Compare commits

...

36 Commits

Author SHA1 Message Date
rkuo-danswer
9456fef307 Merge pull request #3161 from danswer-ai/hotfix/v0.13-indexing-redux
enhanced logging for indexing and increased indexing timeouts
2024-11-18 19:16:39 -08:00
Richard Kuo (Danswer)
cc3c0800f0 no idea how those files got into the merge 2024-11-18 18:38:29 -08:00
Richard Kuo (Danswer)
e860f15b64 hotfix merge 2024-11-18 18:14:21 -08:00
rkuo-danswer
574ef470a4 Merge pull request #3149 from danswer-ai/hotfix/v0.13-overlapping-connectors
merge overlapping connector hotfix
2024-11-16 22:34:02 -08:00
Richard Kuo
9e391495c2 fix unused stuff for hotfix 2024-11-16 21:11:39 -08:00
Richard Kuo
e26d5430fa merge overlapping connector hotfix 2024-11-16 20:59:00 -08:00
rkuo-danswer
cce0ec2f22 Merge pull request #3141 from danswer-ai/hotfix/v0.13-indexing-concurrency
Merge hotfix/v0.13-indexing-concurrency into release/v0.13
2024-11-15 12:51:41 -08:00
rkuo-danswer
a4f09a62a5 Merge pull request #3142 from danswer-ai/hotfix/v0.13-session-text
Merge hotfix/v0.13-session-text into release/v0.13
2024-11-15 12:51:23 -08:00
rkuo-danswer
fd2428d97f Merge pull request #3131 from danswer-ai/bugfix/session_text
use text()
2024-11-15 20:23:18 +00:00
rkuo-danswer
cfc46812c8 scale indexing sql pool based on concurrency (#3130) 2024-11-15 20:21:43 +00:00
pablodanswer
942e47db29 improved mobile scroll (#3110) 2024-11-12 01:57:49 +00:00
pablodanswer
f4a020b599 moderate component fixes (#3095)
* moderate component fixes

* nit

* nit

* update colors

* k
2024-11-12 00:47:35 +00:00
pablodanswer
5166649eae Cleaner EE fallback for no op (#3106)
* treat async values differently

* cleaner approach

* spacing

* typing
2024-11-11 17:42:14 +00:00
Chris Weaver
ba805f766f New assistants api (#3097) 2024-11-11 07:55:23 -08:00
rkuo-danswer
9d57f34c34 re-enable helm (#3053)
* re-enable helm

* allow manual triggering

* change vespa host

* change vespa chart location

* update Chart.lock

* update ct.yaml with new vespa chart repo

* bump vespa to 0.2.5

* update Chart.lock

* update to vespa 0.2.6

* bump vespa to 0.2.7

* bump to 0.2.8

* bump version

* try appending the ordinal

* try new configmap

* bump vespa

* bump vespa

* add debug to see if we can figure out what ct install thinks is failing

* add debug flag to helm

* try disabling nginx because of KinD

* use helm-extra-set-args

* try command line

* try pointing test connection to the correct service name

* bump vespa to 0.2.12

* update chart.lock

* bump vespa to 0.2.13

* bump vespa to 0.2.14

* bump vespa

* bump vespa

* re-enable chart testing only on changes

* name the check more specifically than "lint-test"

* add some debugging

* try setting remote

* might have to specify chart dirs directly

* add comments

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-11-10 01:28:39 +00:00
pablodanswer
cc2f584321 Silence auth logs (#3098)
* silence auth logs

* remove unnecessary line

* k
2024-11-09 21:41:11 +00:00
pablodanswer
a1b95df3b8 Robustify cloud deployment + include initial KEDA configuration (#3094)
* robustify cloud deployment + include initial KEDA configuration

* ensure .github changes are passed

* raise exits
2024-11-09 21:26:51 +00:00
pablodanswer
9272d6ebfe Remove ee (#3093)
* move api key to non-ee

* finalize previous migration

* move token rate limit to non-ee

* general cleanup

* update

* update

* finalize

* finalize

* ensure callable

* k
2024-11-09 20:51:36 +00:00
Yuhong Sun
4fb65dcf73 Reenable OpenAI Tokenizer (#3062)
* k

* clean up test embeddings

* nit

* minor update to ensure consistency

* minor organizational update

* minor updates

---------

Co-authored-by: pablodanswer <pablo@danswer.ai>
2024-11-08 22:54:15 +00:00
rkuo-danswer
2bbc5d5d07 fix saving docker logs (#3090) 2024-11-08 19:54:48 +00:00
rkuo-danswer
950b1c38f2 Merge pull request #3080 from danswer-ai/robust_assistant_description
Account for malformatted starter messages
2024-11-08 11:28:19 -08:00
Yuhong Sun
99fbfba32f File Connector Metadata (#3089) 2024-11-08 10:49:59 -08:00
pablodanswer
0a59efe64a account for malformatted starter messages 2024-11-08 10:21:04 -08:00
pablodanswer
cf5d394d39 adjust default postgres schema for slack listener (#3088) 2024-11-08 18:00:44 +00:00
pablodanswer
f6d8f5ca89 Migrate tenant upgrades to data plane (#3051)
* add provisioning on data plane

* functional but scrappy

* minor cleanup

* minor clean up

* k

* simplify

* update provisioning

* improve import logic

* ensure proper conditional

* minor pydantic update

* minor config update

* nit
2024-11-08 17:13:29 +00:00
hagen-danswer
1fb4cdfcc3 Merge pull request #3073 from skylares/fireflies-dev
Fireflies connector
2024-11-08 06:50:22 -08:00
hagen-danswer
ac51469bcb Merge branch 'main' into fireflies-dev 2024-11-07 18:56:37 -08:00
Skylar Kesselring
c25f164e28 Remove linux 2024-11-07 21:51:58 -05:00
Skylar Kesselring
813720905b Fix failure cases 2024-11-07 21:37:41 -05:00
rkuo-danswer
0c45488ac6 wait for db before allowing worker to proceed (reduces error spam on … (#3079)
* wait for db before allowing worker to proceed (reduces error spam on container startup)

* fix session usage

* rework readiness probe logic to be less confusing and word ongoing probes better

* add vespa probe too

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-11-08 01:25:09 +00:00
Skylar Kesselring
95d9b33c1a Clean up connector 2024-11-07 19:51:40 -05:00
Yuhong Sun
55919f596c PG Dev Max Connections (#3082) 2024-11-07 11:51:23 -08:00
pablodanswer
1d0fb6d012 Evaluate None to default (#3069)
* add sentinel value

* update typing

* clearer

* update comments

* ensure proper attribution
2024-11-07 18:41:42 +00:00
pablodanswer
2b1dbde829 minor improvements (#3081) 2024-11-07 18:35:49 +00:00
Skylar Kesselring
ee4b334a0a Fix errors and cleanup 2024-11-06 14:01:51 -05:00
Skylar Kesselring
7ff18e0a93 Create connector 2024-11-05 19:28:57 -05:00
135 changed files with 4242 additions and 1138 deletions

View File

@@ -3,61 +3,61 @@ name: Build and Push Backend Image on Tag
on:
push:
tags:
- '*'
- "*"
env:
REGISTRY_IMAGE: danswer/danswer-backend
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'danswer/danswer-backend-cloud' || 'danswer/danswer-backend' }}
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
jobs:
build-and-push:
# TODO: investigate a matrix build like the web container
# TODO: investigate a matrix build like the web container
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
runs-on: [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}"]
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Install build-essential
run: |
sudo apt-get update
sudo apt-get install -y build-essential
- name: Backend Image Docker Build and Push
uses: docker/build-push-action@v5
with:
context: ./backend
file: ./backend/Dockerfile
platforms: linux/amd64,linux/arm64
push: true
tags: |
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
build-args: |
DANSWER_VERSION=${{ github.ref_name }}
- name: Install build-essential
run: |
sudo apt-get update
sudo apt-get install -y build-essential
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
# To run locally: trivy image --severity HIGH,CRITICAL danswer/danswer-backend
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
severity: 'CRITICAL,HIGH'
trivyignores: ./backend/.trivyignore
- name: Backend Image Docker Build and Push
uses: docker/build-push-action@v5
with:
context: ./backend
file: ./backend/Dockerfile
platforms: linux/amd64,linux/arm64
push: true
tags: |
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
build-args: |
DANSWER_VERSION=${{ github.ref_name }}
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
env:
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
with:
# To run locally: trivy image --severity HIGH,CRITICAL danswer/danswer-backend
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
severity: "CRITICAL,HIGH"
trivyignores: ./backend/.trivyignore

View File

@@ -4,12 +4,12 @@ name: Build and Push Cloud Web Image on Tag
on:
push:
tags:
- '*'
- "*"
env:
REGISTRY_IMAGE: danswer/danswer-cloud-web-server
REGISTRY_IMAGE: danswer/danswer-web-server-cloud
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
jobs:
build:
runs-on:
@@ -28,11 +28,11 @@ jobs:
- name: Prepare
run: |
platform=${{ matrix.platform }}
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
- name: Checkout
uses: actions/checkout@v4
- name: Docker meta
id: meta
uses: docker/metadata-action@v5
@@ -41,16 +41,16 @@ jobs:
tags: |
type=raw,value=${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
type=raw,value=${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build and push by digest
id: build
uses: docker/build-push-action@v5
@@ -65,17 +65,17 @@ jobs:
NEXT_PUBLIC_POSTHOG_KEY=${{ secrets.POSTHOG_KEY }}
NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
# needed due to weird interactions with the builds for different platforms
# needed due to weird interactions with the builds for different platforms
no-cache: true
labels: ${{ steps.meta.outputs.labels }}
outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
- name: Export digest
run: |
mkdir -p /tmp/digests
digest="${{ steps.build.outputs.digest }}"
touch "/tmp/digests/${digest#sha256:}"
touch "/tmp/digests/${digest#sha256:}"
- name: Upload digest
uses: actions/upload-artifact@v4
with:
@@ -95,42 +95,42 @@ jobs:
path: /tmp/digests
pattern: digests-*
merge-multiple: true
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Docker meta
id: meta
uses: docker/metadata-action@v5
with:
images: ${{ env.REGISTRY_IMAGE }}
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Create manifest list and push
working-directory: /tmp/digests
run: |
docker buildx imagetools create $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \
$(printf '${{ env.REGISTRY_IMAGE }}@sha256:%s ' *)
$(printf '${{ env.REGISTRY_IMAGE }}@sha256:%s ' *)
- name: Inspect image
run: |
docker buildx imagetools inspect ${{ env.REGISTRY_IMAGE }}:${{ steps.meta.outputs.version }}
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
with:
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
severity: 'CRITICAL,HIGH'
severity: "CRITICAL,HIGH"

View File

@@ -3,53 +3,53 @@ name: Build and Push Model Server Image on Tag
on:
push:
tags:
- '*'
- "*"
env:
REGISTRY_IMAGE: danswer/danswer-model-server
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'danswer/danswer-model-server-cloud' || 'danswer/danswer-model-server' }}
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
jobs:
build-and-push:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
runs-on: [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}"]
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Model Server Image Docker Build and Push
uses: docker/build-push-action@v5
with:
context: ./backend
file: ./backend/Dockerfile.model_server
platforms: linux/amd64,linux/arm64
push: true
tags: |
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
build-args: |
DANSWER_VERSION=${{ github.ref_name }}
- name: Model Server Image Docker Build and Push
uses: docker/build-push-action@v5
with:
context: ./backend
file: ./backend/Dockerfile.model_server
platforms: linux/amd64,linux/arm64
push: true
tags: |
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
build-args: |
DANSWER_VERSION=${{ github.ref_name }}
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
image-ref: docker.io/danswer/danswer-model-server:${{ github.ref_name }}
severity: 'CRITICAL,HIGH'
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
env:
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
with:
image-ref: docker.io/danswer/danswer-model-server:${{ github.ref_name }}
severity: "CRITICAL,HIGH"

View File

@@ -210,17 +210,18 @@ jobs:
echo "All integration tests passed successfully."
fi
- name: Stop Docker containers
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
# save before stopping the containers so the logs can be captured
- name: Save Docker logs
if: success() || failure()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack logs > docker-compose.log
mv docker-compose.log ${{ github.workspace }}/docker-compose.log
- name: Stop Docker containers
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
- name: Upload logs
if: success() || failure()

View File

@@ -1,24 +1,20 @@
# This workflow is intentionally disabled while we're still working on it
# It's close to ready, but a race condition needs to be fixed with
# API server and Vespa startup, and it needs to have a way to build/test against
# local containers
name: Helm - Lint and Test Charts
on:
merge_group:
pull_request:
branches: [ main ]
workflow_dispatch: # Allows manual triggering
jobs:
lint-test:
helm-chart-check:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,hdd=256,"run-id=${{ github.run_id }}"]
# fetch-depth 0 is required for helm/chart-testing-action
steps:
- name: Checkout code
uses: actions/checkout@v3
uses: actions/checkout@v4
with:
fetch-depth: 0
@@ -28,7 +24,7 @@ jobs:
version: v3.14.4
- name: Set up Python
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: '3.11'
cache: 'pip'
@@ -45,24 +41,31 @@ jobs:
- name: Set up chart-testing
uses: helm/chart-testing-action@v2.6.1
# even though we specify chart-dirs in ct.yaml, it isn't used by ct for the list-changed command...
- name: Run chart-testing (list-changed)
id: list-changed
run: |
changed=$(ct list-changed --target-branch ${{ github.event.repository.default_branch }})
echo "default_branch: ${{ github.event.repository.default_branch }}"
changed=$(ct list-changed --remote origin --target-branch ${{ github.event.repository.default_branch }} --chart-dirs deployment/helm/charts)
echo "list-changed output: $changed"
if [[ -n "$changed" ]]; then
echo "changed=true" >> "$GITHUB_OUTPUT"
fi
# lint all charts if any changes were detected
- name: Run chart-testing (lint)
# if: steps.list-changed.outputs.changed == 'true'
run: ct lint --all --config ct.yaml --target-branch ${{ github.event.repository.default_branch }}
if: steps.list-changed.outputs.changed == 'true'
run: ct lint --config ct.yaml --all
# the following would lint only changed charts, but linting isn't expensive
# run: ct lint --config ct.yaml --target-branch ${{ github.event.repository.default_branch }}
- name: Create kind cluster
# if: steps.list-changed.outputs.changed == 'true'
if: steps.list-changed.outputs.changed == 'true'
uses: helm/kind-action@v1.10.0
- name: Run chart-testing (install)
# if: steps.list-changed.outputs.changed == 'true'
run: ct install --all --config ct.yaml
# run: ct install --target-branch ${{ github.event.repository.default_branch }}
if: steps.list-changed.outputs.changed == 'true'
run: ct install --all --helm-extra-set-args="--set=nginx.enabled=false" --debug --config ct.yaml
# the following would install only changed charts, but we only have one chart so
# don't worry about that for now
# run: ct install --target-branch ${{ github.event.repository.default_branch }}

View File

@@ -288,6 +288,15 @@ def upgrade() -> None:
def downgrade() -> None:
# NOTE: you will lose all chat history. This is to satisfy the non-nullable constraints
# below
op.execute("DELETE FROM chat_feedback")
op.execute("DELETE FROM chat_message__search_doc")
op.execute("DELETE FROM document_retrieval_feedback")
op.execute("DELETE FROM document_retrieval_feedback")
op.execute("DELETE FROM chat_message")
op.execute("DELETE FROM chat_session")
op.drop_constraint(
"chat_feedback__chat_message_fk", "chat_feedback", type_="foreignkey"
)

View File

@@ -26,7 +26,8 @@ def upgrade() -> None:
FROM jsonb_array_elements(starter_messages) elem
)
WHERE starter_messages IS NOT NULL
"""
AND jsonb_typeof(starter_messages) = 'array'
"""
)
)
@@ -41,6 +42,7 @@ def downgrade() -> None:
FROM jsonb_array_elements(starter_messages) elem
)
WHERE starter_messages IS NOT NULL
"""
AND jsonb_typeof(starter_messages) = 'array'
"""
)
)

View File

@@ -23,6 +23,56 @@ def upgrade() -> None:
def downgrade() -> None:
# Delete chat messages and feedback first since they reference chat sessions
# Get chat messages from sessions with null persona_id
chat_messages_query = """
SELECT id
FROM chat_message
WHERE chat_session_id IN (
SELECT id
FROM chat_session
WHERE persona_id IS NULL
)
"""
# Delete dependent records first
op.execute(
f"""
DELETE FROM document_retrieval_feedback
WHERE chat_message_id IN (
{chat_messages_query}
)
"""
)
op.execute(
f"""
DELETE FROM chat_message__search_doc
WHERE chat_message_id IN (
{chat_messages_query}
)
"""
)
# Delete chat messages
op.execute(
"""
DELETE FROM chat_message
WHERE chat_session_id IN (
SELECT id
FROM chat_session
WHERE persona_id IS NULL
)
"""
)
# Now we can safely delete the chat sessions
op.execute(
"""
DELETE FROM chat_session
WHERE persona_id IS NULL
"""
)
op.alter_column(
"chat_session",
"persona_id",

View File

@@ -8,7 +8,7 @@ from passlib.hash import sha256_crypt
from pydantic import BaseModel
from danswer.auth.schemas import UserRole
from ee.danswer.configs.app_configs import API_KEY_HASH_ROUNDS
from danswer.configs.app_configs import API_KEY_HASH_ROUNDS
_API_KEY_HEADER_NAME = "Authorization"

View File

@@ -48,11 +48,11 @@ from httpx_oauth.integrations.fastapi import OAuth2AuthorizeCallback
from httpx_oauth.oauth2 import BaseOAuth2
from httpx_oauth.oauth2 import OAuth2Token
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy import text
from sqlalchemy.orm import attributes
from sqlalchemy.orm import Session
from danswer.auth.api_key import get_hashed_api_key_from_request
from danswer.auth.invited_users import get_invited_users
from danswer.auth.schemas import UserCreate
from danswer.auth.schemas import UserRole
@@ -75,6 +75,7 @@ from danswer.configs.constants import AuthType
from danswer.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
from danswer.configs.constants import DANSWER_API_KEY_PREFIX
from danswer.configs.constants import UNNAMED_KEY_PLACEHOLDER
from danswer.db.api_key import fetch_user_for_api_key
from danswer.db.auth import get_access_token_db
from danswer.db.auth import get_default_admin_user_emails
from danswer.db.auth import get_user_count
@@ -83,24 +84,27 @@ from danswer.db.auth import SQLAlchemyUserAdminDB
from danswer.db.engine import get_async_session_with_tenant
from danswer.db.engine import get_session
from danswer.db.engine import get_session_with_tenant
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import AccessToken
from danswer.db.models import OAuthAccount
from danswer.db.models import User
from danswer.db.models import UserTenantMapping
from danswer.db.users import get_user_by_email
from danswer.utils.logger import setup_logger
from danswer.utils.telemetry import optional_telemetry
from danswer.utils.telemetry import RecordType
from danswer.utils.variable_functionality import fetch_ee_implementation_or_noop
from danswer.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import async_return_default_schema
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
class BasicAuthenticationError(HTTPException):
def __init__(self, detail: str):
super().__init__(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
def is_user_admin(user: User | None) -> bool:
if AUTH_TYPE == AuthType.DISABLED:
return True
@@ -190,20 +194,6 @@ def verify_email_domain(email: str) -> None:
)
def get_tenant_id_for_email(email: str) -> str:
if not MULTI_TENANT:
return POSTGRES_DEFAULT_SCHEMA
# Implement logic to get tenant_id from the mapping table
with Session(get_sqlalchemy_engine()) as db_session:
result = db_session.execute(
select(UserTenantMapping.tenant_id).where(UserTenantMapping.email == email)
)
tenant_id = result.scalar_one_or_none()
if tenant_id is None:
raise exceptions.UserNotExists()
return tenant_id
def send_user_verification_email(
user_email: str,
token: str,
@@ -238,19 +228,13 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
safe: bool = False,
request: Optional[Request] = None,
) -> User:
try:
tenant_id = (
get_tenant_id_for_email(user_create.email)
if MULTI_TENANT
else POSTGRES_DEFAULT_SCHEMA
)
except exceptions.UserNotExists:
raise HTTPException(status_code=401, detail="User not found")
if not tenant_id:
raise HTTPException(
status_code=401, detail="User does not belong to an organization"
)
tenant_id = await fetch_ee_implementation_or_noop(
"danswer.server.tenants.provisioning",
"get_or_create_tenant_id",
async_return_default_schema,
)(
email=user_create.email,
)
async with get_async_session_with_tenant(tenant_id) as db_session:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
@@ -271,7 +255,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
user_create.role = UserRole.ADMIN
else:
user_create.role = UserRole.BASIC
user = None
try:
user = await super().create(user_create, safe=safe, request=request) # type: ignore
except exceptions.UserAlreadyExists:
@@ -292,7 +276,9 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
else:
raise exceptions.UserAlreadyExists()
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
return user
async def oauth_callback(
@@ -308,19 +294,18 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
associate_by_email: bool = False,
is_verified_by_default: bool = False,
) -> models.UOAP:
# Get tenant_id from mapping table
try:
tenant_id = (
get_tenant_id_for_email(account_email)
if MULTI_TENANT
else POSTGRES_DEFAULT_SCHEMA
)
except exceptions.UserNotExists:
raise HTTPException(status_code=401, detail="User not found")
tenant_id = await fetch_ee_implementation_or_noop(
"danswer.server.tenants.provisioning",
"get_or_create_tenant_id",
async_return_default_schema,
)(
email=account_email,
)
if not tenant_id:
raise HTTPException(status_code=401, detail="User not found")
# Proceed with the tenant context
token = None
async with get_async_session_with_tenant(tenant_id) as db_session:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
@@ -371,9 +356,9 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
# Explicitly set the Postgres schema for this session to ensure
# OAuth account creation happens in the correct tenant schema
await db_session.execute(text(f'SET search_path = "{tenant_id}"'))
user = await self.user_db.add_oauth_account(
user, oauth_account_dict
)
# Add OAuth account
await self.user_db.add_oauth_account(user, oauth_account_dict)
await self.on_after_register(user, request)
else:
@@ -453,7 +438,13 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
email = credentials.username
# Get tenant_id from mapping table
tenant_id = get_tenant_id_for_email(email)
tenant_id = await fetch_ee_implementation_or_noop(
"danswer.server.tenants.provisioning",
"get_or_create_tenant_id",
async_return_default_schema,
)(
email=email,
)
if not tenant_id:
# User not found in mapping
self.password_helper.hash(credentials.password)
@@ -477,8 +468,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
has_web_login = attributes.get_attribute(user, "has_web_login")
if not has_web_login:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
raise BasicAuthenticationError(
detail="NO_WEB_LOGIN_AND_HAS_NO_PASSWORD",
)
@@ -511,7 +501,14 @@ cookie_transport = CookieTransport(
# This strategy is used to add tenant_id to the JWT token
class TenantAwareJWTStrategy(JWTStrategy):
async def _create_token_data(self, user: User, impersonate: bool = False) -> dict:
tenant_id = get_tenant_id_for_email(user.email)
tenant_id = await fetch_ee_implementation_or_noop(
"danswer.server.tenants.provisioning",
"get_or_create_tenant_id",
async_return_default_schema,
)(
email=user.email,
)
data = {
"sub": str(user.id),
"aud": self.token_audience,
@@ -628,14 +625,12 @@ async def double_check_user(
return None
if user is None:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
raise BasicAuthenticationError(
detail="Access denied. User is not authenticated.",
)
if user_needs_to_be_verified() and not user.is_verified:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
raise BasicAuthenticationError(
detail="Access denied. User is not verified.",
)
@@ -644,8 +639,7 @@ async def double_check_user(
and user.oidc_expiry < datetime.now(timezone.utc)
and not include_expired
):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
raise BasicAuthenticationError(
detail="Access denied. User's OIDC token has expired.",
)
@@ -671,15 +665,13 @@ async def current_curator_or_admin_user(
return None
if not user or not hasattr(user, "role"):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
raise BasicAuthenticationError(
detail="Access denied. User is not authenticated or lacks role information.",
)
allowed_roles = {UserRole.GLOBAL_CURATOR, UserRole.CURATOR, UserRole.ADMIN}
if user.role not in allowed_roles:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
raise BasicAuthenticationError(
detail="Access denied. User is not a curator or admin.",
)
@@ -691,8 +683,7 @@ async def current_admin_user(user: User | None = Depends(current_user)) -> User
return None
if not user or not hasattr(user, "role") or user.role != UserRole.ADMIN:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
raise BasicAuthenticationError(
detail="Access denied. User must be an admin to perform this action.",
)
@@ -885,3 +876,22 @@ def get_oauth_router(
return redirect_response
return router
def api_key_dep(
request: Request, db_session: Session = Depends(get_session)
) -> User | None:
if AUTH_TYPE == AuthType.DISABLED:
return None
hashed_api_key = get_hashed_api_key_from_request(request)
if not hashed_api_key:
raise HTTPException(status_code=401, detail="Missing API key")
if hashed_api_key:
user = fetch_user_for_api_key(hashed_api_key, db_session)
if user is None:
raise HTTPException(status_code=401, detail="Invalid API key")
return user

View File

@@ -3,6 +3,7 @@ import multiprocessing
import time
from typing import Any
import requests
import sentry_sdk
from celery import Task
from celery.app import trace
@@ -11,11 +12,15 @@ from celery.states import READY_STATES
from celery.utils.log import get_task_logger
from celery.worker import strategy # type: ignore
from sentry_sdk.integrations.celery import CeleryIntegration
from sqlalchemy import text
from sqlalchemy.orm import Session
from danswer.background.celery.apps.task_formatters import CeleryTaskColoredFormatter
from danswer.background.celery.apps.task_formatters import CeleryTaskPlainFormatter
from danswer.background.celery.celery_utils import celery_is_worker_primary
from danswer.configs.constants import DanswerRedisLocks
from danswer.db.engine import get_sqlalchemy_engine
from danswer.document_index.vespa_constants import VESPA_CONFIG_SERVER_URL
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
from danswer.redis.redis_connector_delete import RedisConnectorDelete
@@ -26,7 +31,6 @@ from danswer.redis.redis_usergroup import RedisUserGroup
from danswer.utils.logger import ColoredFormatter
from danswer.utils.logger import PlainFormatter
from danswer.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import SENTRY_DSN
@@ -139,45 +143,136 @@ def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None
def wait_for_redis(sender: Any, **kwargs: Any) -> None:
"""Waits for redis to become ready subject to a hardcoded timeout.
Will raise WorkerShutdown to kill the celery worker if the timeout is reached."""
r = get_redis_client(tenant_id=None)
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
ready = False
time_start = time.monotonic()
logger.info("Redis: Readiness check starting.")
logger.info("Redis: Readiness probe starting.")
while True:
try:
if r.ping():
ready = True
break
except Exception:
pass
time_elapsed = time.monotonic() - time_start
logger.info(
f"Redis: Ping failed. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
)
if time_elapsed > WAIT_LIMIT:
msg = (
f"Redis: Readiness check did not succeed within the timeout "
f"({WAIT_LIMIT} seconds). Exiting..."
)
logger.error(msg)
raise WorkerShutdown(msg)
break
logger.info(
f"Redis: Readiness probe ongoing. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
)
time.sleep(WAIT_INTERVAL)
logger.info("Redis: Readiness check succeeded. Continuing...")
if not ready:
msg = (
f"Redis: Readiness probe did not succeed within the timeout "
f"({WAIT_LIMIT} seconds). Exiting..."
)
logger.error(msg)
raise WorkerShutdown(msg)
logger.info("Redis: Readiness probe succeeded. Continuing...")
return
def wait_for_db(sender: Any, **kwargs: Any) -> None:
"""Waits for the db to become ready subject to a hardcoded timeout.
Will raise WorkerShutdown to kill the celery worker if the timeout is reached."""
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
ready = False
time_start = time.monotonic()
logger.info("Database: Readiness probe starting.")
while True:
try:
with Session(get_sqlalchemy_engine()) as db_session:
result = db_session.execute(text("SELECT NOW()")).scalar()
if result:
ready = True
break
except Exception:
pass
time_elapsed = time.monotonic() - time_start
if time_elapsed > WAIT_LIMIT:
break
logger.info(
f"Database: Readiness probe ongoing. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
)
time.sleep(WAIT_INTERVAL)
if not ready:
msg = (
f"Database: Readiness probe did not succeed within the timeout "
f"({WAIT_LIMIT} seconds). Exiting..."
)
logger.error(msg)
raise WorkerShutdown(msg)
logger.info("Database: Readiness probe succeeded. Continuing...")
return
def wait_for_vespa(sender: Any, **kwargs: Any) -> None:
"""Waits for Vespa to become ready subject to a hardcoded timeout.
Will raise WorkerShutdown to kill the celery worker if the timeout is reached."""
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
ready = False
time_start = time.monotonic()
logger.info("Vespa: Readiness probe starting.")
while True:
try:
response = requests.get(f"{VESPA_CONFIG_SERVER_URL}/state/v1/health")
response.raise_for_status()
response_dict = response.json()
if response_dict["status"]["code"] == "up":
ready = True
break
except Exception:
pass
time_elapsed = time.monotonic() - time_start
if time_elapsed > WAIT_LIMIT:
break
logger.info(
f"Vespa: Readiness probe ongoing. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
)
time.sleep(WAIT_INTERVAL)
if not ready:
msg = (
f"Vespa: Readiness probe did not succeed within the timeout "
f"({WAIT_LIMIT} seconds). Exiting..."
)
logger.error(msg)
raise WorkerShutdown(msg)
logger.info("Vespa: Readiness probe succeeded. Continuing...")
return
def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
logger.info("Running as a secondary celery worker.")
# Exit early if multi-tenant since primary worker check not needed
if MULTI_TENANT:
return
# Set up variables for waiting on primary worker
WAIT_INTERVAL = 5
WAIT_LIMIT = 60

View File

@@ -12,6 +12,7 @@ from danswer.db.engine import get_all_tenant_ids
from danswer.db.engine import SqlEngine
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import MULTI_TENANT
logger = setup_logger(__name__)
@@ -119,10 +120,10 @@ class DynamicTenantScheduler(PersistentScheduler):
else:
logger.info("Schedule is up to date, no changes needed")
except (AttributeError, KeyError) as e:
logger.exception(f"Failed to process task configuration: {str(e)}")
except Exception as e:
logger.exception(f"Unexpected error updating tenant tasks: {str(e)}")
except (AttributeError, KeyError):
logger.exception("Failed to process task configuration")
except Exception:
logger.exception("Unexpected error updating tenant tasks")
def _should_update_schedule(
self, current_schedule: dict, new_schedule: dict
@@ -143,6 +144,11 @@ def on_beat_init(sender: Any, **kwargs: Any) -> None:
# Celery beat shouldn't touch the db at all. But just setting a low minimum here.
SqlEngine.set_app_name(POSTGRES_CELERY_BEAT_APP_NAME)
SqlEngine.init_engine(pool_size=2, max_overflow=0)
# Startup checks are not needed in multi-tenant case
if MULTI_TENANT:
return
app_base.wait_for_redis(sender, **kwargs)

View File

@@ -13,6 +13,7 @@ import danswer.background.celery.apps.app_base as app_base
from danswer.configs.constants import POSTGRES_CELERY_WORKER_HEAVY_APP_NAME
from danswer.db.engine import SqlEngine
from danswer.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
@@ -60,7 +61,13 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_HEAVY_APP_NAME)
SqlEngine.init_engine(pool_size=4, max_overflow=12)
# Startup checks are not needed in multi-tenant case
if MULTI_TENANT:
return
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa(sender, **kwargs)
app_base.on_secondary_worker_init(sender, **kwargs)

View File

@@ -13,6 +13,7 @@ import danswer.background.celery.apps.app_base as app_base
from danswer.configs.constants import POSTGRES_CELERY_WORKER_INDEXING_APP_NAME
from danswer.db.engine import SqlEngine
from danswer.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
@@ -58,9 +59,15 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
SqlEngine.init_engine(pool_size=8, max_overflow=0)
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=sender.concurrency)
# Startup checks are not needed in multi-tenant case
if MULTI_TENANT:
return
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa(sender, **kwargs)
app_base.on_secondary_worker_init(sender, **kwargs)

View File

@@ -13,6 +13,7 @@ import danswer.background.celery.apps.app_base as app_base
from danswer.configs.constants import POSTGRES_CELERY_WORKER_LIGHT_APP_NAME
from danswer.db.engine import SqlEngine
from danswer.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
@@ -59,8 +60,13 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME)
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8)
# Startup checks are not needed in multi-tenant case
if MULTI_TENANT:
return
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa(sender, **kwargs)
app_base.on_secondary_worker_init(sender, **kwargs)

View File

@@ -14,10 +14,14 @@ from celery.signals import worker_shutdown
import danswer.background.celery.apps.app_base as app_base
from danswer.background.celery.apps.app_base import task_logger
from danswer.background.celery.celery_utils import celery_is_worker_primary
from danswer.background.celery.tasks.vespa.tasks import get_unfenced_index_attempt_ids
from danswer.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
from danswer.configs.constants import DanswerRedisLocks
from danswer.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME
from danswer.db.engine import get_session_with_default_tenant
from danswer.db.engine import SqlEngine
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import mark_attempt_failed
from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
from danswer.redis.redis_connector_delete import RedisConnectorDelete
from danswer.redis.redis_connector_index import RedisConnectorIndex
@@ -75,13 +79,16 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
SqlEngine.init_engine(pool_size=8, max_overflow=0)
app_base.wait_for_redis(sender, **kwargs)
logger.info("Running as the primary celery worker.")
# Startup checks are not needed in multi-tenant case
if MULTI_TENANT:
return
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa(sender, **kwargs)
logger.info("Running as the primary celery worker.")
# This is singleton work that should be done on startup exactly once
# by the primary worker. This is unnecessary in the multi tenant scenario
r = get_redis_client(tenant_id=None)
@@ -131,6 +138,23 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
RedisConnectorStop.reset_all(r)
# mark orphaned index attempts as failed
with get_session_with_default_tenant() as db_session:
unfenced_attempt_ids = get_unfenced_index_attempt_ids(db_session, r)
for attempt_id in unfenced_attempt_ids:
attempt = get_index_attempt(db_session, attempt_id)
if not attempt:
continue
failure_reason = (
f"Orphaned index attempt found on startup: "
f"index_attempt={attempt.id} "
f"cc_pair={attempt.connector_credential_pair_id} "
f"search_settings={attempt.search_settings_id}"
)
logger.warning(failure_reason)
mark_attempt_failed(attempt.id, db_session, failure_reason)
@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:

View File

@@ -1,12 +1,12 @@
from datetime import datetime
from datetime import timezone
import redis
from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from danswer.background.celery.apps.app_base import task_logger
@@ -87,7 +87,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
cc_pair_id: int,
db_session: Session,
r: Redis,
lock_beat: redis.lock.Lock,
lock_beat: RedisLock,
tenant_id: str | None,
) -> int | None:
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.

View File

@@ -3,13 +3,14 @@ from datetime import timezone
from http import HTTPStatus
from time import sleep
import redis
import sentry_sdk
from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from redis.exceptions import LockError
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from danswer.background.celery.apps.app_base import task_logger
@@ -44,7 +45,7 @@ from danswer.db.swap_index import check_index_swap
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_index import RedisConnectorIndexingFenceData
from danswer.redis.redis_connector_index import RedisConnectorIndexPayload
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import global_version
@@ -61,14 +62,18 @@ class RunIndexingCallback(RunIndexingCallbackInterface):
self,
stop_key: str,
generator_progress_key: str,
redis_lock: redis.lock.Lock,
redis_lock: RedisLock,
redis_client: Redis,
):
super().__init__()
self.redis_lock: redis.lock.Lock = redis_lock
self.redis_lock: RedisLock = redis_lock
self.stop_key: str = stop_key
self.generator_progress_key: str = generator_progress_key
self.redis_client = redis_client
self.started: datetime = datetime.now(timezone.utc)
self.redis_lock.reacquire()
self.last_lock_reacquire: datetime = datetime.now(timezone.utc)
def should_stop(self) -> bool:
if self.redis_client.exists(self.stop_key):
@@ -76,7 +81,19 @@ class RunIndexingCallback(RunIndexingCallbackInterface):
return False
def progress(self, amount: int) -> None:
self.redis_lock.reacquire()
try:
self.redis_lock.reacquire()
self.last_lock_reacquire = datetime.now(timezone.utc)
except LockError:
logger.exception(
f"RunIndexingCallback - lock.reacquire exceptioned. "
f"lock_timeout={self.redis_lock.timeout} "
f"start={self.started} "
f"last_reacquired={self.last_lock_reacquire} "
f"now={datetime.now(timezone.utc)}"
)
raise
self.redis_client.incrby(self.generator_progress_key, amount)
@@ -325,7 +342,7 @@ def try_creating_indexing_task(
redis_connector_index.generator_clear()
# set a basic fence to start
payload = RedisConnectorIndexingFenceData(
payload = RedisConnectorIndexPayload(
index_attempt_id=None,
started=None,
submitted=datetime.now(timezone.utc),
@@ -368,7 +385,7 @@ def try_creating_indexing_task(
redis_connector_index.set_fence(payload)
except Exception:
redis_connector_index.set_fence(payload)
redis_connector_index.set_fence(None)
task_logger.exception(
f"Unexpected exception: "
f"tenant={tenant_id} "

View File

@@ -13,6 +13,7 @@ from celery.exceptions import SoftTimeLimitExceeded
from celery.result import AsyncResult
from celery.states import READY_STATES
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from tenacity import RetryError
@@ -162,7 +163,7 @@ def try_generate_stale_document_sync_tasks(
celery_app: Celery,
db_session: Session,
r: Redis,
lock_beat: redis.lock.Lock,
lock_beat: RedisLock,
tenant_id: str | None,
) -> int | None:
# the fence is up, do nothing
@@ -180,7 +181,12 @@ def try_generate_stale_document_sync_tasks(
f"Stale documents found (at least {stale_doc_count}). Generating sync tasks by cc pair."
)
task_logger.info("RedisConnector.generate_tasks starting by cc_pair.")
task_logger.info(
"RedisConnector.generate_tasks starting by cc_pair. "
"Documents spanning multiple cc_pairs will only be synced once."
)
docs_to_skip: set[str] = set()
# rkuo: we could technically sync all stale docs in one big pass.
# but I feel it's more understandable to group the docs by cc_pair
@@ -188,22 +194,21 @@ def try_generate_stale_document_sync_tasks(
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
rc = RedisConnectorCredentialPair(tenant_id, cc_pair.id)
tasks_generated = rc.generate_tasks(
celery_app, db_session, r, lock_beat, tenant_id
)
rc.set_skip_docs(docs_to_skip)
result = rc.generate_tasks(celery_app, db_session, r, lock_beat, tenant_id)
if tasks_generated is None:
if result is None:
continue
if tasks_generated == 0:
if result[1] == 0:
continue
task_logger.info(
f"RedisConnector.generate_tasks finished for single cc_pair. "
f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}"
f"cc_pair={cc_pair.id} tasks_generated={result[0]} tasks_possible={result[1]}"
)
total_tasks_generated += tasks_generated
total_tasks_generated += result[0]
task_logger.info(
f"RedisConnector.generate_tasks finished for all cc_pairs. total_tasks_generated={total_tasks_generated}"
@@ -218,7 +223,7 @@ def try_generate_document_set_sync_tasks(
document_set_id: int,
db_session: Session,
r: Redis,
lock_beat: redis.lock.Lock,
lock_beat: RedisLock,
tenant_id: str | None,
) -> int | None:
lock_beat.reacquire()
@@ -246,12 +251,11 @@ def try_generate_document_set_sync_tasks(
)
# Add all documents that need to be updated into the queue
tasks_generated = rds.generate_tasks(
celery_app, db_session, r, lock_beat, tenant_id
)
if tasks_generated is None:
result = rds.generate_tasks(celery_app, db_session, r, lock_beat, tenant_id)
if result is None:
return None
tasks_generated = result[0]
# Currently we are allowing the sync to proceed with 0 tasks.
# It's possible for sets/groups to be generated initially with no entries
# and they still need to be marked as up to date.
@@ -260,7 +264,7 @@ def try_generate_document_set_sync_tasks(
task_logger.info(
f"RedisDocumentSet.generate_tasks finished. "
f"document_set_id={document_set.id} tasks_generated={tasks_generated}"
f"document_set={document_set.id} tasks_generated={tasks_generated}"
)
# set this only after all tasks have been added
@@ -273,7 +277,7 @@ def try_generate_user_group_sync_tasks(
usergroup_id: int,
db_session: Session,
r: Redis,
lock_beat: redis.lock.Lock,
lock_beat: RedisLock,
tenant_id: str | None,
) -> int | None:
lock_beat.reacquire()
@@ -302,12 +306,11 @@ def try_generate_user_group_sync_tasks(
task_logger.info(
f"RedisUserGroup.generate_tasks starting. usergroup_id={usergroup.id}"
)
tasks_generated = rug.generate_tasks(
celery_app, db_session, r, lock_beat, tenant_id
)
if tasks_generated is None:
result = rug.generate_tasks(celery_app, db_session, r, lock_beat, tenant_id)
if result is None:
return None
tasks_generated = result[0]
# Currently we are allowing the sync to proceed with 0 tasks.
# It's possible for sets/groups to be generated initially with no entries
# and they still need to be marked as up to date.
@@ -316,7 +319,7 @@ def try_generate_user_group_sync_tasks(
task_logger.info(
f"RedisUserGroup.generate_tasks finished. "
f"usergroup_id={usergroup.id} tasks_generated={tasks_generated}"
f"usergroup={usergroup.id} tasks_generated={tasks_generated}"
)
# set this only after all tasks have been added
@@ -580,8 +583,8 @@ def monitor_ccpair_indexing_taskset(
progress = redis_connector_index.get_progress()
if progress is not None:
task_logger.info(
f"Connector indexing progress: cc_pair_id={cc_pair_id} "
f"search_settings_id={search_settings_id} "
f"Connector indexing progress: cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"progress={progress} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
)
@@ -602,8 +605,8 @@ def monitor_ccpair_indexing_taskset(
# if it isn't, then the worker crashed
task_logger.info(
f"Connector indexing aborted: "
f"cc_pair_id={cc_pair_id} "
f"search_settings_id={search_settings_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
)
@@ -621,8 +624,8 @@ def monitor_ccpair_indexing_taskset(
status_enum = HTTPStatus(status_int)
task_logger.info(
f"Connector indexing finished: cc_pair_id={cc_pair_id} "
f"search_settings_id={search_settings_id} "
f"Connector indexing finished: cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"status={status_enum.name} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
)
@@ -630,6 +633,37 @@ def monitor_ccpair_indexing_taskset(
redis_connector_index.reset()
def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[int]:
"""Gets a list of unfenced index attempts. Should not be possible, so we'd typically
want to clean them up.
Unfenced = attempt not in terminal state and fence does not exist.
"""
unfenced_attempts: list[int] = []
# do some cleanup before clearing fences
# check the db for any outstanding index attempts
attempts: list[IndexAttempt] = []
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.NOT_STARTED, db_session)
)
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.IN_PROGRESS, db_session)
)
for attempt in attempts:
# if attempts exist in the db but we don't detect them in redis, mark them as failed
fence_key = RedisConnectorIndex.fence_key_with_ids(
attempt.connector_credential_pair_id, attempt.search_settings_id
)
if r.exists(fence_key):
continue
unfenced_attempts.append(attempt.id)
return unfenced_attempts
@shared_task(name="monitor_vespa_sync", soft_time_limit=300, bind=True)
def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
"""This is a celery beat task that monitors and finalizes metadata sync tasksets.
@@ -643,7 +677,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
"""
r = get_redis_client(tenant_id=tenant_id)
lock_beat: redis.lock.Lock = r.lock(
lock_beat: RedisLock = r.lock(
DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
@@ -677,31 +711,24 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
f"pruning={n_pruning}"
)
# do some cleanup before clearing fences
# check the db for any outstanding index attempts
# Fail any index attempts in the DB that don't have fences
with get_session_with_tenant(tenant_id) as db_session:
attempts: list[IndexAttempt] = []
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.NOT_STARTED, db_session)
)
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.IN_PROGRESS, db_session)
)
unfenced_attempt_ids = get_unfenced_index_attempt_ids(db_session, r)
for attempt_id in unfenced_attempt_ids:
attempt = get_index_attempt(db_session, attempt_id)
if not attempt:
continue
for a in attempts:
# if attempts exist in the db but we don't detect them in redis, mark them as failed
fence_key = RedisConnectorIndex.fence_key_with_ids(
a.connector_credential_pair_id, a.search_settings_id
failure_reason = (
f"Unfenced index attempt found in DB: "
f"index_attempt={attempt.id} "
f"cc_pair={attempt.connector_credential_pair_id} "
f"search_settings={attempt.search_settings_id}"
)
task_logger.warning(failure_reason)
mark_attempt_failed(
attempt.id, db_session, failure_reason=failure_reason
)
if not r.exists(fence_key):
failure_reason = (
f"Unknown index attempt. Might be left over from a process restart: "
f"index_attempt={a.id} "
f"cc_pair={a.connector_credential_pair_id} "
f"search_settings={a.search_settings_id}"
)
task_logger.warning(failure_reason)
mark_attempt_failed(a.id, db_session, failure_reason=failure_reason)
lock_beat.reacquire()
if r.exists(RedisConnectorCredentialPair.get_fence_key()):

View File

@@ -433,11 +433,13 @@ def run_indexing_entrypoint(
with get_session_with_tenant(tenant_id) as db_session:
attempt = transition_attempt_to_in_progress(index_attempt_id, db_session)
tenant_str = ""
if tenant_id is not None:
tenant_str = f" for tenant {tenant_id}"
logger.info(
f"Indexing starting for tenant {tenant_id}: "
if tenant_id is not None
else ""
+ f"connector='{attempt.connector_credential_pair.connector.name}' "
f"Indexing starting{tenant_str}: "
f"connector='{attempt.connector_credential_pair.connector.name}' "
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
f"credentials='{attempt.connector_credential_pair.connector_id}'"
)
@@ -445,10 +447,8 @@ def run_indexing_entrypoint(
_run_indexing(db_session, attempt, tenant_id, callback)
logger.info(
f"Indexing finished for tenant {tenant_id}: "
if tenant_id is not None
else ""
+ f"connector='{attempt.connector_credential_pair.connector.name}' "
f"Indexing finished{tenant_str}: "
f"connector='{attempt.connector_credential_pair.connector.name}' "
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
f"credentials='{attempt.connector_credential_pair.connector_id}'"
)

View File

@@ -0,0 +1,4 @@
def name_sync_external_doc_permissions_task(
cc_pair_id: int, tenant_id: str | None = None
) -> str:
return f"sync_external_doc_permissions_task__{cc_pair_id}"

View File

@@ -19,16 +19,10 @@ from danswer.chat.models import MessageSpecificCitations
from danswer.chat.models import QADocsResponse
from danswer.chat.models import StreamingError
from danswer.chat.models import StreamStopInfo
from danswer.configs.app_configs import AZURE_DALLE_API_BASE
from danswer.configs.app_configs import AZURE_DALLE_API_KEY
from danswer.configs.app_configs import AZURE_DALLE_API_VERSION
from danswer.configs.app_configs import AZURE_DALLE_DEPLOYMENT_NAME
from danswer.configs.chat_configs import BING_API_KEY
from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from danswer.configs.constants import MessageType
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.db.chat import attach_files_to_chat_message
from danswer.db.chat import create_db_search_doc
from danswer.db.chat import create_new_chat_message
@@ -41,7 +35,6 @@ from danswer.db.chat import reserve_message_id
from danswer.db.chat import translate_db_message_to_chat_message_detail
from danswer.db.chat import translate_db_search_doc_to_server_search_doc
from danswer.db.engine import get_session_context_manager
from danswer.db.llm import fetch_existing_llm_providers
from danswer.db.models import SearchDoc as DbSearchDoc
from danswer.db.models import ToolCall
from danswer.db.models import User
@@ -61,14 +54,13 @@ from danswer.llm.answering.models import PromptConfig
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.factory import get_main_llm_from_tuple
from danswer.llm.interfaces import LLMConfig
from danswer.llm.utils import litellm_exception_to_error_msg
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.search.enums import LLMEvaluationType
from danswer.search.enums import OptionalSearchSetting
from danswer.search.enums import QueryFlow
from danswer.search.enums import SearchType
from danswer.search.models import InferenceSection
from danswer.search.models import RetrievalDetails
from danswer.search.retrieval.search_runner import inference_sections_from_ids
from danswer.search.utils import chunks_or_sections_to_search_docs
from danswer.search.utils import dedupe_documents
@@ -77,14 +69,14 @@ from danswer.search.utils import relevant_sections_to_indices
from danswer.server.query_and_chat.models import ChatMessageDetail
from danswer.server.query_and_chat.models import CreateChatMessageRequest
from danswer.server.utils import get_json_line
from danswer.tools.built_in_tools import get_built_in_tool_by_id
from danswer.tools.force import ForceUseTool
from danswer.tools.models import DynamicSchemaInfo
from danswer.tools.models import ToolResponse
from danswer.tools.tool import Tool
from danswer.tools.tool_implementations.custom.custom_tool import (
build_custom_tools_from_openapi_schema_and_headers,
)
from danswer.tools.tool_constructor import construct_tools
from danswer.tools.tool_constructor import CustomToolConfig
from danswer.tools.tool_constructor import ImageGenerationToolConfig
from danswer.tools.tool_constructor import InternetSearchToolConfig
from danswer.tools.tool_constructor import SearchToolConfig
from danswer.tools.tool_implementations.custom.custom_tool import (
CUSTOM_TOOL_RESPONSE_ID,
)
@@ -95,9 +87,6 @@ from danswer.tools.tool_implementations.images.image_generation_tool import (
from danswer.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationResponse,
)
from danswer.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationTool,
)
from danswer.tools.tool_implementations.internet_search.internet_search_tool import (
INTERNET_SEARCH_RESPONSE_ID,
)
@@ -122,9 +111,6 @@ from danswer.tools.tool_implementations.search.search_tool import (
SECTION_RELEVANCE_LIST_ID,
)
from danswer.tools.tool_runner import ToolCallFinalResult
from danswer.tools.utils import compute_all_tool_tokens
from danswer.tools.utils import explicit_tool_calling_supported
from danswer.utils.headers import header_dict_to_header_list
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_generator_function_time
@@ -295,7 +281,6 @@ def stream_chat_message_objects(
max_document_percentage: float = CHAT_TARGET_CHUNK_PERCENTAGE,
# if specified, uses the last user message and does not create a new user message based
# on the `new_msg_req.message`. Currently, requires a state where the last message is a
use_existing_user_message: bool = False,
litellm_additional_headers: dict[str, str] | None = None,
custom_tool_additional_headers: dict[str, str] | None = None,
is_connected: Callable[[], bool] | None = None,
@@ -307,6 +292,9 @@ def stream_chat_message_objects(
3. [always] A set of streamed LLM tokens or an error anywhere along the line if something fails
4. [always] Details on the final AI response message that is created
"""
use_existing_user_message = new_msg_req.use_existing_user_message
existing_assistant_message_id = new_msg_req.existing_assistant_message_id
# Currently surrounding context is not supported for chat
# Chat is already token heavy and harder for the model to process plus it would roll history over much faster
new_msg_req.chunks_above = 0
@@ -428,12 +416,20 @@ def stream_chat_message_objects(
final_msg, history_msgs = create_chat_chain(
chat_session_id=chat_session_id, db_session=db_session
)
if final_msg.message_type != MessageType.USER:
raise RuntimeError(
"The last message was not a user message. Cannot call "
"`stream_chat_message_objects` with `is_regenerate=True` "
"when the last message is not a user message."
)
if existing_assistant_message_id is None:
if final_msg.message_type != MessageType.USER:
raise RuntimeError(
"The last message was not a user message. Cannot call "
"`stream_chat_message_objects` with `is_regenerate=True` "
"when the last message is not a user message."
)
else:
if final_msg.id != existing_assistant_message_id:
raise RuntimeError(
"The last message was not the existing assistant message. "
f"Final message id: {final_msg.id}, "
f"existing assistant message id: {existing_assistant_message_id}"
)
# Disable Query Rephrasing for the first message
# This leads to a better first response since the LLM rephrasing the question
@@ -504,13 +500,19 @@ def stream_chat_message_objects(
),
max_window_percentage=max_document_percentage,
)
reserved_message_id = reserve_message_id(
db_session=db_session,
chat_session_id=chat_session_id,
parent_message=user_message.id
if user_message is not None
else parent_message.id,
message_type=MessageType.ASSISTANT,
# we don't need to reserve a message id if we're using an existing assistant message
reserved_message_id = (
final_msg.id
if existing_assistant_message_id is not None
else reserve_message_id(
db_session=db_session,
chat_session_id=chat_session_id,
parent_message=user_message.id
if user_message is not None
else parent_message.id,
message_type=MessageType.ASSISTANT,
)
)
yield MessageResponseIDInfo(
user_message_id=user_message.id if user_message else None,
@@ -525,7 +527,13 @@ def stream_chat_message_objects(
partial_response = partial(
create_new_chat_message,
chat_session_id=chat_session_id,
parent_message=final_msg,
# if we're using an existing assistant message, then this will just be an
# update operation, in which case the parent should be the parent of
# the latest. If we're creating a new assistant message, then the parent
# should be the latest message (latest user message)
parent_message=(
final_msg if existing_assistant_message_id is None else parent_message
),
prompt_id=prompt_id,
overridden_model=overridden_model,
# message=,
@@ -537,6 +545,7 @@ def stream_chat_message_objects(
# reference_docs=,
db_session=db_session,
commit=False,
reserved_message_id=reserved_message_id,
)
if not final_msg.prompt:
@@ -560,142 +569,39 @@ def stream_chat_message_objects(
structured_response_format=new_msg_req.structured_response_format,
)
# find out what tools to use
search_tool: SearchTool | None = None
tool_dict: dict[int, list[Tool]] = {} # tool_id to tool
for db_tool_model in persona.tools:
# handle in-code tools specially
if db_tool_model.in_code_tool_id:
tool_cls = get_built_in_tool_by_id(db_tool_model.id, db_session)
if tool_cls.__name__ == SearchTool.__name__ and not latest_query_files:
search_tool = SearchTool(
db_session=db_session,
user=user,
persona=persona,
retrieval_options=retrieval_options,
prompt_config=prompt_config,
llm=llm,
fast_llm=fast_llm,
pruning_config=document_pruning_config,
answer_style_config=answer_style_config,
selected_sections=selected_sections,
chunks_above=new_msg_req.chunks_above,
chunks_below=new_msg_req.chunks_below,
full_doc=new_msg_req.full_doc,
evaluation_type=(
LLMEvaluationType.BASIC
if persona.llm_relevance_filter
else LLMEvaluationType.SKIP
),
)
tool_dict[db_tool_model.id] = [search_tool]
elif tool_cls.__name__ == ImageGenerationTool.__name__:
img_generation_llm_config: LLMConfig | None = None
if (
llm
and llm.config.api_key
and llm.config.model_provider == "openai"
):
img_generation_llm_config = LLMConfig(
model_provider=llm.config.model_provider,
model_name="dall-e-3",
temperature=GEN_AI_TEMPERATURE,
api_key=llm.config.api_key,
api_base=llm.config.api_base,
api_version=llm.config.api_version,
)
elif (
llm.config.model_provider == "azure"
and AZURE_DALLE_API_KEY is not None
):
img_generation_llm_config = LLMConfig(
model_provider="azure",
model_name=f"azure/{AZURE_DALLE_DEPLOYMENT_NAME}",
temperature=GEN_AI_TEMPERATURE,
api_key=AZURE_DALLE_API_KEY,
api_base=AZURE_DALLE_API_BASE,
api_version=AZURE_DALLE_API_VERSION,
)
else:
llm_providers = fetch_existing_llm_providers(db_session)
openai_provider = next(
iter(
[
llm_provider
for llm_provider in llm_providers
if llm_provider.provider == "openai"
]
),
None,
)
if not openai_provider or not openai_provider.api_key:
raise ValueError(
"Image generation tool requires an OpenAI API key"
)
img_generation_llm_config = LLMConfig(
model_provider=openai_provider.provider,
model_name="dall-e-3",
temperature=GEN_AI_TEMPERATURE,
api_key=openai_provider.api_key,
api_base=openai_provider.api_base,
api_version=openai_provider.api_version,
)
tool_dict[db_tool_model.id] = [
ImageGenerationTool(
api_key=cast(str, img_generation_llm_config.api_key),
api_base=img_generation_llm_config.api_base,
api_version=img_generation_llm_config.api_version,
additional_headers=litellm_additional_headers,
model=img_generation_llm_config.model_name,
)
]
elif tool_cls.__name__ == InternetSearchTool.__name__:
bing_api_key = BING_API_KEY
if not bing_api_key:
raise ValueError(
"Internet search tool requires a Bing API key, please contact your Danswer admin to get it added!"
)
tool_dict[db_tool_model.id] = [
InternetSearchTool(
api_key=bing_api_key,
answer_style_config=answer_style_config,
prompt_config=prompt_config,
)
]
continue
# handle all custom tools
if db_tool_model.openapi_schema:
tool_dict[db_tool_model.id] = cast(
list[Tool],
build_custom_tools_from_openapi_schema_and_headers(
db_tool_model.openapi_schema,
dynamic_schema_info=DynamicSchemaInfo(
chat_session_id=chat_session_id,
message_id=user_message.id if user_message else None,
),
custom_headers=(db_tool_model.custom_headers or [])
+ (
header_dict_to_header_list(
custom_tool_additional_headers or {}
)
),
),
)
tool_dict = construct_tools(
persona=persona,
prompt_config=prompt_config,
db_session=db_session,
user=user,
llm=llm,
fast_llm=fast_llm,
search_tool_config=SearchToolConfig(
answer_style_config=answer_style_config,
document_pruning_config=document_pruning_config,
retrieval_options=retrieval_options or RetrievalDetails(),
selected_sections=selected_sections,
chunks_above=new_msg_req.chunks_above,
chunks_below=new_msg_req.chunks_below,
full_doc=new_msg_req.full_doc,
latest_query_files=latest_query_files,
),
internet_search_tool_config=InternetSearchToolConfig(
answer_style_config=answer_style_config,
),
image_generation_tool_config=ImageGenerationToolConfig(
additional_headers=litellm_additional_headers,
),
custom_tool_config=CustomToolConfig(
chat_session_id=chat_session_id,
message_id=user_message.id if user_message else None,
additional_headers=custom_tool_additional_headers,
),
)
tools: list[Tool] = []
for tool_list in tool_dict.values():
tools.extend(tool_list)
# factor in tool definition size when pruning
document_pruning_config.tool_num_tokens = compute_all_tool_tokens(
tools, llm_tokenizer
)
document_pruning_config.using_tool_message = explicit_tool_calling_supported(
llm_provider, llm_model_name
)
# LLM prompt building, response capturing, etc.
answer = Answer(
is_connected=is_connected,
@@ -871,7 +777,6 @@ def stream_chat_message_objects(
tool_name_to_tool_id[tool.name] = tool_id
gen_ai_response_message = partial_response(
reserved_message_id=reserved_message_id,
message=answer.llm_answer,
rephrased_query=(
qa_docs_response.rephrased_query if qa_docs_response else None
@@ -879,9 +784,11 @@ def stream_chat_message_objects(
reference_docs=reference_db_search_docs,
files=ai_message_files,
token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
citations=message_specific_citations.citation_map
if message_specific_citations
else None,
citations=(
message_specific_citations.citation_map
if message_specific_citations
else None
),
error=None,
tool_call=(
ToolCall(
@@ -915,7 +822,6 @@ def stream_chat_message_objects(
def stream_chat_message(
new_msg_req: CreateChatMessageRequest,
user: User | None,
use_existing_user_message: bool = False,
litellm_additional_headers: dict[str, str] | None = None,
custom_tool_additional_headers: dict[str, str] | None = None,
is_connected: Callable[[], bool] | None = None,
@@ -925,7 +831,6 @@ def stream_chat_message(
new_msg_req=new_msg_req,
user=user,
db_session=db_session,
use_existing_user_message=use_existing_user_message,
litellm_additional_headers=litellm_additional_headers,
custom_tool_additional_headers=custom_tool_additional_headers,
is_connected=is_connected,

View File

@@ -493,3 +493,13 @@ JWT_ALGORITHM = "HS256"
# Super Users
SUPER_USERS = json.loads(os.environ.get("SUPER_USERS", '["pablo@danswer.ai"]'))
SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key")
#####
# API Key Configs
#####
# refers to the rounds described here: https://passlib.readthedocs.io/en/stable/lib/passlib.hash.sha256_crypt.html
_API_KEY_HASH_ROUNDS_RAW = os.environ.get("API_KEY_HASH_ROUNDS")
API_KEY_HASH_ROUNDS = (
int(_API_KEY_HASH_ROUNDS_RAW) if _API_KEY_HASH_ROUNDS_RAW else None
)

View File

@@ -74,7 +74,7 @@ CELERY_PRIMARY_WORKER_LOCK_TIMEOUT = 120
# needs to be long enough to cover the maximum time it takes to download an object
# if we can get callbacks as object bytes download, we could lower this a lot.
CELERY_INDEXING_LOCK_TIMEOUT = 60 * 60 # 60 min
CELERY_INDEXING_LOCK_TIMEOUT = 3 * 60 * 60 # 60 min
# needs to be long enough to cover the maximum time it takes to download an object
# if we can get callbacks as object bytes download, we could lower this a lot.
@@ -126,6 +126,7 @@ class DocumentSource(str, Enum):
XENFORO = "xenforo"
NOT_APPLICABLE = "not_applicable"
FRESHDESK = "freshdesk"
FIREFLIES = "fireflies"
DocumentSourceRequiringTenantContext: list[DocumentSource] = [DocumentSource.FILE]

View File

@@ -16,6 +16,7 @@ from danswer.connectors.discourse.connector import DiscourseConnector
from danswer.connectors.document360.connector import Document360Connector
from danswer.connectors.dropbox.connector import DropboxConnector
from danswer.connectors.file.connector import LocalFileConnector
from danswer.connectors.fireflies.connector import FirefliesConnector
from danswer.connectors.freshdesk.connector import FreshdeskConnector
from danswer.connectors.github.connector import GithubConnector
from danswer.connectors.gitlab.connector import GitlabConnector
@@ -101,6 +102,7 @@ def identify_connector_class(
DocumentSource.OCI_STORAGE: BlobStorageConnector,
DocumentSource.XENFORO: XenforoConnector,
DocumentSource.FRESHDESK: FreshdeskConnector,
DocumentSource.FIREFLIES: FirefliesConnector,
}
connector_by_source = connector_map.get(source, {})

View File

@@ -123,9 +123,13 @@ def _process_file(
"filename",
"file_display_name",
"title",
"connector_type",
]
}
source_type_str = all_metadata.get("connector_type")
source_type = DocumentSource(source_type_str) if source_type_str else None
p_owner_names = all_metadata.get("primary_owners")
s_owner_names = all_metadata.get("secondary_owners")
p_owners = (
@@ -145,7 +149,7 @@ def _process_file(
sections=[
Section(link=all_metadata.get("link"), text=file_content_raw.strip())
],
source=DocumentSource.FILE,
source=source_type or DocumentSource.FILE,
semantic_identifier=file_display_name,
title=title,
doc_updated_at=final_time_updated,

View File

@@ -0,0 +1,182 @@
from collections.abc import Iterator
from datetime import datetime
from datetime import timezone
from typing import List
import requests
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.models import ConnectorMissingCredentialError
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.utils.logger import setup_logger
logger = setup_logger()
_FIREFLIES_ID_PREFIX = "FIREFLIES_"
_FIREFLIES_API_URL = "https://api.fireflies.ai/graphql"
_FIREFLIES_TRANSCRIPT_QUERY_SIZE = 50 # Max page size is 50
_FIREFLIES_API_QUERY = """
query Transcripts($fromDate: DateTime, $toDate: DateTime, $limit: Int!, $skip: Int!) {
transcripts(fromDate: $fromDate, toDate: $toDate, limit: $limit, skip: $skip) {
id
title
host_email
participants
date
transcript_url
sentences {
text
speaker_name
}
}
}
"""
def _create_doc_from_transcript(transcript: dict) -> Document | None:
meeting_text = ""
sentences = transcript.get("sentences", [])
if sentences:
for sentence in sentences:
meeting_text += sentence.get("speaker_name") or "Unknown Speaker"
meeting_text += ": " + sentence.get("text", "") + "\n\n"
else:
return None
meeting_link = transcript["transcript_url"]
fireflies_id = _FIREFLIES_ID_PREFIX + transcript["id"]
meeting_title = transcript["title"] or "No Title"
meeting_date_unix = transcript["date"]
meeting_date = datetime.fromtimestamp(meeting_date_unix / 1000, tz=timezone.utc)
meeting_host_email = transcript["host_email"]
host_email_user_info = [BasicExpertInfo(email=meeting_host_email)]
meeting_participants_email_list = []
for participant in transcript.get("participants", []):
if participant != meeting_host_email and participant:
meeting_participants_email_list.append(BasicExpertInfo(email=participant))
return Document(
id=fireflies_id,
sections=[
Section(
link=meeting_link,
text=meeting_text,
)
],
source=DocumentSource.FIREFLIES,
semantic_identifier=meeting_title,
metadata={},
doc_updated_at=meeting_date,
primary_owners=host_email_user_info,
secondary_owners=meeting_participants_email_list,
)
class FirefliesConnector(PollConnector, LoadConnector):
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
self.batch_size = batch_size
def load_credentials(self, credentials: dict[str, str]) -> None:
api_key = credentials.get("fireflies_api_key")
if not isinstance(api_key, str):
raise ConnectorMissingCredentialError(
"The Fireflies API key must be a string"
)
self.api_key = api_key
return None
def _fetch_transcripts(
self, start_datetime: str | None = None, end_datetime: str | None = None
) -> Iterator[List[dict]]:
if self.api_key is None:
raise ConnectorMissingCredentialError("Missing API key")
headers = {
"Content-Type": "application/json",
"Authorization": "Bearer " + self.api_key,
}
skip = 0
variables: dict[str, int | str] = {
"limit": _FIREFLIES_TRANSCRIPT_QUERY_SIZE,
}
if start_datetime:
variables["fromDate"] = start_datetime
if end_datetime:
variables["toDate"] = end_datetime
while True:
variables["skip"] = skip
response = requests.post(
_FIREFLIES_API_URL,
headers=headers,
json={"query": _FIREFLIES_API_QUERY, "variables": variables},
)
response.raise_for_status()
if response.status_code == 204:
break
recieved_transcripts = response.json()
parsed_transcripts = recieved_transcripts.get("data", {}).get(
"transcripts", []
)
yield parsed_transcripts
if len(parsed_transcripts) < _FIREFLIES_TRANSCRIPT_QUERY_SIZE:
break
skip += _FIREFLIES_TRANSCRIPT_QUERY_SIZE
def _process_transcripts(
self, start: str | None = None, end: str | None = None
) -> GenerateDocumentsOutput:
doc_batch: List[Document] = []
for transcript_batch in self._fetch_transcripts(start, end):
for transcript in transcript_batch:
if doc := _create_doc_from_transcript(transcript):
doc_batch.append(doc)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
if doc_batch:
yield doc_batch
def load_from_state(self) -> GenerateDocumentsOutput:
return self._process_transcripts()
def poll_source(
self, start_unixtime: SecondsSinceUnixEpoch, end_unixtime: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
start_datetime = datetime.fromtimestamp(
start_unixtime, tz=timezone.utc
).strftime("%Y-%m-%dT%H:%M:%S.000Z")
end_datetime = datetime.fromtimestamp(end_unixtime, tz=timezone.utc).strftime(
"%Y-%m-%dT%H:%M:%S.000Z"
)
yield from self._process_transcripts(start_datetime, end_datetime)

View File

@@ -77,6 +77,7 @@ from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import SLACK_CHANNEL_ID
logger = setup_logger()
@@ -189,59 +190,67 @@ class SlackbotHandler:
continue
logger.debug(f"Acquired lock for tenant {tenant_id}")
with get_session_with_tenant(tenant_id) as db_session:
try:
logger.debug(
f"Setting tenant ID context variable for tenant {tenant_id}"
)
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id or "public")
slack_bot_tokens = fetch_tokens()
logger.debug(f"Fetched Slack bot tokens for tenant {tenant_id}")
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
logger.debug(
f"Reset tenant ID context variable for tenant {tenant_id}"
)
token = CURRENT_TENANT_ID_CONTEXTVAR.set(
tenant_id or POSTGRES_DEFAULT_SCHEMA
)
try:
with get_session_with_tenant(tenant_id) as db_session:
try:
logger.debug(
f"Setting tenant ID context variable for tenant {tenant_id}"
)
slack_bot_tokens = fetch_tokens()
logger.debug(f"Fetched Slack bot tokens for tenant {tenant_id}")
logger.debug(
f"Reset tenant ID context variable for tenant {tenant_id}"
)
if not slack_bot_tokens:
logger.debug(f"No Slack bot token found for tenant {tenant_id}")
if not slack_bot_tokens:
logger.debug(
f"No Slack bot token found for tenant {tenant_id}"
)
if tenant_id in self.socket_clients:
asyncio.run(self.socket_clients[tenant_id].close())
del self.socket_clients[tenant_id]
del self.slack_bot_tokens[tenant_id]
continue
if (
tenant_id not in self.slack_bot_tokens
or slack_bot_tokens != self.slack_bot_tokens[tenant_id]
):
if tenant_id in self.slack_bot_tokens:
logger.info(
f"Slack Bot tokens have changed for tenant {tenant_id} - reconnecting"
)
else:
search_settings = get_current_search_settings(
db_session
)
embedding_model = EmbeddingModel.from_db_model(
search_settings=search_settings,
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
warm_up_bi_encoder(embedding_model=embedding_model)
self.slack_bot_tokens[tenant_id] = slack_bot_tokens
if tenant_id in self.socket_clients:
asyncio.run(self.socket_clients[tenant_id].close())
self.start_socket_client(tenant_id, slack_bot_tokens)
except KvKeyNotFoundError:
logger.debug(f"Missing Slack Bot tokens for tenant {tenant_id}")
if tenant_id in self.socket_clients:
asyncio.run(self.socket_clients[tenant_id].close())
del self.socket_clients[tenant_id]
del self.slack_bot_tokens[tenant_id]
continue
if (
tenant_id not in self.slack_bot_tokens
or slack_bot_tokens != self.slack_bot_tokens[tenant_id]
):
if tenant_id in self.slack_bot_tokens:
logger.info(
f"Slack Bot tokens have changed for tenant {tenant_id} - reconnecting"
)
else:
search_settings = get_current_search_settings(db_session)
embedding_model = EmbeddingModel.from_db_model(
search_settings=search_settings,
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
warm_up_bi_encoder(embedding_model=embedding_model)
self.slack_bot_tokens[tenant_id] = slack_bot_tokens
if tenant_id in self.socket_clients:
asyncio.run(self.socket_clients[tenant_id].close())
self.start_socket_client(tenant_id, slack_bot_tokens)
except KvKeyNotFoundError:
logger.debug(f"Missing Slack Bot tokens for tenant {tenant_id}")
if tenant_id in self.socket_clients:
asyncio.run(self.socket_clients[tenant_id].close())
del self.socket_clients[tenant_id]
del self.slack_bot_tokens[tenant_id]
except Exception as e:
logger.exception(f"Error handling tenant {tenant_id}: {e}")
except Exception as e:
logger.exception(f"Error handling tenant {tenant_id}: {e}")
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
def send_heartbeats(self) -> None:
current_time = int(time.time())

View File

@@ -5,16 +5,16 @@ from sqlalchemy import select
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from danswer.auth.api_key import ApiKeyDescriptor
from danswer.auth.api_key import build_displayable_api_key
from danswer.auth.api_key import generate_api_key
from danswer.auth.api_key import hash_api_key
from danswer.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
from danswer.configs.constants import DANSWER_API_KEY_PREFIX
from danswer.configs.constants import UNNAMED_KEY_PLACEHOLDER
from danswer.db.models import ApiKey
from danswer.db.models import User
from ee.danswer.auth.api_key import ApiKeyDescriptor
from ee.danswer.auth.api_key import build_displayable_api_key
from ee.danswer.auth.api_key import generate_api_key
from ee.danswer.auth.api_key import hash_api_key
from ee.danswer.server.api_key.models import APIKeyArgs
from danswer.server.api_key.models import APIKeyArgs
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

View File

@@ -14,6 +14,7 @@ from sqlalchemy.orm import Session
from danswer.auth.invited_users import get_invited_users
from danswer.auth.schemas import UserRole
from danswer.db.api_key import get_api_key_email_pattern
from danswer.db.engine import get_async_session
from danswer.db.engine import get_async_session_with_tenant
from danswer.db.models import AccessToken
@@ -22,7 +23,6 @@ from danswer.db.models import User
from danswer.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,
)
from ee.danswer.db.api_key import get_api_key_email_pattern
def get_default_admin_user_emails() -> list[str]:

View File

@@ -25,8 +25,8 @@ from danswer.db.models import UserGroup__ConnectorCredentialPair
from danswer.db.models import UserRole
from danswer.server.models import StatusResponse
from danswer.utils.logger import setup_logger
from ee.danswer.db.external_perm import delete_user__ext_group_for_cc_pair__no_commit
from ee.danswer.external_permissions.sync_params import check_if_valid_sync_source
from danswer.utils.variable_functionality import fetch_ee_implementation_or_noop
logger = setup_logger()
@@ -351,7 +351,11 @@ def add_credential_to_connector(
raise HTTPException(status_code=404, detail="Connector does not exist")
if access_type == AccessType.SYNC:
if not check_if_valid_sync_source(connector.source):
if not fetch_ee_implementation_or_noop(
"danswer.external_permissions.sync_params",
"check_if_valid_sync_source",
noop_return_value=True,
)(connector.source):
raise HTTPException(
status_code=400,
detail=f"Connector of type {connector.source} does not support SYNC access type",
@@ -438,7 +442,10 @@ def remove_credential_from_connector(
)
if association is not None:
delete_user__ext_group_for_cc_pair__no_commit(
fetch_ee_implementation_or_noop(
"danswer.db.external_perm",
"delete_user__ext_group_for_cc_pair__no_commit",
)(
db_session=db_session,
cc_pair_id=association.id,
)

View File

@@ -169,6 +169,7 @@ def get_document_connector_counts(
def get_document_counts_for_cc_pairs(
db_session: Session, cc_pair_identifiers: list[ConnectorCredentialPairIdentifier]
) -> Sequence[tuple[int, int, int]]:
"""Returns a sequence of tuples of (connector_id, credential_id, document count)"""
stmt = (
select(
DocumentByConnectorCredentialPair.connector_id,
@@ -323,23 +324,23 @@ def upsert_documents(
def upsert_document_by_connector_credential_pair(
db_session: Session, document_metadata_batch: list[DocumentMetadata]
db_session: Session, connector_id: int, credential_id: int, document_ids: list[str]
) -> None:
"""NOTE: this function is Postgres specific. Not all DBs support the ON CONFLICT clause."""
if not document_metadata_batch:
logger.info("`document_metadata_batch` is empty. Skipping.")
if not document_ids:
logger.info("`document_ids` is empty. Skipping.")
return
insert_stmt = insert(DocumentByConnectorCredentialPair).values(
[
model_to_dict(
DocumentByConnectorCredentialPair(
id=document_metadata.document_id,
connector_id=document_metadata.connector_id,
credential_id=document_metadata.credential_id,
id=doc_id,
connector_id=connector_id,
credential_id=credential_id,
)
)
for document_metadata in document_metadata_batch
for doc_id in document_ids
]
)
# for now, there are no columns to update. If more metadata is added, then this
@@ -400,17 +401,6 @@ def mark_document_as_synced(document_id: str, db_session: Session) -> None:
db_session.commit()
def upsert_documents_complete(
db_session: Session,
document_metadata_batch: list[DocumentMetadata],
) -> None:
upsert_documents(db_session, document_metadata_batch)
upsert_document_by_connector_credential_pair(db_session, document_metadata_batch)
logger.info(
f"Upserted {len(document_metadata_batch)} document store entries into DB"
)
def delete_document_by_connector_credential_pair__no_commit(
db_session: Session,
document_id: str,
@@ -520,7 +510,7 @@ def prepare_to_modify_documents(
db_session.commit() # ensure that we're not in a transaction
lock_acquired = False
for _ in range(_NUM_LOCK_ATTEMPTS):
for i in range(_NUM_LOCK_ATTEMPTS):
try:
with db_session.begin() as transaction:
lock_acquired = acquire_document_locks(
@@ -531,7 +521,7 @@ def prepare_to_modify_documents(
break
except OperationalError as e:
logger.warning(
f"Failed to acquire locks for documents, retrying. Error: {e}"
f"Failed to acquire locks for documents on attempt {i}, retrying. Error: {e}"
)
time.sleep(retry_delay)

View File

@@ -312,7 +312,9 @@ async def get_async_session_with_tenant(
await session.execute(text(f'SET search_path = "{tenant_id}"'))
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
await session.execute(
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
text(
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
)
)
except Exception:
logger.exception("Error setting search_path.")
@@ -323,16 +325,28 @@ async def get_async_session_with_tenant(
yield session
@contextmanager
def get_session_with_default_tenant() -> Generator[Session, None, None]:
"""
Get a database session using the current tenant ID from the context variable.
"""
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
with get_session_with_tenant(tenant_id) as session:
yield session
@contextmanager
def get_session_with_tenant(
tenant_id: str | None = None,
) -> Generator[Session, None, None]:
"""
Generate a database session bound to a connection with the appropriate tenant schema set.
This preserves the tenant ID across the session and reverts to the previous tenant ID
after the session is closed.
If tenant ID is set, we save the previous tenant ID from the context var to set
after the session is closed. The value `None` evaluates to the default schema.
Generate a database session for a specific tenant.
This function:
1. Sets the database schema to the specified tenant's schema.
2. Preserves the tenant ID across the session.
3. Reverts to the previous tenant ID after the session is closed.
4. Uses the default schema if no tenant ID is provided.
"""
engine = get_sqlalchemy_engine()
@@ -340,9 +354,9 @@ def get_session_with_tenant(
previous_tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() or POSTGRES_DEFAULT_SCHEMA
if tenant_id is None:
tenant_id = previous_tenant_id
else:
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
tenant_id = POSTGRES_DEFAULT_SCHEMA
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
event.listen(engine, "checkout", set_search_path_on_checkout)
@@ -361,7 +375,9 @@ def get_session_with_tenant(
cursor.execute(f'SET search_path = "{tenant_id}"')
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
cursor.execute(
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
text(
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
)
)
finally:
cursor.close()

View File

@@ -12,7 +12,7 @@ from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
from danswer.configs.model_configs import OLD_DEFAULT_DOCUMENT_ENCODER_MODEL
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
from danswer.db.engine import get_session_with_tenant
from danswer.db.engine import get_session_with_default_tenant
from danswer.db.llm import fetch_embedding_provider
from danswer.db.models import CloudEmbeddingProvider
from danswer.db.models import IndexAttempt
@@ -152,7 +152,7 @@ def get_all_search_settings(db_session: Session) -> list[SearchSettings]:
def get_multilingual_expansion(db_session: Session | None = None) -> list[str]:
if db_session is None:
with get_session_with_tenant() as db_session:
with get_session_with_default_tenant() as db_session:
search_settings = get_current_search_settings(db_session)
else:
search_settings = get_current_search_settings(db_session)

View File

@@ -0,0 +1,111 @@
from collections.abc import Sequence
from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.configs.constants import TokenRateLimitScope
from danswer.db.models import TokenRateLimit
from danswer.db.models import TokenRateLimit__UserGroup
from danswer.server.token_rate_limits.models import TokenRateLimitArgs
def fetch_all_user_token_rate_limits(
db_session: Session,
enabled_only: bool = False,
ordered: bool = True,
) -> Sequence[TokenRateLimit]:
query = select(TokenRateLimit).where(
TokenRateLimit.scope == TokenRateLimitScope.USER
)
if enabled_only:
query = query.where(TokenRateLimit.enabled.is_(True))
if ordered:
query = query.order_by(TokenRateLimit.created_at.desc())
return db_session.scalars(query).all()
def fetch_all_global_token_rate_limits(
db_session: Session,
enabled_only: bool = False,
ordered: bool = True,
) -> Sequence[TokenRateLimit]:
query = select(TokenRateLimit).where(
TokenRateLimit.scope == TokenRateLimitScope.GLOBAL
)
if enabled_only:
query = query.where(TokenRateLimit.enabled.is_(True))
if ordered:
query = query.order_by(TokenRateLimit.created_at.desc())
token_rate_limits = db_session.scalars(query).all()
return token_rate_limits
def insert_user_token_rate_limit(
db_session: Session,
token_rate_limit_settings: TokenRateLimitArgs,
) -> TokenRateLimit:
token_limit = TokenRateLimit(
enabled=token_rate_limit_settings.enabled,
token_budget=token_rate_limit_settings.token_budget,
period_hours=token_rate_limit_settings.period_hours,
scope=TokenRateLimitScope.USER,
)
db_session.add(token_limit)
db_session.commit()
return token_limit
def insert_global_token_rate_limit(
db_session: Session,
token_rate_limit_settings: TokenRateLimitArgs,
) -> TokenRateLimit:
token_limit = TokenRateLimit(
enabled=token_rate_limit_settings.enabled,
token_budget=token_rate_limit_settings.token_budget,
period_hours=token_rate_limit_settings.period_hours,
scope=TokenRateLimitScope.GLOBAL,
)
db_session.add(token_limit)
db_session.commit()
return token_limit
def update_token_rate_limit(
db_session: Session,
token_rate_limit_id: int,
token_rate_limit_settings: TokenRateLimitArgs,
) -> TokenRateLimit:
token_limit = db_session.get(TokenRateLimit, token_rate_limit_id)
if token_limit is None:
raise ValueError(f"TokenRateLimit with id '{token_rate_limit_id}' not found")
token_limit.enabled = token_rate_limit_settings.enabled
token_limit.token_budget = token_rate_limit_settings.token_budget
token_limit.period_hours = token_rate_limit_settings.period_hours
db_session.commit()
return token_limit
def delete_token_rate_limit(
db_session: Session,
token_rate_limit_id: int,
) -> None:
token_limit = db_session.get(TokenRateLimit, token_rate_limit_id)
if token_limit is None:
raise ValueError(f"TokenRateLimit with id '{token_rate_limit_id}' not found")
db_session.query(TokenRateLimit__UserGroup).filter(
TokenRateLimit__UserGroup.rate_limit_id == token_rate_limit_id
).delete()
db_session.delete(token_limit)
db_session.commit()

View File

@@ -24,6 +24,13 @@ def get_tool_by_id(tool_id: int, db_session: Session) -> Tool:
return tool
def get_tool_by_name(tool_name: str, db_session: Session) -> Tool:
tool = db_session.scalar(select(Tool).where(Tool.name == tool_name))
if not tool:
raise ValueError("Tool by specified name does not exist")
return tool
def create_tool(
name: str,
description: str | None,
@@ -37,7 +44,7 @@ def create_tool(
description=description,
in_code_tool_id=None,
openapi_schema=openapi_schema,
custom_headers=[header.dict() for header in custom_headers]
custom_headers=[header.model_dump() for header in custom_headers]
if custom_headers
else [],
user_id=user_id,

View File

@@ -20,7 +20,8 @@ from danswer.db.document import get_documents_by_ids
from danswer.db.document import prepare_to_modify_documents
from danswer.db.document import update_docs_last_modified__no_commit
from danswer.db.document import update_docs_updated_at__no_commit
from danswer.db.document import upsert_documents_complete
from danswer.db.document import upsert_document_by_connector_credential_pair
from danswer.db.document import upsert_documents
from danswer.db.document_set import fetch_document_sets_for_documents
from danswer.db.index_attempt import create_index_attempt_error
from danswer.db.models import Document as DBDocument
@@ -56,13 +57,13 @@ class IndexingPipelineProtocol(Protocol):
...
def upsert_documents_in_db(
def _upsert_documents_in_db(
documents: list[Document],
index_attempt_metadata: IndexAttemptMetadata,
db_session: Session,
) -> None:
# Metadata here refers to basic document info, not metadata about the actual content
doc_m_batch: list[DocumentMetadata] = []
document_metadata_list: list[DocumentMetadata] = []
for doc in documents:
first_link = next(
(section.link for section in doc.sections if section.link), ""
@@ -77,12 +78,9 @@ def upsert_documents_in_db(
secondary_owners=get_experts_stores_representations(doc.secondary_owners),
from_ingestion_api=doc.from_ingestion_api,
)
doc_m_batch.append(db_doc_metadata)
document_metadata_list.append(db_doc_metadata)
upsert_documents_complete(
db_session=db_session,
document_metadata_batch=doc_m_batch,
)
upsert_documents(db_session, document_metadata_list)
# Insert document content metadata
for doc in documents:
@@ -95,21 +93,25 @@ def upsert_documents_in_db(
document_id=doc.id,
db_session=db_session,
)
else:
create_or_add_document_tag(
tag_key=k,
tag_value=v,
source=doc.source,
document_id=doc.id,
db_session=db_session,
)
continue
create_or_add_document_tag(
tag_key=k,
tag_value=v,
source=doc.source,
document_id=doc.id,
db_session=db_session,
)
def get_doc_ids_to_update(
documents: list[Document], db_docs: list[DBDocument]
) -> list[Document]:
"""Figures out which documents actually need to be updated. If a document is already present
and the `updated_at` hasn't changed, we shouldn't need to do anything with it."""
and the `updated_at` hasn't changed, we shouldn't need to do anything with it.
NB: Still need to associate the document in the DB if multiple connectors are
indexing the same doc."""
id_update_time_map = {
doc.id: doc.doc_updated_at for doc in db_docs if doc.doc_updated_at
}
@@ -195,9 +197,9 @@ def index_doc_batch_prepare(
db_session: Session,
ignore_time_skip: bool = False,
) -> DocumentBatchPrepareContext | None:
"""This sets up the documents in the relational DB (source of truth) for permissions, metadata, etc.
"""Sets up the documents in the relational DB (source of truth) for permissions, metadata, etc.
This preceeds indexing it into the actual document index."""
documents = []
documents: list[Document] = []
for document in document_batch:
empty_contents = not any(section.text.strip() for section in document.sections)
if (
@@ -212,43 +214,58 @@ def index_doc_batch_prepare(
logger.warning(
f"Skipping document with ID {document.id} as it has neither title nor content."
)
elif (
document.title is not None and not document.title.strip() and empty_contents
):
continue
if document.title is not None and not document.title.strip() and empty_contents:
# The title is explicitly empty ("" and not None) and the document is empty
# so when building the chunk text representation, it will be empty and unuseable
logger.warning(
f"Skipping document with ID {document.id} as the chunks will be empty."
)
else:
documents.append(document)
continue
document_ids = [document.id for document in documents]
documents.append(document)
# Create a trimmed list of docs that don't have a newer updated at
# Shortcuts the time-consuming flow on connector index retries
document_ids: list[str] = [document.id for document in documents]
db_docs: list[DBDocument] = get_documents_by_ids(
db_session=db_session,
document_ids=document_ids,
)
# Skip indexing docs that don't have a newer updated at
# Shortcuts the time-consuming flow on connector index retries
updatable_docs = (
get_doc_ids_to_update(documents=documents, db_docs=db_docs)
if not ignore_time_skip
else documents
)
# No docs to update either because the batch is empty or every doc was already indexed
# for all updatable docs, upsert into the DB
# Does not include doc_updated_at which is also used to indicate a successful update
if updatable_docs:
_upsert_documents_in_db(
documents=updatable_docs,
index_attempt_metadata=index_attempt_metadata,
db_session=db_session,
)
logger.info(
f"Upserted {len(updatable_docs)} changed docs out of "
f"{len(documents)} total docs into the DB"
)
# for all docs, upsert the document to cc pair relationship
upsert_document_by_connector_credential_pair(
db_session,
index_attempt_metadata.connector_id,
index_attempt_metadata.credential_id,
document_ids,
)
# No docs to process because the batch is empty or every doc was already indexed
if not updatable_docs:
return None
# Create records in the source of truth about these documents,
# does not include doc_updated_at which is also used to indicate a successful update
upsert_documents_in_db(
documents=documents,
index_attempt_metadata=index_attempt_metadata,
db_session=db_session,
)
id_to_db_doc_map = {doc.id: doc for doc in db_docs}
return DocumentBatchPrepareContext(
updatable_docs=updatable_docs, id_to_db_doc_map=id_to_db_doc_map
@@ -269,7 +286,10 @@ def index_doc_batch(
) -> tuple[int, int]:
"""Takes different pieces of the indexing pipeline and applies it to a batch of documents
Note that the documents should already be batched at this point so that it does not inflate the
memory requirements"""
memory requirements
Returns a tuple where the first element is the number of new docs and the
second element is the number of chunks."""
no_access = DocumentAccess.build(
user_emails=[],
@@ -312,9 +332,9 @@ def index_doc_batch(
# we're concerned about race conditions where multiple simultaneous indexings might result
# in one set of metadata overwriting another one in vespa.
# we still write data here for immediate and most likely correct sync, but
# we still write data here for the immediate and most likely correct sync, but
# to resolve this, an update of the last modified field at the end of this loop
# always triggers a final metadata sync
# always triggers a final metadata sync via the celery queue
access_aware_chunks = [
DocMetadataAwareIndexChunk.from_index_chunk(
index_chunk=chunk,
@@ -351,7 +371,8 @@ def index_doc_batch(
ids_to_new_updated_at = {}
for doc in successful_docs:
last_modified_ids.append(doc.id)
# doc_updated_at is the connector source's idea of when the doc was last modified
# doc_updated_at is the source's idea (on the other end of the connector)
# of when the doc was last modified
if doc.doc_updated_at is None:
continue
ids_to_new_updated_at[doc.id] = doc.doc_updated_at
@@ -366,10 +387,13 @@ def index_doc_batch(
db_session.commit()
return len([r for r in insertion_records if r.already_existed is False]), len(
access_aware_chunks
result = (
len([r for r in insertion_records if r.already_existed is False]),
len(access_aware_chunks),
)
return result
def build_indexing_pipeline(
*,

View File

@@ -25,6 +25,7 @@ from danswer.auth.schemas import UserCreate
from danswer.auth.schemas import UserRead
from danswer.auth.schemas import UserUpdate
from danswer.auth.users import auth_backend
from danswer.auth.users import BasicAuthenticationError
from danswer.auth.users import fastapi_users
from danswer.configs.app_configs import APP_API_PREFIX
from danswer.configs.app_configs import APP_HOST
@@ -73,6 +74,9 @@ from danswer.server.manage.search_settings import router as search_settings_rout
from danswer.server.manage.slack_bot import router as slack_bot_management_router
from danswer.server.manage.users import router as user_router
from danswer.server.middleware.latency_logging import add_latency_logging_middleware
from danswer.server.openai_assistants_api.full_openai_assistants_api import (
get_full_openai_assistants_api_router,
)
from danswer.server.query_and_chat.chat_backend import router as chat_router
from danswer.server.query_and_chat.query_backend import (
admin_router as admin_query_router,
@@ -194,7 +198,12 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
def log_http_error(_: Request, exc: Exception) -> JSONResponse:
status_code = getattr(exc, "status_code", 500)
if status_code >= 400:
if isinstance(exc, BasicAuthenticationError):
# For BasicAuthenticationError, just log a brief message without stack trace (almost always spam)
logger.error(f"Authentication failed: {str(exc)}")
elif status_code >= 400:
error_msg = f"{str(exc)}\n"
error_msg += "".join(traceback.format_tb(exc.__traceback__))
logger.error(error_msg)
@@ -220,7 +229,6 @@ def get_application() -> FastAPI:
else:
logger.debug("Sentry DSN not provided, skipping Sentry initialization")
# Add the custom exception handler
application.add_exception_handler(status.HTTP_400_BAD_REQUEST, log_http_error)
application.add_exception_handler(status.HTTP_401_UNAUTHORIZED, log_http_error)
application.add_exception_handler(status.HTTP_403_FORBIDDEN, log_http_error)
@@ -265,6 +273,9 @@ def get_application() -> FastAPI:
application, token_rate_limit_settings_router
)
include_router_with_global_prefix_prepended(application, indexing_router)
include_router_with_global_prefix_prepended(
application, get_full_openai_assistants_api_router()
)
if AUTH_TYPE == AuthType.DISABLED:
# Server logs this during auth setup verification step
@@ -277,12 +288,14 @@ def get_application() -> FastAPI:
prefix="/auth",
tags=["auth"],
)
include_router_with_global_prefix_prepended(
application,
fastapi_users.get_register_router(UserRead, UserCreate),
prefix="/auth",
tags=["auth"],
)
include_router_with_global_prefix_prepended(
application,
fastapi_users.get_reset_password_router(),

View File

@@ -35,23 +35,31 @@ class BaseTokenizer(ABC):
class TiktokenTokenizer(BaseTokenizer):
_instances: dict[str, "TiktokenTokenizer"] = {}
def __new__(cls, encoding_name: str = "cl100k_base") -> "TiktokenTokenizer":
if encoding_name not in cls._instances:
cls._instances[encoding_name] = super(TiktokenTokenizer, cls).__new__(cls)
return cls._instances[encoding_name]
def __new__(cls, model_name: str) -> "TiktokenTokenizer":
if model_name not in cls._instances:
cls._instances[model_name] = super(TiktokenTokenizer, cls).__new__(cls)
return cls._instances[model_name]
def __init__(self, encoding_name: str = "cl100k_base"):
def __init__(self, model_name: str):
if not hasattr(self, "encoder"):
import tiktoken
self.encoder = tiktoken.get_encoding(encoding_name)
self.encoder = tiktoken.encoding_for_model(model_name)
def encode(self, string: str) -> list[int]:
# this returns no special tokens
# this ignores special tokens that the model is trained on, see encode_ordinary for details
return self.encoder.encode_ordinary(string)
def tokenize(self, string: str) -> list[str]:
return [self.encoder.decode([token]) for token in self.encode(string)]
encoded = self.encode(string)
decoded = [self.encoder.decode([token]) for token in encoded]
if len(decoded) != len(encoded):
logger.warning(
f"OpenAI tokenized length {len(decoded)} does not match encoded length {len(encoded)} for string: {string}"
)
return decoded
def decode(self, tokens: list[int]) -> str:
return self.encoder.decode(tokens)
@@ -74,22 +82,35 @@ class HuggingFaceTokenizer(BaseTokenizer):
return self.encoder.decode(tokens)
_TOKENIZER_CACHE: dict[str, BaseTokenizer] = {}
_TOKENIZER_CACHE: dict[tuple[EmbeddingProvider | None, str | None], BaseTokenizer] = {}
def _check_tokenizer_cache(tokenizer_name: str) -> BaseTokenizer:
def _check_tokenizer_cache(
model_provider: EmbeddingProvider | None, model_name: str | None
) -> BaseTokenizer:
global _TOKENIZER_CACHE
if tokenizer_name not in _TOKENIZER_CACHE:
if tokenizer_name == "openai":
_TOKENIZER_CACHE[tokenizer_name] = TiktokenTokenizer("cl100k_base")
return _TOKENIZER_CACHE[tokenizer_name]
id_tuple = (model_provider, model_name)
if id_tuple not in _TOKENIZER_CACHE:
if model_provider in [EmbeddingProvider.OPENAI, EmbeddingProvider.AZURE]:
if model_name is None:
raise ValueError(
"model_name is required for OPENAI and AZURE embeddings"
)
_TOKENIZER_CACHE[id_tuple] = TiktokenTokenizer(model_name)
return _TOKENIZER_CACHE[id_tuple]
try:
logger.debug(f"Initializing HuggingFaceTokenizer for: {tokenizer_name}")
_TOKENIZER_CACHE[tokenizer_name] = HuggingFaceTokenizer(tokenizer_name)
if model_name is None:
model_name = DOCUMENT_ENCODER_MODEL
logger.debug(f"Initializing HuggingFaceTokenizer for: {model_name}")
_TOKENIZER_CACHE[id_tuple] = HuggingFaceTokenizer(model_name)
except Exception as primary_error:
logger.error(
f"Error initializing HuggingFaceTokenizer for {tokenizer_name}: {primary_error}"
f"Error initializing HuggingFaceTokenizer for {model_name}: {primary_error}"
)
logger.warning(
f"Falling back to default embedding model: {DOCUMENT_ENCODER_MODEL}"
@@ -98,7 +119,7 @@ def _check_tokenizer_cache(tokenizer_name: str) -> BaseTokenizer:
try:
# Cache this tokenizer name to the default so we don't have to try to load it again
# and fail again
_TOKENIZER_CACHE[tokenizer_name] = HuggingFaceTokenizer(
_TOKENIZER_CACHE[id_tuple] = HuggingFaceTokenizer(
DOCUMENT_ENCODER_MODEL
)
except Exception as fallback_error:
@@ -106,10 +127,10 @@ def _check_tokenizer_cache(tokenizer_name: str) -> BaseTokenizer:
f"Error initializing fallback HuggingFaceTokenizer: {fallback_error}"
)
raise ValueError(
f"Failed to initialize tokenizer for {tokenizer_name} and fallback model"
f"Failed to initialize tokenizer for {model_name} and fallback model"
) from fallback_error
return _TOKENIZER_CACHE[tokenizer_name]
return _TOKENIZER_CACHE[id_tuple]
_DEFAULT_TOKENIZER: BaseTokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)
@@ -118,11 +139,16 @@ _DEFAULT_TOKENIZER: BaseTokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)
def get_tokenizer(
model_name: str | None, provider_type: EmbeddingProvider | str | None
) -> BaseTokenizer:
# Currently all of the viable models use the same sentencepiece tokenizer
# OpenAI uses a different one but currently it's not supported due to quality issues
# the inconsistent chunking makes using the sentencepiece tokenizer default better for now
# LLM tokenizers are specified by strings
global _DEFAULT_TOKENIZER
if provider_type is not None:
if isinstance(provider_type, str):
try:
provider_type = EmbeddingProvider(provider_type)
except ValueError:
logger.debug(
f"Invalid provider_type '{provider_type}'. Falling back to default tokenizer."
)
return _DEFAULT_TOKENIZER
return _check_tokenizer_cache(provider_type, model_name)
return _DEFAULT_TOKENIZER

View File

@@ -65,7 +65,7 @@ from danswer.tools.tool_implementations.search.search_tool import (
from danswer.tools.tool_runner import ToolCallKickoff
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_generator_function_time
from ee.danswer.server.query_and_chat.utils import create_temporary_persona
from danswer.utils.variable_functionality import fetch_ee_implementation_or_noop
logger = setup_logger()
@@ -125,11 +125,11 @@ def stream_answer_objects(
)
temporary_persona: Persona | None = None
if query_req.persona_config is not None:
new_persona = create_temporary_persona(
db_session=db_session, persona_config=query_req.persona_config, user=user
)
temporary_persona = new_persona
temporary_persona = fetch_ee_implementation_or_noop(
"danswer.server.query_and_chat.utils", "create_temporary_persona", None
)(db_session=db_session, persona_config=query_req.persona_config, user=user)
persona = temporary_persona if temporary_persona else chat_session.persona

View File

@@ -1,9 +1,10 @@
import time
from typing import cast
from uuid import uuid4
import redis
from celery import Celery
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
@@ -13,6 +14,7 @@ from danswer.db.connector_credential_pair import get_connector_credential_pair_f
from danswer.db.document import (
construct_document_select_for_connector_credential_pair_by_needs_sync,
)
from danswer.db.models import Document
from danswer.redis.redis_object_helper import RedisObjectHelper
@@ -30,6 +32,9 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
def __init__(self, tenant_id: str | None, id: int) -> None:
super().__init__(tenant_id, str(id))
# documents that should be skipped
self.skip_docs: set[str] = set()
@classmethod
def get_fence_key(cls) -> str:
return RedisConnectorCredentialPair.FENCE_PREFIX
@@ -45,14 +50,19 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
# example: connector_taskset
return f"{self.TASKSET_PREFIX}"
def set_skip_docs(self, skip_docs: set[str]) -> None:
# documents that should be skipped. Note that this classes updates
# the list on the fly
self.skip_docs = skip_docs
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
lock: RedisLock,
tenant_id: str | None,
) -> int | None:
) -> tuple[int, int] | None:
last_lock_time = time.monotonic()
async_results = []
@@ -63,7 +73,10 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
stmt = construct_document_select_for_connector_credential_pair_by_needs_sync(
cc_pair.connector_id, cc_pair.credential_id
)
num_docs = 0
for doc in db_session.scalars(stmt).yield_per(1):
doc = cast(Document, doc)
current_time = time.monotonic()
if current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
@@ -71,6 +84,12 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
lock.reacquire()
last_lock_time = current_time
num_docs += 1
# check if we should skip the document (typically because it's already syncing)
if doc.id in self.skip_docs:
continue
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
@@ -93,5 +112,6 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
)
async_results.append(result)
self.skip_docs.add(doc.id)
return len(async_results)
return len(async_results), num_docs

View File

@@ -6,6 +6,7 @@ from uuid import uuid4
import redis
from celery import Celery
from pydantic import BaseModel
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
@@ -83,7 +84,7 @@ class RedisConnectorDelete:
self,
celery_app: Celery,
db_session: Session,
lock: redis.lock.Lock,
lock: RedisLock,
) -> int | None:
"""Returns None if the cc_pair doesn't exist.
Otherwise, returns an int with the number of generated tasks."""

View File

@@ -6,7 +6,7 @@ import redis
from pydantic import BaseModel
class RedisConnectorIndexingFenceData(BaseModel):
class RedisConnectorIndexPayload(BaseModel):
index_attempt_id: int | None
started: datetime | None
submitted: datetime
@@ -71,22 +71,20 @@ class RedisConnectorIndex:
return False
@property
def payload(self) -> RedisConnectorIndexingFenceData | None:
def payload(self) -> RedisConnectorIndexPayload | None:
# read related data and evaluate/print task progress
fence_bytes = cast(bytes, self.redis.get(self.fence_key))
if fence_bytes is None:
return None
fence_str = fence_bytes.decode("utf-8")
payload = RedisConnectorIndexingFenceData.model_validate_json(
cast(str, fence_str)
)
payload = RedisConnectorIndexPayload.model_validate_json(cast(str, fence_str))
return payload
def set_fence(
self,
payload: RedisConnectorIndexingFenceData | None,
payload: RedisConnectorIndexPayload | None,
) -> None:
if not payload:
self.redis.delete(self.fence_key)

View File

@@ -4,6 +4,7 @@ from uuid import uuid4
import redis
from celery import Celery
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
@@ -105,7 +106,7 @@ class RedisConnectorPrune:
documents_to_prune: set[str],
celery_app: Celery,
db_session: Session,
lock: redis.lock.Lock | None,
lock: RedisLock | None,
) -> int | None:
last_lock_time = time.monotonic()

View File

@@ -5,6 +5,7 @@ from uuid import uuid4
import redis
from celery import Celery
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
@@ -50,9 +51,9 @@ class RedisDocumentSet(RedisObjectHelper):
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
lock: RedisLock,
tenant_id: str | None,
) -> int | None:
) -> tuple[int, int] | None:
last_lock_time = time.monotonic()
async_results = []
@@ -84,7 +85,7 @@ class RedisDocumentSet(RedisObjectHelper):
async_results.append(result)
return len(async_results)
return len(async_results), len(async_results)
def reset(self) -> None:
self.redis.delete(self.taskset_key)

View File

@@ -1,9 +1,9 @@
from abc import ABC
from abc import abstractmethod
import redis
from celery import Celery
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from danswer.redis.redis_pool import get_redis_client
@@ -85,7 +85,13 @@ class RedisObjectHelper(ABC):
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
lock: RedisLock,
tenant_id: str | None,
) -> int | None:
pass
) -> tuple[int, int] | None:
"""First element should be the number of actual tasks generated, second should
be the number of docs that were candidates to be synced for the cc pair.
The need for this is when we are syncing stale docs referenced by multiple
connectors. In a single pass across multiple cc pairs, we only want a task
for be created for a particular document id the first time we see it.
The rest can be skipped."""

View File

@@ -5,6 +5,7 @@ from uuid import uuid4
import redis
from celery import Celery
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
@@ -51,15 +52,15 @@ class RedisUserGroup(RedisObjectHelper):
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
lock: RedisLock,
tenant_id: str | None,
) -> int | None:
) -> tuple[int, int] | None:
last_lock_time = time.monotonic()
async_results = []
if not global_version.is_ee_version():
return 0
return 0, 0
try:
construct_document_select_by_usergroup = fetch_versioned_implementation(
@@ -67,7 +68,7 @@ class RedisUserGroup(RedisObjectHelper):
"construct_document_select_by_usergroup",
)
except ModuleNotFoundError:
return 0
return 0, 0
stmt = construct_document_select_by_usergroup(int(self._id))
for doc in db_session.scalars(stmt).yield_per(1):
@@ -97,7 +98,7 @@ class RedisUserGroup(RedisObjectHelper):
async_results.append(result)
return len(async_results)
return len(async_results), len(async_results)
def reset(self) -> None:
self.redis.delete(self.taskset_key)

View File

@@ -3,15 +3,15 @@ from fastapi import Depends
from sqlalchemy.orm import Session
from danswer.auth.users import current_admin_user
from danswer.db.api_key import ApiKeyDescriptor
from danswer.db.api_key import fetch_api_keys
from danswer.db.api_key import insert_api_key
from danswer.db.api_key import regenerate_api_key
from danswer.db.api_key import remove_api_key
from danswer.db.api_key import update_api_key
from danswer.db.engine import get_session
from danswer.db.models import User
from ee.danswer.db.api_key import ApiKeyDescriptor
from ee.danswer.db.api_key import fetch_api_keys
from ee.danswer.db.api_key import insert_api_key
from ee.danswer.db.api_key import regenerate_api_key
from ee.danswer.db.api_key import remove_api_key
from ee.danswer.db.api_key import update_api_key
from ee.danswer.server.api_key.models import APIKeyArgs
from danswer.server.api_key.models import APIKeyArgs
router = APIRouter(prefix="/admin/api-key")

View File

@@ -10,8 +10,7 @@ from danswer.auth.users import current_user
from danswer.auth.users import current_user_with_expired_token
from danswer.configs.app_configs import APP_API_PREFIX
from danswer.server.danswer_api.ingestion import api_key_dep
from ee.danswer.auth.users import current_cloud_superuser
from ee.danswer.server.tenants.access import control_plane_dep
from danswer.utils.variable_functionality import fetch_ee_implementation_or_noop
PUBLIC_ENDPOINT_SPECS = [
@@ -81,6 +80,14 @@ def check_router_auth(
(1) have auth enabled OR
(2) are explicitly marked as a public endpoint
"""
control_plane_dep = fetch_ee_implementation_or_noop(
"danswer.server.tenants.access", "control_plane_dep"
)
current_cloud_superuser = fetch_ee_implementation_or_noop(
"danswer.auth.users", "current_cloud_superuser"
)
for route in application.routes:
# explicitly marked as public
if is_route_in_spec_list(route, public_endpoint_specs):

View File

@@ -3,6 +3,7 @@ from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session
from danswer.auth.users import api_key_dep
from danswer.configs.constants import DocumentSource
from danswer.connectors.models import Document
from danswer.connectors.models import IndexAttemptMetadata
@@ -22,7 +23,6 @@ from danswer.server.danswer_api.models import DocMinimalInfo
from danswer.server.danswer_api.models import IngestionDocument
from danswer.server.danswer_api.models import IngestionResult
from danswer.utils.logger import setup_logger
from ee.danswer.auth.users import api_key_dep
logger = setup_logger()

View File

@@ -16,6 +16,9 @@ from danswer.background.celery.tasks.pruning.tasks import (
try_creating_prune_generator_task,
)
from danswer.background.celery.versioned_apps.primary import app as primary_app
from danswer.background.task_name_builders import (
name_sync_external_doc_permissions_task,
)
from danswer.db.connector_credential_pair import add_credential_to_connector
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.connector_credential_pair import remove_credential_from_connector
@@ -47,11 +50,7 @@ from danswer.server.documents.models import ConnectorCredentialPairMetadata
from danswer.server.documents.models import PaginatedIndexAttempts
from danswer.server.models import StatusResponse
from danswer.utils.logger import setup_logger
from ee.danswer.background.task_name_builders import (
name_sync_external_doc_permissions_task,
)
from ee.danswer.db.user_group import validate_user_creation_permissions
from danswer.utils.variable_functionality import fetch_ee_implementation_or_noop
logger = setup_logger()
router = APIRouter(prefix="/manage")
@@ -332,9 +331,6 @@ def sync_cc_pair(
db_session: Session = Depends(get_session),
) -> StatusResponse[list[int]]:
# avoiding circular refs
from ee.danswer.background.celery.apps.primary import (
sync_external_doc_permissions_task,
)
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id=cc_pair_id,
@@ -360,12 +356,19 @@ def sync_cc_pair(
)
logger.info(f"Syncing the {cc_pair.connector.name} connector.")
sync_external_doc_permissions_task.apply_async(
kwargs=dict(
cc_pair_id=cc_pair_id, tenant_id=CURRENT_TENANT_ID_CONTEXTVAR.get()
),
sync_external_doc_permissions_task = fetch_ee_implementation_or_noop(
"danswer.background.celery.apps.primary",
"sync_external_doc_permissions_task",
None,
)
if sync_external_doc_permissions_task:
sync_external_doc_permissions_task.apply_async(
kwargs=dict(
cc_pair_id=cc_pair_id, tenant_id=CURRENT_TENANT_ID_CONTEXTVAR.get()
),
)
return StatusResponse(
success=True,
message="Successfully created the sync task.",
@@ -380,7 +383,9 @@ def associate_credential_to_connector(
user: User | None = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> StatusResponse[int]:
validate_user_creation_permissions(
fetch_ee_implementation_or_noop(
"danswer.db.user_group", "validate_user_creation_permissions", None
)(
db_session=db_session,
user=user,
target_group_ids=metadata.groups,

View File

@@ -108,7 +108,7 @@ from danswer.server.documents.models import ObjectCreationIdResponse
from danswer.server.documents.models import RunConnectorRequest
from danswer.server.models import StatusResponse
from danswer.utils.logger import setup_logger
from ee.danswer.db.user_group import validate_user_creation_permissions
from danswer.utils.variable_functionality import fetch_ee_implementation_or_noop
logger = setup_logger()
@@ -658,7 +658,10 @@ def create_connector_from_model(
) -> ObjectCreationIdResponse:
try:
_validate_connector_allowed(connector_data.source)
validate_user_creation_permissions(
fetch_ee_implementation_or_noop(
"danswer.db.user_group", "validate_user_creation_permissions", None
)(
db_session=db_session,
user=user,
target_group_ids=connector_data.groups,
@@ -732,7 +735,9 @@ def update_connector_from_model(
) -> ConnectorSnapshot | StatusResponse[int]:
try:
_validate_connector_allowed(connector_data.source)
validate_user_creation_permissions(
fetch_ee_implementation_or_noop(
"danswer.db.user_group", "validate_user_creation_permissions", None
)(
db_session=db_session,
user=user,
target_group_ids=connector_data.groups,

View File

@@ -28,7 +28,7 @@ from danswer.server.documents.models import CredentialSwapRequest
from danswer.server.documents.models import ObjectCreationIdResponse
from danswer.server.models import StatusResponse
from danswer.utils.logger import setup_logger
from ee.danswer.db.user_group import validate_user_creation_permissions
from danswer.utils.variable_functionality import fetch_ee_implementation_or_noop
logger = setup_logger()
@@ -121,7 +121,9 @@ def create_credential_from_model(
db_session: Session = Depends(get_session),
) -> ObjectCreationIdResponse:
if not _ignore_credential_permissions(credential_info.source):
validate_user_creation_permissions(
fetch_ee_implementation_or_noop(
"danswer.db.user_group", "validate_user_creation_permissions", None
)(
db_session=db_session,
user=user,
target_group_ids=credential_info.groups,

View File

@@ -18,7 +18,7 @@ from danswer.server.features.document_set.models import CheckDocSetPublicRespons
from danswer.server.features.document_set.models import DocumentSet
from danswer.server.features.document_set.models import DocumentSetCreationRequest
from danswer.server.features.document_set.models import DocumentSetUpdateRequest
from ee.danswer.db.user_group import validate_user_creation_permissions
from danswer.utils.variable_functionality import fetch_ee_implementation_or_noop
router = APIRouter(prefix="/manage")
@@ -30,7 +30,9 @@ def create_document_set(
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> int:
validate_user_creation_permissions(
fetch_ee_implementation_or_noop(
"danswer.db.user_group", "validate_user_creation_permissions", None
)(
db_session=db_session,
user=user,
target_group_ids=document_set_creation_request.groups,
@@ -53,7 +55,9 @@ def patch_document_set(
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> None:
validate_user_creation_permissions(
fetch_ee_implementation_or_noop(
"danswer.db.user_group", "validate_user_creation_permissions", None
)(
db_session=db_session,
user=user,
target_group_ids=document_set_update_request.groups,

View File

@@ -11,7 +11,6 @@ from fastapi import Body
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Request
from fastapi import status
from psycopg2.errors import UniqueViolation
from pydantic import BaseModel
from sqlalchemy import Column
@@ -27,10 +26,10 @@ from danswer.auth.noauth_user import fetch_no_auth_user
from danswer.auth.noauth_user import set_no_auth_user_preferences
from danswer.auth.schemas import UserRole
from danswer.auth.schemas import UserStatus
from danswer.auth.users import BasicAuthenticationError
from danswer.auth.users import current_admin_user
from danswer.auth.users import current_curator_or_admin_user
from danswer.auth.users import current_user
from danswer.auth.users import get_tenant_id_for_email
from danswer.auth.users import optional_user
from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import ENABLE_EMAIL_INVITES
@@ -38,6 +37,7 @@ from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
from danswer.configs.app_configs import SUPER_USERS
from danswer.configs.app_configs import VALID_EMAIL_DOMAINS
from danswer.configs.constants import AuthType
from danswer.db.api_key import is_api_key_email_address
from danswer.db.auth import get_total_users_count
from danswer.db.engine import CURRENT_TENANT_ID_CONTEXTVAR
from danswer.db.engine import get_session
@@ -61,12 +61,7 @@ from danswer.server.models import InvitedUserSnapshot
from danswer.server.models import MinimalUserSnapshot
from danswer.server.utils import send_user_email_invite
from danswer.utils.logger import setup_logger
from ee.danswer.db.api_key import is_api_key_email_address
from ee.danswer.db.external_perm import delete_user__ext_group_for_user__no_commit
from ee.danswer.db.user_group import remove_curator_status__no_commit
from ee.danswer.server.tenants.billing import register_tenant_users
from ee.danswer.server.tenants.provisioning import add_users_to_tenant
from ee.danswer.server.tenants.provisioning import remove_users_from_tenant
from danswer.utils.variable_functionality import fetch_ee_implementation_or_noop
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
@@ -105,7 +100,10 @@ def set_user_role(
)
if user_to_update.role == UserRole.CURATOR:
remove_curator_status__no_commit(db_session, user_to_update)
fetch_ee_implementation_or_noop(
"danswer.db.user_group",
"remove_curator_status__no_commit",
)(db_session, user_to_update)
user_to_update.role = user_role_update_request.new_role.value
@@ -205,7 +203,9 @@ def bulk_invite_users(
if MULTI_TENANT:
try:
add_users_to_tenant(normalized_emails, tenant_id)
fetch_ee_implementation_or_noop(
"danswer.server.tenants.provisioning", "add_users_to_tenant", None
)(normalized_emails, tenant_id)
except IntegrityError as e:
if isinstance(e.orig, UniqueViolation):
@@ -226,9 +226,9 @@ def bulk_invite_users(
return number_of_invited_users
try:
logger.info("Registering tenant users")
register_tenant_users(
CURRENT_TENANT_ID_CONTEXTVAR.get(), get_total_users_count(db_session)
)
fetch_ee_implementation_or_noop(
"danswer.server.tenants.billing", "register_tenant_users", None
)(CURRENT_TENANT_ID_CONTEXTVAR.get(), get_total_users_count(db_session))
if ENABLE_EMAIL_INVITES:
try:
for email in all_emails:
@@ -243,7 +243,9 @@ def bulk_invite_users(
"Reverting changes: removing users from tenant and resetting invited users"
)
write_invited_users(initial_invited_users) # Reset to original state
remove_users_from_tenant(normalized_emails, tenant_id)
fetch_ee_implementation_or_noop(
"danswer.server.tenants.user_mapping", "remove_users_from_tenant", None
)(normalized_emails, tenant_id)
raise e
@@ -257,14 +259,16 @@ def remove_invited_user(
remaining_users = [user for user in user_emails if user != user_email.user_email]
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
remove_users_from_tenant([user_email.user_email], tenant_id)
fetch_ee_implementation_or_noop(
"danswer.server.tenants.user_mapping", "remove_users_from_tenant", None
)([user_email.user_email], tenant_id)
number_of_invited_users = write_invited_users(remaining_users)
try:
if MULTI_TENANT:
register_tenant_users(
CURRENT_TENANT_ID_CONTEXTVAR.get(), get_total_users_count(db_session)
)
fetch_ee_implementation_or_noop(
"danswer.server.tenants.billing", "register_tenant_users", None
)(CURRENT_TENANT_ID_CONTEXTVAR.get(), get_total_users_count(db_session))
except Exception:
logger.error(
"Request to update number of seats taken in control plane failed. "
@@ -331,7 +335,10 @@ async def delete_user(
for oauth_account in user_to_delete.oauth_accounts:
db_session.delete(oauth_account)
delete_user__ext_group_for_user__no_commit(
fetch_ee_implementation_or_noop(
"danswer.db.external_perm",
"delete_user__ext_group_for_user__no_commit",
)(
db_session=db_session,
user_id=user_to_delete.id,
)
@@ -485,20 +492,19 @@ def verify_user_logged_in(
store = get_kv_store()
return fetch_no_auth_user(store)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="User Not Authenticated"
)
raise BasicAuthenticationError(detail="User Not Authenticated")
if user.oidc_expiry and user.oidc_expiry < datetime.now(timezone.utc):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
raise BasicAuthenticationError(
detail="Access denied. User's OIDC token has expired.",
)
token_created_at = (
None if MULTI_TENANT else get_current_token_creation(user, db_session)
)
organization_name = get_tenant_id_for_email(user.email)
organization_name = fetch_ee_implementation_or_noop(
"danswer.server.tenants.user_mapping", "get_tenant_id_for_email", None
)(user.email)
user_info = UserInfo.from_model(
user,

View File

@@ -0,0 +1,273 @@
from typing import Any
from typing import Optional
from uuid import uuid4
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Query
from pydantic import BaseModel
from sqlalchemy.orm import Session
from danswer.auth.users import current_user
from danswer.db.engine import get_session
from danswer.db.models import Persona
from danswer.db.models import User
from danswer.db.persona import get_persona_by_id
from danswer.db.persona import get_personas
from danswer.db.persona import mark_persona_as_deleted
from danswer.db.persona import upsert_persona
from danswer.db.persona import upsert_prompt
from danswer.db.tools import get_tool_by_name
from danswer.search.enums import RecencyBiasSetting
from danswer.utils.logger import setup_logger
logger = setup_logger()
router = APIRouter(prefix="/assistants")
# Base models
class AssistantObject(BaseModel):
id: int
object: str = "assistant"
created_at: int
name: Optional[str] = None
description: Optional[str] = None
model: str
instructions: Optional[str] = None
tools: list[dict[str, Any]]
file_ids: list[str]
metadata: Optional[dict[str, Any]] = None
class CreateAssistantRequest(BaseModel):
model: str
name: Optional[str] = None
description: Optional[str] = None
instructions: Optional[str] = None
tools: Optional[list[dict[str, Any]]] = None
file_ids: Optional[list[str]] = None
metadata: Optional[dict[str, Any]] = None
class ModifyAssistantRequest(BaseModel):
model: Optional[str] = None
name: Optional[str] = None
description: Optional[str] = None
instructions: Optional[str] = None
tools: Optional[list[dict[str, Any]]] = None
file_ids: Optional[list[str]] = None
metadata: Optional[dict[str, Any]] = None
class DeleteAssistantResponse(BaseModel):
id: int
object: str = "assistant.deleted"
deleted: bool
class ListAssistantsResponse(BaseModel):
object: str = "list"
data: list[AssistantObject]
first_id: Optional[int] = None
last_id: Optional[int] = None
has_more: bool
def persona_to_assistant(persona: Persona) -> AssistantObject:
return AssistantObject(
id=persona.id,
created_at=0,
name=persona.name,
description=persona.description,
model=persona.llm_model_version_override or "gpt-3.5-turbo",
instructions=persona.prompts[0].system_prompt if persona.prompts else None,
tools=[
{
"type": tool.display_name,
"function": {
"name": tool.name,
"description": tool.description,
"schema": tool.openapi_schema,
},
}
for tool in persona.tools
],
file_ids=[], # Assuming no file support for now
metadata={}, # Assuming no metadata for now
)
# API endpoints
@router.post("")
def create_assistant(
request: CreateAssistantRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> AssistantObject:
prompt = None
if request.instructions:
prompt = upsert_prompt(
user=user,
name=f"Prompt for {request.name or 'New Assistant'}",
description="Auto-generated prompt",
system_prompt=request.instructions,
task_prompt="",
include_citations=True,
datetime_aware=True,
personas=[],
db_session=db_session,
)
tool_ids = []
for tool in request.tools or []:
tool_type = tool.get("type")
if not tool_type:
continue
try:
tool_db = get_tool_by_name(tool_type, db_session)
tool_ids.append(tool_db.id)
except ValueError:
# Skip tools that don't exist in the database
logger.error(f"Tool {tool_type} not found in database")
raise HTTPException(
status_code=404, detail=f"Tool {tool_type} not found in database"
)
persona = upsert_persona(
user=user,
name=request.name or f"Assistant-{uuid4()}",
description=request.description or "",
num_chunks=25,
llm_relevance_filter=True,
llm_filter_extraction=True,
recency_bias=RecencyBiasSetting.AUTO,
llm_model_provider_override=None,
llm_model_version_override=request.model,
starter_messages=None,
is_public=False,
db_session=db_session,
prompt_ids=[prompt.id] if prompt else [0],
document_set_ids=[],
tool_ids=tool_ids,
icon_color=None,
icon_shape=None,
is_visible=True,
)
if prompt:
prompt.personas = [persona]
db_session.commit()
return persona_to_assistant(persona)
""
@router.get("/{assistant_id}")
def retrieve_assistant(
assistant_id: int,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> AssistantObject:
try:
persona = get_persona_by_id(
persona_id=assistant_id,
user=user,
db_session=db_session,
is_for_edit=False,
)
except ValueError:
persona = None
if not persona:
raise HTTPException(status_code=404, detail="Assistant not found")
return persona_to_assistant(persona)
@router.post("/{assistant_id}")
def modify_assistant(
assistant_id: int,
request: ModifyAssistantRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> AssistantObject:
persona = get_persona_by_id(
persona_id=assistant_id,
user=user,
db_session=db_session,
is_for_edit=True,
)
if not persona:
raise HTTPException(status_code=404, detail="Assistant not found")
update_data = request.model_dump(exclude_unset=True)
for key, value in update_data.items():
setattr(persona, key, value)
if "instructions" in update_data and persona.prompts:
persona.prompts[0].system_prompt = update_data["instructions"]
db_session.commit()
return persona_to_assistant(persona)
@router.delete("/{assistant_id}")
def delete_assistant(
assistant_id: int,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> DeleteAssistantResponse:
try:
mark_persona_as_deleted(
persona_id=int(assistant_id),
user=user,
db_session=db_session,
)
return DeleteAssistantResponse(id=assistant_id, deleted=True)
except ValueError:
raise HTTPException(status_code=404, detail="Assistant not found")
@router.get("")
def list_assistants(
limit: int = Query(20, le=100),
order: str = Query("desc", regex="^(asc|desc)$"),
after: Optional[int] = None,
before: Optional[int] = None,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> ListAssistantsResponse:
personas = list(
get_personas(
user=user,
db_session=db_session,
get_editable=False,
joinedload_all=True,
)
)
# Apply filtering based on after and before
if after:
personas = [p for p in personas if p.id > int(after)]
if before:
personas = [p for p in personas if p.id < int(before)]
# Apply ordering
personas.sort(key=lambda p: p.id, reverse=(order == "desc"))
# Apply limit
personas = personas[:limit]
assistants = [persona_to_assistant(p) for p in personas]
return ListAssistantsResponse(
data=assistants,
first_id=assistants[0].id if assistants else None,
last_id=assistants[-1].id if assistants else None,
has_more=len(personas) == limit,
)

View File

@@ -0,0 +1,19 @@
from fastapi import APIRouter
from danswer.server.openai_assistants_api.asssistants_api import (
router as assistants_router,
)
from danswer.server.openai_assistants_api.messages_api import router as messages_router
from danswer.server.openai_assistants_api.runs_api import router as runs_router
from danswer.server.openai_assistants_api.threads_api import router as threads_router
def get_full_openai_assistants_api_router() -> APIRouter:
router = APIRouter(prefix="/openai-assistants")
router.include_router(assistants_router)
router.include_router(runs_router)
router.include_router(threads_router)
router.include_router(messages_router)
return router

View File

@@ -0,0 +1,235 @@
import uuid
from datetime import datetime
from typing import Any
from typing import Literal
from typing import Optional
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from pydantic import BaseModel
from pydantic import Field
from sqlalchemy.orm import Session
from danswer.auth.users import current_user
from danswer.configs.constants import MessageType
from danswer.db.chat import create_new_chat_message
from danswer.db.chat import get_chat_message
from danswer.db.chat import get_chat_messages_by_session
from danswer.db.chat import get_chat_session_by_id
from danswer.db.chat import get_or_create_root_message
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.llm.utils import check_number_of_tokens
router = APIRouter(prefix="")
Role = Literal["user", "assistant"]
class MessageContent(BaseModel):
type: Literal["text"]
text: str
class Message(BaseModel):
id: str = Field(default_factory=lambda: f"msg_{uuid.uuid4()}")
object: Literal["thread.message"] = "thread.message"
created_at: int = Field(default_factory=lambda: int(datetime.now().timestamp()))
thread_id: str
role: Role
content: list[MessageContent]
file_ids: list[str] = []
assistant_id: Optional[str] = None
run_id: Optional[str] = None
metadata: Optional[dict[str, Any]] = None # Change this line to use dict[str, Any]
class CreateMessageRequest(BaseModel):
role: Role
content: str
file_ids: list[str] = []
metadata: Optional[dict] = None
class ListMessagesResponse(BaseModel):
object: Literal["list"] = "list"
data: list[Message]
first_id: str
last_id: str
has_more: bool
@router.post("/threads/{thread_id}/messages")
def create_message(
thread_id: str,
message: CreateMessageRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> Message:
user_id = user.id if user else None
try:
chat_session = get_chat_session_by_id(
chat_session_id=uuid.UUID(thread_id),
user_id=user_id,
db_session=db_session,
)
except ValueError:
raise HTTPException(status_code=404, detail="Chat session not found")
chat_messages = get_chat_messages_by_session(
chat_session_id=chat_session.id,
user_id=user.id if user else None,
db_session=db_session,
)
latest_message = (
chat_messages[-1]
if chat_messages
else get_or_create_root_message(chat_session.id, db_session)
)
new_message = create_new_chat_message(
chat_session_id=chat_session.id,
parent_message=latest_message,
message=message.content,
prompt_id=chat_session.persona.prompts[0].id,
token_count=check_number_of_tokens(message.content),
message_type=(
MessageType.USER if message.role == "user" else MessageType.ASSISTANT
),
db_session=db_session,
)
return Message(
id=str(new_message.id),
thread_id=thread_id,
role="user",
content=[MessageContent(type="text", text=message.content)],
file_ids=message.file_ids,
metadata=message.metadata,
)
@router.get("/threads/{thread_id}/messages")
def list_messages(
thread_id: str,
limit: int = 20,
order: Literal["asc", "desc"] = "desc",
after: Optional[str] = None,
before: Optional[str] = None,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> ListMessagesResponse:
user_id = user.id if user else None
try:
chat_session = get_chat_session_by_id(
chat_session_id=uuid.UUID(thread_id),
user_id=user_id,
db_session=db_session,
)
except ValueError:
raise HTTPException(status_code=404, detail="Chat session not found")
messages = get_chat_messages_by_session(
chat_session_id=chat_session.id,
user_id=user_id,
db_session=db_session,
)
# Apply filtering based on after and before
if after:
messages = [m for m in messages if str(m.id) >= after]
if before:
messages = [m for m in messages if str(m.id) <= before]
# Apply ordering
messages = sorted(messages, key=lambda m: m.id, reverse=(order == "desc"))
# Apply limit
messages = messages[:limit]
data = [
Message(
id=str(m.id),
thread_id=thread_id,
role="user" if m.message_type == "user" else "assistant",
content=[MessageContent(type="text", text=m.message)],
created_at=int(m.time_sent.timestamp()),
)
for m in messages
]
return ListMessagesResponse(
data=data,
first_id=str(data[0].id) if data else "",
last_id=str(data[-1].id) if data else "",
has_more=len(messages) == limit,
)
@router.get("/threads/{thread_id}/messages/{message_id}")
def retrieve_message(
thread_id: str,
message_id: int,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> Message:
user_id = user.id if user else None
try:
chat_message = get_chat_message(
chat_message_id=message_id,
user_id=user_id,
db_session=db_session,
)
except ValueError:
raise HTTPException(status_code=404, detail="Message not found")
return Message(
id=str(chat_message.id),
thread_id=thread_id,
role="user" if chat_message.message_type == "user" else "assistant",
content=[MessageContent(type="text", text=chat_message.message)],
created_at=int(chat_message.time_sent.timestamp()),
)
class ModifyMessageRequest(BaseModel):
metadata: dict
@router.post("/threads/{thread_id}/messages/{message_id}")
def modify_message(
thread_id: str,
message_id: int,
request: ModifyMessageRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> Message:
user_id = user.id if user else None
try:
chat_message = get_chat_message(
chat_message_id=message_id,
user_id=user_id,
db_session=db_session,
)
except ValueError:
raise HTTPException(status_code=404, detail="Message not found")
# Update metadata
# TODO: Uncomment this once we have metadata in the chat message
# chat_message.metadata = request.metadata
# db_session.commit()
return Message(
id=str(chat_message.id),
thread_id=thread_id,
role="user" if chat_message.message_type == "user" else "assistant",
content=[MessageContent(type="text", text=chat_message.message)],
created_at=int(chat_message.time_sent.timestamp()),
metadata=request.metadata,
)

View File

@@ -0,0 +1,344 @@
from typing import Literal
from typing import Optional
from uuid import UUID
from fastapi import APIRouter
from fastapi import BackgroundTasks
from fastapi import Depends
from fastapi import HTTPException
from pydantic import BaseModel
from sqlalchemy.orm import Session
from danswer.auth.users import current_user
from danswer.chat.process_message import stream_chat_message_objects
from danswer.configs.constants import MessageType
from danswer.db.chat import create_new_chat_message
from danswer.db.chat import get_chat_message
from danswer.db.chat import get_chat_messages_by_session
from danswer.db.chat import get_chat_session_by_id
from danswer.db.chat import get_or_create_root_message
from danswer.db.engine import get_session
from danswer.db.models import ChatMessage
from danswer.db.models import User
from danswer.search.models import RetrievalDetails
from danswer.server.query_and_chat.models import ChatMessageDetail
from danswer.server.query_and_chat.models import CreateChatMessageRequest
from danswer.tools.tool_implementations.search.search_tool import SearchTool
from danswer.utils.logger import setup_logger
logger = setup_logger()
router = APIRouter()
class RunRequest(BaseModel):
assistant_id: int
model: Optional[str] = None
instructions: Optional[str] = None
additional_instructions: Optional[str] = None
tools: Optional[list[dict]] = None
metadata: Optional[dict] = None
RunStatus = Literal[
"queued",
"in_progress",
"requires_action",
"cancelling",
"cancelled",
"failed",
"completed",
"expired",
]
class RunResponse(BaseModel):
id: str
object: Literal["thread.run"]
created_at: int
assistant_id: int
thread_id: UUID
status: RunStatus
started_at: Optional[int] = None
expires_at: Optional[int] = None
cancelled_at: Optional[int] = None
failed_at: Optional[int] = None
completed_at: Optional[int] = None
last_error: Optional[dict] = None
model: str
instructions: str
tools: list[dict]
file_ids: list[str]
metadata: Optional[dict] = None
def process_run_in_background(
message_id: int,
parent_message_id: int,
chat_session_id: UUID,
assistant_id: int,
instructions: str,
tools: list[dict],
user: User | None,
db_session: Session,
) -> None:
# Get the latest message in the chat session
chat_session = get_chat_session_by_id(
chat_session_id=chat_session_id,
user_id=user.id if user else None,
db_session=db_session,
)
search_tool_retrieval_details = RetrievalDetails()
for tool in tools:
if tool["type"] == SearchTool.__name__ and (
retrieval_details := tool.get("retrieval_details")
):
search_tool_retrieval_details = RetrievalDetails.model_validate(
retrieval_details
)
break
new_msg_req = CreateChatMessageRequest(
chat_session_id=chat_session_id,
parent_message_id=int(parent_message_id) if parent_message_id else None,
message=instructions,
file_descriptors=[],
prompt_id=chat_session.persona.prompts[0].id,
search_doc_ids=None,
retrieval_options=search_tool_retrieval_details, # Adjust as needed
query_override=None,
regenerate=None,
llm_override=None,
prompt_override=None,
alternate_assistant_id=assistant_id,
use_existing_user_message=True,
existing_assistant_message_id=message_id,
)
run_message = get_chat_message(message_id, user.id if user else None, db_session)
try:
for packet in stream_chat_message_objects(
new_msg_req=new_msg_req,
user=user,
db_session=db_session,
):
if isinstance(packet, ChatMessageDetail):
# Update the run status and message content
run_message = get_chat_message(
message_id, user.id if user else None, db_session
)
if run_message:
# this handles cancelling
if run_message.error:
return
run_message.message = packet.message
run_message.message_type = MessageType.ASSISTANT
db_session.commit()
except Exception as e:
logger.exception("Error processing run in background")
run_message.error = str(e)
db_session.commit()
return
db_session.refresh(run_message)
if run_message.token_count == 0:
run_message.error = "No tokens generated"
db_session.commit()
@router.post("/threads/{thread_id}/runs")
def create_run(
thread_id: UUID,
run_request: RunRequest,
background_tasks: BackgroundTasks,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> RunResponse:
try:
chat_session = get_chat_session_by_id(
chat_session_id=thread_id,
user_id=user.id if user else None,
db_session=db_session,
)
except ValueError:
raise HTTPException(status_code=404, detail="Thread not found")
chat_messages = get_chat_messages_by_session(
chat_session_id=chat_session.id,
user_id=user.id if user else None,
db_session=db_session,
)
latest_message = (
chat_messages[-1]
if chat_messages
else get_or_create_root_message(chat_session.id, db_session)
)
# Create a new "run" (chat message) in the session
new_message = create_new_chat_message(
chat_session_id=chat_session.id,
parent_message=latest_message,
message="",
prompt_id=chat_session.persona.prompts[0].id,
token_count=0,
message_type=MessageType.ASSISTANT,
db_session=db_session,
commit=False,
)
db_session.flush()
latest_message.latest_child_message = new_message.id
db_session.commit()
# Schedule the background task
background_tasks.add_task(
process_run_in_background,
new_message.id,
latest_message.id,
chat_session.id,
run_request.assistant_id,
run_request.instructions or "",
run_request.tools or [],
user,
db_session,
)
return RunResponse(
id=str(new_message.id),
object="thread.run",
created_at=int(new_message.time_sent.timestamp()),
assistant_id=run_request.assistant_id,
thread_id=chat_session.id,
status="queued",
model=run_request.model or "default_model",
instructions=run_request.instructions or "",
tools=run_request.tools or [],
file_ids=[],
metadata=run_request.metadata,
)
@router.get("/threads/{thread_id}/runs/{run_id}")
def retrieve_run(
thread_id: UUID,
run_id: str,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> RunResponse:
# Retrieve the chat message (which represents a "run" in DAnswer)
chat_message = get_chat_message(
chat_message_id=int(run_id), # Convert string run_id to int
user_id=user.id if user else None,
db_session=db_session,
)
if not chat_message:
raise HTTPException(status_code=404, detail="Run not found")
chat_session = chat_message.chat_session
# Map DAnswer status to OpenAI status
run_status: RunStatus = "queued"
if chat_message.message:
run_status = "in_progress"
if chat_message.token_count != 0:
run_status = "completed"
if chat_message.error:
run_status = "cancelled"
return RunResponse(
id=run_id,
object="thread.run",
created_at=int(chat_message.time_sent.timestamp()),
assistant_id=chat_session.persona_id or 0,
thread_id=chat_session.id,
status=run_status,
started_at=int(chat_message.time_sent.timestamp()),
completed_at=(
int(chat_message.time_sent.timestamp()) if chat_message.message else None
),
model=chat_session.current_alternate_model or "default_model",
instructions="", # DAnswer doesn't store per-message instructions
tools=[], # DAnswer doesn't have a direct equivalent for tools
file_ids=(
[file["id"] for file in chat_message.files] if chat_message.files else []
),
metadata=None, # DAnswer doesn't store metadata for individual messages
)
@router.post("/threads/{thread_id}/runs/{run_id}/cancel")
def cancel_run(
thread_id: UUID,
run_id: str,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> RunResponse:
# In DAnswer, we don't have a direct equivalent to cancelling a run
# We'll simulate it by marking the message as "cancelled"
chat_message = (
db_session.query(ChatMessage).filter(ChatMessage.id == run_id).first()
)
if not chat_message:
raise HTTPException(status_code=404, detail="Run not found")
chat_message.error = "Cancelled"
db_session.commit()
return retrieve_run(thread_id, run_id, user, db_session)
@router.get("/threads/{thread_id}/runs")
def list_runs(
thread_id: UUID,
limit: int = 20,
order: Literal["asc", "desc"] = "desc",
after: Optional[str] = None,
before: Optional[str] = None,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> list[RunResponse]:
# In DAnswer, we'll treat each message in a chat session as a "run"
chat_messages = get_chat_messages_by_session(
chat_session_id=thread_id,
user_id=user.id if user else None,
db_session=db_session,
)
# Apply pagination
if after:
chat_messages = [msg for msg in chat_messages if str(msg.id) > after]
if before:
chat_messages = [msg for msg in chat_messages if str(msg.id) < before]
# Apply ordering
chat_messages = sorted(
chat_messages, key=lambda msg: msg.time_sent, reverse=(order == "desc")
)
# Apply limit
chat_messages = chat_messages[:limit]
return [
retrieve_run(thread_id, str(msg.id), user, db_session) for msg in chat_messages
]
@router.get("/threads/{thread_id}/runs/{run_id}/steps")
def list_run_steps(
run_id: str,
limit: int = 20,
order: Literal["asc", "desc"] = "desc",
after: Optional[str] = None,
before: Optional[str] = None,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> list[dict]: # You may want to create a specific model for run steps
# DAnswer doesn't have an equivalent to run steps
# We'll return an empty list to maintain API compatibility
return []
# Additional helper functions can be added here if needed

View File

@@ -0,0 +1,156 @@
from typing import Optional
from uuid import UUID
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from pydantic import BaseModel
from sqlalchemy.orm import Session
from danswer.auth.users import current_user
from danswer.db.chat import create_chat_session
from danswer.db.chat import delete_chat_session
from danswer.db.chat import get_chat_session_by_id
from danswer.db.chat import get_chat_sessions_by_user
from danswer.db.chat import update_chat_session
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.server.query_and_chat.models import ChatSessionDetails
from danswer.server.query_and_chat.models import ChatSessionsResponse
router = APIRouter(prefix="/threads")
# Models
class Thread(BaseModel):
id: UUID
object: str = "thread"
created_at: int
metadata: Optional[dict[str, str]] = None
class CreateThreadRequest(BaseModel):
messages: Optional[list[dict]] = None
metadata: Optional[dict[str, str]] = None
class ModifyThreadRequest(BaseModel):
metadata: Optional[dict[str, str]] = None
# API Endpoints
@router.post("")
def create_thread(
request: CreateThreadRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> Thread:
user_id = user.id if user else None
new_chat_session = create_chat_session(
db_session=db_session,
description="", # Leave the naming till later to prevent delay
user_id=user_id,
persona_id=0,
)
return Thread(
id=new_chat_session.id,
created_at=int(new_chat_session.time_created.timestamp()),
metadata=request.metadata,
)
@router.get("/{thread_id}")
def retrieve_thread(
thread_id: UUID,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> Thread:
user_id = user.id if user else None
try:
chat_session = get_chat_session_by_id(
chat_session_id=thread_id,
user_id=user_id,
db_session=db_session,
)
except ValueError:
raise HTTPException(status_code=404, detail="Thread not found")
return Thread(
id=chat_session.id,
created_at=int(chat_session.time_created.timestamp()),
metadata=None, # Assuming we don't store metadata in our current implementation
)
@router.post("/{thread_id}")
def modify_thread(
thread_id: UUID,
request: ModifyThreadRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> Thread:
user_id = user.id if user else None
try:
chat_session = update_chat_session(
db_session=db_session,
user_id=user_id,
chat_session_id=thread_id,
description=None, # Not updating description
sharing_status=None, # Not updating sharing status
)
except ValueError:
raise HTTPException(status_code=404, detail="Thread not found")
return Thread(
id=chat_session.id,
created_at=int(chat_session.time_created.timestamp()),
metadata=request.metadata,
)
@router.delete("/{thread_id}")
def delete_thread(
thread_id: UUID,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> dict:
user_id = user.id if user else None
try:
delete_chat_session(
user_id=user_id,
chat_session_id=thread_id,
db_session=db_session,
)
except ValueError:
raise HTTPException(status_code=404, detail="Thread not found")
return {"id": str(thread_id), "object": "thread.deleted", "deleted": True}
@router.get("")
def list_threads(
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> ChatSessionsResponse:
user_id = user.id if user else None
chat_sessions = get_chat_sessions_by_user(
user_id=user_id,
deleted=False,
db_session=db_session,
)
return ChatSessionsResponse(
sessions=[
ChatSessionDetails(
id=chat.id,
name=chat.description,
persona_id=chat.persona_id,
time_created=chat.time_created.isoformat(),
shared_status=chat.shared_status,
folder_id=chat.folder_id,
current_alternate_model=chat.current_alternate_model,
)
for chat in chat_sessions
]
)

View File

@@ -347,7 +347,6 @@ def handle_new_chat_message(
for packet in stream_chat_message(
new_msg_req=chat_message_req,
user=user,
use_existing_user_message=chat_message_req.use_existing_user_message,
litellm_additional_headers=extract_headers(
request.headers, LITELLM_PASS_THROUGH_HEADERS
),
@@ -359,7 +358,7 @@ def handle_new_chat_message(
yield json.dumps(packet) if isinstance(packet, dict) else packet
except Exception as e:
logger.exception(f"Error in chat message streaming: {e}")
logger.exception("Error in chat message streaming")
yield json.dumps({"error": str(e)})
finally:

View File

@@ -108,6 +108,9 @@ class CreateChatMessageRequest(ChunkContext):
# used for seeded chats to kick off the generation of an AI answer
use_existing_user_message: bool = False
# used for "OpenAI Assistants API"
existing_assistant_message_id: int | None = None
# forces the LLM to return a structured response, see
# https://platform.openai.com/docs/guides/structured-outputs/introduction
structured_response_format: dict | None = None

View File

@@ -279,7 +279,7 @@ def get_answer_with_quote(
):
yield json.dumps(packet) if isinstance(packet, dict) else packet
except Exception as e:
logger.exception(f"Error in search answer streaming: {e}")
logger.exception("Error in search answer streaming")
yield json.dumps({"error": str(e)})
return StreamingResponse(stream_generator(), media_type="application/json")

View File

@@ -18,9 +18,9 @@ from danswer.db.models import ChatMessage
from danswer.db.models import ChatSession
from danswer.db.models import TokenRateLimit
from danswer.db.models import User
from danswer.db.token_limit import fetch_all_global_token_rate_limits
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import fetch_versioned_implementation
from ee.danswer.db.token_limit import fetch_all_global_token_rate_limits
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

View File

@@ -5,13 +5,13 @@ from sqlalchemy.orm import Session
from danswer.auth.users import current_admin_user
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.db.token_limit import delete_token_rate_limit
from danswer.db.token_limit import fetch_all_global_token_rate_limits
from danswer.db.token_limit import insert_global_token_rate_limit
from danswer.db.token_limit import update_token_rate_limit
from danswer.server.query_and_chat.token_limit import any_rate_limit_exists
from danswer.server.token_rate_limits.models import TokenRateLimitArgs
from danswer.server.token_rate_limits.models import TokenRateLimitDisplay
from ee.danswer.db.token_limit import delete_token_rate_limit
from ee.danswer.db.token_limit import fetch_all_global_token_rate_limits
from ee.danswer.db.token_limit import insert_global_token_rate_limit
from ee.danswer.db.token_limit import update_token_rate_limit
router = APIRouter(prefix="/admin/token-rate-limits")

View File

@@ -0,0 +1,255 @@
from typing import cast
from uuid import UUID
from pydantic import BaseModel
from pydantic import Field
from sqlalchemy.orm import Session
from danswer.configs.app_configs import AZURE_DALLE_API_BASE
from danswer.configs.app_configs import AZURE_DALLE_API_KEY
from danswer.configs.app_configs import AZURE_DALLE_API_VERSION
from danswer.configs.app_configs import AZURE_DALLE_DEPLOYMENT_NAME
from danswer.configs.chat_configs import BING_API_KEY
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.db.llm import fetch_existing_llm_providers
from danswer.db.models import Persona
from danswer.db.models import User
from danswer.file_store.models import InMemoryChatFile
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import CitationConfig
from danswer.llm.answering.models import DocumentPruningConfig
from danswer.llm.answering.models import PromptConfig
from danswer.llm.interfaces import LLM
from danswer.llm.interfaces import LLMConfig
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.search.enums import LLMEvaluationType
from danswer.search.models import InferenceSection
from danswer.search.models import RetrievalDetails
from danswer.tools.built_in_tools import get_built_in_tool_by_id
from danswer.tools.models import DynamicSchemaInfo
from danswer.tools.tool import Tool
from danswer.tools.tool_implementations.custom.custom_tool import (
build_custom_tools_from_openapi_schema_and_headers,
)
from danswer.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationTool,
)
from danswer.tools.tool_implementations.internet_search.internet_search_tool import (
InternetSearchTool,
)
from danswer.tools.tool_implementations.search.search_tool import SearchTool
from danswer.tools.utils import compute_all_tool_tokens
from danswer.tools.utils import explicit_tool_calling_supported
from danswer.utils.headers import header_dict_to_header_list
from danswer.utils.logger import setup_logger
logger = setup_logger()
def _get_image_generation_config(llm: LLM, db_session: Session) -> LLMConfig:
"""Helper function to get image generation LLM config based on available providers"""
if llm and llm.config.api_key and llm.config.model_provider == "openai":
return LLMConfig(
model_provider=llm.config.model_provider,
model_name="dall-e-3",
temperature=GEN_AI_TEMPERATURE,
api_key=llm.config.api_key,
api_base=llm.config.api_base,
api_version=llm.config.api_version,
)
if llm.config.model_provider == "azure" and AZURE_DALLE_API_KEY is not None:
return LLMConfig(
model_provider="azure",
model_name=f"azure/{AZURE_DALLE_DEPLOYMENT_NAME}",
temperature=GEN_AI_TEMPERATURE,
api_key=AZURE_DALLE_API_KEY,
api_base=AZURE_DALLE_API_BASE,
api_version=AZURE_DALLE_API_VERSION,
)
# Fallback to checking for OpenAI provider in database
llm_providers = fetch_existing_llm_providers(db_session)
openai_provider = next(
iter(
[
llm_provider
for llm_provider in llm_providers
if llm_provider.provider == "openai"
]
),
None,
)
if not openai_provider or not openai_provider.api_key:
raise ValueError("Image generation tool requires an OpenAI API key")
return LLMConfig(
model_provider=openai_provider.provider,
model_name="dall-e-3",
temperature=GEN_AI_TEMPERATURE,
api_key=openai_provider.api_key,
api_base=openai_provider.api_base,
api_version=openai_provider.api_version,
)
class SearchToolConfig(BaseModel):
answer_style_config: AnswerStyleConfig = Field(
default_factory=lambda: AnswerStyleConfig(citation_config=CitationConfig())
)
document_pruning_config: DocumentPruningConfig = Field(
default_factory=DocumentPruningConfig
)
retrieval_options: RetrievalDetails = Field(default_factory=RetrievalDetails)
selected_sections: list[InferenceSection] | None = None
chunks_above: int = 0
chunks_below: int = 0
full_doc: bool = False
latest_query_files: list[InMemoryChatFile] | None = None
class InternetSearchToolConfig(BaseModel):
answer_style_config: AnswerStyleConfig = Field(
default_factory=lambda: AnswerStyleConfig(
citation_config=CitationConfig(all_docs_useful=True)
)
)
class ImageGenerationToolConfig(BaseModel):
additional_headers: dict[str, str] | None = None
class CustomToolConfig(BaseModel):
chat_session_id: UUID | None = None
message_id: int | None = None
additional_headers: dict[str, str] | None = None
def construct_tools(
persona: Persona,
prompt_config: PromptConfig,
db_session: Session,
user: User | None,
llm: LLM,
fast_llm: LLM,
search_tool_config: SearchToolConfig | None = None,
internet_search_tool_config: InternetSearchToolConfig | None = None,
image_generation_tool_config: ImageGenerationToolConfig | None = None,
custom_tool_config: CustomToolConfig | None = None,
) -> dict[int, list[Tool]]:
"""Constructs tools based on persona configuration and available APIs"""
tool_dict: dict[int, list[Tool]] = {}
for db_tool_model in persona.tools:
if db_tool_model.in_code_tool_id:
tool_cls = get_built_in_tool_by_id(db_tool_model.id, db_session)
# Handle Search Tool
if tool_cls.__name__ == SearchTool.__name__:
if not search_tool_config:
search_tool_config = SearchToolConfig()
search_tool = SearchTool(
db_session=db_session,
user=user,
persona=persona,
retrieval_options=search_tool_config.retrieval_options,
prompt_config=prompt_config,
llm=llm,
fast_llm=fast_llm,
pruning_config=search_tool_config.document_pruning_config,
answer_style_config=search_tool_config.answer_style_config,
selected_sections=search_tool_config.selected_sections,
chunks_above=search_tool_config.chunks_above,
chunks_below=search_tool_config.chunks_below,
full_doc=search_tool_config.full_doc,
evaluation_type=(
LLMEvaluationType.BASIC
if persona.llm_relevance_filter
else LLMEvaluationType.SKIP
),
)
tool_dict[db_tool_model.id] = [search_tool]
# Handle Image Generation Tool
elif tool_cls.__name__ == ImageGenerationTool.__name__:
if not image_generation_tool_config:
image_generation_tool_config = ImageGenerationToolConfig()
img_generation_llm_config = _get_image_generation_config(
llm, db_session
)
tool_dict[db_tool_model.id] = [
ImageGenerationTool(
api_key=cast(str, img_generation_llm_config.api_key),
api_base=img_generation_llm_config.api_base,
api_version=img_generation_llm_config.api_version,
additional_headers=image_generation_tool_config.additional_headers,
model=img_generation_llm_config.model_name,
)
]
# Handle Internet Search Tool
elif tool_cls.__name__ == InternetSearchTool.__name__:
if not internet_search_tool_config:
internet_search_tool_config = InternetSearchToolConfig()
if not BING_API_KEY:
raise ValueError(
"Internet search tool requires a Bing API key, please contact your Danswer admin to get it added!"
)
tool_dict[db_tool_model.id] = [
InternetSearchTool(
api_key=BING_API_KEY,
answer_style_config=internet_search_tool_config.answer_style_config,
prompt_config=prompt_config,
)
]
# Handle custom tools
elif db_tool_model.openapi_schema:
if not custom_tool_config:
custom_tool_config = CustomToolConfig()
tool_dict[db_tool_model.id] = cast(
list[Tool],
build_custom_tools_from_openapi_schema_and_headers(
db_tool_model.openapi_schema,
dynamic_schema_info=DynamicSchemaInfo(
chat_session_id=custom_tool_config.chat_session_id,
message_id=custom_tool_config.message_id,
),
custom_headers=(db_tool_model.custom_headers or [])
+ (
header_dict_to_header_list(
custom_tool_config.additional_headers or {}
)
),
),
)
tools: list[Tool] = []
for tool_list in tool_dict.values():
tools.extend(tool_list)
# factor in tool definition size when pruning
if search_tool_config:
search_tool_config.document_pruning_config.tool_num_tokens = (
compute_all_tool_tokens(
tools,
get_tokenizer(
model_name=llm.config.model_name,
provider_type=llm.config.model_provider,
),
)
)
search_tool_config.document_pruning_config.using_tool_message = (
explicit_tool_calling_supported(
llm.config.model_provider, llm.config.model_name
)
)
return tool_dict

View File

@@ -15,7 +15,7 @@ from langchain_core.messages import SystemMessage
from pydantic import BaseModel
from danswer.configs.constants import FileOrigin
from danswer.db.engine import get_session_with_tenant
from danswer.db.engine import get_session_with_default_tenant
from danswer.file_store.file_store import get_default_file_store
from danswer.file_store.models import ChatFileType
from danswer.file_store.models import InMemoryChatFile
@@ -187,7 +187,7 @@ class CustomTool(BaseTool):
def _save_and_get_file_references(
self, file_content: bytes | str, content_type: str
) -> List[str]:
with get_session_with_tenant() as db_session:
with get_session_with_default_tenant() as db_session:
file_store = get_default_file_store(db_session)
file_id = str(uuid.uuid4())
@@ -299,7 +299,7 @@ class CustomTool(BaseTool):
# Load files from storage
files = []
with get_session_with_tenant() as db_session:
with get_session_with_default_tenant() as db_session:
file_store = get_default_file_store(db_session)
for file_id in response.tool_result.file_ids:

View File

@@ -1,5 +1,6 @@
import functools
import importlib
import inspect
from typing import Any
from typing import TypeVar
@@ -119,3 +120,41 @@ def noop_fallback(*args: Any, **kwargs: Any) -> None:
Returns:
None
"""
def fetch_ee_implementation_or_noop(
module: str, attribute: str, noop_return_value: Any = None
) -> Any:
"""
Fetches an EE implementation if EE is enabled, otherwise returns a no-op function.
Raises an exception if EE is enabled but the fetch fails.
Args:
module (str): The name of the module from which to fetch the attribute.
attribute (str): The name of the attribute to fetch from the module.
Returns:
Any: The fetched EE implementation if successful and EE is enabled, otherwise a no-op function.
Raises:
Exception: If EE is enabled but the fetch fails.
"""
if not global_version.is_ee_version():
if inspect.iscoroutinefunction(noop_return_value):
async def async_noop(*args: Any, **kwargs: Any) -> Any:
return await noop_return_value(*args, **kwargs)
return async_noop
else:
def sync_noop(*args: Any, **kwargs: Any) -> Any:
return noop_return_value
return sync_noop
try:
return fetch_versioned_implementation(module, attribute)
except Exception as e:
logger.error(f"Failed to fetch implementation for {module}.{attribute}: {e}")
raise

View File

@@ -4,16 +4,15 @@ from fastapi import Request
from fastapi import status
from sqlalchemy.orm import Session
from danswer.auth.api_key import get_hashed_api_key_from_request
from danswer.auth.users import current_admin_user
from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import SUPER_CLOUD_API_KEY
from danswer.configs.app_configs import SUPER_USERS
from danswer.configs.constants import AuthType
from danswer.db.engine import get_session
from danswer.db.api_key import fetch_user_for_api_key
from danswer.db.models import User
from danswer.utils.logger import setup_logger
from ee.danswer.auth.api_key import get_hashed_api_key_from_request
from ee.danswer.db.api_key import fetch_user_for_api_key
from ee.danswer.db.saml import get_saml_account
from ee.danswer.server.seeding import get_seed_config
from ee.danswer.utils.secrets import extract_hashed_cookie
@@ -48,25 +47,6 @@ async def optional_user_(
return user
def api_key_dep(
request: Request, db_session: Session = Depends(get_session)
) -> User | None:
if AUTH_TYPE == AuthType.DISABLED:
return None
hashed_api_key = get_hashed_api_key_from_request(request)
if not hashed_api_key:
raise HTTPException(status_code=401, detail="Missing API key")
if hashed_api_key:
user = fetch_user_for_api_key(hashed_api_key, db_session)
if user is None:
raise HTTPException(status_code=401, detail="Invalid API key")
return user
def get_default_admin_user_emails_() -> list[str]:
seed_config = get_seed_config()
if seed_config and seed_config.admin_user_emails:

View File

@@ -1,4 +1,7 @@
from danswer.background.celery.apps.primary import celery_app
from danswer.background.task_name_builders import (
name_sync_external_doc_permissions_task,
)
from danswer.background.task_utils import build_celery_task_wrapper
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.db.chat import delete_chat_sessions_older_than
@@ -14,9 +17,6 @@ from ee.danswer.background.celery_utils import (
should_perform_external_group_permissions_check,
)
from ee.danswer.background.task_name_builders import name_chat_ttl_task
from ee.danswer.background.task_name_builders import (
name_sync_external_doc_permissions_task,
)
from ee.danswer.background.task_name_builders import (
name_sync_external_group_permissions_task,
)

View File

@@ -3,15 +3,15 @@ from datetime import timezone
from sqlalchemy.orm import Session
from danswer.background.task_name_builders import (
name_sync_external_doc_permissions_task,
)
from danswer.db.enums import AccessType
from danswer.db.models import ConnectorCredentialPair
from danswer.db.tasks import check_task_is_live_and_not_timed_out
from danswer.db.tasks import get_latest_task
from danswer.utils.logger import setup_logger
from ee.danswer.background.task_name_builders import name_chat_ttl_task
from ee.danswer.background.task_name_builders import (
name_sync_external_doc_permissions_task,
)
from ee.danswer.background.task_name_builders import (
name_sync_external_group_permissions_task,
)

View File

@@ -2,12 +2,6 @@ def name_chat_ttl_task(retention_limit_days: int, tenant_id: str | None = None)
return f"chat_ttl_{retention_limit_days}_days"
def name_sync_external_doc_permissions_task(
cc_pair_id: int, tenant_id: str | None = None
) -> str:
return f"sync_external_doc_permissions_task__{cc_pair_id}"
def name_sync_external_group_permissions_task(
cc_pair_id: int, tenant_id: str | None = None
) -> str:

View File

@@ -7,16 +7,6 @@ OPENID_CONFIG_URL = os.environ.get("OPENID_CONFIG_URL", "")
SAML_CONF_DIR = os.environ.get("SAML_CONF_DIR") or "/app/ee/danswer/configs/saml_config"
#####
# API Key Configs
#####
# refers to the rounds described here: https://passlib.readthedocs.io/en/stable/lib/passlib.hash.sha256_crypt.html
_API_KEY_HASH_ROUNDS_RAW = os.environ.get("API_KEY_HASH_ROUNDS")
API_KEY_HASH_ROUNDS = (
int(_API_KEY_HASH_ROUNDS_RAW) if _API_KEY_HASH_ROUNDS_RAW else None
)
#####
# Auto Permission Sync
#####

View File

@@ -65,64 +65,6 @@ def _add_user_filters(
return stmt.where(where_clause)
def fetch_all_user_token_rate_limits(
db_session: Session,
enabled_only: bool = False,
ordered: bool = True,
) -> Sequence[TokenRateLimit]:
query = select(TokenRateLimit).where(
TokenRateLimit.scope == TokenRateLimitScope.USER
)
if enabled_only:
query = query.where(TokenRateLimit.enabled.is_(True))
if ordered:
query = query.order_by(TokenRateLimit.created_at.desc())
return db_session.scalars(query).all()
def fetch_all_global_token_rate_limits(
db_session: Session,
enabled_only: bool = False,
ordered: bool = True,
) -> Sequence[TokenRateLimit]:
query = select(TokenRateLimit).where(
TokenRateLimit.scope == TokenRateLimitScope.GLOBAL
)
if enabled_only:
query = query.where(TokenRateLimit.enabled.is_(True))
if ordered:
query = query.order_by(TokenRateLimit.created_at.desc())
token_rate_limits = db_session.scalars(query).all()
return token_rate_limits
def fetch_user_group_token_rate_limits(
db_session: Session,
group_id: int,
user: User | None = None,
enabled_only: bool = False,
ordered: bool = True,
get_editable: bool = True,
) -> Sequence[TokenRateLimit]:
stmt = select(TokenRateLimit)
stmt = stmt.where(User__UserGroup.user_group_id == group_id)
stmt = _add_user_filters(stmt, user, get_editable)
if enabled_only:
stmt = stmt.where(TokenRateLimit.enabled.is_(True))
if ordered:
stmt = stmt.order_by(TokenRateLimit.created_at.desc())
return db_session.scalars(stmt).all()
def fetch_all_user_group_token_rate_limits_by_group(
db_session: Session,
) -> Sequence[Row[tuple[TokenRateLimit, str]]]:
@@ -138,38 +80,6 @@ def fetch_all_user_group_token_rate_limits_by_group(
return db_session.execute(query).all()
def insert_user_token_rate_limit(
db_session: Session,
token_rate_limit_settings: TokenRateLimitArgs,
) -> TokenRateLimit:
token_limit = TokenRateLimit(
enabled=token_rate_limit_settings.enabled,
token_budget=token_rate_limit_settings.token_budget,
period_hours=token_rate_limit_settings.period_hours,
scope=TokenRateLimitScope.USER,
)
db_session.add(token_limit)
db_session.commit()
return token_limit
def insert_global_token_rate_limit(
db_session: Session,
token_rate_limit_settings: TokenRateLimitArgs,
) -> TokenRateLimit:
token_limit = TokenRateLimit(
enabled=token_rate_limit_settings.enabled,
token_budget=token_rate_limit_settings.token_budget,
period_hours=token_rate_limit_settings.period_hours,
scope=TokenRateLimitScope.GLOBAL,
)
db_session.add(token_limit)
db_session.commit()
return token_limit
def insert_user_group_token_rate_limit(
db_session: Session,
token_rate_limit_settings: TokenRateLimitArgs,
@@ -193,34 +103,22 @@ def insert_user_group_token_rate_limit(
return token_limit
def update_token_rate_limit(
def fetch_user_group_token_rate_limits(
db_session: Session,
token_rate_limit_id: int,
token_rate_limit_settings: TokenRateLimitArgs,
) -> TokenRateLimit:
token_limit = db_session.get(TokenRateLimit, token_rate_limit_id)
if token_limit is None:
raise ValueError(f"TokenRateLimit with id '{token_rate_limit_id}' not found")
group_id: int,
user: User | None = None,
enabled_only: bool = False,
ordered: bool = True,
get_editable: bool = True,
) -> Sequence[TokenRateLimit]:
stmt = select(TokenRateLimit)
stmt = stmt.where(User__UserGroup.user_group_id == group_id)
stmt = _add_user_filters(stmt, user, get_editable)
token_limit.enabled = token_rate_limit_settings.enabled
token_limit.token_budget = token_rate_limit_settings.token_budget
token_limit.period_hours = token_rate_limit_settings.period_hours
db_session.commit()
if enabled_only:
stmt = stmt.where(TokenRateLimit.enabled.is_(True))
return token_limit
if ordered:
stmt = stmt.order_by(TokenRateLimit.created_at.desc())
def delete_token_rate_limit(
db_session: Session,
token_rate_limit_id: int,
) -> None:
token_limit = db_session.get(TokenRateLimit, token_rate_limit_id)
if token_limit is None:
raise ValueError(f"TokenRateLimit with id '{token_rate_limit_id}' not found")
db_session.query(TokenRateLimit__UserGroup).filter(
TokenRateLimit__UserGroup.rate_limit_id == token_rate_limit_id
).delete()
db_session.delete(token_limit)
db_session.commit()
return db_session.scalars(stmt).all()

View File

@@ -12,11 +12,11 @@ from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import AuthType
from danswer.main import get_application as get_application_base
from danswer.main import include_router_with_global_prefix_prepended
from danswer.server.api_key.api import router as api_key_router
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import global_version
from ee.danswer.configs.app_configs import OPENID_CONFIG_URL
from ee.danswer.server.analytics.api import router as analytics_router
from ee.danswer.server.api_key.api import router as api_key_router
from ee.danswer.server.auth_check import check_ee_router_auth
from ee.danswer.server.enterprise_settings.api import (
admin_router as enterprise_settings_admin_router,

View File

@@ -8,9 +8,9 @@ from fastapi import HTTPException
from fastapi import Request
from fastapi import Response
from danswer.auth.api_key import extract_tenant_from_api_key_header
from danswer.configs.app_configs import USER_AUTH_SECRET
from danswer.db.engine import is_valid_schema_name
from ee.danswer.auth.api_key import extract_tenant_from_api_key_header
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

View File

@@ -12,6 +12,7 @@ from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.db.api_key import is_api_key_email_address
from danswer.db.engine import get_session_with_tenant
from danswer.db.models import ChatMessage
from danswer.db.models import ChatSession
@@ -20,12 +21,11 @@ from danswer.db.models import TokenRateLimit__UserGroup
from danswer.db.models import User
from danswer.db.models import User__UserGroup
from danswer.db.models import UserGroup
from danswer.db.token_limit import fetch_all_user_token_rate_limits
from danswer.server.query_and_chat.token_limit import _get_cutoff_time
from danswer.server.query_and_chat.token_limit import _is_rate_limited
from danswer.server.query_and_chat.token_limit import _user_is_rate_limited_by_global
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from ee.danswer.db.api_key import is_api_key_email_address
from ee.danswer.db.token_limit import fetch_all_user_token_rate_limits
def _check_token_rate_limits(user: User | None, tenant_id: str | None) -> None:

View File

@@ -7,7 +7,6 @@ from fastapi import Response
from danswer.auth.users import auth_backend
from danswer.auth.users import current_admin_user
from danswer.auth.users import get_jwt_strategy
from danswer.auth.users import get_tenant_id_for_email
from danswer.auth.users import User
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.db.engine import get_session_with_tenant
@@ -15,7 +14,6 @@ from danswer.db.notification import create_notification
from danswer.db.users import get_user_by_email
from danswer.server.settings.store import load_settings
from danswer.server.settings.store import store_settings
from danswer.setup import setup_danswer
from danswer.utils.logger import setup_logger
from ee.danswer.auth.users import current_cloud_superuser
from ee.danswer.configs.app_configs import STRIPE_SECRET_KEY
@@ -23,15 +21,9 @@ from ee.danswer.server.tenants.access import control_plane_dep
from ee.danswer.server.tenants.billing import fetch_billing_information
from ee.danswer.server.tenants.billing import fetch_tenant_stripe_information
from ee.danswer.server.tenants.models import BillingInformation
from ee.danswer.server.tenants.models import CreateTenantRequest
from ee.danswer.server.tenants.models import ImpersonateRequest
from ee.danswer.server.tenants.models import ProductGatingRequest
from ee.danswer.server.tenants.provisioning import add_users_to_tenant
from ee.danswer.server.tenants.provisioning import configure_default_api_keys
from ee.danswer.server.tenants.provisioning import ensure_schema_exists
from ee.danswer.server.tenants.provisioning import run_alembic_migrations
from ee.danswer.server.tenants.provisioning import user_owns_a_tenant
from shared_configs.configs import MULTI_TENANT
from ee.danswer.server.tenants.user_mapping import get_tenant_id_for_email
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
stripe.api_key = STRIPE_SECRET_KEY
@@ -40,52 +32,6 @@ logger = setup_logger()
router = APIRouter(prefix="/tenants")
@router.post("/create")
def create_tenant(
create_tenant_request: CreateTenantRequest, _: None = Depends(control_plane_dep)
) -> dict[str, str]:
if not MULTI_TENANT:
raise HTTPException(status_code=403, detail="Multi-tenancy is not enabled")
tenant_id = create_tenant_request.tenant_id
email = create_tenant_request.initial_admin_email
token = None
if user_owns_a_tenant(email):
raise HTTPException(
status_code=409, detail="User already belongs to an organization"
)
try:
if not ensure_schema_exists(tenant_id):
logger.info(f"Created schema for tenant {tenant_id}")
else:
logger.info(f"Schema already exists for tenant {tenant_id}")
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
run_alembic_migrations(tenant_id)
with get_session_with_tenant(tenant_id) as db_session:
setup_danswer(db_session, tenant_id)
configure_default_api_keys(db_session)
add_users_to_tenant([email], tenant_id)
return {
"status": "success",
"message": f"Tenant {tenant_id} created successfully",
}
except Exception as e:
logger.exception(f"Failed to create tenant {tenant_id}: {str(e)}")
raise HTTPException(
status_code=500, detail=f"Failed to create tenant: {str(e)}"
)
finally:
if token is not None:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
@router.post("/product-gating")
def gate_product(
product_gating_request: ProductGatingRequest, _: None = Depends(control_plane_dep)

View File

@@ -33,3 +33,8 @@ class CheckoutSessionCreationResponse(BaseModel):
class ImpersonateRequest(BaseModel):
email: str
class TenantCreationPayload(BaseModel):
tenant_id: str
email: str

View File

@@ -1,145 +1,210 @@
import os
from types import SimpleNamespace
import asyncio
import logging
import uuid
from sqlalchemy import text
import aiohttp # Async HTTP client
from fastapi import HTTPException
from sqlalchemy.orm import Session
from sqlalchemy.schema import CreateSchema
from alembic import command
from alembic.config import Config
from danswer.db.engine import build_connection_string
from danswer.auth.users import exceptions
from danswer.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from danswer.db.engine import get_session_with_tenant
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.llm import update_default_provider
from danswer.db.llm import upsert_cloud_embedding_provider
from danswer.db.llm import upsert_llm_provider
from danswer.db.models import UserTenantMapping
from danswer.llm.llm_provider_options import ANTHROPIC_MODEL_NAMES
from danswer.llm.llm_provider_options import ANTHROPIC_PROVIDER_NAME
from danswer.llm.llm_provider_options import OPEN_AI_MODEL_NAMES
from danswer.llm.llm_provider_options import OPENAI_PROVIDER_NAME
from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
from danswer.utils.logger import setup_logger
from danswer.setup import setup_danswer
from ee.danswer.configs.app_configs import ANTHROPIC_DEFAULT_API_KEY
from ee.danswer.configs.app_configs import COHERE_DEFAULT_API_KEY
from ee.danswer.configs.app_configs import OPENAI_DEFAULT_API_KEY
from ee.danswer.server.tenants.access import generate_data_plane_token
from ee.danswer.server.tenants.models import TenantCreationPayload
from ee.danswer.server.tenants.schema_management import create_schema_if_not_exists
from ee.danswer.server.tenants.schema_management import drop_schema
from ee.danswer.server.tenants.schema_management import run_alembic_migrations
from ee.danswer.server.tenants.user_mapping import add_users_to_tenant
from ee.danswer.server.tenants.user_mapping import get_tenant_id_for_email
from ee.danswer.server.tenants.user_mapping import user_owns_a_tenant
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import TENANT_ID_PREFIX
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.enums import EmbeddingProvider
logger = setup_logger()
logger = logging.getLogger(__name__)
def run_alembic_migrations(schema_name: str) -> None:
logger.info(f"Starting Alembic migrations for schema: {schema_name}")
async def get_or_create_tenant_id(email: str) -> str:
"""Get existing tenant ID for an email or create a new tenant if none exists."""
if not MULTI_TENANT:
return POSTGRES_DEFAULT_SCHEMA
try:
current_dir = os.path.dirname(os.path.abspath(__file__))
root_dir = os.path.abspath(os.path.join(current_dir, "..", "..", "..", ".."))
alembic_ini_path = os.path.join(root_dir, "alembic.ini")
tenant_id = get_tenant_id_for_email(email)
except exceptions.UserNotExists:
# If tenant does not exist and in Multi tenant mode, provision a new tenant
try:
tenant_id = await create_tenant(email)
except Exception as e:
logger.error(f"Tenant provisioning failed: {e}")
raise HTTPException(status_code=500, detail="Failed to provision tenant.")
# Configure Alembic
alembic_cfg = Config(alembic_ini_path)
alembic_cfg.set_main_option("sqlalchemy.url", build_connection_string())
alembic_cfg.set_main_option(
"script_location", os.path.join(root_dir, "alembic")
if not tenant_id:
raise HTTPException(
status_code=401, detail="User does not belong to an organization"
)
# Ensure that logging isn't broken
alembic_cfg.attributes["configure_logger"] = False
return tenant_id
# Mimic command-line options by adding 'cmd_opts' to the config
alembic_cfg.cmd_opts = SimpleNamespace() # type: ignore
alembic_cfg.cmd_opts.x = [f"schema={schema_name}"] # type: ignore
# Run migrations programmatically
command.upgrade(alembic_cfg, "head")
async def create_tenant(email: str) -> str:
tenant_id = TENANT_ID_PREFIX + str(uuid.uuid4())
try:
# Provision tenant on data plane
await provision_tenant(tenant_id, email)
# Notify control plane
await notify_control_plane(tenant_id, email)
except Exception as e:
logger.error(f"Tenant provisioning failed: {e}")
await rollback_tenant_provisioning(tenant_id)
raise HTTPException(status_code=500, detail="Failed to provision tenant.")
return tenant_id
# Run migrations programmatically
logger.info(
f"Alembic migrations completed successfully for schema: {schema_name}"
async def provision_tenant(tenant_id: str, email: str) -> None:
if not MULTI_TENANT:
raise HTTPException(status_code=403, detail="Multi-tenancy is not enabled")
if user_owns_a_tenant(email):
raise HTTPException(
status_code=409, detail="User already belongs to an organization"
)
logger.info(f"Provisioning tenant: {tenant_id}")
token = None
try:
if not create_schema_if_not_exists(tenant_id):
logger.info(f"Created schema for tenant {tenant_id}")
else:
logger.info(f"Schema already exists for tenant {tenant_id}")
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
# Await the Alembic migrations
await asyncio.to_thread(run_alembic_migrations, tenant_id)
with get_session_with_tenant(tenant_id) as db_session:
setup_danswer(db_session, tenant_id)
configure_default_api_keys(db_session)
add_users_to_tenant([email], tenant_id)
except Exception as e:
logger.exception(f"Alembic migration failed for schema {schema_name}: {str(e)}")
raise
logger.exception(f"Failed to create tenant {tenant_id}")
raise HTTPException(
status_code=500, detail=f"Failed to create tenant: {str(e)}"
)
finally:
if token is not None:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
async def notify_control_plane(tenant_id: str, email: str) -> None:
logger.info("Fetching billing information")
token = generate_data_plane_token()
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
}
payload = TenantCreationPayload(tenant_id=tenant_id, email=email)
async with aiohttp.ClientSession() as session:
async with session.post(
f"{CONTROL_PLANE_API_BASE_URL}/tenants/create",
headers=headers,
json=payload.model_dump(),
) as response:
if response.status != 200:
error_text = await response.text()
logger.error(f"Control plane tenant creation failed: {error_text}")
raise Exception(
f"Failed to create tenant on control plane: {error_text}"
)
async def rollback_tenant_provisioning(tenant_id: str) -> None:
# Logic to rollback tenant provisioning on data plane
logger.info(f"Rolling back tenant provisioning for tenant_id: {tenant_id}")
try:
# Drop the tenant's schema to rollback provisioning
drop_schema(tenant_id)
# Remove tenant mapping
with Session(get_sqlalchemy_engine()) as db_session:
db_session.query(UserTenantMapping).filter(
UserTenantMapping.tenant_id == tenant_id
).delete()
db_session.commit()
except Exception as e:
logger.error(f"Failed to rollback tenant provisioning: {e}")
def configure_default_api_keys(db_session: Session) -> None:
open_provider = LLMProviderUpsertRequest(
name="OpenAI",
provider="OpenAI",
api_key=OPENAI_DEFAULT_API_KEY,
default_model_name="gpt-4o",
)
anthropic_provider = LLMProviderUpsertRequest(
name="Anthropic",
provider="Anthropic",
api_key=ANTHROPIC_DEFAULT_API_KEY,
default_model_name="claude-3-5-sonnet-20240620",
)
upsert_llm_provider(open_provider, db_session)
upsert_llm_provider(anthropic_provider, db_session)
cloud_embedding_provider = CloudEmbeddingProviderCreationRequest(
provider_type=EmbeddingProvider.COHERE,
api_key=COHERE_DEFAULT_API_KEY,
)
upsert_cloud_embedding_provider(db_session, cloud_embedding_provider)
def ensure_schema_exists(tenant_id: str) -> bool:
with Session(get_sqlalchemy_engine()) as db_session:
with db_session.begin():
result = db_session.execute(
text(
"SELECT schema_name FROM information_schema.schemata WHERE schema_name = :schema_name"
),
{"schema_name": tenant_id},
)
schema_exists = result.scalar() is not None
if not schema_exists:
stmt = CreateSchema(tenant_id)
db_session.execute(stmt)
return True
return False
# For now, we're implementing a primitive mapping between users and tenants.
# This function is only used to determine a user's relationship to a tenant upon creation (implying ownership).
def user_owns_a_tenant(email: str) -> bool:
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
result = (
db_session.query(UserTenantMapping)
.filter(UserTenantMapping.email == email)
.first()
if OPENAI_DEFAULT_API_KEY:
open_provider = LLMProviderUpsertRequest(
name="OpenAI",
provider=OPENAI_PROVIDER_NAME,
api_key=OPENAI_DEFAULT_API_KEY,
default_model_name="gpt-4",
fast_default_model_name="gpt-4o-mini",
model_names=OPEN_AI_MODEL_NAMES,
)
return result is not None
def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
try:
for email in emails:
db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))
full_provider = upsert_llm_provider(open_provider, db_session)
update_default_provider(full_provider.id, db_session)
except Exception as e:
logger.exception(f"Failed to add users to tenant {tenant_id}: {str(e)}")
db_session.commit()
logger.error(f"Failed to configure OpenAI provider: {e}")
else:
logger.error(
"OPENAI_DEFAULT_API_KEY not set, skipping OpenAI provider configuration"
)
def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
if ANTHROPIC_DEFAULT_API_KEY:
anthropic_provider = LLMProviderUpsertRequest(
name="Anthropic",
provider=ANTHROPIC_PROVIDER_NAME,
api_key=ANTHROPIC_DEFAULT_API_KEY,
default_model_name="claude-3-5-sonnet-20241022",
fast_default_model_name="claude-3-5-sonnet-20241022",
model_names=ANTHROPIC_MODEL_NAMES,
)
try:
mappings_to_delete = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email.in_(emails),
UserTenantMapping.tenant_id == tenant_id,
)
.all()
)
for mapping in mappings_to_delete:
db_session.delete(mapping)
db_session.commit()
full_provider = upsert_llm_provider(anthropic_provider, db_session)
update_default_provider(full_provider.id, db_session)
except Exception as e:
logger.exception(
f"Failed to remove users from tenant {tenant_id}: {str(e)}"
)
db_session.rollback()
logger.error(f"Failed to configure Anthropic provider: {e}")
else:
logger.error(
"ANTHROPIC_DEFAULT_API_KEY not set, skipping Anthropic provider configuration"
)
if COHERE_DEFAULT_API_KEY:
cloud_embedding_provider = CloudEmbeddingProviderCreationRequest(
provider_type=EmbeddingProvider.COHERE,
api_key=COHERE_DEFAULT_API_KEY,
)
try:
upsert_cloud_embedding_provider(db_session, cloud_embedding_provider)
except Exception as e:
logger.error(f"Failed to configure Cohere embedding provider: {e}")
else:
logger.error(
"COHERE_DEFAULT_API_KEY not set, skipping Cohere embedding provider configuration"
)

View File

@@ -0,0 +1,76 @@
import logging
import os
from types import SimpleNamespace
from sqlalchemy import text
from sqlalchemy.orm import Session
from sqlalchemy.schema import CreateSchema
from alembic import command
from alembic.config import Config
from danswer.db.engine import build_connection_string
from danswer.db.engine import get_sqlalchemy_engine
logger = logging.getLogger(__name__)
def run_alembic_migrations(schema_name: str) -> None:
logger.info(f"Starting Alembic migrations for schema: {schema_name}")
try:
current_dir = os.path.dirname(os.path.abspath(__file__))
root_dir = os.path.abspath(os.path.join(current_dir, "..", "..", "..", ".."))
alembic_ini_path = os.path.join(root_dir, "alembic.ini")
# Configure Alembic
alembic_cfg = Config(alembic_ini_path)
alembic_cfg.set_main_option("sqlalchemy.url", build_connection_string())
alembic_cfg.set_main_option(
"script_location", os.path.join(root_dir, "alembic")
)
# Ensure that logging isn't broken
alembic_cfg.attributes["configure_logger"] = False
# Mimic command-line options by adding 'cmd_opts' to the config
alembic_cfg.cmd_opts = SimpleNamespace() # type: ignore
alembic_cfg.cmd_opts.x = [f"schema={schema_name}"] # type: ignore
# Run migrations programmatically
command.upgrade(alembic_cfg, "head")
# Run migrations programmatically
logger.info(
f"Alembic migrations completed successfully for schema: {schema_name}"
)
except Exception as e:
logger.exception(f"Alembic migration failed for schema {schema_name}: {str(e)}")
raise
def create_schema_if_not_exists(tenant_id: str) -> bool:
with Session(get_sqlalchemy_engine()) as db_session:
with db_session.begin():
result = db_session.execute(
text(
"SELECT schema_name FROM information_schema.schemata WHERE schema_name = :schema_name"
),
{"schema_name": tenant_id},
)
schema_exists = result.scalar() is not None
if not schema_exists:
stmt = CreateSchema(tenant_id)
db_session.execute(stmt)
return True
return False
def drop_schema(tenant_id: str) -> None:
if not tenant_id.isidentifier():
raise ValueError("Invalid tenant_id.")
with get_sqlalchemy_engine().connect() as connection:
connection.execute(
text("DROP SCHEMA IF EXISTS %(schema_name)s CASCADE"),
{"schema_name": tenant_id},
)

View File

@@ -0,0 +1,70 @@
import logging
from fastapi_users import exceptions
from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.db.engine import get_session_with_tenant
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import UserTenantMapping
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
logger = logging.getLogger(__name__)
def get_tenant_id_for_email(email: str) -> str:
if not MULTI_TENANT:
return POSTGRES_DEFAULT_SCHEMA
# Implement logic to get tenant_id from the mapping table
with Session(get_sqlalchemy_engine()) as db_session:
result = db_session.execute(
select(UserTenantMapping.tenant_id).where(UserTenantMapping.email == email)
)
tenant_id = result.scalar_one_or_none()
if tenant_id is None:
raise exceptions.UserNotExists()
return tenant_id
def user_owns_a_tenant(email: str) -> bool:
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
result = (
db_session.query(UserTenantMapping)
.filter(UserTenantMapping.email == email)
.first()
)
return result is not None
def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
try:
for email in emails:
db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))
except Exception:
logger.exception(f"Failed to add users to tenant {tenant_id}")
db_session.commit()
def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
try:
mappings_to_delete = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email.in_(emails),
UserTenantMapping.tenant_id == tenant_id,
)
.all()
)
for mapping in mappings_to_delete:
db_session.delete(mapping)
db_session.commit()
except Exception as e:
logger.exception(
f"Failed to remove users from tenant {tenant_id}: {str(e)}"
)
db_session.rollback()

View File

@@ -8,14 +8,14 @@ from danswer.auth.users import current_admin_user
from danswer.auth.users import current_curator_or_admin_user
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.db.token_limit import fetch_all_user_token_rate_limits
from danswer.db.token_limit import insert_user_token_rate_limit
from danswer.server.query_and_chat.token_limit import any_rate_limit_exists
from danswer.server.token_rate_limits.models import TokenRateLimitArgs
from danswer.server.token_rate_limits.models import TokenRateLimitDisplay
from ee.danswer.db.token_limit import fetch_all_user_group_token_rate_limits_by_group
from ee.danswer.db.token_limit import fetch_all_user_token_rate_limits
from ee.danswer.db.token_limit import fetch_user_group_token_rate_limits
from ee.danswer.db.token_limit import insert_user_group_token_rate_limit
from ee.danswer.db.token_limit import insert_user_token_rate_limit
router = APIRouter(prefix="/admin/token-rate-limits")

View File

@@ -15,9 +15,9 @@ docker rm danswer_postgres danswer_vespa danswer_redis
# Start the PostgreSQL container with optional volume
echo "Starting PostgreSQL container..."
if [[ -n "$POSTGRES_VOLUME" ]]; then
docker run -p 5432:5432 --name danswer_postgres -e POSTGRES_PASSWORD=password -d -v $POSTGRES_VOLUME:/var/lib/postgresql/data postgres
docker run -p 5432:5432 --name danswer_postgres -e POSTGRES_PASSWORD=password -d -v $POSTGRES_VOLUME:/var/lib/postgresql/data postgres -c max_connections=250
else
docker run -p 5432:5432 --name danswer_postgres -e POSTGRES_PASSWORD=password -d postgres
docker run -p 5432:5432 --name danswer_postgres -e POSTGRES_PASSWORD=password -d postgres -c max_connections=250
fi
# Start the Vespa container with optional volume

View File

@@ -1,4 +1,5 @@
import os
from typing import Any
from typing import List
from urllib.parse import urlparse
@@ -133,6 +134,11 @@ MULTI_TENANT = os.environ.get("MULTI_TENANT", "").lower() == "true"
POSTGRES_DEFAULT_SCHEMA = os.environ.get("POSTGRES_DEFAULT_SCHEMA") or "public"
async def async_return_default_schema(*args: Any, **kwargs: Any) -> str:
return POSTGRES_DEFAULT_SCHEMA
# Prefix used for all tenant ids
TENANT_ID_PREFIX = "tenant_"

View File

@@ -3,7 +3,7 @@ from uuid import uuid4
import requests
from danswer.db.models import UserRole
from ee.danswer.server.api_key.models import APIKeyArgs
from danswer.server.api_key.models import APIKeyArgs
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestAPIKey

View File

@@ -13,6 +13,14 @@ from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestUser
DOMAIN = "test.com"
DEFAULT_PASSWORD = "test"
def build_email(name: str) -> str:
return f"{name}@test.com"
class UserManager:
@staticmethod
def create(
@@ -23,9 +31,9 @@ class UserManager:
name = f"test{str(uuid4())}"
if email is None:
email = f"{name}@test.com"
email = build_email(name)
password = "test"
password = DEFAULT_PASSWORD
body = {
"email": email,

View File

@@ -0,0 +1,55 @@
from typing import Optional
from uuid import UUID
import pytest
import requests
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
from tests.integration.common_utils.managers.user import build_email
from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestLLMProvider
from tests.integration.common_utils.test_models import DATestUser
BASE_URL = f"{API_SERVER_URL}/openai-assistants"
@pytest.fixture
def admin_user() -> DATestUser | None:
try:
return UserManager.create("admin_user")
except Exception:
pass
try:
return UserManager.login_as_user(
DATestUser(
id="",
email=build_email("admin_user"),
password=DEFAULT_PASSWORD,
headers=GENERAL_HEADERS,
)
)
except Exception:
pass
return None
@pytest.fixture
def llm_provider(admin_user: DATestUser | None) -> DATestLLMProvider:
return LLMProviderManager.create(user_performing_action=admin_user)
@pytest.fixture
def thread_id(admin_user: Optional[DATestUser]) -> UUID:
# Create a thread to use in the tests
response = requests.post(
f"{BASE_URL}/threads", # Updated endpoint path
json={},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
return UUID(response.json()["id"])

View File

@@ -0,0 +1,151 @@
import requests
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestUser
ASSISTANTS_URL = f"{API_SERVER_URL}/openai-assistants/assistants"
def test_create_assistant(admin_user: DATestUser | None) -> None:
response = requests.post(
ASSISTANTS_URL,
json={
"model": "gpt-3.5-turbo",
"name": "Test Assistant",
"description": "A test assistant",
"instructions": "You are a helpful assistant.",
},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
data = response.json()
assert data["name"] == "Test Assistant"
assert data["description"] == "A test assistant"
assert data["model"] == "gpt-3.5-turbo"
assert data["instructions"] == "You are a helpful assistant."
def test_retrieve_assistant(admin_user: DATestUser | None) -> None:
# First, create an assistant
create_response = requests.post(
ASSISTANTS_URL,
json={"model": "gpt-3.5-turbo", "name": "Retrieve Test"},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert create_response.status_code == 200
assistant_id = create_response.json()["id"]
# Now, retrieve the assistant
response = requests.get(
f"{ASSISTANTS_URL}/{assistant_id}",
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
data = response.json()
assert data["id"] == assistant_id
assert data["name"] == "Retrieve Test"
def test_modify_assistant(admin_user: DATestUser | None) -> None:
# First, create an assistant
create_response = requests.post(
ASSISTANTS_URL,
json={"model": "gpt-3.5-turbo", "name": "Modify Test"},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert create_response.status_code == 200
assistant_id = create_response.json()["id"]
# Now, modify the assistant
response = requests.post(
f"{ASSISTANTS_URL}/{assistant_id}",
json={"name": "Modified Assistant", "instructions": "New instructions"},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
data = response.json()
assert data["id"] == assistant_id
assert data["name"] == "Modified Assistant"
assert data["instructions"] == "New instructions"
def test_delete_assistant(admin_user: DATestUser | None) -> None:
# First, create an assistant
create_response = requests.post(
ASSISTANTS_URL,
json={"model": "gpt-3.5-turbo", "name": "Delete Test"},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert create_response.status_code == 200
assistant_id = create_response.json()["id"]
# Now, delete the assistant
response = requests.delete(
f"{ASSISTANTS_URL}/{assistant_id}",
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
data = response.json()
assert data["id"] == assistant_id
assert data["deleted"] is True
def test_list_assistants(admin_user: DATestUser | None) -> None:
# Create multiple assistants
for i in range(3):
requests.post(
ASSISTANTS_URL,
json={"model": "gpt-3.5-turbo", "name": f"List Test {i}"},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
# Now, list the assistants
response = requests.get(
ASSISTANTS_URL,
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
data = response.json()
assert data["object"] == "list"
assert len(data["data"]) >= 3 # At least the 3 we just created
assert all(assistant["object"] == "assistant" for assistant in data["data"])
def test_list_assistants_pagination(admin_user: DATestUser | None) -> None:
# Create 5 assistants
for i in range(5):
requests.post(
ASSISTANTS_URL,
json={"model": "gpt-3.5-turbo", "name": f"Pagination Test {i}"},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
# List assistants with limit
response = requests.get(
f"{ASSISTANTS_URL}?limit=2",
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
data = response.json()
assert len(data["data"]) == 2
assert data["has_more"] is True
# Get next page
before = data["last_id"]
response = requests.get(
f"{ASSISTANTS_URL}?limit=2&before={before}",
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
data = response.json()
assert len(data["data"]) == 2
def test_assistant_not_found(admin_user: DATestUser | None) -> None:
non_existent_id = -99
response = requests.get(
f"{ASSISTANTS_URL}/{non_existent_id}",
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 404

View File

@@ -0,0 +1,133 @@
import uuid
from typing import Optional
import pytest
import requests
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestUser
BASE_URL = f"{API_SERVER_URL}/openai-assistants/threads"
@pytest.fixture
def thread_id(admin_user: Optional[DATestUser]) -> str:
response = requests.post(
BASE_URL,
json={},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
return response.json()["id"]
def test_create_message(admin_user: Optional[DATestUser], thread_id: str) -> None:
response = requests.post(
f"{BASE_URL}/{thread_id}/messages", # URL structure matches API
json={
"role": "user",
"content": "Hello, world!",
"file_ids": [],
"metadata": {"key": "value"},
},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
response_json = response.json()
assert "id" in response_json
assert response_json["thread_id"] == thread_id
assert response_json["role"] == "user"
assert response_json["content"] == [{"type": "text", "text": "Hello, world!"}]
assert response_json["metadata"] == {"key": "value"}
def test_list_messages(admin_user: Optional[DATestUser], thread_id: str) -> None:
# Create a message first
requests.post(
f"{BASE_URL}/{thread_id}/messages",
json={"role": "user", "content": "Test message"},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
# Now, list the messages
response = requests.get(
f"{BASE_URL}/{thread_id}/messages",
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
response_json = response.json()
assert response_json["object"] == "list"
assert isinstance(response_json["data"], list)
assert len(response_json["data"]) > 0
assert "first_id" in response_json
assert "last_id" in response_json
assert "has_more" in response_json
def test_retrieve_message(admin_user: Optional[DATestUser], thread_id: str) -> None:
# Create a message first
create_response = requests.post(
f"{BASE_URL}/{thread_id}/messages",
json={"role": "user", "content": "Test message"},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
message_id = create_response.json()["id"]
# Now, retrieve the message
response = requests.get(
f"{BASE_URL}/{thread_id}/messages/{message_id}",
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
response_json = response.json()
assert response_json["id"] == message_id
assert response_json["thread_id"] == thread_id
assert response_json["role"] == "user"
assert response_json["content"] == [{"type": "text", "text": "Test message"}]
def test_modify_message(admin_user: Optional[DATestUser], thread_id: str) -> None:
# Create a message first
create_response = requests.post(
f"{BASE_URL}/{thread_id}/messages",
json={"role": "user", "content": "Test message"},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
message_id = create_response.json()["id"]
# Now, modify the message
response = requests.post(
f"{BASE_URL}/{thread_id}/messages/{message_id}",
json={"metadata": {"new_key": "new_value"}},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
response_json = response.json()
assert response_json["id"] == message_id
assert response_json["thread_id"] == thread_id
assert response_json["metadata"] == {"new_key": "new_value"}
def test_error_handling(admin_user: Optional[DATestUser]) -> None:
non_existent_thread_id = str(uuid.uuid4())
non_existent_message_id = -99
# Test with non-existent thread
response = requests.post(
f"{BASE_URL}/{non_existent_thread_id}/messages",
json={"role": "user", "content": "Test message"},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 404
# Test with non-existent message
response = requests.get(
f"{BASE_URL}/{non_existent_thread_id}/messages/{non_existent_message_id}",
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 404

View File

@@ -0,0 +1,137 @@
from uuid import UUID
import pytest
import requests
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestLLMProvider
from tests.integration.common_utils.test_models import DATestUser
BASE_URL = f"{API_SERVER_URL}/openai-assistants"
@pytest.fixture
def run_id(admin_user: DATestUser | None, thread_id: UUID) -> str:
"""Create a run and return its ID."""
response = requests.post(
f"{BASE_URL}/threads/{thread_id}/runs",
json={
"assistant_id": 0,
},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
return response.json()["id"]
def test_create_run(
admin_user: DATestUser | None, thread_id: UUID, llm_provider: DATestLLMProvider
) -> None:
response = requests.post(
f"{BASE_URL}/threads/{thread_id}/runs",
json={
"assistant_id": 0,
"model": "gpt-3.5-turbo",
"instructions": "Test instructions",
},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
response_json = response.json()
assert "id" in response_json
assert response_json["object"] == "thread.run"
assert "created_at" in response_json
assert response_json["assistant_id"] == 0
assert UUID(response_json["thread_id"]) == thread_id
assert response_json["status"] == "queued"
assert response_json["model"] == "gpt-3.5-turbo"
assert response_json["instructions"] == "Test instructions"
def test_retrieve_run(
admin_user: DATestUser | None,
thread_id: UUID,
run_id: str,
llm_provider: DATestLLMProvider,
) -> None:
retrieve_response = requests.get(
f"{BASE_URL}/threads/{thread_id}/runs/{run_id}",
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert retrieve_response.status_code == 200
response_json = retrieve_response.json()
assert response_json["id"] == run_id
assert response_json["object"] == "thread.run"
assert "created_at" in response_json
assert UUID(response_json["thread_id"]) == thread_id
def test_cancel_run(
admin_user: DATestUser | None,
thread_id: UUID,
run_id: str,
llm_provider: DATestLLMProvider,
) -> None:
cancel_response = requests.post(
f"{BASE_URL}/threads/{thread_id}/runs/{run_id}/cancel",
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert cancel_response.status_code == 200
response_json = cancel_response.json()
assert response_json["id"] == run_id
assert response_json["status"] == "cancelled"
def test_list_runs(
admin_user: DATestUser | None, thread_id: UUID, llm_provider: DATestLLMProvider
) -> None:
# Create a few runs
for _ in range(3):
requests.post(
f"{BASE_URL}/threads/{thread_id}/runs",
json={
"assistant_id": 0,
},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
# Now, list the runs
list_response = requests.get(
f"{BASE_URL}/threads/{thread_id}/runs",
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert list_response.status_code == 200
response_json = list_response.json()
assert isinstance(response_json, list)
assert len(response_json) >= 3
for run in response_json:
assert "id" in run
assert run["object"] == "thread.run"
assert "created_at" in run
assert UUID(run["thread_id"]) == thread_id
assert "status" in run
assert "model" in run
def test_list_run_steps(
admin_user: DATestUser | None,
thread_id: UUID,
run_id: str,
llm_provider: DATestLLMProvider,
) -> None:
steps_response = requests.get(
f"{BASE_URL}/threads/{thread_id}/runs/{run_id}/steps",
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert steps_response.status_code == 200
response_json = steps_response.json()
assert isinstance(response_json, list)
# Since DAnswer doesn't have an equivalent to run steps, we expect an empty list
assert len(response_json) == 0

View File

@@ -0,0 +1,132 @@
from uuid import UUID
import requests
from danswer.db.models import ChatSessionSharedStatus
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestUser
THREADS_URL = f"{API_SERVER_URL}/openai-assistants/threads"
def test_create_thread(admin_user: DATestUser | None) -> None:
response = requests.post(
THREADS_URL,
json={"messages": None, "metadata": {"key": "value"}},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert response.status_code == 200
response_json = response.json()
assert "id" in response_json
assert response_json["object"] == "thread"
assert "created_at" in response_json
assert response_json["metadata"] == {"key": "value"}
def test_retrieve_thread(admin_user: DATestUser | None) -> None:
# First, create a thread
create_response = requests.post(
THREADS_URL,
json={},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert create_response.status_code == 200
thread_id = create_response.json()["id"]
# Now, retrieve the thread
retrieve_response = requests.get(
f"{THREADS_URL}/{thread_id}",
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert retrieve_response.status_code == 200
response_json = retrieve_response.json()
assert response_json["id"] == thread_id
assert response_json["object"] == "thread"
assert "created_at" in response_json
def test_modify_thread(admin_user: DATestUser | None) -> None:
# First, create a thread
create_response = requests.post(
THREADS_URL,
json={},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert create_response.status_code == 200
thread_id = create_response.json()["id"]
# Now, modify the thread
modify_response = requests.post(
f"{THREADS_URL}/{thread_id}",
json={"metadata": {"new_key": "new_value"}},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert modify_response.status_code == 200
response_json = modify_response.json()
assert response_json["id"] == thread_id
assert response_json["metadata"] == {"new_key": "new_value"}
def test_delete_thread(admin_user: DATestUser | None) -> None:
# First, create a thread
create_response = requests.post(
THREADS_URL,
json={},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert create_response.status_code == 200
thread_id = create_response.json()["id"]
# Now, delete the thread
delete_response = requests.delete(
f"{THREADS_URL}/{thread_id}",
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert delete_response.status_code == 200
response_json = delete_response.json()
assert response_json["id"] == thread_id
assert response_json["object"] == "thread.deleted"
assert response_json["deleted"] is True
def test_list_threads(admin_user: DATestUser | None) -> None:
# Create a few threads
for _ in range(3):
requests.post(
THREADS_URL,
json={},
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
# Now, list the threads
list_response = requests.get(
THREADS_URL,
headers=admin_user.headers if admin_user else GENERAL_HEADERS,
)
assert list_response.status_code == 200
response_json = list_response.json()
assert "sessions" in response_json
assert len(response_json["sessions"]) >= 3
for session in response_json["sessions"]:
assert "id" in session
assert "name" in session
assert "persona_id" in session
assert "time_created" in session
assert "shared_status" in session
assert "folder_id" in session
assert "current_alternate_model" in session
# Validate UUID
UUID(session["id"])
# Validate shared_status
assert session["shared_status"] in [
status.value for status in ChatSessionSharedStatus
]

View File

@@ -29,6 +29,78 @@ from tests.integration.common_utils.test_models import DATestUserGroup
from tests.integration.common_utils.vespa import vespa_fixture
# def test_connector_creation(reset: None) -> None:
# # Creating an admin user (first user created is automatically an admin)
# admin_user: DATestUser = UserManager.create(name="admin_user")
# # create connectors
# cc_pair_1 = CCPairManager.create_from_scratch(
# source=DocumentSource.INGESTION_API,
# user_performing_action=admin_user,
# )
# cc_pair_info = CCPairManager.get_single(
# cc_pair_1.id, user_performing_action=admin_user
# )
# assert cc_pair_info
# assert cc_pair_info.creator
# assert str(cc_pair_info.creator) == admin_user.id
# assert cc_pair_info.creator_email == admin_user.email
# TODO(rkuo): will enable this once i have credentials on github
# def test_overlapping_connector_creation(reset: None) -> None:
# # Creating an admin user (first user created is automatically an admin)
# admin_user: DATestUser = UserManager.create(name="admin_user")
# config = {
# "wiki_base": os.environ["CONFLUENCE_TEST_SPACE_URL"],
# "space": os.environ["CONFLUENCE_TEST_SPACE"],
# "is_cloud": True,
# "page_id": "",
# }
# credential = {
# "confluence_username": os.environ["CONFLUENCE_USER_NAME"],
# "confluence_access_token": os.environ["CONFLUENCE_ACCESS_TOKEN"],
# }
# # store the time before we create the connector so that we know after
# # when the indexing should have started
# now = datetime.now(timezone.utc)
# # create connector
# cc_pair_1 = CCPairManager.create_from_scratch(
# source=DocumentSource.CONFLUENCE,
# connector_specific_config=config,
# credential_json=credential,
# user_performing_action=admin_user,
# )
# CCPairManager.wait_for_indexing(
# cc_pair_1, now, timeout=60, user_performing_action=admin_user
# )
# cc_pair_2 = CCPairManager.create_from_scratch(
# source=DocumentSource.CONFLUENCE,
# connector_specific_config=config,
# credential_json=credential,
# user_performing_action=admin_user,
# )
# CCPairManager.wait_for_indexing(
# cc_pair_2, now, timeout=60, user_performing_action=admin_user
# )
# info_1 = CCPairManager.get_single(cc_pair_1.id)
# assert info_1
# info_2 = CCPairManager.get_single(cc_pair_2.id)
# assert info_2
# assert info_1.num_docs_indexed == info_2.num_docs_indexed
def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None:
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(name="admin_user")

10
ct.yaml
View File

@@ -1,12 +1,18 @@
# See https://github.com/helm/chart-testing#configuration
# still have to specify this on the command line for list-changed
chart-dirs:
- deployment/helm/charts
# must be kept in sync with Chart.yaml
chart-repos:
- vespa=https://unoplat.github.io/vespa-helm-charts
- vespa=https://danswer-ai.github.io/vespa-helm-charts
- postgresql=https://charts.bitnami.com/bitnami
helm-extra-args: --timeout 600s
helm-extra-args: --debug --timeout 600s
# nginx appears to not work on kind, likely due to lack of loadbalancer support
# helm-extra-set-args also only works on the command line, not in this yaml
# helm-extra-set-args: --set=nginx.enabled=false
validate-maintainers: false

View File

@@ -0,0 +1,13 @@
apiVersion: keda.sh/v1alpha1
kind: TriggerAuthentication
metadata:
name: celery-worker-auth
namespace: danswer
spec:
secretTargetRef:
- parameter: host
name: keda-redis-secret
key: host
- parameter: password
name: keda-redis-secret
key: password

View File

@@ -0,0 +1,46 @@
apiVersion: keda.sh/v1alpha1
kind: ScaledObject
metadata:
name: celery-worker-indexing-scaledobject
namespace: danswer
labels:
app: celery-worker-indexing
spec:
scaleTargetRef:
name: celery-worker-indexing
minReplicaCount: 1
maxReplicaCount: 10
triggers:
- type: redis
metadata:
sslEnabled: "true"
host: "{host}"
port: "6379"
enableTLS: "true"
listName: connector_indexing
listLength: "1"
databaseIndex: "15"
authenticationRef:
name: celery-worker-auth
- type: redis
metadata:
sslEnabled: "true"
host: "{host}"
port: "6379"
enableTLS: "true"
listName: connector_indexing:2
listLength: "1"
databaseIndex: "15"
authenticationRef:
name: celery-worker-auth
- type: redis
metadata:
sslEnabled: "true"
host: "{host}"
port: "6379"
enableTLS: "true"
listName: connector_indexing:3
listLength: "1"
databaseIndex: "15"
authenticationRef:
name: celery-worker-auth

View File

@@ -0,0 +1,63 @@
apiVersion: keda.sh/v1alpha1
kind: ScaledObject
metadata:
name: celery-worker-light-scaledobject
namespace: danswer
labels:
app: celery-worker-light
spec:
scaleTargetRef:
name: celery-worker-light
minReplicaCount: 1
maxReplicaCount: 20
triggers:
- type: redis
metadata:
host: "{host}"
port: "6379"
enableTLS: "true"
listName: vespa_metadata_sync
listLength: "1"
databaseIndex: "15"
authenticationRef:
name: celery-worker-auth
- type: redis
metadata:
host: "{host}"
port: "6379"
enableTLS: "true"
listName: vespa_metadata_sync:2
listLength: "1"
databaseIndex: "15"
authenticationRef:
name: celery-worker-auth
- type: redis
metadata:
host: "{host}"
port: "6379"
enableTLS: "true"
listName: vespa_metadata_sync:3
listLength: "1"
databaseIndex: "15"
authenticationRef:
name: celery-worker-auth
- type: redis
metadata:
host: "{host}"
port: "6379"
enableTLS: "true"
listName: connector_deletion
listLength: "1"
databaseIndex: "15"
authenticationRef:
name: celery-worker-auth
- type: redis
metadata:
host: "{host}"
port: "6379"
enableTLS: "true"
listName: connector_deletion:2
listLength: "1"
databaseIndex: "15"
authenticationRef:
name: celery-worker-auth

View File

@@ -0,0 +1,76 @@
apiVersion: keda.sh/v1alpha1
kind: ScaledObject
metadata:
name: celery-worker-primary-scaledobject
namespace: danswer
labels:
app: celery-worker-primary
spec:
scaleTargetRef:
name: celery-worker-primary
pollingInterval: 15 # Check every 15 seconds
cooldownPeriod: 30 # Wait 30 seconds before scaling down
minReplicaCount: 1
maxReplicaCount: 1
triggers:
- type: redis
metadata:
host: "{host}"
port: "6379"
enableTLS: "true"
listName: celery
listLength: "1"
databaseIndex: "15"
authenticationRef:
name: celery-worker-auth
- type: redis
metadata:
host: "{host}"
port: "6379"
enableTLS: "true"
listName: celery:1
listLength: "1"
databaseIndex: "15"
authenticationRef:
name: celery-worker-auth
- type: redis
metadata:
host: "{host}"
port: "6379"
enableTLS: "true"
listName: celery:2
listLength: "1"
databaseIndex: "15"
authenticationRef:
name: celery-worker-auth
- type: redis
metadata:
host: "{host}"
port: "6379"
enableTLS: "true"
listName: celery:3
listLength: "1"
databaseIndex: "15"
authenticationRef:
name: celery-worker-auth
- type: redis
metadata:
host: "{host}"
port: "6379"
enableTLS: "true"
listName: periodic_tasks
listLength: "1"
databaseIndex: "15"
authenticationRef:
name: celery-worker-auth
- type: redis
metadata:
host: "{host}"
port: "6379"
enableTLS: "true"
listName: periodic_tasks:2
listLength: "1"
databaseIndex: "15"
authenticationRef:
name: celery-worker-auth

Some files were not shown because too many files have changed in this diff Show More