mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-01 13:45:44 +00:00
Compare commits
8 Commits
nit
...
new_seq_to
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
59e9a33b30 | ||
|
|
6e60437c56 | ||
|
|
9cde51f1a2 | ||
|
|
8b8952f117 | ||
|
|
dc01eea610 | ||
|
|
c89d8318c0 | ||
|
|
3f2d6557dc | ||
|
|
b3818877af |
@@ -3,61 +3,61 @@ name: Build and Push Backend Image on Tag
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- "*"
|
||||
- '*'
|
||||
|
||||
env:
|
||||
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'danswer/danswer-backend-cloud' || 'danswer/danswer-backend' }}
|
||||
REGISTRY_IMAGE: danswer/danswer-backend
|
||||
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
|
||||
|
||||
|
||||
jobs:
|
||||
build-and-push:
|
||||
# TODO: investigate a matrix build like the web container
|
||||
# TODO: investigate a matrix build like the web container
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}"]
|
||||
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Install build-essential
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y build-essential
|
||||
- name: Install build-essential
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y build-essential
|
||||
|
||||
- name: Backend Image Docker Build and Push
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: true
|
||||
tags: |
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
|
||||
build-args: |
|
||||
DANSWER_VERSION=${{ github.ref_name }}
|
||||
|
||||
- name: Backend Image Docker Build and Push
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: true
|
||||
tags: |
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
|
||||
build-args: |
|
||||
DANSWER_VERSION=${{ github.ref_name }}
|
||||
|
||||
# trivy has their own rate limiting issues causing this action to flake
|
||||
# we worked around it by hardcoding to different db repos in env
|
||||
# can re-enable when they figure it out
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@master
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
|
||||
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
|
||||
with:
|
||||
# To run locally: trivy image --severity HIGH,CRITICAL danswer/danswer-backend
|
||||
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
severity: "CRITICAL,HIGH"
|
||||
trivyignores: ./backend/.trivyignore
|
||||
# trivy has their own rate limiting issues causing this action to flake
|
||||
# we worked around it by hardcoding to different db repos in env
|
||||
# can re-enable when they figure it out
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@master
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
|
||||
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
|
||||
with:
|
||||
# To run locally: trivy image --severity HIGH,CRITICAL danswer/danswer-backend
|
||||
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
severity: 'CRITICAL,HIGH'
|
||||
trivyignores: ./backend/.trivyignore
|
||||
|
||||
@@ -4,12 +4,12 @@ name: Build and Push Cloud Web Image on Tag
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- "*"
|
||||
- '*'
|
||||
|
||||
env:
|
||||
REGISTRY_IMAGE: danswer/danswer-web-server-cloud
|
||||
REGISTRY_IMAGE: danswer/danswer-cloud-web-server
|
||||
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
|
||||
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on:
|
||||
@@ -28,11 +28,11 @@ jobs:
|
||||
- name: Prepare
|
||||
run: |
|
||||
platform=${{ matrix.platform }}
|
||||
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
|
||||
|
||||
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
@@ -41,16 +41,16 @@ jobs:
|
||||
tags: |
|
||||
type=raw,value=${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
type=raw,value=${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
|
||||
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
|
||||
- name: Build and push by digest
|
||||
id: build
|
||||
uses: docker/build-push-action@v5
|
||||
@@ -65,17 +65,17 @@ jobs:
|
||||
NEXT_PUBLIC_POSTHOG_KEY=${{ secrets.POSTHOG_KEY }}
|
||||
NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
|
||||
NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
|
||||
# needed due to weird interactions with the builds for different platforms
|
||||
# needed due to weird interactions with the builds for different platforms
|
||||
no-cache: true
|
||||
labels: ${{ steps.meta.outputs.labels }}
|
||||
outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
|
||||
|
||||
- name: Export digest
|
||||
run: |
|
||||
mkdir -p /tmp/digests
|
||||
digest="${{ steps.build.outputs.digest }}"
|
||||
touch "/tmp/digests/${digest#sha256:}"
|
||||
|
||||
touch "/tmp/digests/${digest#sha256:}"
|
||||
|
||||
- name: Upload digest
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
@@ -95,42 +95,42 @@ jobs:
|
||||
path: /tmp/digests
|
||||
pattern: digests-*
|
||||
merge-multiple: true
|
||||
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
|
||||
- name: Docker meta
|
||||
id: meta
|
||||
uses: docker/metadata-action@v5
|
||||
with:
|
||||
images: ${{ env.REGISTRY_IMAGE }}
|
||||
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
|
||||
- name: Create manifest list and push
|
||||
working-directory: /tmp/digests
|
||||
run: |
|
||||
docker buildx imagetools create $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \
|
||||
$(printf '${{ env.REGISTRY_IMAGE }}@sha256:%s ' *)
|
||||
|
||||
$(printf '${{ env.REGISTRY_IMAGE }}@sha256:%s ' *)
|
||||
|
||||
- name: Inspect image
|
||||
run: |
|
||||
docker buildx imagetools inspect ${{ env.REGISTRY_IMAGE }}:${{ steps.meta.outputs.version }}
|
||||
|
||||
# trivy has their own rate limiting issues causing this action to flake
|
||||
# we worked around it by hardcoding to different db repos in env
|
||||
# can re-enable when they figure it out
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
# trivy has their own rate limiting issues causing this action to flake
|
||||
# we worked around it by hardcoding to different db repos in env
|
||||
# can re-enable when they figure it out
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@master
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
|
||||
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
|
||||
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
|
||||
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
|
||||
with:
|
||||
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
severity: "CRITICAL,HIGH"
|
||||
severity: 'CRITICAL,HIGH'
|
||||
|
||||
@@ -3,53 +3,53 @@ name: Build and Push Model Server Image on Tag
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- "*"
|
||||
- '*'
|
||||
|
||||
env:
|
||||
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'danswer/danswer-model-server-cloud' || 'danswer/danswer-model-server' }}
|
||||
REGISTRY_IMAGE: danswer/danswer-model-server
|
||||
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
|
||||
|
||||
|
||||
jobs:
|
||||
build-and-push:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}"]
|
||||
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Model Server Image Docker Build and Push
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: true
|
||||
tags: |
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
|
||||
build-args: |
|
||||
DANSWER_VERSION=${{ github.ref_name }}
|
||||
- name: Model Server Image Docker Build and Push
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: true
|
||||
tags: |
|
||||
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
|
||||
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
|
||||
build-args: |
|
||||
DANSWER_VERSION=${{ github.ref_name }}
|
||||
|
||||
# trivy has their own rate limiting issues causing this action to flake
|
||||
# we worked around it by hardcoding to different db repos in env
|
||||
# can re-enable when they figure it out
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@master
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
|
||||
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
|
||||
with:
|
||||
image-ref: docker.io/danswer/danswer-model-server:${{ github.ref_name }}
|
||||
severity: "CRITICAL,HIGH"
|
||||
# trivy has their own rate limiting issues causing this action to flake
|
||||
# we worked around it by hardcoding to different db repos in env
|
||||
# can re-enable when they figure it out
|
||||
# https://github.com/aquasecurity/trivy/discussions/7538
|
||||
# https://github.com/aquasecurity/trivy-action/issues/389
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@master
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
|
||||
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
|
||||
with:
|
||||
image-ref: docker.io/danswer/danswer-model-server:${{ github.ref_name }}
|
||||
severity: 'CRITICAL,HIGH'
|
||||
|
||||
11
.github/workflows/pr-Integration-tests.yml
vendored
11
.github/workflows/pr-Integration-tests.yml
vendored
@@ -210,18 +210,17 @@ jobs:
|
||||
echo "All integration tests passed successfully."
|
||||
fi
|
||||
|
||||
# save before stopping the containers so the logs can be captured
|
||||
- name: Stop Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
|
||||
|
||||
- name: Save Docker logs
|
||||
if: success() || failure()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack logs > docker-compose.log
|
||||
mv docker-compose.log ${{ github.workspace }}/docker-compose.log
|
||||
|
||||
- name: Stop Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
|
||||
|
||||
- name: Upload logs
|
||||
if: success() || failure()
|
||||
|
||||
@@ -1,20 +1,24 @@
|
||||
# This workflow is intentionally disabled while we're still working on it
|
||||
# It's close to ready, but a race condition needs to be fixed with
|
||||
# API server and Vespa startup, and it needs to have a way to build/test against
|
||||
# local containers
|
||||
|
||||
name: Helm - Lint and Test Charts
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
workflow_dispatch: # Allows manual triggering
|
||||
|
||||
|
||||
jobs:
|
||||
helm-chart-check:
|
||||
lint-test:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=8cpu-linux-x64,hdd=256,"run-id=${{ github.run_id }}"]
|
||||
|
||||
# fetch-depth 0 is required for helm/chart-testing-action
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
uses: actions/checkout@v3
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
@@ -24,7 +28,7 @@ jobs:
|
||||
version: v3.14.4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.11'
|
||||
cache: 'pip'
|
||||
@@ -41,31 +45,24 @@ jobs:
|
||||
- name: Set up chart-testing
|
||||
uses: helm/chart-testing-action@v2.6.1
|
||||
|
||||
# even though we specify chart-dirs in ct.yaml, it isn't used by ct for the list-changed command...
|
||||
- name: Run chart-testing (list-changed)
|
||||
id: list-changed
|
||||
run: |
|
||||
echo "default_branch: ${{ github.event.repository.default_branch }}"
|
||||
changed=$(ct list-changed --remote origin --target-branch ${{ github.event.repository.default_branch }} --chart-dirs deployment/helm/charts)
|
||||
echo "list-changed output: $changed"
|
||||
changed=$(ct list-changed --target-branch ${{ github.event.repository.default_branch }})
|
||||
if [[ -n "$changed" ]]; then
|
||||
echo "changed=true" >> "$GITHUB_OUTPUT"
|
||||
fi
|
||||
|
||||
# lint all charts if any changes were detected
|
||||
- name: Run chart-testing (lint)
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: ct lint --config ct.yaml --all
|
||||
# the following would lint only changed charts, but linting isn't expensive
|
||||
# run: ct lint --config ct.yaml --target-branch ${{ github.event.repository.default_branch }}
|
||||
# if: steps.list-changed.outputs.changed == 'true'
|
||||
run: ct lint --all --config ct.yaml --target-branch ${{ github.event.repository.default_branch }}
|
||||
|
||||
- name: Create kind cluster
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
# if: steps.list-changed.outputs.changed == 'true'
|
||||
uses: helm/kind-action@v1.10.0
|
||||
|
||||
- name: Run chart-testing (install)
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: ct install --all --helm-extra-set-args="--set=nginx.enabled=false" --debug --config ct.yaml
|
||||
# the following would install only changed charts, but we only have one chart so
|
||||
# don't worry about that for now
|
||||
# run: ct install --target-branch ${{ github.event.repository.default_branch }}
|
||||
# if: steps.list-changed.outputs.changed == 'true'
|
||||
run: ct install --all --config ct.yaml
|
||||
# run: ct install --target-branch ${{ github.event.repository.default_branch }}
|
||||
|
||||
@@ -26,8 +26,7 @@ def upgrade() -> None:
|
||||
FROM jsonb_array_elements(starter_messages) elem
|
||||
)
|
||||
WHERE starter_messages IS NOT NULL
|
||||
AND jsonb_typeof(starter_messages) = 'array'
|
||||
"""
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
@@ -42,7 +41,6 @@ def downgrade() -> None:
|
||||
FROM jsonb_array_elements(starter_messages) elem
|
||||
)
|
||||
WHERE starter_messages IS NOT NULL
|
||||
AND jsonb_typeof(starter_messages) = 'array'
|
||||
"""
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
@@ -48,11 +48,11 @@ from httpx_oauth.integrations.fastapi import OAuth2AuthorizeCallback
|
||||
from httpx_oauth.oauth2 import BaseOAuth2
|
||||
from httpx_oauth.oauth2 import OAuth2Token
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import attributes
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.api_key import get_hashed_api_key_from_request
|
||||
from danswer.auth.invited_users import get_invited_users
|
||||
from danswer.auth.schemas import UserCreate
|
||||
from danswer.auth.schemas import UserRole
|
||||
@@ -75,7 +75,6 @@ from danswer.configs.constants import AuthType
|
||||
from danswer.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
||||
from danswer.configs.constants import DANSWER_API_KEY_PREFIX
|
||||
from danswer.configs.constants import UNNAMED_KEY_PLACEHOLDER
|
||||
from danswer.db.api_key import fetch_user_for_api_key
|
||||
from danswer.db.auth import get_access_token_db
|
||||
from danswer.db.auth import get_default_admin_user_emails
|
||||
from danswer.db.auth import get_user_count
|
||||
@@ -84,27 +83,24 @@ from danswer.db.auth import SQLAlchemyUserAdminDB
|
||||
from danswer.db.engine import get_async_session_with_tenant
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.models import AccessToken
|
||||
from danswer.db.models import OAuthAccount
|
||||
from danswer.db.models import User
|
||||
from danswer.db.models import UserTenantMapping
|
||||
from danswer.db.users import get_user_by_email
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.telemetry import optional_telemetry
|
||||
from danswer.utils.telemetry import RecordType
|
||||
from danswer.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
from shared_configs.configs import async_return_default_schema
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class BasicAuthenticationError(HTTPException):
|
||||
def __init__(self, detail: str):
|
||||
super().__init__(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
|
||||
|
||||
|
||||
def is_user_admin(user: User | None) -> bool:
|
||||
if AUTH_TYPE == AuthType.DISABLED:
|
||||
return True
|
||||
@@ -194,6 +190,20 @@ def verify_email_domain(email: str) -> None:
|
||||
)
|
||||
|
||||
|
||||
def get_tenant_id_for_email(email: str) -> str:
|
||||
if not MULTI_TENANT:
|
||||
return POSTGRES_DEFAULT_SCHEMA
|
||||
# Implement logic to get tenant_id from the mapping table
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
result = db_session.execute(
|
||||
select(UserTenantMapping.tenant_id).where(UserTenantMapping.email == email)
|
||||
)
|
||||
tenant_id = result.scalar_one_or_none()
|
||||
if tenant_id is None:
|
||||
raise exceptions.UserNotExists()
|
||||
return tenant_id
|
||||
|
||||
|
||||
def send_user_verification_email(
|
||||
user_email: str,
|
||||
token: str,
|
||||
@@ -228,13 +238,19 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
safe: bool = False,
|
||||
request: Optional[Request] = None,
|
||||
) -> User:
|
||||
tenant_id = await fetch_ee_implementation_or_noop(
|
||||
"danswer.server.tenants.provisioning",
|
||||
"get_or_create_tenant_id",
|
||||
async_return_default_schema,
|
||||
)(
|
||||
email=user_create.email,
|
||||
)
|
||||
try:
|
||||
tenant_id = (
|
||||
get_tenant_id_for_email(user_create.email)
|
||||
if MULTI_TENANT
|
||||
else POSTGRES_DEFAULT_SCHEMA
|
||||
)
|
||||
except exceptions.UserNotExists:
|
||||
raise HTTPException(status_code=401, detail="User not found")
|
||||
|
||||
if not tenant_id:
|
||||
raise HTTPException(
|
||||
status_code=401, detail="User does not belong to an organization"
|
||||
)
|
||||
|
||||
async with get_async_session_with_tenant(tenant_id) as db_session:
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
@@ -255,7 +271,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
user_create.role = UserRole.ADMIN
|
||||
else:
|
||||
user_create.role = UserRole.BASIC
|
||||
|
||||
user = None
|
||||
try:
|
||||
user = await super().create(user_create, safe=safe, request=request) # type: ignore
|
||||
except exceptions.UserAlreadyExists:
|
||||
@@ -276,9 +292,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
else:
|
||||
raise exceptions.UserAlreadyExists()
|
||||
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
return user
|
||||
|
||||
async def oauth_callback(
|
||||
@@ -294,18 +308,19 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
associate_by_email: bool = False,
|
||||
is_verified_by_default: bool = False,
|
||||
) -> models.UOAP:
|
||||
tenant_id = await fetch_ee_implementation_or_noop(
|
||||
"danswer.server.tenants.provisioning",
|
||||
"get_or_create_tenant_id",
|
||||
async_return_default_schema,
|
||||
)(
|
||||
email=account_email,
|
||||
)
|
||||
# Get tenant_id from mapping table
|
||||
try:
|
||||
tenant_id = (
|
||||
get_tenant_id_for_email(account_email)
|
||||
if MULTI_TENANT
|
||||
else POSTGRES_DEFAULT_SCHEMA
|
||||
)
|
||||
except exceptions.UserNotExists:
|
||||
raise HTTPException(status_code=401, detail="User not found")
|
||||
|
||||
if not tenant_id:
|
||||
raise HTTPException(status_code=401, detail="User not found")
|
||||
|
||||
# Proceed with the tenant context
|
||||
token = None
|
||||
async with get_async_session_with_tenant(tenant_id) as db_session:
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
@@ -356,9 +371,9 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
# Explicitly set the Postgres schema for this session to ensure
|
||||
# OAuth account creation happens in the correct tenant schema
|
||||
await db_session.execute(text(f'SET search_path = "{tenant_id}"'))
|
||||
|
||||
# Add OAuth account
|
||||
await self.user_db.add_oauth_account(user, oauth_account_dict)
|
||||
user = await self.user_db.add_oauth_account(
|
||||
user, oauth_account_dict
|
||||
)
|
||||
await self.on_after_register(user, request)
|
||||
|
||||
else:
|
||||
@@ -438,13 +453,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
email = credentials.username
|
||||
|
||||
# Get tenant_id from mapping table
|
||||
tenant_id = await fetch_ee_implementation_or_noop(
|
||||
"danswer.server.tenants.provisioning",
|
||||
"get_or_create_tenant_id",
|
||||
async_return_default_schema,
|
||||
)(
|
||||
email=email,
|
||||
)
|
||||
tenant_id = get_tenant_id_for_email(email)
|
||||
if not tenant_id:
|
||||
# User not found in mapping
|
||||
self.password_helper.hash(credentials.password)
|
||||
@@ -468,7 +477,8 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
has_web_login = attributes.get_attribute(user, "has_web_login")
|
||||
|
||||
if not has_web_login:
|
||||
raise BasicAuthenticationError(
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="NO_WEB_LOGIN_AND_HAS_NO_PASSWORD",
|
||||
)
|
||||
|
||||
@@ -501,14 +511,7 @@ cookie_transport = CookieTransport(
|
||||
# This strategy is used to add tenant_id to the JWT token
|
||||
class TenantAwareJWTStrategy(JWTStrategy):
|
||||
async def _create_token_data(self, user: User, impersonate: bool = False) -> dict:
|
||||
tenant_id = await fetch_ee_implementation_or_noop(
|
||||
"danswer.server.tenants.provisioning",
|
||||
"get_or_create_tenant_id",
|
||||
async_return_default_schema,
|
||||
)(
|
||||
email=user.email,
|
||||
)
|
||||
|
||||
tenant_id = get_tenant_id_for_email(user.email)
|
||||
data = {
|
||||
"sub": str(user.id),
|
||||
"aud": self.token_audience,
|
||||
@@ -625,12 +628,14 @@ async def double_check_user(
|
||||
return None
|
||||
|
||||
if user is None:
|
||||
raise BasicAuthenticationError(
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied. User is not authenticated.",
|
||||
)
|
||||
|
||||
if user_needs_to_be_verified() and not user.is_verified:
|
||||
raise BasicAuthenticationError(
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied. User is not verified.",
|
||||
)
|
||||
|
||||
@@ -639,7 +644,8 @@ async def double_check_user(
|
||||
and user.oidc_expiry < datetime.now(timezone.utc)
|
||||
and not include_expired
|
||||
):
|
||||
raise BasicAuthenticationError(
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied. User's OIDC token has expired.",
|
||||
)
|
||||
|
||||
@@ -665,13 +671,15 @@ async def current_curator_or_admin_user(
|
||||
return None
|
||||
|
||||
if not user or not hasattr(user, "role"):
|
||||
raise BasicAuthenticationError(
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied. User is not authenticated or lacks role information.",
|
||||
)
|
||||
|
||||
allowed_roles = {UserRole.GLOBAL_CURATOR, UserRole.CURATOR, UserRole.ADMIN}
|
||||
if user.role not in allowed_roles:
|
||||
raise BasicAuthenticationError(
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied. User is not a curator or admin.",
|
||||
)
|
||||
|
||||
@@ -683,7 +691,8 @@ async def current_admin_user(user: User | None = Depends(current_user)) -> User
|
||||
return None
|
||||
|
||||
if not user or not hasattr(user, "role") or user.role != UserRole.ADMIN:
|
||||
raise BasicAuthenticationError(
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied. User must be an admin to perform this action.",
|
||||
)
|
||||
|
||||
@@ -876,22 +885,3 @@ def get_oauth_router(
|
||||
return redirect_response
|
||||
|
||||
return router
|
||||
|
||||
|
||||
def api_key_dep(
|
||||
request: Request, db_session: Session = Depends(get_session)
|
||||
) -> User | None:
|
||||
if AUTH_TYPE == AuthType.DISABLED:
|
||||
return None
|
||||
|
||||
hashed_api_key = get_hashed_api_key_from_request(request)
|
||||
if not hashed_api_key:
|
||||
raise HTTPException(status_code=401, detail="Missing API key")
|
||||
|
||||
if hashed_api_key:
|
||||
user = fetch_user_for_api_key(hashed_api_key, db_session)
|
||||
|
||||
if user is None:
|
||||
raise HTTPException(status_code=401, detail="Invalid API key")
|
||||
|
||||
return user
|
||||
|
||||
@@ -3,7 +3,6 @@ import multiprocessing
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
import sentry_sdk
|
||||
from celery import Task
|
||||
from celery.app import trace
|
||||
@@ -12,15 +11,11 @@ from celery.states import READY_STATES
|
||||
from celery.utils.log import get_task_logger
|
||||
from celery.worker import strategy # type: ignore
|
||||
from sentry_sdk.integrations.celery import CeleryIntegration
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.apps.task_formatters import CeleryTaskColoredFormatter
|
||||
from danswer.background.celery.apps.task_formatters import CeleryTaskPlainFormatter
|
||||
from danswer.background.celery.celery_utils import celery_is_worker_primary
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.document_index.vespa_constants import VESPA_CONFIG_SERVER_URL
|
||||
from danswer.redis.redis_connector import RedisConnector
|
||||
from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
|
||||
from danswer.redis.redis_connector_delete import RedisConnectorDelete
|
||||
@@ -31,6 +26,7 @@ from danswer.redis.redis_usergroup import RedisUserGroup
|
||||
from danswer.utils.logger import ColoredFormatter
|
||||
from danswer.utils.logger import PlainFormatter
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import SENTRY_DSN
|
||||
|
||||
|
||||
@@ -143,136 +139,45 @@ def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None
|
||||
|
||||
|
||||
def wait_for_redis(sender: Any, **kwargs: Any) -> None:
|
||||
"""Waits for redis to become ready subject to a hardcoded timeout.
|
||||
Will raise WorkerShutdown to kill the celery worker if the timeout is reached."""
|
||||
|
||||
r = get_redis_client(tenant_id=None)
|
||||
|
||||
WAIT_INTERVAL = 5
|
||||
WAIT_LIMIT = 60
|
||||
|
||||
ready = False
|
||||
time_start = time.monotonic()
|
||||
logger.info("Redis: Readiness probe starting.")
|
||||
logger.info("Redis: Readiness check starting.")
|
||||
while True:
|
||||
try:
|
||||
if r.ping():
|
||||
ready = True
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
if time_elapsed > WAIT_LIMIT:
|
||||
break
|
||||
|
||||
logger.info(
|
||||
f"Redis: Readiness probe ongoing. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
|
||||
f"Redis: Ping failed. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
|
||||
)
|
||||
if time_elapsed > WAIT_LIMIT:
|
||||
msg = (
|
||||
f"Redis: Readiness check did not succeed within the timeout "
|
||||
f"({WAIT_LIMIT} seconds). Exiting..."
|
||||
)
|
||||
logger.error(msg)
|
||||
raise WorkerShutdown(msg)
|
||||
|
||||
time.sleep(WAIT_INTERVAL)
|
||||
|
||||
if not ready:
|
||||
msg = (
|
||||
f"Redis: Readiness probe did not succeed within the timeout "
|
||||
f"({WAIT_LIMIT} seconds). Exiting..."
|
||||
)
|
||||
logger.error(msg)
|
||||
raise WorkerShutdown(msg)
|
||||
|
||||
logger.info("Redis: Readiness probe succeeded. Continuing...")
|
||||
return
|
||||
|
||||
|
||||
def wait_for_db(sender: Any, **kwargs: Any) -> None:
    """Block until the database answers a trivial query, or give up.

    Runs ``SELECT NOW()`` through a fresh session every WAIT_INTERVAL seconds
    until it returns a truthy value or more than WAIT_LIMIT seconds have
    elapsed since the first attempt.

    Args:
        sender: Unused; accepted so the function can be called as
            ``wait_for_db(sender, **kwargs)`` like the other startup checks.
        **kwargs: Unused.

    Raises:
        WorkerShutdown: if the database did not become ready within
            WAIT_LIMIT seconds.
    """
    WAIT_INTERVAL = 5
    WAIT_LIMIT = 60

    ready = False
    time_start = time.monotonic()
    logger.info("Database: Readiness probe starting.")
    while True:
        try:
            with Session(get_sqlalchemy_engine()) as db_session:
                result = db_session.execute(text("SELECT NOW()")).scalar()
                if result:
                    ready = True
                    break
        except Exception as e:
            # Best-effort probe: a failure only means "not ready yet", but
            # record the reason so a startup timeout can be diagnosed.
            logger.warning(f"Database: Readiness probe attempt failed: {e!r}")

        time_elapsed = time.monotonic() - time_start
        if time_elapsed > WAIT_LIMIT:
            break

        logger.info(
            f"Database: Readiness probe ongoing. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
        )

        time.sleep(WAIT_INTERVAL)

    if not ready:
        msg = (
            f"Database: Readiness probe did not succeed within the timeout "
            f"({WAIT_LIMIT} seconds). Exiting..."
        )
        logger.error(msg)
        raise WorkerShutdown(msg)

    logger.info("Database: Readiness probe succeeded. Continuing...")
    return
|
||||
|
||||
|
||||
def wait_for_vespa(sender: Any, **kwargs: Any) -> None:
    """Block until the Vespa config server reports healthy, or give up.

    Polls ``{VESPA_CONFIG_SERVER_URL}/state/v1/health`` every WAIT_INTERVAL
    seconds until the reported status code is ``"up"`` or more than
    WAIT_LIMIT seconds have elapsed since the first attempt.

    Args:
        sender: Unused; accepted so the function can be called as
            ``wait_for_vespa(sender, **kwargs)`` like the other startup checks.
        **kwargs: Unused.

    Raises:
        WorkerShutdown: if Vespa did not become ready within WAIT_LIMIT
            seconds.
    """
    WAIT_INTERVAL = 5
    WAIT_LIMIT = 60

    ready = False
    time_start = time.monotonic()
    logger.info("Vespa: Readiness probe starting.")
    while True:
        try:
            # Bound every request: without a timeout a hung connection blocks
            # forever and the WAIT_LIMIT check below is never reached.
            response = requests.get(
                f"{VESPA_CONFIG_SERVER_URL}/state/v1/health",
                timeout=WAIT_INTERVAL,
            )
            response.raise_for_status()

            response_dict = response.json()
            if response_dict["status"]["code"] == "up":
                ready = True
                break
        except Exception as e:
            # Best-effort probe: a failure only means "not ready yet", but
            # record the reason so a startup timeout can be diagnosed.
            logger.warning(f"Vespa: Readiness probe attempt failed: {e!r}")

        time_elapsed = time.monotonic() - time_start
        if time_elapsed > WAIT_LIMIT:
            break

        logger.info(
            f"Vespa: Readiness probe ongoing. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
        )

        time.sleep(WAIT_INTERVAL)

    if not ready:
        msg = (
            f"Vespa: Readiness probe did not succeed within the timeout "
            f"({WAIT_LIMIT} seconds). Exiting..."
        )
        logger.error(msg)
        raise WorkerShutdown(msg)

    logger.info("Vespa: Readiness probe succeeded. Continuing...")
|
||||
logger.info("Redis: Readiness check succeeded. Continuing...")
|
||||
return
|
||||
|
||||
|
||||
def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
logger.info("Running as a secondary celery worker.")
|
||||
|
||||
# Exit early if multi-tenant since primary worker check not needed
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
# Set up variables for waiting on primary worker
|
||||
WAIT_INTERVAL = 5
|
||||
WAIT_LIMIT = 60
|
||||
|
||||
@@ -12,7 +12,6 @@ from danswer.db.engine import get_all_tenant_ids
|
||||
from danswer.db.engine import SqlEngine
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
@@ -120,10 +119,10 @@ class DynamicTenantScheduler(PersistentScheduler):
|
||||
else:
|
||||
logger.info("Schedule is up to date, no changes needed")
|
||||
|
||||
except (AttributeError, KeyError):
|
||||
logger.exception("Failed to process task configuration")
|
||||
except Exception:
|
||||
logger.exception("Unexpected error updating tenant tasks")
|
||||
except (AttributeError, KeyError) as e:
|
||||
logger.exception(f"Failed to process task configuration: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.exception(f"Unexpected error updating tenant tasks: {str(e)}")
|
||||
|
||||
def _should_update_schedule(
|
||||
self, current_schedule: dict, new_schedule: dict
|
||||
@@ -144,11 +143,6 @@ def on_beat_init(sender: Any, **kwargs: Any) -> None:
|
||||
# Celery beat shouldn't touch the db at all. But just setting a low minimum here.
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_BEAT_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=2, max_overflow=0)
|
||||
|
||||
# Startup checks are not needed in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
|
||||
|
||||
|
||||
@@ -13,7 +13,6 @@ import danswer.background.celery.apps.app_base as app_base
|
||||
from danswer.configs.constants import POSTGRES_CELERY_WORKER_HEAVY_APP_NAME
|
||||
from danswer.db.engine import SqlEngine
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -61,13 +60,7 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_HEAVY_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=4, max_overflow=12)
|
||||
|
||||
# Startup checks are not needed in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa(sender, **kwargs)
|
||||
app_base.on_secondary_worker_init(sender, **kwargs)
|
||||
|
||||
|
||||
|
||||
@@ -13,7 +13,6 @@ import danswer.background.celery.apps.app_base as app_base
|
||||
from danswer.configs.constants import POSTGRES_CELERY_WORKER_INDEXING_APP_NAME
|
||||
from danswer.db.engine import SqlEngine
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -61,13 +60,7 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=8, max_overflow=0)
|
||||
|
||||
# Startup checks are not needed in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa(sender, **kwargs)
|
||||
app_base.on_secondary_worker_init(sender, **kwargs)
|
||||
|
||||
|
||||
|
||||
@@ -13,7 +13,6 @@ import danswer.background.celery.apps.app_base as app_base
|
||||
from danswer.configs.constants import POSTGRES_CELERY_WORKER_LIGHT_APP_NAME
|
||||
from danswer.db.engine import SqlEngine
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -60,13 +59,8 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8)
|
||||
# Startup checks are not needed in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa(sender, **kwargs)
|
||||
app_base.on_secondary_worker_init(sender, **kwargs)
|
||||
|
||||
|
||||
|
||||
@@ -75,16 +75,13 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=8, max_overflow=0)
|
||||
|
||||
# Startup checks are not needed in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa(sender, **kwargs)
|
||||
|
||||
logger.info("Running as the primary celery worker.")
|
||||
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
# This is singleton work that should be done on startup exactly once
|
||||
# by the primary worker. This is unnecessary in the multi tenant scenario
|
||||
r = get_redis_client(tenant_id=None)
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
def name_sync_external_doc_permissions_task(
|
||||
cc_pair_id: int, tenant_id: str | None = None
|
||||
) -> str:
|
||||
return f"sync_external_doc_permissions_task__{cc_pair_id}"
|
||||
@@ -48,6 +48,7 @@ class QADocsResponse(RetrievalDocs):
|
||||
class StreamStopReason(Enum):
|
||||
CONTEXT_LENGTH = "context_length"
|
||||
CANCELLED = "cancelled"
|
||||
NEW_RESPONSE = "new_response"
|
||||
|
||||
|
||||
class StreamStopInfo(BaseModel):
|
||||
|
||||
@@ -19,6 +19,7 @@ from danswer.chat.models import MessageSpecificCitations
|
||||
from danswer.chat.models import QADocsResponse
|
||||
from danswer.chat.models import StreamingError
|
||||
from danswer.chat.models import StreamStopInfo
|
||||
from danswer.chat.models import StreamStopReason
|
||||
from danswer.configs.app_configs import AZURE_DALLE_API_BASE
|
||||
from danswer.configs.app_configs import AZURE_DALLE_API_KEY
|
||||
from danswer.configs.app_configs import AZURE_DALLE_API_VERSION
|
||||
@@ -137,6 +138,7 @@ def _translate_citations(
|
||||
"""Always cites the first instance of the document_id, assumes the db_docs
|
||||
are sorted in the order displayed in the UI"""
|
||||
doc_id_to_saved_doc_id_map: dict[str, int] = {}
|
||||
|
||||
for db_doc in db_docs:
|
||||
if db_doc.document_id not in doc_id_to_saved_doc_id_map:
|
||||
doc_id_to_saved_doc_id_map[db_doc.document_id] = db_doc.id
|
||||
@@ -687,6 +689,10 @@ def stream_chat_message_objects(
|
||||
tools: list[Tool] = []
|
||||
for tool_list in tool_dict.values():
|
||||
tools.extend(tool_list)
|
||||
tool_name_to_tool_id = {}
|
||||
for tool_id, tool_list in tool_dict.items():
|
||||
for tool in tool_list:
|
||||
tool_name_to_tool_id[tool.name] = tool_id
|
||||
|
||||
# factor in tool definition size when pruning
|
||||
document_pruning_config.tool_num_tokens = compute_all_tool_tokens(
|
||||
@@ -729,6 +735,74 @@ def stream_chat_message_objects(
|
||||
tool_result = None
|
||||
|
||||
for packet in answer.processed_streamed_output:
|
||||
if isinstance(packet, StreamStopInfo):
|
||||
if packet.stop_reason is not StreamStopReason.NEW_RESPONSE:
|
||||
break
|
||||
db_citations = None
|
||||
|
||||
if reference_db_search_docs:
|
||||
db_citations = _translate_citations(
|
||||
citations_list=answer.citations,
|
||||
db_docs=reference_db_search_docs,
|
||||
)
|
||||
|
||||
# Saving Gen AI answer and responding with message info
|
||||
if tool_result is None:
|
||||
tool_call = None
|
||||
else:
|
||||
tool_call = ToolCall(
|
||||
tool_id=tool_name_to_tool_id[tool_result.tool_name],
|
||||
tool_name=tool_result.tool_name,
|
||||
tool_arguments=tool_result.tool_args,
|
||||
tool_result=tool_result.tool_result,
|
||||
)
|
||||
|
||||
gen_ai_response_message = partial_response(
|
||||
reserved_message_id=reserved_message_id,
|
||||
message=answer.llm_answer,
|
||||
rephrased_query=cast(
|
||||
QADocsResponse, qa_docs_response
|
||||
).rephrased_query
|
||||
if qa_docs_response is not None
|
||||
else None,
|
||||
reference_docs=reference_db_search_docs,
|
||||
files=ai_message_files,
|
||||
token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
|
||||
citations=cast(MessageSpecificCitations, db_citations).citation_map
|
||||
if db_citations is not None
|
||||
else None,
|
||||
error=None,
|
||||
tool_call=tool_call,
|
||||
)
|
||||
db_session.commit() # actually save user / assistant message
|
||||
|
||||
msg_detail_response = translate_db_message_to_chat_message_detail(
|
||||
gen_ai_response_message
|
||||
)
|
||||
|
||||
yield msg_detail_response
|
||||
reserved_message_id = reserve_message_id(
|
||||
db_session=db_session,
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message=gen_ai_response_message.id
|
||||
if user_message is not None
|
||||
else gen_ai_response_message.id,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
)
|
||||
|
||||
partial_response = partial(
|
||||
create_new_chat_message,
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message=gen_ai_response_message,
|
||||
prompt_id=prompt_id,
|
||||
overridden_model=overridden_model,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
alternate_assistant_id=new_msg_req.alternate_assistant_id,
|
||||
db_session=db_session,
|
||||
commit=False,
|
||||
)
|
||||
reference_db_search_docs = None
|
||||
|
||||
if isinstance(packet, ToolResponse):
|
||||
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
|
||||
(
|
||||
@@ -869,6 +943,8 @@ def stream_chat_message_objects(
|
||||
for tool_id, tool_list in tool_dict.items():
|
||||
for tool in tool_list:
|
||||
tool_name_to_tool_id[tool.name] = tool_id
|
||||
if answer.llm_answer == "":
|
||||
return
|
||||
|
||||
gen_ai_response_message = partial_response(
|
||||
reserved_message_id=reserved_message_id,
|
||||
|
||||
@@ -493,13 +493,3 @@ JWT_ALGORITHM = "HS256"
|
||||
# Super Users
|
||||
SUPER_USERS = json.loads(os.environ.get("SUPER_USERS", '["pablo@danswer.ai"]'))
|
||||
SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key")
|
||||
|
||||
|
||||
#####
|
||||
# API Key Configs
|
||||
#####
|
||||
# refers to the rounds described here: https://passlib.readthedocs.io/en/stable/lib/passlib.hash.sha256_crypt.html
|
||||
_API_KEY_HASH_ROUNDS_RAW = os.environ.get("API_KEY_HASH_ROUNDS")
|
||||
API_KEY_HASH_ROUNDS = (
|
||||
int(_API_KEY_HASH_ROUNDS_RAW) if _API_KEY_HASH_ROUNDS_RAW else None
|
||||
)
|
||||
|
||||
@@ -126,7 +126,6 @@ class DocumentSource(str, Enum):
|
||||
XENFORO = "xenforo"
|
||||
NOT_APPLICABLE = "not_applicable"
|
||||
FRESHDESK = "freshdesk"
|
||||
FIREFLIES = "fireflies"
|
||||
|
||||
|
||||
DocumentSourceRequiringTenantContext: list[DocumentSource] = [DocumentSource.FILE]
|
||||
|
||||
@@ -16,7 +16,6 @@ from danswer.connectors.discourse.connector import DiscourseConnector
|
||||
from danswer.connectors.document360.connector import Document360Connector
|
||||
from danswer.connectors.dropbox.connector import DropboxConnector
|
||||
from danswer.connectors.file.connector import LocalFileConnector
|
||||
from danswer.connectors.fireflies.connector import FirefliesConnector
|
||||
from danswer.connectors.freshdesk.connector import FreshdeskConnector
|
||||
from danswer.connectors.github.connector import GithubConnector
|
||||
from danswer.connectors.gitlab.connector import GitlabConnector
|
||||
@@ -102,7 +101,6 @@ def identify_connector_class(
|
||||
DocumentSource.OCI_STORAGE: BlobStorageConnector,
|
||||
DocumentSource.XENFORO: XenforoConnector,
|
||||
DocumentSource.FRESHDESK: FreshdeskConnector,
|
||||
DocumentSource.FIREFLIES: FirefliesConnector,
|
||||
}
|
||||
connector_by_source = connector_map.get(source, {})
|
||||
|
||||
|
||||
@@ -123,13 +123,9 @@ def _process_file(
|
||||
"filename",
|
||||
"file_display_name",
|
||||
"title",
|
||||
"connector_type",
|
||||
]
|
||||
}
|
||||
|
||||
source_type_str = all_metadata.get("connector_type")
|
||||
source_type = DocumentSource(source_type_str) if source_type_str else None
|
||||
|
||||
p_owner_names = all_metadata.get("primary_owners")
|
||||
s_owner_names = all_metadata.get("secondary_owners")
|
||||
p_owners = (
|
||||
@@ -149,7 +145,7 @@ def _process_file(
|
||||
sections=[
|
||||
Section(link=all_metadata.get("link"), text=file_content_raw.strip())
|
||||
],
|
||||
source=source_type or DocumentSource.FILE,
|
||||
source=DocumentSource.FILE,
|
||||
semantic_identifier=file_display_name,
|
||||
title=title,
|
||||
doc_updated_at=final_time_updated,
|
||||
|
||||
@@ -1,182 +0,0 @@
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import List
|
||||
|
||||
import requests
|
||||
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from danswer.connectors.models import BasicExpertInfo
|
||||
from danswer.connectors.models import ConnectorMissingCredentialError
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_FIREFLIES_ID_PREFIX = "FIREFLIES_"
|
||||
|
||||
_FIREFLIES_API_URL = "https://api.fireflies.ai/graphql"
|
||||
|
||||
_FIREFLIES_TRANSCRIPT_QUERY_SIZE = 50 # Max page size is 50
|
||||
|
||||
_FIREFLIES_API_QUERY = """
|
||||
query Transcripts($fromDate: DateTime, $toDate: DateTime, $limit: Int!, $skip: Int!) {
|
||||
transcripts(fromDate: $fromDate, toDate: $toDate, limit: $limit, skip: $skip) {
|
||||
id
|
||||
title
|
||||
host_email
|
||||
participants
|
||||
date
|
||||
transcript_url
|
||||
sentences {
|
||||
text
|
||||
speaker_name
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
|
||||
def _create_doc_from_transcript(transcript: dict) -> Document | None:
    """Convert one Fireflies transcript payload into a Document.

    Args:
        transcript: One transcript dict as returned by the Fireflies query.
            Must contain ``transcript_url``, ``id``, ``title``, ``date`` and
            ``host_email``; ``sentences`` and ``participants`` may be absent
            or null.

    Returns:
        The Document for the meeting, or None when the transcript has no
        sentences.
    """
    sentences = transcript.get("sentences")
    if not sentences:
        return None

    # `or` instead of a .get() default: a key that is present with a null
    # value would otherwise slip through and break the concatenation.
    meeting_text = "".join(
        [
            (sentence.get("speaker_name") or "Unknown Speaker")
            + ": "
            + (sentence.get("text") or "")
            + "\n\n"
            for sentence in sentences
        ]
    )

    meeting_link = transcript["transcript_url"]

    fireflies_id = _FIREFLIES_ID_PREFIX + transcript["id"]

    meeting_title = transcript["title"] or "No Title"

    # `date` is divided by 1000 before conversion, i.e. it is read as a
    # millisecond Unix timestamp.
    meeting_date_unix = transcript["date"]
    meeting_date = datetime.fromtimestamp(meeting_date_unix / 1000, tz=timezone.utc)

    meeting_host_email = transcript["host_email"]
    host_email_user_info = [BasicExpertInfo(email=meeting_host_email)]

    # Everyone except the host (and empty entries) becomes a secondary owner.
    # A null participants list is treated the same as a missing one.
    meeting_participants_email_list = [
        BasicExpertInfo(email=participant)
        for participant in (transcript.get("participants") or [])
        if participant != meeting_host_email and participant
    ]

    return Document(
        id=fireflies_id,
        sections=[
            Section(
                link=meeting_link,
                text=meeting_text,
            )
        ],
        source=DocumentSource.FIREFLIES,
        semantic_identifier=meeting_title,
        metadata={},
        doc_updated_at=meeting_date,
        primary_owners=host_email_user_info,
        secondary_owners=meeting_participants_email_list,
    )
|
||||
|
||||
|
||||
class FirefliesConnector(PollConnector, LoadConnector):
    """Connector that turns Fireflies meeting transcripts into Documents.

    Transcripts are fetched page by page from the Fireflies GraphQL endpoint
    and emitted in batches of at most ``batch_size`` documents.
    """

    # Seconds to wait on the Fireflies API before abandoning a request.
    _REQUEST_TIMEOUT_SECONDS = 60

    def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
        self.batch_size = batch_size
        # Set by load_credentials(). Initialised here so the missing-key
        # check in _fetch_transcripts raises the intended error instead of
        # an AttributeError.
        self.api_key: str | None = None

    def load_credentials(self, credentials: dict[str, str]) -> None:
        """Store the API key found under ``fireflies_api_key``.

        Raises:
            ConnectorMissingCredentialError: if the key is missing or not a
                string.
        """
        api_key = credentials.get("fireflies_api_key")

        if not isinstance(api_key, str):
            raise ConnectorMissingCredentialError(
                "The Fireflies API key must be a string"
            )

        self.api_key = api_key

        return None

    def _fetch_transcripts(
        self, start_datetime: str | None = None, end_datetime: str | None = None
    ) -> Iterator[List[dict]]:
        """Yield pages of raw transcript dicts from the Fireflies API.

        Args:
            start_datetime: Optional value for the query's ``fromDate``.
            end_datetime: Optional value for the query's ``toDate``.

        Raises:
            ConnectorMissingCredentialError: if no API key has been loaded.
            requests.HTTPError: if the API answers with an error status.
            RuntimeError: if the response carries GraphQL errors and no data.
        """
        if self.api_key is None:
            raise ConnectorMissingCredentialError("Missing API key")

        headers = {
            "Content-Type": "application/json",
            "Authorization": "Bearer " + self.api_key,
        }

        skip = 0
        variables: dict[str, int | str] = {
            "limit": _FIREFLIES_TRANSCRIPT_QUERY_SIZE,
        }

        if start_datetime:
            variables["fromDate"] = start_datetime
        if end_datetime:
            variables["toDate"] = end_datetime

        while True:
            variables["skip"] = skip
            response = requests.post(
                _FIREFLIES_API_URL,
                headers=headers,
                json={"query": _FIREFLIES_API_QUERY, "variables": variables},
                # Never let a stalled connection hang the indexing run.
                timeout=self._REQUEST_TIMEOUT_SECONDS,
            )

            response.raise_for_status()

            if response.status_code == 204:
                break

            received_transcripts = response.json()
            data = received_transcripts.get("data")
            if data is None and received_transcripts.get("errors"):
                # A GraphQL failure arrives with a success status and null
                # data; surface the API's own message rather than crashing
                # on the None below.
                raise RuntimeError(
                    f"Fireflies API returned errors: {received_transcripts['errors']}"
                )
            parsed_transcripts = (data or {}).get("transcripts") or []

            yield parsed_transcripts

            # A short page means there is nothing left to fetch.
            if len(parsed_transcripts) < _FIREFLIES_TRANSCRIPT_QUERY_SIZE:
                break

            skip += _FIREFLIES_TRANSCRIPT_QUERY_SIZE

    def _process_transcripts(
        self, start: str | None = None, end: str | None = None
    ) -> GenerateDocumentsOutput:
        """Yield Document batches built from the fetched transcripts.

        Transcripts that do not convert to a Document are skipped. The final
        batch may hold fewer than ``batch_size`` documents.
        """
        doc_batch: List[Document] = []

        for transcript_batch in self._fetch_transcripts(start, end):
            for transcript in transcript_batch:
                if doc := _create_doc_from_transcript(transcript):
                    doc_batch.append(doc)

                    if len(doc_batch) >= self.batch_size:
                        yield doc_batch
                        doc_batch = []

        if doc_batch:
            yield doc_batch

    def load_from_state(self) -> GenerateDocumentsOutput:
        """Return batches for every transcript, with no date filter."""
        return self._process_transcripts()

    def poll_source(
        self, start_unixtime: SecondsSinceUnixEpoch, end_unixtime: SecondsSinceUnixEpoch
    ) -> GenerateDocumentsOutput:
        """Yield batches for transcripts between the two timestamps.

        Both bounds are converted to UTC strings of the form
        ``YYYY-MM-DDTHH:MM:SS.000Z`` before being sent to the API.
        """
        start_datetime = datetime.fromtimestamp(
            start_unixtime, tz=timezone.utc
        ).strftime("%Y-%m-%dT%H:%M:%S.000Z")
        end_datetime = datetime.fromtimestamp(end_unixtime, tz=timezone.utc).strftime(
            "%Y-%m-%dT%H:%M:%S.000Z"
        )

        yield from self._process_transcripts(start_datetime, end_datetime)
|
||||
@@ -77,7 +77,6 @@ from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
from shared_configs.configs import MODEL_SERVER_HOST
|
||||
from shared_configs.configs import MODEL_SERVER_PORT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.configs import SLACK_CHANNEL_ID
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -190,67 +189,59 @@ class SlackbotHandler:
|
||||
continue
|
||||
|
||||
logger.debug(f"Acquired lock for tenant {tenant_id}")
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(
|
||||
tenant_id or POSTGRES_DEFAULT_SCHEMA
|
||||
)
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
try:
|
||||
logger.debug(
|
||||
f"Setting tenant ID context variable for tenant {tenant_id}"
|
||||
)
|
||||
slack_bot_tokens = fetch_tokens()
|
||||
logger.debug(f"Fetched Slack bot tokens for tenant {tenant_id}")
|
||||
logger.debug(
|
||||
f"Reset tenant ID context variable for tenant {tenant_id}"
|
||||
)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
try:
|
||||
logger.debug(
|
||||
f"Setting tenant ID context variable for tenant {tenant_id}"
|
||||
)
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id or "public")
|
||||
slack_bot_tokens = fetch_tokens()
|
||||
logger.debug(f"Fetched Slack bot tokens for tenant {tenant_id}")
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
logger.debug(
|
||||
f"Reset tenant ID context variable for tenant {tenant_id}"
|
||||
)
|
||||
|
||||
if not slack_bot_tokens:
|
||||
logger.debug(
|
||||
f"No Slack bot token found for tenant {tenant_id}"
|
||||
)
|
||||
if tenant_id in self.socket_clients:
|
||||
asyncio.run(self.socket_clients[tenant_id].close())
|
||||
del self.socket_clients[tenant_id]
|
||||
del self.slack_bot_tokens[tenant_id]
|
||||
continue
|
||||
|
||||
if (
|
||||
tenant_id not in self.slack_bot_tokens
|
||||
or slack_bot_tokens != self.slack_bot_tokens[tenant_id]
|
||||
):
|
||||
if tenant_id in self.slack_bot_tokens:
|
||||
logger.info(
|
||||
f"Slack Bot tokens have changed for tenant {tenant_id} - reconnecting"
|
||||
)
|
||||
else:
|
||||
search_settings = get_current_search_settings(
|
||||
db_session
|
||||
)
|
||||
embedding_model = EmbeddingModel.from_db_model(
|
||||
search_settings=search_settings,
|
||||
server_host=MODEL_SERVER_HOST,
|
||||
server_port=MODEL_SERVER_PORT,
|
||||
)
|
||||
warm_up_bi_encoder(embedding_model=embedding_model)
|
||||
|
||||
self.slack_bot_tokens[tenant_id] = slack_bot_tokens
|
||||
|
||||
if tenant_id in self.socket_clients:
|
||||
asyncio.run(self.socket_clients[tenant_id].close())
|
||||
|
||||
self.start_socket_client(tenant_id, slack_bot_tokens)
|
||||
|
||||
except KvKeyNotFoundError:
|
||||
logger.debug(f"Missing Slack Bot tokens for tenant {tenant_id}")
|
||||
if not slack_bot_tokens:
|
||||
logger.debug(f"No Slack bot token found for tenant {tenant_id}")
|
||||
if tenant_id in self.socket_clients:
|
||||
asyncio.run(self.socket_clients[tenant_id].close())
|
||||
del self.socket_clients[tenant_id]
|
||||
del self.slack_bot_tokens[tenant_id]
|
||||
except Exception as e:
|
||||
logger.exception(f"Error handling tenant {tenant_id}: {e}")
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
continue
|
||||
|
||||
if (
|
||||
tenant_id not in self.slack_bot_tokens
|
||||
or slack_bot_tokens != self.slack_bot_tokens[tenant_id]
|
||||
):
|
||||
if tenant_id in self.slack_bot_tokens:
|
||||
logger.info(
|
||||
f"Slack Bot tokens have changed for tenant {tenant_id} - reconnecting"
|
||||
)
|
||||
else:
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
embedding_model = EmbeddingModel.from_db_model(
|
||||
search_settings=search_settings,
|
||||
server_host=MODEL_SERVER_HOST,
|
||||
server_port=MODEL_SERVER_PORT,
|
||||
)
|
||||
warm_up_bi_encoder(embedding_model=embedding_model)
|
||||
|
||||
self.slack_bot_tokens[tenant_id] = slack_bot_tokens
|
||||
|
||||
if tenant_id in self.socket_clients:
|
||||
asyncio.run(self.socket_clients[tenant_id].close())
|
||||
|
||||
self.start_socket_client(tenant_id, slack_bot_tokens)
|
||||
|
||||
except KvKeyNotFoundError:
|
||||
logger.debug(f"Missing Slack Bot tokens for tenant {tenant_id}")
|
||||
if tenant_id in self.socket_clients:
|
||||
asyncio.run(self.socket_clients[tenant_id].close())
|
||||
del self.socket_clients[tenant_id]
|
||||
del self.slack_bot_tokens[tenant_id]
|
||||
except Exception as e:
|
||||
logger.exception(f"Error handling tenant {tenant_id}: {e}")
|
||||
|
||||
def send_heartbeats(self) -> None:
|
||||
current_time = int(time.time())
|
||||
|
||||
@@ -14,7 +14,6 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.invited_users import get_invited_users
|
||||
from danswer.auth.schemas import UserRole
|
||||
from danswer.db.api_key import get_api_key_email_pattern
|
||||
from danswer.db.engine import get_async_session
|
||||
from danswer.db.engine import get_async_session_with_tenant
|
||||
from danswer.db.models import AccessToken
|
||||
@@ -23,6 +22,7 @@ from danswer.db.models import User
|
||||
from danswer.utils.variable_functionality import (
|
||||
fetch_versioned_implementation_with_fallback,
|
||||
)
|
||||
from ee.danswer.db.api_key import get_api_key_email_pattern
|
||||
|
||||
|
||||
def get_default_admin_user_emails() -> list[str]:
|
||||
|
||||
@@ -25,8 +25,8 @@ from danswer.db.models import UserGroup__ConnectorCredentialPair
|
||||
from danswer.db.models import UserRole
|
||||
from danswer.server.models import StatusResponse
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
|
||||
from ee.danswer.db.external_perm import delete_user__ext_group_for_cc_pair__no_commit
|
||||
from ee.danswer.external_permissions.sync_params import check_if_valid_sync_source
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -351,11 +351,7 @@ def add_credential_to_connector(
|
||||
raise HTTPException(status_code=404, detail="Connector does not exist")
|
||||
|
||||
if access_type == AccessType.SYNC:
|
||||
if not fetch_ee_implementation_or_noop(
|
||||
"danswer.external_permissions.sync_params",
|
||||
"check_if_valid_sync_source",
|
||||
noop_return_value=True,
|
||||
)(connector.source):
|
||||
if not check_if_valid_sync_source(connector.source):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Connector of type {connector.source} does not support SYNC access type",
|
||||
@@ -442,10 +438,7 @@ def remove_credential_from_connector(
|
||||
)
|
||||
|
||||
if association is not None:
|
||||
fetch_ee_implementation_or_noop(
|
||||
"danswer.db.external_perm",
|
||||
"delete_user__ext_group_for_cc_pair__no_commit",
|
||||
)(
|
||||
delete_user__ext_group_for_cc_pair__no_commit(
|
||||
db_session=db_session,
|
||||
cc_pair_id=association.id,
|
||||
)
|
||||
|
||||
@@ -323,28 +323,16 @@ async def get_async_session_with_tenant(
|
||||
yield session
|
||||
|
||||
|
||||
@contextmanager
def get_session_with_default_tenant() -> Generator[Session, None, None]:
    """Yield a session for whichever tenant the context variable holds.

    Reads CURRENT_TENANT_ID_CONTEXTVAR once and delegates to
    get_session_with_tenant with that value.
    """
    current_tenant = CURRENT_TENANT_ID_CONTEXTVAR.get()
    with get_session_with_tenant(current_tenant) as tenant_session:
        yield tenant_session
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_session_with_tenant(
|
||||
tenant_id: str | None = None,
|
||||
) -> Generator[Session, None, None]:
|
||||
"""
|
||||
Generate a database session for a specific tenant.
|
||||
|
||||
This function:
|
||||
1. Sets the database schema to the specified tenant's schema.
|
||||
2. Preserves the tenant ID across the session.
|
||||
3. Reverts to the previous tenant ID after the session is closed.
|
||||
4. Uses the default schema if no tenant ID is provided.
|
||||
Generate a database session bound to a connection with the appropriate tenant schema set.
|
||||
This preserves the tenant ID across the session and reverts to the previous tenant ID
|
||||
after the session is closed.
|
||||
If tenant ID is set, we save the previous tenant ID from the context var to set
|
||||
after the session is closed. The value `None` evaluates to the default schema.
|
||||
"""
|
||||
engine = get_sqlalchemy_engine()
|
||||
|
||||
@@ -352,9 +340,9 @@ def get_session_with_tenant(
|
||||
previous_tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() or POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
if tenant_id is None:
|
||||
tenant_id = POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
tenant_id = previous_tenant_id
|
||||
else:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
event.listen(engine, "checkout", set_search_path_on_checkout)
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
|
||||
from danswer.configs.model_configs import OLD_DEFAULT_DOCUMENT_ENCODER_MODEL
|
||||
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
|
||||
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
|
||||
from danswer.db.engine import get_session_with_default_tenant
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.llm import fetch_embedding_provider
|
||||
from danswer.db.models import CloudEmbeddingProvider
|
||||
from danswer.db.models import IndexAttempt
|
||||
@@ -152,7 +152,7 @@ def get_all_search_settings(db_session: Session) -> list[SearchSettings]:
|
||||
|
||||
def get_multilingual_expansion(db_session: Session | None = None) -> list[str]:
|
||||
if db_session is None:
|
||||
with get_session_with_default_tenant() as db_session:
|
||||
with get_session_with_tenant() as db_session:
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
else:
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
|
||||
@@ -1,111 +0,0 @@
|
||||
from collections.abc import Sequence
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.constants import TokenRateLimitScope
|
||||
from danswer.db.models import TokenRateLimit
|
||||
from danswer.db.models import TokenRateLimit__UserGroup
|
||||
from danswer.server.token_rate_limits.models import TokenRateLimitArgs
|
||||
|
||||
|
||||
def fetch_all_user_token_rate_limits(
|
||||
db_session: Session,
|
||||
enabled_only: bool = False,
|
||||
ordered: bool = True,
|
||||
) -> Sequence[TokenRateLimit]:
|
||||
query = select(TokenRateLimit).where(
|
||||
TokenRateLimit.scope == TokenRateLimitScope.USER
|
||||
)
|
||||
|
||||
if enabled_only:
|
||||
query = query.where(TokenRateLimit.enabled.is_(True))
|
||||
|
||||
if ordered:
|
||||
query = query.order_by(TokenRateLimit.created_at.desc())
|
||||
|
||||
return db_session.scalars(query).all()
|
||||
|
||||
|
||||
def fetch_all_global_token_rate_limits(
|
||||
db_session: Session,
|
||||
enabled_only: bool = False,
|
||||
ordered: bool = True,
|
||||
) -> Sequence[TokenRateLimit]:
|
||||
query = select(TokenRateLimit).where(
|
||||
TokenRateLimit.scope == TokenRateLimitScope.GLOBAL
|
||||
)
|
||||
|
||||
if enabled_only:
|
||||
query = query.where(TokenRateLimit.enabled.is_(True))
|
||||
|
||||
if ordered:
|
||||
query = query.order_by(TokenRateLimit.created_at.desc())
|
||||
|
||||
token_rate_limits = db_session.scalars(query).all()
|
||||
return token_rate_limits
|
||||
|
||||
|
||||
def insert_user_token_rate_limit(
|
||||
db_session: Session,
|
||||
token_rate_limit_settings: TokenRateLimitArgs,
|
||||
) -> TokenRateLimit:
|
||||
token_limit = TokenRateLimit(
|
||||
enabled=token_rate_limit_settings.enabled,
|
||||
token_budget=token_rate_limit_settings.token_budget,
|
||||
period_hours=token_rate_limit_settings.period_hours,
|
||||
scope=TokenRateLimitScope.USER,
|
||||
)
|
||||
db_session.add(token_limit)
|
||||
db_session.commit()
|
||||
|
||||
return token_limit
|
||||
|
||||
|
||||
def insert_global_token_rate_limit(
|
||||
db_session: Session,
|
||||
token_rate_limit_settings: TokenRateLimitArgs,
|
||||
) -> TokenRateLimit:
|
||||
token_limit = TokenRateLimit(
|
||||
enabled=token_rate_limit_settings.enabled,
|
||||
token_budget=token_rate_limit_settings.token_budget,
|
||||
period_hours=token_rate_limit_settings.period_hours,
|
||||
scope=TokenRateLimitScope.GLOBAL,
|
||||
)
|
||||
db_session.add(token_limit)
|
||||
db_session.commit()
|
||||
|
||||
return token_limit
|
||||
|
||||
|
||||
def update_token_rate_limit(
|
||||
db_session: Session,
|
||||
token_rate_limit_id: int,
|
||||
token_rate_limit_settings: TokenRateLimitArgs,
|
||||
) -> TokenRateLimit:
|
||||
token_limit = db_session.get(TokenRateLimit, token_rate_limit_id)
|
||||
if token_limit is None:
|
||||
raise ValueError(f"TokenRateLimit with id '{token_rate_limit_id}' not found")
|
||||
|
||||
token_limit.enabled = token_rate_limit_settings.enabled
|
||||
token_limit.token_budget = token_rate_limit_settings.token_budget
|
||||
token_limit.period_hours = token_rate_limit_settings.period_hours
|
||||
db_session.commit()
|
||||
|
||||
return token_limit
|
||||
|
||||
|
||||
def delete_token_rate_limit(
|
||||
db_session: Session,
|
||||
token_rate_limit_id: int,
|
||||
) -> None:
|
||||
token_limit = db_session.get(TokenRateLimit, token_rate_limit_id)
|
||||
if token_limit is None:
|
||||
raise ValueError(f"TokenRateLimit with id '{token_rate_limit_id}' not found")
|
||||
|
||||
db_session.query(TokenRateLimit__UserGroup).filter(
|
||||
TokenRateLimit__UserGroup.rate_limit_id == token_rate_limit_id
|
||||
).delete()
|
||||
|
||||
db_session.delete(token_limit)
|
||||
db_session.commit()
|
||||
@@ -9,6 +9,8 @@ from langchain_core.messages import ToolCall
|
||||
from danswer.chat.models import AnswerQuestionPossibleReturn
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import StreamStopInfo
|
||||
from danswer.chat.models import StreamStopReason
|
||||
from danswer.file_store.utils import InMemoryChatFile
|
||||
from danswer.llm.answering.llm_response_handler import LLMCall
|
||||
from danswer.llm.answering.llm_response_handler import LLMResponseHandlerManager
|
||||
@@ -118,6 +120,9 @@ class Answer:
|
||||
)
|
||||
and not skip_explicit_tool_calling
|
||||
)
|
||||
self.current_streamed_output: list = []
|
||||
|
||||
self.processing_stream: list = []
|
||||
|
||||
def _get_tools_list(self) -> list[Tool]:
|
||||
if not self.force_use_tool.force_use:
|
||||
@@ -155,6 +160,7 @@ class Answer:
|
||||
response_handler_manager = LLMResponseHandlerManager(
|
||||
tool_handler, DummyAnswerResponseHandler(), self.is_cancelled
|
||||
)
|
||||
|
||||
yield from response_handler_manager.handle_llm_response(
|
||||
iter([dummy_tool_call_chunk])
|
||||
)
|
||||
@@ -165,7 +171,13 @@ class Answer:
|
||||
else:
|
||||
raise RuntimeError("Tool call handler did not return a new LLM call")
|
||||
|
||||
def _get_response(self, llm_calls: list[LLMCall]) -> AnswerStream:
|
||||
def _get_response(
|
||||
self,
|
||||
llm_calls: list[LLMCall],
|
||||
check_for_tool_call: bool = False,
|
||||
previously_used_tool: Tool | None = None,
|
||||
previous_tool_response: ToolResponse | None = None,
|
||||
) -> AnswerStream:
|
||||
current_llm_call = llm_calls[-1]
|
||||
|
||||
# handle the case where no decision has to be made; we simply run the tool
|
||||
@@ -231,7 +243,6 @@ class Answer:
|
||||
tool_call_handler, answer_handler, self.is_cancelled
|
||||
)
|
||||
|
||||
# DEBUG: good breakpoint
|
||||
stream = self.llm.stream(
|
||||
prompt=current_llm_call.prompt_builder.build(),
|
||||
tools=[tool.tool_definition() for tool in current_llm_call.tools] or None,
|
||||
@@ -242,11 +253,101 @@ class Answer:
|
||||
),
|
||||
structured_response_format=self.answer_style_config.structured_response_format,
|
||||
)
|
||||
yield from response_handler_manager.handle_llm_response(stream)
|
||||
|
||||
new_llm_call = response_handler_manager.next_llm_call(current_llm_call)
|
||||
tool_call_made = False
|
||||
tool_call_name: str | None = None
|
||||
buffered_packets = []
|
||||
|
||||
tool_response = None
|
||||
for packet in response_handler_manager.handle_llm_response(stream):
|
||||
if isinstance(packet, DanswerAnswerPiece):
|
||||
pass
|
||||
|
||||
if isinstance(packet, ToolResponse):
|
||||
tool_response = packet
|
||||
|
||||
if check_for_tool_call:
|
||||
buffered_packets.append(packet)
|
||||
if isinstance(packet, ToolCallKickoff):
|
||||
# if has_streamed_text and not has_completed:
|
||||
# yield StreamStopInfo(stop_reason=StreamStopReason.NEW_RESPONSE)
|
||||
# has_completed = True
|
||||
|
||||
tool_call_name = packet.tool_name
|
||||
tool_call_made = True
|
||||
for buffered_packet in buffered_packets:
|
||||
yield buffered_packet
|
||||
buffered_packets = []
|
||||
else:
|
||||
yield packet
|
||||
if isinstance(packet, ToolCallKickoff):
|
||||
# if has_streamed_text and not has_completed:
|
||||
# yield StreamStopInfo(stop_reason=StreamStopReason.NEW_RESPONSE)
|
||||
# has_completed = True
|
||||
tool_call_name = packet.tool_name
|
||||
tool_call_made = True
|
||||
|
||||
if check_for_tool_call and not tool_call_made:
|
||||
for remaining_packet in buffered_packets:
|
||||
yield remaining_packet
|
||||
return
|
||||
|
||||
for remaining_packet in buffered_packets:
|
||||
yield remaining_packet
|
||||
|
||||
new_llm_call = response_handler_manager.next_llm_call(
|
||||
current_llm_call, tool_call_made
|
||||
)
|
||||
tool_used: Tool | None = None
|
||||
if tool_call_made:
|
||||
tool_used = next(
|
||||
(tool for tool in self.tools if tool.name == tool_call_name), None
|
||||
)
|
||||
|
||||
if new_llm_call:
|
||||
yield from self._get_response(llm_calls + [new_llm_call])
|
||||
yield from self._get_response(
|
||||
llm_calls + [new_llm_call],
|
||||
check_for_tool_call=not tool_call_made,
|
||||
previously_used_tool=tool_used,
|
||||
previous_tool_response=tool_response,
|
||||
)
|
||||
|
||||
else:
|
||||
yield StreamStopInfo(stop_reason=StreamStopReason.NEW_RESPONSE)
|
||||
|
||||
# Logic here
|
||||
if (
|
||||
not check_for_tool_call
|
||||
and not tool_call_made
|
||||
and not previously_used_tool
|
||||
):
|
||||
return
|
||||
|
||||
if previously_used_tool:
|
||||
previously_used_tool.build_prompt_after_tool_call(
|
||||
current_llm_call.prompt_builder,
|
||||
self.question,
|
||||
self.llm_answer,
|
||||
previous_tool_response,
|
||||
)
|
||||
# Build next prompter with the original question and the LLM's last answer
|
||||
# current_llm_call.prompt_builder.update_user_prompt(HumanMessage(content=self.question))
|
||||
# current_llm_call.prompt_builder.build_next_prompter(self.question, self.llm_answer)
|
||||
|
||||
llm_call = LLMCall(
|
||||
prompt_builder=current_llm_call.prompt_builder,
|
||||
tools=self._get_tools_list(),
|
||||
force_use_tool=self.force_use_tool,
|
||||
files=self.latest_query_files,
|
||||
tool_call_info=[],
|
||||
using_tool_calling_llm=self.using_tool_calling_llm,
|
||||
)
|
||||
yield from self._get_response(
|
||||
[llm_call],
|
||||
check_for_tool_call=not tool_call_made,
|
||||
previously_used_tool=tool_used,
|
||||
previous_tool_response=tool_response,
|
||||
)
|
||||
|
||||
@property
|
||||
def processed_streamed_output(self) -> AnswerStream:
|
||||
@@ -276,26 +377,32 @@ class Answer:
|
||||
using_tool_calling_llm=self.using_tool_calling_llm,
|
||||
)
|
||||
|
||||
processed_stream = []
|
||||
for processed_packet in self._get_response([llm_call]):
|
||||
processed_stream.append(processed_packet)
|
||||
if (
|
||||
isinstance(processed_packet, StreamStopInfo)
|
||||
and processed_packet.stop_reason == StreamStopReason.NEW_RESPONSE
|
||||
):
|
||||
self.current_streamed_output = self.processing_stream
|
||||
self.processing_stream = []
|
||||
self.processing_stream.append(processed_packet)
|
||||
yield processed_packet
|
||||
|
||||
self._processed_stream = processed_stream
|
||||
self.current_streamed_output = self.processing_stream
|
||||
self._processed_stream = self.processing_stream
|
||||
|
||||
@property
|
||||
def llm_answer(self) -> str:
|
||||
answer = ""
|
||||
for packet in self.processed_streamed_output:
|
||||
if not self._processed_stream and not self.current_streamed_output:
|
||||
return ""
|
||||
for packet in self.current_streamed_output or self._processed_stream or []:
|
||||
if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece:
|
||||
answer += packet.answer_piece
|
||||
|
||||
return answer
|
||||
|
||||
@property
|
||||
def citations(self) -> list[CitationInfo]:
|
||||
citations: list[CitationInfo] = []
|
||||
for packet in self.processed_streamed_output:
|
||||
for packet in self.current_streamed_output or self._processed_stream or []:
|
||||
if isinstance(packet, CitationInfo):
|
||||
citations.append(packet)
|
||||
|
||||
|
||||
@@ -80,5 +80,7 @@ class LLMResponseHandlerManager:
|
||||
yield from self.tool_handler.handle_response_part(None, all_messages)
|
||||
yield from self.answer_handler.handle_response_part(None, all_messages)
|
||||
|
||||
def next_llm_call(self, llm_call: LLMCall) -> LLMCall | None:
|
||||
return self.tool_handler.next_llm_call(llm_call)
|
||||
def next_llm_call(
|
||||
self, llm_call: LLMCall, tool_call_made: bool = False
|
||||
) -> LLMCall | None:
|
||||
return self.tool_handler.next_llm_call(llm_call, tool_call_made)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import SystemMessage
|
||||
@@ -36,7 +37,10 @@ def default_build_system_message(
|
||||
|
||||
|
||||
def default_build_user_message(
|
||||
user_query: str, prompt_config: PromptConfig, files: list[InMemoryChatFile] = []
|
||||
user_query: str,
|
||||
prompt_config: PromptConfig,
|
||||
files: list[InMemoryChatFile] = [],
|
||||
previous_tool_call_count: int = 0,
|
||||
) -> HumanMessage:
|
||||
user_prompt = (
|
||||
CHAT_USER_CONTEXT_FREE_PROMPT.format(
|
||||
@@ -45,10 +49,16 @@ def default_build_user_message(
|
||||
if prompt_config.task_prompt
|
||||
else user_query
|
||||
)
|
||||
if previous_tool_call_count > 0:
|
||||
user_prompt = (
|
||||
f"You have already generated the above so do not call a tool if not necessary. "
|
||||
f"Remember the query is: `{user_prompt}`"
|
||||
)
|
||||
user_prompt = user_prompt.strip()
|
||||
user_msg = HumanMessage(
|
||||
content=build_content_with_imgs(user_prompt, files) if files else user_prompt
|
||||
)
|
||||
|
||||
return user_msg
|
||||
|
||||
|
||||
@@ -87,6 +97,30 @@ class AnswerPromptBuilder:
|
||||
)
|
||||
|
||||
self.new_messages_and_token_cnts: list[tuple[BaseMessage, int]] = []
|
||||
self.task_reminder: tuple[HumanMessage, int] | None = None
|
||||
|
||||
def update_task_reminder(self, task_reminder: HumanMessage) -> None:
|
||||
token_count = check_message_tokens(
|
||||
task_reminder, self.llm_tokenizer_encode_func
|
||||
)
|
||||
self.task_reminder = (task_reminder, token_count)
|
||||
|
||||
def build_next_prompter(
|
||||
self, question: str, llm_answer: str, task_reminder: str | None = None
|
||||
):
|
||||
# Append the AI's previous response
|
||||
self.append_message(AIMessage(content=llm_answer))
|
||||
# Add a new user message prompting the assistant to continue
|
||||
self.append_message(
|
||||
HumanMessage(
|
||||
content=(
|
||||
f"If your previous responses did not fully answer the original query: '{question}', "
|
||||
"please continue and complete the answer. Only add information if the original question "
|
||||
"wasn't fully addressed. Use any necessary tools to provide a comprehensive response. "
|
||||
"If the original query was already completely fulfilled, do NOT call a tool."
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
def update_system_prompt(self, system_message: SystemMessage | None) -> None:
|
||||
if not system_message:
|
||||
@@ -132,6 +166,8 @@ class AnswerPromptBuilder:
|
||||
|
||||
if self.new_messages_and_token_cnts:
|
||||
final_messages_with_tokens.extend(self.new_messages_and_token_cnts)
|
||||
if self.task_reminder:
|
||||
final_messages_with_tokens.append(self.task_reminder)
|
||||
|
||||
return drop_messages_history_overflow(
|
||||
final_messages_with_tokens, self.max_tokens
|
||||
|
||||
@@ -173,7 +173,9 @@ class ToolResponseHandler:
|
||||
|
||||
return
|
||||
|
||||
def next_llm_call(self, current_llm_call: LLMCall) -> LLMCall | None:
|
||||
def next_llm_call(
|
||||
self, current_llm_call: LLMCall, tool_call_made: bool
|
||||
) -> LLMCall | None:
|
||||
if (
|
||||
self.tool_runner is None
|
||||
or self.tool_call_summary is None
|
||||
@@ -191,7 +193,9 @@ class ToolResponseHandler:
|
||||
)
|
||||
return LLMCall(
|
||||
prompt_builder=new_prompt_builder,
|
||||
tools=[], # for now, only allow one tool call per response
|
||||
tools=self.tools
|
||||
if not tool_call_made
|
||||
else [], # for now, only allow one tool call per response
|
||||
force_use_tool=ForceUseTool(
|
||||
force_use=False,
|
||||
tool_name="",
|
||||
|
||||
@@ -25,7 +25,6 @@ from danswer.auth.schemas import UserCreate
|
||||
from danswer.auth.schemas import UserRead
|
||||
from danswer.auth.schemas import UserUpdate
|
||||
from danswer.auth.users import auth_backend
|
||||
from danswer.auth.users import BasicAuthenticationError
|
||||
from danswer.auth.users import fastapi_users
|
||||
from danswer.configs.app_configs import APP_API_PREFIX
|
||||
from danswer.configs.app_configs import APP_HOST
|
||||
@@ -195,12 +194,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
|
||||
|
||||
def log_http_error(_: Request, exc: Exception) -> JSONResponse:
|
||||
status_code = getattr(exc, "status_code", 500)
|
||||
|
||||
if isinstance(exc, BasicAuthenticationError):
|
||||
# For BasicAuthenticationError, just log a brief message without stack trace (almost always spam)
|
||||
logger.error(f"Authentication failed: {str(exc)}")
|
||||
|
||||
elif status_code >= 400:
|
||||
if status_code >= 400:
|
||||
error_msg = f"{str(exc)}\n"
|
||||
error_msg += "".join(traceback.format_tb(exc.__traceback__))
|
||||
logger.error(error_msg)
|
||||
@@ -226,6 +220,7 @@ def get_application() -> FastAPI:
|
||||
else:
|
||||
logger.debug("Sentry DSN not provided, skipping Sentry initialization")
|
||||
|
||||
# Add the custom exception handler
|
||||
application.add_exception_handler(status.HTTP_400_BAD_REQUEST, log_http_error)
|
||||
application.add_exception_handler(status.HTTP_401_UNAUTHORIZED, log_http_error)
|
||||
application.add_exception_handler(status.HTTP_403_FORBIDDEN, log_http_error)
|
||||
@@ -282,14 +277,12 @@ def get_application() -> FastAPI:
|
||||
prefix="/auth",
|
||||
tags=["auth"],
|
||||
)
|
||||
|
||||
include_router_with_global_prefix_prepended(
|
||||
application,
|
||||
fastapi_users.get_register_router(UserRead, UserCreate),
|
||||
prefix="/auth",
|
||||
tags=["auth"],
|
||||
)
|
||||
|
||||
include_router_with_global_prefix_prepended(
|
||||
application,
|
||||
fastapi_users.get_reset_password_router(),
|
||||
|
||||
@@ -35,31 +35,23 @@ class BaseTokenizer(ABC):
|
||||
class TiktokenTokenizer(BaseTokenizer):
|
||||
_instances: dict[str, "TiktokenTokenizer"] = {}
|
||||
|
||||
def __new__(cls, model_name: str) -> "TiktokenTokenizer":
|
||||
if model_name not in cls._instances:
|
||||
cls._instances[model_name] = super(TiktokenTokenizer, cls).__new__(cls)
|
||||
return cls._instances[model_name]
|
||||
def __new__(cls, encoding_name: str = "cl100k_base") -> "TiktokenTokenizer":
|
||||
if encoding_name not in cls._instances:
|
||||
cls._instances[encoding_name] = super(TiktokenTokenizer, cls).__new__(cls)
|
||||
return cls._instances[encoding_name]
|
||||
|
||||
def __init__(self, model_name: str):
|
||||
def __init__(self, encoding_name: str = "cl100k_base"):
|
||||
if not hasattr(self, "encoder"):
|
||||
import tiktoken
|
||||
|
||||
self.encoder = tiktoken.encoding_for_model(model_name)
|
||||
self.encoder = tiktoken.get_encoding(encoding_name)
|
||||
|
||||
def encode(self, string: str) -> list[int]:
|
||||
# this ignores special tokens that the model is trained on, see encode_ordinary for details
|
||||
# this returns no special tokens
|
||||
return self.encoder.encode_ordinary(string)
|
||||
|
||||
def tokenize(self, string: str) -> list[str]:
|
||||
encoded = self.encode(string)
|
||||
decoded = [self.encoder.decode([token]) for token in encoded]
|
||||
|
||||
if len(decoded) != len(encoded):
|
||||
logger.warning(
|
||||
f"OpenAI tokenized length {len(decoded)} does not match encoded length {len(encoded)} for string: {string}"
|
||||
)
|
||||
|
||||
return decoded
|
||||
return [self.encoder.decode([token]) for token in self.encode(string)]
|
||||
|
||||
def decode(self, tokens: list[int]) -> str:
|
||||
return self.encoder.decode(tokens)
|
||||
@@ -82,35 +74,22 @@ class HuggingFaceTokenizer(BaseTokenizer):
|
||||
return self.encoder.decode(tokens)
|
||||
|
||||
|
||||
_TOKENIZER_CACHE: dict[tuple[EmbeddingProvider | None, str | None], BaseTokenizer] = {}
|
||||
_TOKENIZER_CACHE: dict[str, BaseTokenizer] = {}
|
||||
|
||||
|
||||
def _check_tokenizer_cache(
|
||||
model_provider: EmbeddingProvider | None, model_name: str | None
|
||||
) -> BaseTokenizer:
|
||||
def _check_tokenizer_cache(tokenizer_name: str) -> BaseTokenizer:
|
||||
global _TOKENIZER_CACHE
|
||||
|
||||
id_tuple = (model_provider, model_name)
|
||||
|
||||
if id_tuple not in _TOKENIZER_CACHE:
|
||||
if model_provider in [EmbeddingProvider.OPENAI, EmbeddingProvider.AZURE]:
|
||||
if model_name is None:
|
||||
raise ValueError(
|
||||
"model_name is required for OPENAI and AZURE embeddings"
|
||||
)
|
||||
|
||||
_TOKENIZER_CACHE[id_tuple] = TiktokenTokenizer(model_name)
|
||||
return _TOKENIZER_CACHE[id_tuple]
|
||||
|
||||
if tokenizer_name not in _TOKENIZER_CACHE:
|
||||
if tokenizer_name == "openai":
|
||||
_TOKENIZER_CACHE[tokenizer_name] = TiktokenTokenizer("cl100k_base")
|
||||
return _TOKENIZER_CACHE[tokenizer_name]
|
||||
try:
|
||||
if model_name is None:
|
||||
model_name = DOCUMENT_ENCODER_MODEL
|
||||
|
||||
logger.debug(f"Initializing HuggingFaceTokenizer for: {model_name}")
|
||||
_TOKENIZER_CACHE[id_tuple] = HuggingFaceTokenizer(model_name)
|
||||
logger.debug(f"Initializing HuggingFaceTokenizer for: {tokenizer_name}")
|
||||
_TOKENIZER_CACHE[tokenizer_name] = HuggingFaceTokenizer(tokenizer_name)
|
||||
except Exception as primary_error:
|
||||
logger.error(
|
||||
f"Error initializing HuggingFaceTokenizer for {model_name}: {primary_error}"
|
||||
f"Error initializing HuggingFaceTokenizer for {tokenizer_name}: {primary_error}"
|
||||
)
|
||||
logger.warning(
|
||||
f"Falling back to default embedding model: {DOCUMENT_ENCODER_MODEL}"
|
||||
@@ -119,7 +98,7 @@ def _check_tokenizer_cache(
|
||||
try:
|
||||
# Cache this tokenizer name to the default so we don't have to try to load it again
|
||||
# and fail again
|
||||
_TOKENIZER_CACHE[id_tuple] = HuggingFaceTokenizer(
|
||||
_TOKENIZER_CACHE[tokenizer_name] = HuggingFaceTokenizer(
|
||||
DOCUMENT_ENCODER_MODEL
|
||||
)
|
||||
except Exception as fallback_error:
|
||||
@@ -127,10 +106,10 @@ def _check_tokenizer_cache(
|
||||
f"Error initializing fallback HuggingFaceTokenizer: {fallback_error}"
|
||||
)
|
||||
raise ValueError(
|
||||
f"Failed to initialize tokenizer for {model_name} and fallback model"
|
||||
f"Failed to initialize tokenizer for {tokenizer_name} and fallback model"
|
||||
) from fallback_error
|
||||
|
||||
return _TOKENIZER_CACHE[id_tuple]
|
||||
return _TOKENIZER_CACHE[tokenizer_name]
|
||||
|
||||
|
||||
_DEFAULT_TOKENIZER: BaseTokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)
|
||||
@@ -139,16 +118,11 @@ _DEFAULT_TOKENIZER: BaseTokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)
|
||||
def get_tokenizer(
|
||||
model_name: str | None, provider_type: EmbeddingProvider | str | None
|
||||
) -> BaseTokenizer:
|
||||
if provider_type is not None:
|
||||
if isinstance(provider_type, str):
|
||||
try:
|
||||
provider_type = EmbeddingProvider(provider_type)
|
||||
except ValueError:
|
||||
logger.debug(
|
||||
f"Invalid provider_type '{provider_type}'. Falling back to default tokenizer."
|
||||
)
|
||||
return _DEFAULT_TOKENIZER
|
||||
return _check_tokenizer_cache(provider_type, model_name)
|
||||
# Currently all of the viable models use the same sentencepiece tokenizer
|
||||
# OpenAI uses a different one but currently it's not supported due to quality issues
|
||||
# the inconsistent chunking makes using the sentencepiece tokenizer default better for now
|
||||
# LLM tokenizers are specified by strings
|
||||
global _DEFAULT_TOKENIZER
|
||||
return _DEFAULT_TOKENIZER
|
||||
|
||||
|
||||
|
||||
@@ -65,7 +65,7 @@ from danswer.tools.tool_implementations.search.search_tool import (
|
||||
from danswer.tools.tool_runner import ToolCallKickoff
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.timing import log_generator_function_time
|
||||
from danswer.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
from ee.danswer.server.query_and_chat.utils import create_temporary_persona
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -125,11 +125,11 @@ def stream_answer_objects(
|
||||
)
|
||||
|
||||
temporary_persona: Persona | None = None
|
||||
|
||||
if query_req.persona_config is not None:
|
||||
temporary_persona = fetch_ee_implementation_or_noop(
|
||||
"danswer.server.query_and_chat.utils", "create_temporary_persona", None
|
||||
)(db_session=db_session, persona_config=query_req.persona_config, user=user)
|
||||
new_persona = create_temporary_persona(
|
||||
db_session=db_session, persona_config=query_req.persona_config, user=user
|
||||
)
|
||||
temporary_persona = new_persona
|
||||
|
||||
persona = temporary_persona if temporary_persona else chat_session.persona
|
||||
|
||||
|
||||
@@ -10,7 +10,8 @@ from danswer.auth.users import current_user
|
||||
from danswer.auth.users import current_user_with_expired_token
|
||||
from danswer.configs.app_configs import APP_API_PREFIX
|
||||
from danswer.server.danswer_api.ingestion import api_key_dep
|
||||
from danswer.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
from ee.danswer.auth.users import current_cloud_superuser
|
||||
from ee.danswer.server.tenants.access import control_plane_dep
|
||||
|
||||
|
||||
PUBLIC_ENDPOINT_SPECS = [
|
||||
@@ -80,14 +81,6 @@ def check_router_auth(
|
||||
(1) have auth enabled OR
|
||||
(2) are explicitly marked as a public endpoint
|
||||
"""
|
||||
|
||||
control_plane_dep = fetch_ee_implementation_or_noop(
|
||||
"danswer.server.tenants.access", "control_plane_dep"
|
||||
)
|
||||
current_cloud_superuser = fetch_ee_implementation_or_noop(
|
||||
"danswer.auth.users", "current_cloud_superuser"
|
||||
)
|
||||
|
||||
for route in application.routes:
|
||||
# explicitly marked as public
|
||||
if is_route_in_spec_list(route, public_endpoint_specs):
|
||||
|
||||
@@ -3,7 +3,6 @@ from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import api_key_dep
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import IndexAttemptMetadata
|
||||
@@ -23,6 +22,7 @@ from danswer.server.danswer_api.models import DocMinimalInfo
|
||||
from danswer.server.danswer_api.models import IngestionDocument
|
||||
from danswer.server.danswer_api.models import IngestionResult
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.auth.users import api_key_dep
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@@ -16,9 +16,6 @@ from danswer.background.celery.tasks.pruning.tasks import (
|
||||
try_creating_prune_generator_task,
|
||||
)
|
||||
from danswer.background.celery.versioned_apps.primary import app as primary_app
|
||||
from danswer.background.task_name_builders import (
|
||||
name_sync_external_doc_permissions_task,
|
||||
)
|
||||
from danswer.db.connector_credential_pair import add_credential_to_connector
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.connector_credential_pair import remove_credential_from_connector
|
||||
@@ -50,7 +47,11 @@ from danswer.server.documents.models import ConnectorCredentialPairMetadata
|
||||
from danswer.server.documents.models import PaginatedIndexAttempts
|
||||
from danswer.server.models import StatusResponse
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
from ee.danswer.background.task_name_builders import (
|
||||
name_sync_external_doc_permissions_task,
|
||||
)
|
||||
from ee.danswer.db.user_group import validate_user_creation_permissions
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
router = APIRouter(prefix="/manage")
|
||||
@@ -331,6 +332,9 @@ def sync_cc_pair(
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> StatusResponse[list[int]]:
|
||||
# avoiding circular refs
|
||||
from ee.danswer.background.celery.apps.primary import (
|
||||
sync_external_doc_permissions_task,
|
||||
)
|
||||
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
cc_pair_id=cc_pair_id,
|
||||
@@ -356,19 +360,12 @@ def sync_cc_pair(
|
||||
)
|
||||
|
||||
logger.info(f"Syncing the {cc_pair.connector.name} connector.")
|
||||
sync_external_doc_permissions_task = fetch_ee_implementation_or_noop(
|
||||
"danswer.background.celery.apps.primary",
|
||||
"sync_external_doc_permissions_task",
|
||||
None,
|
||||
sync_external_doc_permissions_task.apply_async(
|
||||
kwargs=dict(
|
||||
cc_pair_id=cc_pair_id, tenant_id=CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
),
|
||||
)
|
||||
|
||||
if sync_external_doc_permissions_task:
|
||||
sync_external_doc_permissions_task.apply_async(
|
||||
kwargs=dict(
|
||||
cc_pair_id=cc_pair_id, tenant_id=CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
),
|
||||
)
|
||||
|
||||
return StatusResponse(
|
||||
success=True,
|
||||
message="Successfully created the sync task.",
|
||||
@@ -383,9 +380,7 @@ def associate_credential_to_connector(
|
||||
user: User | None = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> StatusResponse[int]:
|
||||
fetch_ee_implementation_or_noop(
|
||||
"danswer.db.user_group", "validate_user_creation_permissions", None
|
||||
)(
|
||||
validate_user_creation_permissions(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
target_group_ids=metadata.groups,
|
||||
|
||||
@@ -108,7 +108,7 @@ from danswer.server.documents.models import ObjectCreationIdResponse
|
||||
from danswer.server.documents.models import RunConnectorRequest
|
||||
from danswer.server.models import StatusResponse
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
from ee.danswer.db.user_group import validate_user_creation_permissions
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -658,10 +658,7 @@ def create_connector_from_model(
|
||||
) -> ObjectCreationIdResponse:
|
||||
try:
|
||||
_validate_connector_allowed(connector_data.source)
|
||||
|
||||
fetch_ee_implementation_or_noop(
|
||||
"danswer.db.user_group", "validate_user_creation_permissions", None
|
||||
)(
|
||||
validate_user_creation_permissions(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
target_group_ids=connector_data.groups,
|
||||
@@ -735,9 +732,7 @@ def update_connector_from_model(
|
||||
) -> ConnectorSnapshot | StatusResponse[int]:
|
||||
try:
|
||||
_validate_connector_allowed(connector_data.source)
|
||||
fetch_ee_implementation_or_noop(
|
||||
"danswer.db.user_group", "validate_user_creation_permissions", None
|
||||
)(
|
||||
validate_user_creation_permissions(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
target_group_ids=connector_data.groups,
|
||||
|
||||
@@ -28,7 +28,7 @@ from danswer.server.documents.models import CredentialSwapRequest
|
||||
from danswer.server.documents.models import ObjectCreationIdResponse
|
||||
from danswer.server.models import StatusResponse
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
from ee.danswer.db.user_group import validate_user_creation_permissions
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -121,9 +121,7 @@ def create_credential_from_model(
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ObjectCreationIdResponse:
|
||||
if not _ignore_credential_permissions(credential_info.source):
|
||||
fetch_ee_implementation_or_noop(
|
||||
"danswer.db.user_group", "validate_user_creation_permissions", None
|
||||
)(
|
||||
validate_user_creation_permissions(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
target_group_ids=credential_info.groups,
|
||||
|
||||
@@ -18,7 +18,7 @@ from danswer.server.features.document_set.models import CheckDocSetPublicRespons
|
||||
from danswer.server.features.document_set.models import DocumentSet
|
||||
from danswer.server.features.document_set.models import DocumentSetCreationRequest
|
||||
from danswer.server.features.document_set.models import DocumentSetUpdateRequest
|
||||
from danswer.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
from ee.danswer.db.user_group import validate_user_creation_permissions
|
||||
|
||||
|
||||
router = APIRouter(prefix="/manage")
|
||||
@@ -30,9 +30,7 @@ def create_document_set(
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> int:
|
||||
fetch_ee_implementation_or_noop(
|
||||
"danswer.db.user_group", "validate_user_creation_permissions", None
|
||||
)(
|
||||
validate_user_creation_permissions(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
target_group_ids=document_set_creation_request.groups,
|
||||
@@ -55,9 +53,7 @@ def patch_document_set(
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
fetch_ee_implementation_or_noop(
|
||||
"danswer.db.user_group", "validate_user_creation_permissions", None
|
||||
)(
|
||||
validate_user_creation_permissions(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
target_group_ids=document_set_update_request.groups,
|
||||
|
||||
@@ -11,6 +11,7 @@ from fastapi import Body
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
from fastapi import status
|
||||
from psycopg2.errors import UniqueViolation
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import Column
|
||||
@@ -26,10 +27,10 @@ from danswer.auth.noauth_user import fetch_no_auth_user
|
||||
from danswer.auth.noauth_user import set_no_auth_user_preferences
|
||||
from danswer.auth.schemas import UserRole
|
||||
from danswer.auth.schemas import UserStatus
|
||||
from danswer.auth.users import BasicAuthenticationError
|
||||
from danswer.auth.users import current_admin_user
|
||||
from danswer.auth.users import current_curator_or_admin_user
|
||||
from danswer.auth.users import current_user
|
||||
from danswer.auth.users import get_tenant_id_for_email
|
||||
from danswer.auth.users import optional_user
|
||||
from danswer.configs.app_configs import AUTH_TYPE
|
||||
from danswer.configs.app_configs import ENABLE_EMAIL_INVITES
|
||||
@@ -37,7 +38,6 @@ from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
|
||||
from danswer.configs.app_configs import SUPER_USERS
|
||||
from danswer.configs.app_configs import VALID_EMAIL_DOMAINS
|
||||
from danswer.configs.constants import AuthType
|
||||
from danswer.db.api_key import is_api_key_email_address
|
||||
from danswer.db.auth import get_total_users_count
|
||||
from danswer.db.engine import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from danswer.db.engine import get_session
|
||||
@@ -61,7 +61,12 @@ from danswer.server.models import InvitedUserSnapshot
|
||||
from danswer.server.models import MinimalUserSnapshot
|
||||
from danswer.server.utils import send_user_email_invite
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
from ee.danswer.db.api_key import is_api_key_email_address
|
||||
from ee.danswer.db.external_perm import delete_user__ext_group_for_user__no_commit
|
||||
from ee.danswer.db.user_group import remove_curator_status__no_commit
|
||||
from ee.danswer.server.tenants.billing import register_tenant_users
|
||||
from ee.danswer.server.tenants.provisioning import add_users_to_tenant
|
||||
from ee.danswer.server.tenants.provisioning import remove_users_from_tenant
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -100,10 +105,7 @@ def set_user_role(
|
||||
)
|
||||
|
||||
if user_to_update.role == UserRole.CURATOR:
|
||||
fetch_ee_implementation_or_noop(
|
||||
"danswer.db.user_group",
|
||||
"remove_curator_status__no_commit",
|
||||
)(db_session, user_to_update)
|
||||
remove_curator_status__no_commit(db_session, user_to_update)
|
||||
|
||||
user_to_update.role = user_role_update_request.new_role.value
|
||||
|
||||
@@ -203,9 +205,7 @@ def bulk_invite_users(
|
||||
|
||||
if MULTI_TENANT:
|
||||
try:
|
||||
fetch_ee_implementation_or_noop(
|
||||
"danswer.server.tenants.provisioning", "add_users_to_tenant", None
|
||||
)(normalized_emails, tenant_id)
|
||||
add_users_to_tenant(normalized_emails, tenant_id)
|
||||
|
||||
except IntegrityError as e:
|
||||
if isinstance(e.orig, UniqueViolation):
|
||||
@@ -226,9 +226,9 @@ def bulk_invite_users(
|
||||
return number_of_invited_users
|
||||
try:
|
||||
logger.info("Registering tenant users")
|
||||
fetch_ee_implementation_or_noop(
|
||||
"danswer.server.tenants.billing", "register_tenant_users", None
|
||||
)(CURRENT_TENANT_ID_CONTEXTVAR.get(), get_total_users_count(db_session))
|
||||
register_tenant_users(
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.get(), get_total_users_count(db_session)
|
||||
)
|
||||
if ENABLE_EMAIL_INVITES:
|
||||
try:
|
||||
for email in all_emails:
|
||||
@@ -243,9 +243,7 @@ def bulk_invite_users(
|
||||
"Reverting changes: removing users from tenant and resetting invited users"
|
||||
)
|
||||
write_invited_users(initial_invited_users) # Reset to original state
|
||||
fetch_ee_implementation_or_noop(
|
||||
"danswer.server.tenants.user_mapping", "remove_users_from_tenant", None
|
||||
)(normalized_emails, tenant_id)
|
||||
remove_users_from_tenant(normalized_emails, tenant_id)
|
||||
raise e
|
||||
|
||||
|
||||
@@ -259,16 +257,14 @@ def remove_invited_user(
|
||||
remaining_users = [user for user in user_emails if user != user_email.user_email]
|
||||
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
fetch_ee_implementation_or_noop(
|
||||
"danswer.server.tenants.user_mapping", "remove_users_from_tenant", None
|
||||
)([user_email.user_email], tenant_id)
|
||||
remove_users_from_tenant([user_email.user_email], tenant_id)
|
||||
number_of_invited_users = write_invited_users(remaining_users)
|
||||
|
||||
try:
|
||||
if MULTI_TENANT:
|
||||
fetch_ee_implementation_or_noop(
|
||||
"danswer.server.tenants.billing", "register_tenant_users", None
|
||||
)(CURRENT_TENANT_ID_CONTEXTVAR.get(), get_total_users_count(db_session))
|
||||
register_tenant_users(
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.get(), get_total_users_count(db_session)
|
||||
)
|
||||
except Exception:
|
||||
logger.error(
|
||||
"Request to update number of seats taken in control plane failed. "
|
||||
@@ -335,10 +331,7 @@ async def delete_user(
|
||||
for oauth_account in user_to_delete.oauth_accounts:
|
||||
db_session.delete(oauth_account)
|
||||
|
||||
fetch_ee_implementation_or_noop(
|
||||
"danswer.db.external_perm",
|
||||
"delete_user__ext_group_for_user__no_commit",
|
||||
)(
|
||||
delete_user__ext_group_for_user__no_commit(
|
||||
db_session=db_session,
|
||||
user_id=user_to_delete.id,
|
||||
)
|
||||
@@ -492,19 +485,20 @@ def verify_user_logged_in(
|
||||
store = get_kv_store()
|
||||
return fetch_no_auth_user(store)
|
||||
|
||||
raise BasicAuthenticationError(detail="User Not Authenticated")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail="User Not Authenticated"
|
||||
)
|
||||
|
||||
if user.oidc_expiry and user.oidc_expiry < datetime.now(timezone.utc):
|
||||
raise BasicAuthenticationError(
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied. User's OIDC token has expired.",
|
||||
)
|
||||
|
||||
token_created_at = (
|
||||
None if MULTI_TENANT else get_current_token_creation(user, db_session)
|
||||
)
|
||||
organization_name = fetch_ee_implementation_or_noop(
|
||||
"danswer.server.tenants.user_mapping", "get_tenant_id_for_email", None
|
||||
)(user.email)
|
||||
organization_name = get_tenant_id_for_email(user.email)
|
||||
|
||||
user_info = UserInfo.from_model(
|
||||
user,
|
||||
|
||||
@@ -359,7 +359,7 @@ def handle_new_chat_message(
|
||||
yield json.dumps(packet) if isinstance(packet, dict) else packet
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Error in chat message streaming")
|
||||
logger.exception(f"Error in chat message streaming: {e}")
|
||||
yield json.dumps({"error": str(e)})
|
||||
|
||||
finally:
|
||||
|
||||
@@ -279,7 +279,7 @@ def get_answer_with_quote(
|
||||
):
|
||||
yield json.dumps(packet) if isinstance(packet, dict) else packet
|
||||
except Exception as e:
|
||||
logger.exception("Error in search answer streaming")
|
||||
logger.exception(f"Error in search answer streaming: {e}")
|
||||
yield json.dumps({"error": str(e)})
|
||||
|
||||
return StreamingResponse(stream_generator(), media_type="application/json")
|
||||
|
||||
@@ -18,9 +18,9 @@ from danswer.db.models import ChatMessage
|
||||
from danswer.db.models import ChatSession
|
||||
from danswer.db.models import TokenRateLimit
|
||||
from danswer.db.models import User
|
||||
from danswer.db.token_limit import fetch_all_global_token_rate_limits
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
from ee.danswer.db.token_limit import fetch_all_global_token_rate_limits
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
|
||||
|
||||
@@ -5,13 +5,13 @@ from sqlalchemy.orm import Session
|
||||
from danswer.auth.users import current_admin_user
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.models import User
|
||||
from danswer.db.token_limit import delete_token_rate_limit
|
||||
from danswer.db.token_limit import fetch_all_global_token_rate_limits
|
||||
from danswer.db.token_limit import insert_global_token_rate_limit
|
||||
from danswer.db.token_limit import update_token_rate_limit
|
||||
from danswer.server.query_and_chat.token_limit import any_rate_limit_exists
|
||||
from danswer.server.token_rate_limits.models import TokenRateLimitArgs
|
||||
from danswer.server.token_rate_limits.models import TokenRateLimitDisplay
|
||||
from ee.danswer.db.token_limit import delete_token_rate_limit
|
||||
from ee.danswer.db.token_limit import fetch_all_global_token_rate_limits
|
||||
from ee.danswer.db.token_limit import insert_global_token_rate_limit
|
||||
from ee.danswer.db.token_limit import update_token_rate_limit
|
||||
|
||||
router = APIRouter(prefix="/admin/token-rate-limits")
|
||||
|
||||
|
||||
@@ -80,3 +80,15 @@ class Tool(abc.ABC):
|
||||
using_tool_calling_llm: bool,
|
||||
) -> "AnswerPromptBuilder":
|
||||
raise NotImplementedError
|
||||
|
||||
# This is the prompt builder that is used when the tool call AND LLM response has been updated
|
||||
# and we need to build the next prompt (for LLM calling tools)
|
||||
# @abc.abstractmethod
|
||||
|
||||
def build_prompt_after_tool_call(
|
||||
self,
|
||||
prompt_builder: "AnswerPromptBuilder",
|
||||
query: str,
|
||||
llm_answer: str,
|
||||
) -> "AnswerPromptBuilder":
|
||||
pass
|
||||
|
||||
@@ -15,7 +15,7 @@ from langchain_core.messages import SystemMessage
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.configs.constants import FileOrigin
|
||||
from danswer.db.engine import get_session_with_default_tenant
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.file_store.file_store import get_default_file_store
|
||||
from danswer.file_store.models import ChatFileType
|
||||
from danswer.file_store.models import InMemoryChatFile
|
||||
@@ -187,7 +187,7 @@ class CustomTool(BaseTool):
|
||||
def _save_and_get_file_references(
|
||||
self, file_content: bytes | str, content_type: str
|
||||
) -> List[str]:
|
||||
with get_session_with_default_tenant() as db_session:
|
||||
with get_session_with_tenant() as db_session:
|
||||
file_store = get_default_file_store(db_session)
|
||||
|
||||
file_id = str(uuid.uuid4())
|
||||
@@ -299,7 +299,7 @@ class CustomTool(BaseTool):
|
||||
|
||||
# Load files from storage
|
||||
files = []
|
||||
with get_session_with_default_tenant() as db_session:
|
||||
with get_session_with_tenant() as db_session:
|
||||
file_store = get_default_file_store(db_session)
|
||||
|
||||
for file_id in response.tool_result.file_ids:
|
||||
|
||||
@@ -4,6 +4,8 @@ from enum import Enum
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.messages import HumanMessage
|
||||
from litellm import image_generation # type: ignore
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -22,6 +24,9 @@ from danswer.tools.tool import Tool
|
||||
from danswer.tools.tool_implementations.images.prompt import (
|
||||
build_image_generation_user_prompt,
|
||||
)
|
||||
from danswer.tools.tool_implementations.images.prompt import (
|
||||
build_image_generation_user_task_prompt,
|
||||
)
|
||||
from danswer.utils.headers import build_llm_extra_headers
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
@@ -297,3 +302,49 @@ class ImageGenerationTool(Tool):
|
||||
)
|
||||
|
||||
return prompt_builder
|
||||
|
||||
def build_prompt_after_tool_call(
|
||||
self,
|
||||
prompt_builder: "AnswerPromptBuilder",
|
||||
query: str,
|
||||
llm_answer: str,
|
||||
tool_responses: "ToolResponse",
|
||||
) -> "AnswerPromptBuilder":
|
||||
# Append the assistant's previous response to the message history
|
||||
|
||||
img_generation_response = cast(
|
||||
list[ImageGenerationResponse], tool_responses.response
|
||||
)
|
||||
|
||||
if img_generation_response is None:
|
||||
raise ValueError("No image generation response found")
|
||||
|
||||
img_urls = [img.url for img in img_generation_response]
|
||||
|
||||
# Build a user message that includes the images generated
|
||||
user_message = build_image_generation_user_task_prompt(
|
||||
img_urls=img_urls,
|
||||
)
|
||||
prompt_builder.update_user_prompt(HumanMessage(content=query))
|
||||
|
||||
# Update the user prompt with the new message containing images
|
||||
prompt_builder.append_message(user_message)
|
||||
|
||||
prompt_builder.append_message(
|
||||
AIMessage(
|
||||
content=f"The images I generated can be described as the following: {llm_answer}"
|
||||
)
|
||||
)
|
||||
|
||||
# Append a new user message reminding the assistant of the original query and what remains to be done
|
||||
prompt_builder.update_task_reminder(
|
||||
HumanMessage(
|
||||
content=f"Reminder: the original request was: '{query}'.\n\n"
|
||||
"You generated the above images as part of this request. "
|
||||
"If any parts have not been fulfilled, please proceed to complete them using the appropriate tools. "
|
||||
"If the original request has been fulfilled with the prior messages,"
|
||||
"you can provide a final summary and DO NOT call a tool."
|
||||
)
|
||||
)
|
||||
|
||||
return prompt_builder
|
||||
|
||||
@@ -10,6 +10,11 @@ Can you please summarize them in a sentence or two? Do NOT include image urls or
|
||||
"""
|
||||
|
||||
|
||||
IMG_GENERATION_USER_PROMPT = """
|
||||
These are the IMAGES you generated
|
||||
"""
|
||||
|
||||
|
||||
def build_image_generation_user_prompt(
|
||||
query: str, img_urls: list[str] | None = None
|
||||
) -> HumanMessage:
|
||||
@@ -19,3 +24,14 @@ def build_image_generation_user_prompt(
|
||||
img_urls=img_urls,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def build_image_generation_user_task_prompt(
|
||||
img_urls: list[str] | None = None,
|
||||
) -> HumanMessage:
|
||||
return HumanMessage(
|
||||
content=build_content_with_imgs(
|
||||
message=IMG_GENERATION_USER_PROMPT,
|
||||
img_urls=img_urls,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -119,30 +119,3 @@ def noop_fallback(*args: Any, **kwargs: Any) -> None:
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
|
||||
def fetch_ee_implementation_or_noop(
|
||||
module: str, attribute: str, noop_return_value: Any = None
|
||||
) -> Any:
|
||||
"""
|
||||
Fetches an EE implementation if EE is enabled, otherwise returns a no-op function.
|
||||
Raises an exception if EE is enabled but the fetch fails.
|
||||
|
||||
Args:
|
||||
module (str): The name of the module from which to fetch the attribute.
|
||||
attribute (str): The name of the attribute to fetch from the module.
|
||||
|
||||
Returns:
|
||||
Any: The fetched EE implementation if successful and EE is enabled, otherwise a no-op function.
|
||||
|
||||
Raises:
|
||||
Exception: If EE is enabled but the fetch fails.
|
||||
"""
|
||||
if not global_version.is_ee_version():
|
||||
return lambda *args, **kwargs: noop_return_value
|
||||
|
||||
try:
|
||||
return fetch_versioned_implementation(module, attribute)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to fetch implementation for {module}.{attribute}: {e}")
|
||||
raise
|
||||
|
||||
@@ -8,7 +8,7 @@ from passlib.hash import sha256_crypt
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.auth.schemas import UserRole
|
||||
from danswer.configs.app_configs import API_KEY_HASH_ROUNDS
|
||||
from ee.danswer.configs.app_configs import API_KEY_HASH_ROUNDS
|
||||
|
||||
|
||||
_API_KEY_HEADER_NAME = "Authorization"
|
||||
@@ -4,15 +4,16 @@ from fastapi import Request
|
||||
from fastapi import status
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.api_key import get_hashed_api_key_from_request
|
||||
from danswer.auth.users import current_admin_user
|
||||
from danswer.configs.app_configs import AUTH_TYPE
|
||||
from danswer.configs.app_configs import SUPER_CLOUD_API_KEY
|
||||
from danswer.configs.app_configs import SUPER_USERS
|
||||
from danswer.configs.constants import AuthType
|
||||
from danswer.db.api_key import fetch_user_for_api_key
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.models import User
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.auth.api_key import get_hashed_api_key_from_request
|
||||
from ee.danswer.db.api_key import fetch_user_for_api_key
|
||||
from ee.danswer.db.saml import get_saml_account
|
||||
from ee.danswer.server.seeding import get_seed_config
|
||||
from ee.danswer.utils.secrets import extract_hashed_cookie
|
||||
@@ -47,6 +48,25 @@ async def optional_user_(
|
||||
return user
|
||||
|
||||
|
||||
def api_key_dep(
|
||||
request: Request, db_session: Session = Depends(get_session)
|
||||
) -> User | None:
|
||||
if AUTH_TYPE == AuthType.DISABLED:
|
||||
return None
|
||||
|
||||
hashed_api_key = get_hashed_api_key_from_request(request)
|
||||
if not hashed_api_key:
|
||||
raise HTTPException(status_code=401, detail="Missing API key")
|
||||
|
||||
if hashed_api_key:
|
||||
user = fetch_user_for_api_key(hashed_api_key, db_session)
|
||||
|
||||
if user is None:
|
||||
raise HTTPException(status_code=401, detail="Invalid API key")
|
||||
|
||||
return user
|
||||
|
||||
|
||||
def get_default_admin_user_emails_() -> list[str]:
|
||||
seed_config = get_seed_config()
|
||||
if seed_config and seed_config.admin_user_emails:
|
||||
|
||||
@@ -1,7 +1,4 @@
|
||||
from danswer.background.celery.apps.primary import celery_app
|
||||
from danswer.background.task_name_builders import (
|
||||
name_sync_external_doc_permissions_task,
|
||||
)
|
||||
from danswer.background.task_utils import build_celery_task_wrapper
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.db.chat import delete_chat_sessions_older_than
|
||||
@@ -17,6 +14,9 @@ from ee.danswer.background.celery_utils import (
|
||||
should_perform_external_group_permissions_check,
|
||||
)
|
||||
from ee.danswer.background.task_name_builders import name_chat_ttl_task
|
||||
from ee.danswer.background.task_name_builders import (
|
||||
name_sync_external_doc_permissions_task,
|
||||
)
|
||||
from ee.danswer.background.task_name_builders import (
|
||||
name_sync_external_group_permissions_task,
|
||||
)
|
||||
|
||||
@@ -3,15 +3,15 @@ from datetime import timezone
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.task_name_builders import (
|
||||
name_sync_external_doc_permissions_task,
|
||||
)
|
||||
from danswer.db.enums import AccessType
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.tasks import check_task_is_live_and_not_timed_out
|
||||
from danswer.db.tasks import get_latest_task
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.background.task_name_builders import name_chat_ttl_task
|
||||
from ee.danswer.background.task_name_builders import (
|
||||
name_sync_external_doc_permissions_task,
|
||||
)
|
||||
from ee.danswer.background.task_name_builders import (
|
||||
name_sync_external_group_permissions_task,
|
||||
)
|
||||
|
||||
@@ -2,6 +2,12 @@ def name_chat_ttl_task(retention_limit_days: int, tenant_id: str | None = None)
|
||||
return f"chat_ttl_{retention_limit_days}_days"
|
||||
|
||||
|
||||
def name_sync_external_doc_permissions_task(
|
||||
cc_pair_id: int, tenant_id: str | None = None
|
||||
) -> str:
|
||||
return f"sync_external_doc_permissions_task__{cc_pair_id}"
|
||||
|
||||
|
||||
def name_sync_external_group_permissions_task(
|
||||
cc_pair_id: int, tenant_id: str | None = None
|
||||
) -> str:
|
||||
|
||||
@@ -7,6 +7,16 @@ OPENID_CONFIG_URL = os.environ.get("OPENID_CONFIG_URL", "")
|
||||
SAML_CONF_DIR = os.environ.get("SAML_CONF_DIR") or "/app/ee/danswer/configs/saml_config"
|
||||
|
||||
|
||||
#####
|
||||
# API Key Configs
|
||||
#####
|
||||
# refers to the rounds described here: https://passlib.readthedocs.io/en/stable/lib/passlib.hash.sha256_crypt.html
|
||||
_API_KEY_HASH_ROUNDS_RAW = os.environ.get("API_KEY_HASH_ROUNDS")
|
||||
API_KEY_HASH_ROUNDS = (
|
||||
int(_API_KEY_HASH_ROUNDS_RAW) if _API_KEY_HASH_ROUNDS_RAW else None
|
||||
)
|
||||
|
||||
|
||||
#####
|
||||
# Auto Permission Sync
|
||||
#####
|
||||
|
||||
@@ -5,16 +5,16 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.api_key import ApiKeyDescriptor
|
||||
from danswer.auth.api_key import build_displayable_api_key
|
||||
from danswer.auth.api_key import generate_api_key
|
||||
from danswer.auth.api_key import hash_api_key
|
||||
from danswer.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
||||
from danswer.configs.constants import DANSWER_API_KEY_PREFIX
|
||||
from danswer.configs.constants import UNNAMED_KEY_PLACEHOLDER
|
||||
from danswer.db.models import ApiKey
|
||||
from danswer.db.models import User
|
||||
from danswer.server.api_key.models import APIKeyArgs
|
||||
from ee.danswer.auth.api_key import ApiKeyDescriptor
|
||||
from ee.danswer.auth.api_key import build_displayable_api_key
|
||||
from ee.danswer.auth.api_key import generate_api_key
|
||||
from ee.danswer.auth.api_key import hash_api_key
|
||||
from ee.danswer.server.api_key.models import APIKeyArgs
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
@@ -65,6 +65,64 @@ def _add_user_filters(
|
||||
return stmt.where(where_clause)
|
||||
|
||||
|
||||
def fetch_all_user_token_rate_limits(
|
||||
db_session: Session,
|
||||
enabled_only: bool = False,
|
||||
ordered: bool = True,
|
||||
) -> Sequence[TokenRateLimit]:
|
||||
query = select(TokenRateLimit).where(
|
||||
TokenRateLimit.scope == TokenRateLimitScope.USER
|
||||
)
|
||||
|
||||
if enabled_only:
|
||||
query = query.where(TokenRateLimit.enabled.is_(True))
|
||||
|
||||
if ordered:
|
||||
query = query.order_by(TokenRateLimit.created_at.desc())
|
||||
|
||||
return db_session.scalars(query).all()
|
||||
|
||||
|
||||
def fetch_all_global_token_rate_limits(
|
||||
db_session: Session,
|
||||
enabled_only: bool = False,
|
||||
ordered: bool = True,
|
||||
) -> Sequence[TokenRateLimit]:
|
||||
query = select(TokenRateLimit).where(
|
||||
TokenRateLimit.scope == TokenRateLimitScope.GLOBAL
|
||||
)
|
||||
|
||||
if enabled_only:
|
||||
query = query.where(TokenRateLimit.enabled.is_(True))
|
||||
|
||||
if ordered:
|
||||
query = query.order_by(TokenRateLimit.created_at.desc())
|
||||
|
||||
token_rate_limits = db_session.scalars(query).all()
|
||||
return token_rate_limits
|
||||
|
||||
|
||||
def fetch_user_group_token_rate_limits(
|
||||
db_session: Session,
|
||||
group_id: int,
|
||||
user: User | None = None,
|
||||
enabled_only: bool = False,
|
||||
ordered: bool = True,
|
||||
get_editable: bool = True,
|
||||
) -> Sequence[TokenRateLimit]:
|
||||
stmt = select(TokenRateLimit)
|
||||
stmt = stmt.where(User__UserGroup.user_group_id == group_id)
|
||||
stmt = _add_user_filters(stmt, user, get_editable)
|
||||
|
||||
if enabled_only:
|
||||
stmt = stmt.where(TokenRateLimit.enabled.is_(True))
|
||||
|
||||
if ordered:
|
||||
stmt = stmt.order_by(TokenRateLimit.created_at.desc())
|
||||
|
||||
return db_session.scalars(stmt).all()
|
||||
|
||||
|
||||
def fetch_all_user_group_token_rate_limits_by_group(
|
||||
db_session: Session,
|
||||
) -> Sequence[Row[tuple[TokenRateLimit, str]]]:
|
||||
@@ -80,6 +138,38 @@ def fetch_all_user_group_token_rate_limits_by_group(
|
||||
return db_session.execute(query).all()
|
||||
|
||||
|
||||
def insert_user_token_rate_limit(
|
||||
db_session: Session,
|
||||
token_rate_limit_settings: TokenRateLimitArgs,
|
||||
) -> TokenRateLimit:
|
||||
token_limit = TokenRateLimit(
|
||||
enabled=token_rate_limit_settings.enabled,
|
||||
token_budget=token_rate_limit_settings.token_budget,
|
||||
period_hours=token_rate_limit_settings.period_hours,
|
||||
scope=TokenRateLimitScope.USER,
|
||||
)
|
||||
db_session.add(token_limit)
|
||||
db_session.commit()
|
||||
|
||||
return token_limit
|
||||
|
||||
|
||||
def insert_global_token_rate_limit(
|
||||
db_session: Session,
|
||||
token_rate_limit_settings: TokenRateLimitArgs,
|
||||
) -> TokenRateLimit:
|
||||
token_limit = TokenRateLimit(
|
||||
enabled=token_rate_limit_settings.enabled,
|
||||
token_budget=token_rate_limit_settings.token_budget,
|
||||
period_hours=token_rate_limit_settings.period_hours,
|
||||
scope=TokenRateLimitScope.GLOBAL,
|
||||
)
|
||||
db_session.add(token_limit)
|
||||
db_session.commit()
|
||||
|
||||
return token_limit
|
||||
|
||||
|
||||
def insert_user_group_token_rate_limit(
|
||||
db_session: Session,
|
||||
token_rate_limit_settings: TokenRateLimitArgs,
|
||||
@@ -103,22 +193,34 @@ def insert_user_group_token_rate_limit(
|
||||
return token_limit
|
||||
|
||||
|
||||
def fetch_user_group_token_rate_limits(
|
||||
def update_token_rate_limit(
|
||||
db_session: Session,
|
||||
group_id: int,
|
||||
user: User | None = None,
|
||||
enabled_only: bool = False,
|
||||
ordered: bool = True,
|
||||
get_editable: bool = True,
|
||||
) -> Sequence[TokenRateLimit]:
|
||||
stmt = select(TokenRateLimit)
|
||||
stmt = stmt.where(User__UserGroup.user_group_id == group_id)
|
||||
stmt = _add_user_filters(stmt, user, get_editable)
|
||||
token_rate_limit_id: int,
|
||||
token_rate_limit_settings: TokenRateLimitArgs,
|
||||
) -> TokenRateLimit:
|
||||
token_limit = db_session.get(TokenRateLimit, token_rate_limit_id)
|
||||
if token_limit is None:
|
||||
raise ValueError(f"TokenRateLimit with id '{token_rate_limit_id}' not found")
|
||||
|
||||
if enabled_only:
|
||||
stmt = stmt.where(TokenRateLimit.enabled.is_(True))
|
||||
token_limit.enabled = token_rate_limit_settings.enabled
|
||||
token_limit.token_budget = token_rate_limit_settings.token_budget
|
||||
token_limit.period_hours = token_rate_limit_settings.period_hours
|
||||
db_session.commit()
|
||||
|
||||
if ordered:
|
||||
stmt = stmt.order_by(TokenRateLimit.created_at.desc())
|
||||
return token_limit
|
||||
|
||||
return db_session.scalars(stmt).all()
|
||||
|
||||
def delete_token_rate_limit(
|
||||
db_session: Session,
|
||||
token_rate_limit_id: int,
|
||||
) -> None:
|
||||
token_limit = db_session.get(TokenRateLimit, token_rate_limit_id)
|
||||
if token_limit is None:
|
||||
raise ValueError(f"TokenRateLimit with id '{token_rate_limit_id}' not found")
|
||||
|
||||
db_session.query(TokenRateLimit__UserGroup).filter(
|
||||
TokenRateLimit__UserGroup.rate_limit_id == token_rate_limit_id
|
||||
).delete()
|
||||
|
||||
db_session.delete(token_limit)
|
||||
db_session.commit()
|
||||
|
||||
@@ -12,11 +12,11 @@ from danswer.configs.app_configs import WEB_DOMAIN
|
||||
from danswer.configs.constants import AuthType
|
||||
from danswer.main import get_application as get_application_base
|
||||
from danswer.main import include_router_with_global_prefix_prepended
|
||||
from danswer.server.api_key.api import router as api_key_router
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import global_version
|
||||
from ee.danswer.configs.app_configs import OPENID_CONFIG_URL
|
||||
from ee.danswer.server.analytics.api import router as analytics_router
|
||||
from ee.danswer.server.api_key.api import router as api_key_router
|
||||
from ee.danswer.server.auth_check import check_ee_router_auth
|
||||
from ee.danswer.server.enterprise_settings.api import (
|
||||
admin_router as enterprise_settings_admin_router,
|
||||
|
||||
@@ -3,15 +3,15 @@ from fastapi import Depends
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import current_admin_user
|
||||
from danswer.db.api_key import ApiKeyDescriptor
|
||||
from danswer.db.api_key import fetch_api_keys
|
||||
from danswer.db.api_key import insert_api_key
|
||||
from danswer.db.api_key import regenerate_api_key
|
||||
from danswer.db.api_key import remove_api_key
|
||||
from danswer.db.api_key import update_api_key
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.models import User
|
||||
from danswer.server.api_key.models import APIKeyArgs
|
||||
from ee.danswer.db.api_key import ApiKeyDescriptor
|
||||
from ee.danswer.db.api_key import fetch_api_keys
|
||||
from ee.danswer.db.api_key import insert_api_key
|
||||
from ee.danswer.db.api_key import regenerate_api_key
|
||||
from ee.danswer.db.api_key import remove_api_key
|
||||
from ee.danswer.db.api_key import update_api_key
|
||||
from ee.danswer.server.api_key.models import APIKeyArgs
|
||||
|
||||
|
||||
router = APIRouter(prefix="/admin/api-key")
|
||||
@@ -8,9 +8,9 @@ from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
from fastapi import Response
|
||||
|
||||
from danswer.auth.api_key import extract_tenant_from_api_key_header
|
||||
from danswer.configs.app_configs import USER_AUTH_SECRET
|
||||
from danswer.db.engine import is_valid_schema_name
|
||||
from ee.danswer.auth.api_key import extract_tenant_from_api_key_header
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
@@ -12,7 +12,6 @@ from sqlalchemy import func
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.db.api_key import is_api_key_email_address
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.models import ChatMessage
|
||||
from danswer.db.models import ChatSession
|
||||
@@ -21,11 +20,12 @@ from danswer.db.models import TokenRateLimit__UserGroup
|
||||
from danswer.db.models import User
|
||||
from danswer.db.models import User__UserGroup
|
||||
from danswer.db.models import UserGroup
|
||||
from danswer.db.token_limit import fetch_all_user_token_rate_limits
|
||||
from danswer.server.query_and_chat.token_limit import _get_cutoff_time
|
||||
from danswer.server.query_and_chat.token_limit import _is_rate_limited
|
||||
from danswer.server.query_and_chat.token_limit import _user_is_rate_limited_by_global
|
||||
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
from ee.danswer.db.api_key import is_api_key_email_address
|
||||
from ee.danswer.db.token_limit import fetch_all_user_token_rate_limits
|
||||
|
||||
|
||||
def _check_token_rate_limits(user: User | None, tenant_id: str | None) -> None:
|
||||
|
||||
@@ -7,6 +7,7 @@ from fastapi import Response
|
||||
from danswer.auth.users import auth_backend
|
||||
from danswer.auth.users import current_admin_user
|
||||
from danswer.auth.users import get_jwt_strategy
|
||||
from danswer.auth.users import get_tenant_id_for_email
|
||||
from danswer.auth.users import User
|
||||
from danswer.configs.app_configs import WEB_DOMAIN
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
@@ -14,6 +15,7 @@ from danswer.db.notification import create_notification
|
||||
from danswer.db.users import get_user_by_email
|
||||
from danswer.server.settings.store import load_settings
|
||||
from danswer.server.settings.store import store_settings
|
||||
from danswer.setup import setup_danswer
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.auth.users import current_cloud_superuser
|
||||
from ee.danswer.configs.app_configs import STRIPE_SECRET_KEY
|
||||
@@ -21,9 +23,15 @@ from ee.danswer.server.tenants.access import control_plane_dep
|
||||
from ee.danswer.server.tenants.billing import fetch_billing_information
|
||||
from ee.danswer.server.tenants.billing import fetch_tenant_stripe_information
|
||||
from ee.danswer.server.tenants.models import BillingInformation
|
||||
from ee.danswer.server.tenants.models import CreateTenantRequest
|
||||
from ee.danswer.server.tenants.models import ImpersonateRequest
|
||||
from ee.danswer.server.tenants.models import ProductGatingRequest
|
||||
from ee.danswer.server.tenants.user_mapping import get_tenant_id_for_email
|
||||
from ee.danswer.server.tenants.provisioning import add_users_to_tenant
|
||||
from ee.danswer.server.tenants.provisioning import configure_default_api_keys
|
||||
from ee.danswer.server.tenants.provisioning import ensure_schema_exists
|
||||
from ee.danswer.server.tenants.provisioning import run_alembic_migrations
|
||||
from ee.danswer.server.tenants.provisioning import user_owns_a_tenant
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
stripe.api_key = STRIPE_SECRET_KEY
|
||||
@@ -32,6 +40,52 @@ logger = setup_logger()
|
||||
router = APIRouter(prefix="/tenants")
|
||||
|
||||
|
||||
@router.post("/create")
|
||||
def create_tenant(
|
||||
create_tenant_request: CreateTenantRequest, _: None = Depends(control_plane_dep)
|
||||
) -> dict[str, str]:
|
||||
if not MULTI_TENANT:
|
||||
raise HTTPException(status_code=403, detail="Multi-tenancy is not enabled")
|
||||
|
||||
tenant_id = create_tenant_request.tenant_id
|
||||
email = create_tenant_request.initial_admin_email
|
||||
token = None
|
||||
|
||||
if user_owns_a_tenant(email):
|
||||
raise HTTPException(
|
||||
status_code=409, detail="User already belongs to an organization"
|
||||
)
|
||||
|
||||
try:
|
||||
if not ensure_schema_exists(tenant_id):
|
||||
logger.info(f"Created schema for tenant {tenant_id}")
|
||||
else:
|
||||
logger.info(f"Schema already exists for tenant {tenant_id}")
|
||||
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
run_alembic_migrations(tenant_id)
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
setup_danswer(db_session, tenant_id)
|
||||
|
||||
configure_default_api_keys(db_session)
|
||||
|
||||
add_users_to_tenant([email], tenant_id)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"message": f"Tenant {tenant_id} created successfully",
|
||||
}
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to create tenant {tenant_id}: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to create tenant: {str(e)}"
|
||||
)
|
||||
finally:
|
||||
if token is not None:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
@router.post("/product-gating")
|
||||
def gate_product(
|
||||
product_gating_request: ProductGatingRequest, _: None = Depends(control_plane_dep)
|
||||
|
||||
@@ -33,8 +33,3 @@ class CheckoutSessionCreationResponse(BaseModel):
|
||||
|
||||
class ImpersonateRequest(BaseModel):
|
||||
email: str
|
||||
|
||||
|
||||
class TenantCreationPayload(BaseModel):
|
||||
tenant_id: str
|
||||
email: str
|
||||
|
||||
@@ -1,210 +1,145 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
import os
|
||||
from types import SimpleNamespace
|
||||
|
||||
import aiohttp # Async HTTP client
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.schema import CreateSchema
|
||||
|
||||
from danswer.auth.users import exceptions
|
||||
from danswer.configs.app_configs import CONTROL_PLANE_API_BASE_URL
|
||||
from alembic import command
|
||||
from alembic.config import Config
|
||||
from danswer.db.engine import build_connection_string
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.llm import update_default_provider
|
||||
from danswer.db.llm import upsert_cloud_embedding_provider
|
||||
from danswer.db.llm import upsert_llm_provider
|
||||
from danswer.db.models import UserTenantMapping
|
||||
from danswer.llm.llm_provider_options import ANTHROPIC_MODEL_NAMES
|
||||
from danswer.llm.llm_provider_options import ANTHROPIC_PROVIDER_NAME
|
||||
from danswer.llm.llm_provider_options import OPEN_AI_MODEL_NAMES
|
||||
from danswer.llm.llm_provider_options import OPENAI_PROVIDER_NAME
|
||||
from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
|
||||
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from danswer.setup import setup_danswer
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.configs.app_configs import ANTHROPIC_DEFAULT_API_KEY
|
||||
from ee.danswer.configs.app_configs import COHERE_DEFAULT_API_KEY
|
||||
from ee.danswer.configs.app_configs import OPENAI_DEFAULT_API_KEY
|
||||
from ee.danswer.server.tenants.access import generate_data_plane_token
|
||||
from ee.danswer.server.tenants.models import TenantCreationPayload
|
||||
from ee.danswer.server.tenants.schema_management import create_schema_if_not_exists
|
||||
from ee.danswer.server.tenants.schema_management import drop_schema
|
||||
from ee.danswer.server.tenants.schema_management import run_alembic_migrations
|
||||
from ee.danswer.server.tenants.user_mapping import add_users_to_tenant
|
||||
from ee.danswer.server.tenants.user_mapping import get_tenant_id_for_email
|
||||
from ee.danswer.server.tenants.user_mapping import user_owns_a_tenant
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.configs import TENANT_ID_PREFIX
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
async def get_or_create_tenant_id(email: str) -> str:
|
||||
"""Get existing tenant ID for an email or create a new tenant if none exists."""
|
||||
if not MULTI_TENANT:
|
||||
return POSTGRES_DEFAULT_SCHEMA
|
||||
def run_alembic_migrations(schema_name: str) -> None:
|
||||
logger.info(f"Starting Alembic migrations for schema: {schema_name}")
|
||||
|
||||
try:
|
||||
tenant_id = get_tenant_id_for_email(email)
|
||||
except exceptions.UserNotExists:
|
||||
# If tenant does not exist and in Multi tenant mode, provision a new tenant
|
||||
try:
|
||||
tenant_id = await create_tenant(email)
|
||||
except Exception as e:
|
||||
logger.error(f"Tenant provisioning failed: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to provision tenant.")
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
root_dir = os.path.abspath(os.path.join(current_dir, "..", "..", "..", ".."))
|
||||
alembic_ini_path = os.path.join(root_dir, "alembic.ini")
|
||||
|
||||
if not tenant_id:
|
||||
raise HTTPException(
|
||||
status_code=401, detail="User does not belong to an organization"
|
||||
# Configure Alembic
|
||||
alembic_cfg = Config(alembic_ini_path)
|
||||
alembic_cfg.set_main_option("sqlalchemy.url", build_connection_string())
|
||||
alembic_cfg.set_main_option(
|
||||
"script_location", os.path.join(root_dir, "alembic")
|
||||
)
|
||||
|
||||
return tenant_id
|
||||
# Ensure that logging isn't broken
|
||||
alembic_cfg.attributes["configure_logger"] = False
|
||||
|
||||
# Mimic command-line options by adding 'cmd_opts' to the config
|
||||
alembic_cfg.cmd_opts = SimpleNamespace() # type: ignore
|
||||
alembic_cfg.cmd_opts.x = [f"schema={schema_name}"] # type: ignore
|
||||
|
||||
async def create_tenant(email: str) -> str:
|
||||
tenant_id = TENANT_ID_PREFIX + str(uuid.uuid4())
|
||||
try:
|
||||
# Provision tenant on data plane
|
||||
await provision_tenant(tenant_id, email)
|
||||
# Notify control plane
|
||||
await notify_control_plane(tenant_id, email)
|
||||
except Exception as e:
|
||||
logger.error(f"Tenant provisioning failed: {e}")
|
||||
await rollback_tenant_provisioning(tenant_id)
|
||||
raise HTTPException(status_code=500, detail="Failed to provision tenant.")
|
||||
return tenant_id
|
||||
# Run migrations programmatically
|
||||
command.upgrade(alembic_cfg, "head")
|
||||
|
||||
|
||||
async def provision_tenant(tenant_id: str, email: str) -> None:
|
||||
if not MULTI_TENANT:
|
||||
raise HTTPException(status_code=403, detail="Multi-tenancy is not enabled")
|
||||
|
||||
if user_owns_a_tenant(email):
|
||||
raise HTTPException(
|
||||
status_code=409, detail="User already belongs to an organization"
|
||||
# Run migrations programmatically
|
||||
logger.info(
|
||||
f"Alembic migrations completed successfully for schema: {schema_name}"
|
||||
)
|
||||
|
||||
logger.info(f"Provisioning tenant: {tenant_id}")
|
||||
token = None
|
||||
|
||||
try:
|
||||
if not create_schema_if_not_exists(tenant_id):
|
||||
logger.info(f"Created schema for tenant {tenant_id}")
|
||||
else:
|
||||
logger.info(f"Schema already exists for tenant {tenant_id}")
|
||||
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
# Await the Alembic migrations
|
||||
await asyncio.to_thread(run_alembic_migrations, tenant_id)
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
setup_danswer(db_session, tenant_id)
|
||||
configure_default_api_keys(db_session)
|
||||
|
||||
add_users_to_tenant([email], tenant_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to create tenant {tenant_id}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to create tenant: {str(e)}"
|
||||
)
|
||||
finally:
|
||||
if token is not None:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
async def notify_control_plane(tenant_id: str, email: str) -> None:
|
||||
logger.info("Fetching billing information")
|
||||
token = generate_data_plane_token()
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
payload = TenantCreationPayload(tenant_id=tenant_id, email=email)
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{CONTROL_PLANE_API_BASE_URL}/tenants/create",
|
||||
headers=headers,
|
||||
json=payload.model_dump(),
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
logger.error(f"Control plane tenant creation failed: {error_text}")
|
||||
raise Exception(
|
||||
f"Failed to create tenant on control plane: {error_text}"
|
||||
)
|
||||
|
||||
|
||||
async def rollback_tenant_provisioning(tenant_id: str) -> None:
|
||||
# Logic to rollback tenant provisioning on data plane
|
||||
logger.info(f"Rolling back tenant provisioning for tenant_id: {tenant_id}")
|
||||
try:
|
||||
# Drop the tenant's schema to rollback provisioning
|
||||
drop_schema(tenant_id)
|
||||
# Remove tenant mapping
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
db_session.query(UserTenantMapping).filter(
|
||||
UserTenantMapping.tenant_id == tenant_id
|
||||
).delete()
|
||||
db_session.commit()
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to rollback tenant provisioning: {e}")
|
||||
logger.exception(f"Alembic migration failed for schema {schema_name}: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def configure_default_api_keys(db_session: Session) -> None:
|
||||
if OPENAI_DEFAULT_API_KEY:
|
||||
open_provider = LLMProviderUpsertRequest(
|
||||
name="OpenAI",
|
||||
provider=OPENAI_PROVIDER_NAME,
|
||||
api_key=OPENAI_DEFAULT_API_KEY,
|
||||
default_model_name="gpt-4",
|
||||
fast_default_model_name="gpt-4o-mini",
|
||||
model_names=OPEN_AI_MODEL_NAMES,
|
||||
)
|
||||
try:
|
||||
full_provider = upsert_llm_provider(open_provider, db_session)
|
||||
update_default_provider(full_provider.id, db_session)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to configure OpenAI provider: {e}")
|
||||
else:
|
||||
logger.error(
|
||||
"OPENAI_DEFAULT_API_KEY not set, skipping OpenAI provider configuration"
|
||||
)
|
||||
open_provider = LLMProviderUpsertRequest(
|
||||
name="OpenAI",
|
||||
provider="OpenAI",
|
||||
api_key=OPENAI_DEFAULT_API_KEY,
|
||||
default_model_name="gpt-4o",
|
||||
)
|
||||
anthropic_provider = LLMProviderUpsertRequest(
|
||||
name="Anthropic",
|
||||
provider="Anthropic",
|
||||
api_key=ANTHROPIC_DEFAULT_API_KEY,
|
||||
default_model_name="claude-3-5-sonnet-20240620",
|
||||
)
|
||||
upsert_llm_provider(open_provider, db_session)
|
||||
upsert_llm_provider(anthropic_provider, db_session)
|
||||
|
||||
if ANTHROPIC_DEFAULT_API_KEY:
|
||||
anthropic_provider = LLMProviderUpsertRequest(
|
||||
name="Anthropic",
|
||||
provider=ANTHROPIC_PROVIDER_NAME,
|
||||
api_key=ANTHROPIC_DEFAULT_API_KEY,
|
||||
default_model_name="claude-3-5-sonnet-20241022",
|
||||
fast_default_model_name="claude-3-5-sonnet-20241022",
|
||||
model_names=ANTHROPIC_MODEL_NAMES,
|
||||
)
|
||||
try:
|
||||
full_provider = upsert_llm_provider(anthropic_provider, db_session)
|
||||
update_default_provider(full_provider.id, db_session)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to configure Anthropic provider: {e}")
|
||||
else:
|
||||
logger.error(
|
||||
"ANTHROPIC_DEFAULT_API_KEY not set, skipping Anthropic provider configuration"
|
||||
)
|
||||
cloud_embedding_provider = CloudEmbeddingProviderCreationRequest(
|
||||
provider_type=EmbeddingProvider.COHERE,
|
||||
api_key=COHERE_DEFAULT_API_KEY,
|
||||
)
|
||||
upsert_cloud_embedding_provider(db_session, cloud_embedding_provider)
|
||||
|
||||
if COHERE_DEFAULT_API_KEY:
|
||||
cloud_embedding_provider = CloudEmbeddingProviderCreationRequest(
|
||||
provider_type=EmbeddingProvider.COHERE,
|
||||
api_key=COHERE_DEFAULT_API_KEY,
|
||||
|
||||
def ensure_schema_exists(tenant_id: str) -> bool:
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
with db_session.begin():
|
||||
result = db_session.execute(
|
||||
text(
|
||||
"SELECT schema_name FROM information_schema.schemata WHERE schema_name = :schema_name"
|
||||
),
|
||||
{"schema_name": tenant_id},
|
||||
)
|
||||
schema_exists = result.scalar() is not None
|
||||
if not schema_exists:
|
||||
stmt = CreateSchema(tenant_id)
|
||||
db_session.execute(stmt)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# For now, we're implementing a primitive mapping between users and tenants.
|
||||
# This function is only used to determine a user's relationship to a tenant upon creation (implying ownership).
|
||||
def user_owns_a_tenant(email: str) -> bool:
|
||||
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
|
||||
result = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(UserTenantMapping.email == email)
|
||||
.first()
|
||||
)
|
||||
return result is not None
|
||||
|
||||
|
||||
def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
|
||||
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
|
||||
try:
|
||||
upsert_cloud_embedding_provider(db_session, cloud_embedding_provider)
|
||||
for email in emails:
|
||||
db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to configure Cohere embedding provider: {e}")
|
||||
else:
|
||||
logger.error(
|
||||
"COHERE_DEFAULT_API_KEY not set, skipping Cohere embedding provider configuration"
|
||||
)
|
||||
logger.exception(f"Failed to add users to tenant {tenant_id}: {str(e)}")
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
|
||||
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
|
||||
try:
|
||||
mappings_to_delete = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.email.in_(emails),
|
||||
UserTenantMapping.tenant_id == tenant_id,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
for mapping in mappings_to_delete:
|
||||
db_session.delete(mapping)
|
||||
|
||||
db_session.commit()
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to remove users from tenant {tenant_id}: {str(e)}"
|
||||
)
|
||||
db_session.rollback()
|
||||
|
||||
@@ -1,76 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
from types import SimpleNamespace
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.schema import CreateSchema
|
||||
|
||||
from alembic import command
|
||||
from alembic.config import Config
|
||||
from danswer.db.engine import build_connection_string
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def run_alembic_migrations(schema_name: str) -> None:
|
||||
logger.info(f"Starting Alembic migrations for schema: {schema_name}")
|
||||
|
||||
try:
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
root_dir = os.path.abspath(os.path.join(current_dir, "..", "..", "..", ".."))
|
||||
alembic_ini_path = os.path.join(root_dir, "alembic.ini")
|
||||
|
||||
# Configure Alembic
|
||||
alembic_cfg = Config(alembic_ini_path)
|
||||
alembic_cfg.set_main_option("sqlalchemy.url", build_connection_string())
|
||||
alembic_cfg.set_main_option(
|
||||
"script_location", os.path.join(root_dir, "alembic")
|
||||
)
|
||||
|
||||
# Ensure that logging isn't broken
|
||||
alembic_cfg.attributes["configure_logger"] = False
|
||||
|
||||
# Mimic command-line options by adding 'cmd_opts' to the config
|
||||
alembic_cfg.cmd_opts = SimpleNamespace() # type: ignore
|
||||
alembic_cfg.cmd_opts.x = [f"schema={schema_name}"] # type: ignore
|
||||
|
||||
# Run migrations programmatically
|
||||
command.upgrade(alembic_cfg, "head")
|
||||
|
||||
# Run migrations programmatically
|
||||
logger.info(
|
||||
f"Alembic migrations completed successfully for schema: {schema_name}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Alembic migration failed for schema {schema_name}: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
def create_schema_if_not_exists(tenant_id: str) -> bool:
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
with db_session.begin():
|
||||
result = db_session.execute(
|
||||
text(
|
||||
"SELECT schema_name FROM information_schema.schemata WHERE schema_name = :schema_name"
|
||||
),
|
||||
{"schema_name": tenant_id},
|
||||
)
|
||||
schema_exists = result.scalar() is not None
|
||||
if not schema_exists:
|
||||
stmt = CreateSchema(tenant_id)
|
||||
db_session.execute(stmt)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def drop_schema(tenant_id: str) -> None:
|
||||
if not tenant_id.isidentifier():
|
||||
raise ValueError("Invalid tenant_id.")
|
||||
with get_sqlalchemy_engine().connect() as connection:
|
||||
connection.execute(
|
||||
text("DROP SCHEMA IF EXISTS %(schema_name)s CASCADE"),
|
||||
{"schema_name": tenant_id},
|
||||
)
|
||||
@@ -1,70 +0,0 @@
|
||||
import logging
|
||||
|
||||
from fastapi_users import exceptions
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.models import UserTenantMapping
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_tenant_id_for_email(email: str) -> str:
|
||||
if not MULTI_TENANT:
|
||||
return POSTGRES_DEFAULT_SCHEMA
|
||||
# Implement logic to get tenant_id from the mapping table
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
result = db_session.execute(
|
||||
select(UserTenantMapping.tenant_id).where(UserTenantMapping.email == email)
|
||||
)
|
||||
tenant_id = result.scalar_one_or_none()
|
||||
if tenant_id is None:
|
||||
raise exceptions.UserNotExists()
|
||||
return tenant_id
|
||||
|
||||
|
||||
def user_owns_a_tenant(email: str) -> bool:
|
||||
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
|
||||
result = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(UserTenantMapping.email == email)
|
||||
.first()
|
||||
)
|
||||
return result is not None
|
||||
|
||||
|
||||
def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
|
||||
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
|
||||
try:
|
||||
for email in emails:
|
||||
db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))
|
||||
except Exception:
|
||||
logger.exception(f"Failed to add users to tenant {tenant_id}")
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
|
||||
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
|
||||
try:
|
||||
mappings_to_delete = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.email.in_(emails),
|
||||
UserTenantMapping.tenant_id == tenant_id,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
for mapping in mappings_to_delete:
|
||||
db_session.delete(mapping)
|
||||
|
||||
db_session.commit()
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to remove users from tenant {tenant_id}: {str(e)}"
|
||||
)
|
||||
db_session.rollback()
|
||||
@@ -8,14 +8,14 @@ from danswer.auth.users import current_admin_user
|
||||
from danswer.auth.users import current_curator_or_admin_user
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.models import User
|
||||
from danswer.db.token_limit import fetch_all_user_token_rate_limits
|
||||
from danswer.db.token_limit import insert_user_token_rate_limit
|
||||
from danswer.server.query_and_chat.token_limit import any_rate_limit_exists
|
||||
from danswer.server.token_rate_limits.models import TokenRateLimitArgs
|
||||
from danswer.server.token_rate_limits.models import TokenRateLimitDisplay
|
||||
from ee.danswer.db.token_limit import fetch_all_user_group_token_rate_limits_by_group
|
||||
from ee.danswer.db.token_limit import fetch_all_user_token_rate_limits
|
||||
from ee.danswer.db.token_limit import fetch_user_group_token_rate_limits
|
||||
from ee.danswer.db.token_limit import insert_user_group_token_rate_limit
|
||||
from ee.danswer.db.token_limit import insert_user_token_rate_limit
|
||||
|
||||
router = APIRouter(prefix="/admin/token-rate-limits")
|
||||
|
||||
|
||||
@@ -15,9 +15,9 @@ docker rm danswer_postgres danswer_vespa danswer_redis
|
||||
# Start the PostgreSQL container with optional volume
|
||||
echo "Starting PostgreSQL container..."
|
||||
if [[ -n "$POSTGRES_VOLUME" ]]; then
|
||||
docker run -p 5432:5432 --name danswer_postgres -e POSTGRES_PASSWORD=password -d -v $POSTGRES_VOLUME:/var/lib/postgresql/data postgres -c max_connections=250
|
||||
docker run -p 5432:5432 --name danswer_postgres -e POSTGRES_PASSWORD=password -d -v $POSTGRES_VOLUME:/var/lib/postgresql/data postgres
|
||||
else
|
||||
docker run -p 5432:5432 --name danswer_postgres -e POSTGRES_PASSWORD=password -d postgres -c max_connections=250
|
||||
docker run -p 5432:5432 --name danswer_postgres -e POSTGRES_PASSWORD=password -d postgres
|
||||
fi
|
||||
|
||||
# Start the Vespa container with optional volume
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import os
|
||||
from typing import Any
|
||||
from typing import List
|
||||
from urllib.parse import urlparse
|
||||
|
||||
@@ -134,11 +133,6 @@ MULTI_TENANT = os.environ.get("MULTI_TENANT", "").lower() == "true"
|
||||
|
||||
POSTGRES_DEFAULT_SCHEMA = os.environ.get("POSTGRES_DEFAULT_SCHEMA") or "public"
|
||||
|
||||
|
||||
async def async_return_default_schema(*args: Any, **kwargs: Any) -> str:
|
||||
return POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
|
||||
# Prefix used for all tenant ids
|
||||
TENANT_ID_PREFIX = "tenant_"
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from uuid import uuid4
|
||||
import requests
|
||||
|
||||
from danswer.db.models import UserRole
|
||||
from danswer.server.api_key.models import APIKeyArgs
|
||||
from ee.danswer.server.api_key.models import APIKeyArgs
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.test_models import DATestAPIKey
|
||||
|
||||
10
ct.yaml
10
ct.yaml
@@ -1,18 +1,12 @@
|
||||
# See https://github.com/helm/chart-testing#configuration
|
||||
|
||||
# still have to specify this on the command line for list-changed
|
||||
chart-dirs:
|
||||
- deployment/helm/charts
|
||||
|
||||
# must be kept in sync with Chart.yaml
|
||||
chart-repos:
|
||||
- vespa=https://danswer-ai.github.io/vespa-helm-charts
|
||||
- vespa=https://unoplat.github.io/vespa-helm-charts
|
||||
- postgresql=https://charts.bitnami.com/bitnami
|
||||
|
||||
helm-extra-args: --debug --timeout 600s
|
||||
|
||||
# nginx appears to not work on kind, likely due to lack of loadbalancer support
|
||||
# helm-extra-set-args also only works on the command line, not in this yaml
|
||||
# helm-extra-set-args: --set=nginx.enabled=false
|
||||
helm-extra-args: --timeout 600s
|
||||
|
||||
validate-maintainers: false
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
apiVersion: keda.sh/v1alpha1
|
||||
kind: TriggerAuthentication
|
||||
metadata:
|
||||
name: celery-worker-auth
|
||||
namespace: danswer
|
||||
spec:
|
||||
secretTargetRef:
|
||||
- parameter: host
|
||||
name: keda-redis-secret
|
||||
key: host
|
||||
- parameter: password
|
||||
name: keda-redis-secret
|
||||
key: password
|
||||
@@ -1,46 +0,0 @@
|
||||
apiVersion: keda.sh/v1alpha1
|
||||
kind: ScaledObject
|
||||
metadata:
|
||||
name: celery-worker-indexing-scaledobject
|
||||
namespace: danswer
|
||||
labels:
|
||||
app: celery-worker-indexing
|
||||
spec:
|
||||
scaleTargetRef:
|
||||
name: celery-worker-indexing
|
||||
minReplicaCount: 1
|
||||
maxReplicaCount: 10
|
||||
triggers:
|
||||
- type: redis
|
||||
metadata:
|
||||
sslEnabled: "true"
|
||||
host: "{host}"
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: connector_indexing
|
||||
listLength: "1"
|
||||
databaseIndex: "15"
|
||||
authenticationRef:
|
||||
name: celery-worker-auth
|
||||
- type: redis
|
||||
metadata:
|
||||
sslEnabled: "true"
|
||||
host: "{host}"
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: connector_indexing:2
|
||||
listLength: "1"
|
||||
databaseIndex: "15"
|
||||
authenticationRef:
|
||||
name: celery-worker-auth
|
||||
- type: redis
|
||||
metadata:
|
||||
sslEnabled: "true"
|
||||
host: "{host}"
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: connector_indexing:3
|
||||
listLength: "1"
|
||||
databaseIndex: "15"
|
||||
authenticationRef:
|
||||
name: celery-worker-auth
|
||||
@@ -1,63 +0,0 @@
|
||||
apiVersion: keda.sh/v1alpha1
|
||||
kind: ScaledObject
|
||||
metadata:
|
||||
name: celery-worker-light-scaledobject
|
||||
namespace: danswer
|
||||
labels:
|
||||
app: celery-worker-light
|
||||
spec:
|
||||
scaleTargetRef:
|
||||
name: celery-worker-light
|
||||
minReplicaCount: 1
|
||||
maxReplicaCount: 20
|
||||
triggers:
|
||||
- type: redis
|
||||
metadata:
|
||||
host: "{host}"
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: vespa_metadata_sync
|
||||
listLength: "1"
|
||||
databaseIndex: "15"
|
||||
authenticationRef:
|
||||
name: celery-worker-auth
|
||||
- type: redis
|
||||
metadata:
|
||||
host: "{host}"
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: vespa_metadata_sync:2
|
||||
listLength: "1"
|
||||
databaseIndex: "15"
|
||||
authenticationRef:
|
||||
name: celery-worker-auth
|
||||
- type: redis
|
||||
metadata:
|
||||
host: "{host}"
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: vespa_metadata_sync:3
|
||||
listLength: "1"
|
||||
databaseIndex: "15"
|
||||
authenticationRef:
|
||||
name: celery-worker-auth
|
||||
- type: redis
|
||||
metadata:
|
||||
host: "{host}"
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: connector_deletion
|
||||
listLength: "1"
|
||||
databaseIndex: "15"
|
||||
authenticationRef:
|
||||
name: celery-worker-auth
|
||||
- type: redis
|
||||
metadata:
|
||||
host: "{host}"
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: connector_deletion:2
|
||||
listLength: "1"
|
||||
databaseIndex: "15"
|
||||
authenticationRef:
|
||||
name: celery-worker-auth
|
||||
@@ -1,76 +0,0 @@
|
||||
apiVersion: keda.sh/v1alpha1
|
||||
kind: ScaledObject
|
||||
metadata:
|
||||
name: celery-worker-primary-scaledobject
|
||||
namespace: danswer
|
||||
labels:
|
||||
app: celery-worker-primary
|
||||
spec:
|
||||
scaleTargetRef:
|
||||
name: celery-worker-primary
|
||||
pollingInterval: 15 # Check every 15 seconds
|
||||
cooldownPeriod: 30 # Wait 30 seconds before scaling down
|
||||
minReplicaCount: 1
|
||||
maxReplicaCount: 1
|
||||
triggers:
|
||||
- type: redis
|
||||
metadata:
|
||||
host: "{host}"
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: celery
|
||||
listLength: "1"
|
||||
databaseIndex: "15"
|
||||
authenticationRef:
|
||||
name: celery-worker-auth
|
||||
|
||||
- type: redis
|
||||
metadata:
|
||||
host: "{host}"
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: celery:1
|
||||
listLength: "1"
|
||||
databaseIndex: "15"
|
||||
authenticationRef:
|
||||
name: celery-worker-auth
|
||||
- type: redis
|
||||
metadata:
|
||||
host: "{host}"
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: celery:2
|
||||
listLength: "1"
|
||||
databaseIndex: "15"
|
||||
authenticationRef:
|
||||
name: celery-worker-auth
|
||||
- type: redis
|
||||
metadata:
|
||||
host: "{host}"
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: celery:3
|
||||
listLength: "1"
|
||||
databaseIndex: "15"
|
||||
authenticationRef:
|
||||
name: celery-worker-auth
|
||||
- type: redis
|
||||
metadata:
|
||||
host: "{host}"
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: periodic_tasks
|
||||
listLength: "1"
|
||||
databaseIndex: "15"
|
||||
authenticationRef:
|
||||
name: celery-worker-auth
|
||||
- type: redis
|
||||
metadata:
|
||||
host: "{host}"
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: periodic_tasks:2
|
||||
listLength: "1"
|
||||
databaseIndex: "15"
|
||||
authenticationRef:
|
||||
name: celery-worker-auth
|
||||
@@ -1,9 +0,0 @@
|
||||
apiVersion: v1
|
||||
kind: Secret
|
||||
metadata:
|
||||
name: keda-redis-secret
|
||||
namespace: danswer
|
||||
type: Opaque
|
||||
data:
|
||||
host: { { base64-encoded-hostname } }
|
||||
password: { { base64-encoded-password } }
|
||||
@@ -14,7 +14,7 @@ spec:
|
||||
spec:
|
||||
containers:
|
||||
- name: celery-beat
|
||||
image: danswer/danswer-backend-cloud:v0.12.0-cloud.beta.2
|
||||
image: danswer/danswer-backend:v0.11.0-cloud.beta.4
|
||||
imagePullPolicy: Always
|
||||
command:
|
||||
[
|
||||
@@ -31,7 +31,7 @@ spec:
|
||||
name: danswer-secrets
|
||||
key: redis_password
|
||||
- name: DANSWER_VERSION
|
||||
value: "v0.11.0-cloud.beta.8"
|
||||
value: "v0.11.0-cloud.beta.4"
|
||||
envFrom:
|
||||
- configMapRef:
|
||||
name: env-configmap
|
||||
|
||||
@@ -14,7 +14,7 @@ spec:
|
||||
spec:
|
||||
containers:
|
||||
- name: celery-worker-heavy
|
||||
image: danswer/danswer-backend-cloud:v0.12.0-cloud.beta.2
|
||||
image: danswer/danswer-backend:v0.11.0-cloud.beta.4
|
||||
imagePullPolicy: Always
|
||||
command:
|
||||
[
|
||||
@@ -34,7 +34,7 @@ spec:
|
||||
name: danswer-secrets
|
||||
key: redis_password
|
||||
- name: DANSWER_VERSION
|
||||
value: "v0.11.0-cloud.beta.8"
|
||||
value: "v0.11.0-cloud.beta.4"
|
||||
envFrom:
|
||||
- configMapRef:
|
||||
name: env-configmap
|
||||
|
||||
@@ -14,7 +14,7 @@ spec:
|
||||
spec:
|
||||
containers:
|
||||
- name: celery-worker-indexing
|
||||
image: danswer/danswer-backend-cloud:v0.12.0-cloud.beta.2
|
||||
image: danswer/danswer-backend:v0.11.0-cloud.beta.4
|
||||
imagePullPolicy: Always
|
||||
command:
|
||||
[
|
||||
@@ -26,8 +26,6 @@ spec:
|
||||
"--hostname=indexing@%n",
|
||||
"-Q",
|
||||
"connector_indexing",
|
||||
"--prefetch-multiplier=1",
|
||||
"--concurrency=10",
|
||||
]
|
||||
env:
|
||||
- name: REDIS_PASSWORD
|
||||
@@ -36,7 +34,7 @@ spec:
|
||||
name: danswer-secrets
|
||||
key: redis_password
|
||||
- name: DANSWER_VERSION
|
||||
value: "v0.11.0-cloud.beta.8"
|
||||
value: "v0.11.0-cloud.beta.4"
|
||||
envFrom:
|
||||
- configMapRef:
|
||||
name: env-configmap
|
||||
|
||||
@@ -14,7 +14,7 @@ spec:
|
||||
spec:
|
||||
containers:
|
||||
- name: celery-worker-light
|
||||
image: danswer/danswer-backend-cloud:v0.12.0-cloud.beta.2
|
||||
image: danswer/danswer-backend:v0.11.0-cloud.beta.4
|
||||
imagePullPolicy: Always
|
||||
command:
|
||||
[
|
||||
@@ -26,8 +26,6 @@ spec:
|
||||
"--hostname=light@%n",
|
||||
"-Q",
|
||||
"vespa_metadata_sync,connector_deletion",
|
||||
"--prefetch-multiplier=1",
|
||||
"--concurrency=10",
|
||||
]
|
||||
env:
|
||||
- name: REDIS_PASSWORD
|
||||
@@ -36,7 +34,7 @@ spec:
|
||||
name: danswer-secrets
|
||||
key: redis_password
|
||||
- name: DANSWER_VERSION
|
||||
value: "v0.11.0-cloud.beta.8"
|
||||
value: "v0.11.0-cloud.beta.4"
|
||||
envFrom:
|
||||
- configMapRef:
|
||||
name: env-configmap
|
||||
|
||||
@@ -14,7 +14,7 @@ spec:
|
||||
spec:
|
||||
containers:
|
||||
- name: celery-worker-primary
|
||||
image: danswer/danswer-backend-cloud:v0.12.0-cloud.beta.2
|
||||
image: danswer/danswer-backend:v0.11.0-cloud.beta.4
|
||||
imagePullPolicy: Always
|
||||
command:
|
||||
[
|
||||
@@ -25,9 +25,7 @@ spec:
|
||||
"--loglevel=INFO",
|
||||
"--hostname=primary@%n",
|
||||
"-Q",
|
||||
"celery,periodic_tasks",
|
||||
"--prefetch-multiplier=1",
|
||||
"--concurrency=10",
|
||||
"celery,periodic_tasks,vespa_metadata_sync",
|
||||
]
|
||||
env:
|
||||
- name: REDIS_PASSWORD
|
||||
@@ -36,7 +34,7 @@ spec:
|
||||
name: danswer-secrets
|
||||
key: redis_password
|
||||
- name: DANSWER_VERSION
|
||||
value: "v0.11.0-cloud.beta.8"
|
||||
value: "v0.11.0-cloud.beta.4"
|
||||
envFrom:
|
||||
- configMapRef:
|
||||
name: env-configmap
|
||||
|
||||
@@ -221,10 +221,6 @@ services:
|
||||
max-size: "50m"
|
||||
max-file: "6"
|
||||
# Uncomment the following lines if you need to include a custom CA certificate
|
||||
# This section enables the use of a custom CA certificate
|
||||
# If present, the custom CA certificate is mounted as a volume
|
||||
# The container checks for its existence and updates the system's CA certificates
|
||||
# This allows for secure communication with services using custom SSL certificates
|
||||
# Optional volume mount for CA certificate
|
||||
# volumes:
|
||||
# # Maps to the CA_CERT_PATH environment variable in the Dockerfile
|
||||
|
||||
@@ -65,10 +65,6 @@ services:
|
||||
max-size: "50m"
|
||||
max-file: "6"
|
||||
# Uncomment the following lines if you need to include a custom CA certificate
|
||||
# This section enables the use of a custom CA certificate
|
||||
# If present, the custom CA certificate is mounted as a volume
|
||||
# The container checks for its existence and updates the system's CA certificates
|
||||
# This allows for secure communication with services using custom SSL certificates
|
||||
# volumes:
|
||||
# # Maps to the CA_CERT_PATH environment variable in the Dockerfile
|
||||
# - ${CA_CERT_PATH:-./custom-ca.crt}:/etc/ssl/certs/custom-ca.crt:ro
|
||||
|
||||
@@ -3,13 +3,13 @@ dependencies:
|
||||
repository: https://charts.bitnami.com/bitnami
|
||||
version: 14.3.1
|
||||
- name: vespa
|
||||
repository: https://danswer-ai.github.io/vespa-helm-charts
|
||||
version: 0.2.16
|
||||
repository: https://unoplat.github.io/vespa-helm-charts
|
||||
version: 0.2.3
|
||||
- name: nginx
|
||||
repository: oci://registry-1.docker.io/bitnamicharts
|
||||
version: 15.14.0
|
||||
- name: redis
|
||||
repository: https://charts.bitnami.com/bitnami
|
||||
version: 20.1.0
|
||||
digest: sha256:711bbb76ba6ab604a36c9bf1839ab6faa5610afb21e535afd933c78f2d102232
|
||||
generated: "2024-11-07T09:39:30.17171-08:00"
|
||||
digest: sha256:fb42426c1d13667a4929d0d6a7d681bf08120e4a4eb1d15437e4ec70920be3f8
|
||||
generated: "2024-09-11T09:16:03.312328-07:00"
|
||||
|
||||
@@ -5,7 +5,7 @@ home: https://www.danswer.ai/
|
||||
sources:
|
||||
- "https://github.com/danswer-ai/danswer"
|
||||
type: application
|
||||
version: 0.2.1
|
||||
version: 0.2.0
|
||||
appVersion: "latest"
|
||||
annotations:
|
||||
category: Productivity
|
||||
@@ -23,8 +23,8 @@ dependencies:
|
||||
repository: https://charts.bitnami.com/bitnami
|
||||
condition: postgresql.enabled
|
||||
- name: vespa
|
||||
version: 0.2.16
|
||||
repository: https://danswer-ai.github.io/vespa-helm-charts
|
||||
version: 0.2.3
|
||||
repository: https://unoplat.github.io/vespa-helm-charts
|
||||
condition: vespa.enabled
|
||||
- name: nginx
|
||||
version: 15.14.0
|
||||
|
||||
@@ -7,7 +7,7 @@ metadata:
|
||||
data:
|
||||
INTERNAL_URL: "http://{{ include "danswer-stack.fullname" . }}-api-service:{{ .Values.api.service.port | default 8080 }}"
|
||||
POSTGRES_HOST: {{ .Release.Name }}-postgresql
|
||||
VESPA_HOST: da-vespa-0.vespa-service
|
||||
VESPA_HOST: "document-index-service"
|
||||
REDIS_HOST: {{ .Release.Name }}-redis-master
|
||||
MODEL_SERVER_HOST: "{{ include "danswer-stack.fullname" . }}-inference-model-service"
|
||||
INDEXING_MODEL_SERVER_HOST: "{{ include "danswer-stack.fullname" . }}-indexing-model-service"
|
||||
|
||||
@@ -11,5 +11,5 @@ spec:
|
||||
- name: wget
|
||||
image: busybox
|
||||
command: ['wget']
|
||||
args: ['{{ include "danswer-stack.fullname" . }}-webserver:{{ .Values.webserver.service.port }}']
|
||||
args: ['{{ include "danswer-stack.fullname" . }}:{{ .Values.webserver.service.port }}']
|
||||
restartPolicy: Never
|
||||
|
||||
@@ -17,13 +17,11 @@ spec:
|
||||
image: danswer/danswer-backend:latest
|
||||
imagePullPolicy: IfNotPresent
|
||||
command:
|
||||
- "/bin/sh"
|
||||
- "-c"
|
||||
- |
|
||||
if [ -f /etc/ssl/certs/custom-ca.crt ]; then
|
||||
update-ca-certificates;
|
||||
fi &&
|
||||
/usr/bin/supervisord -c /etc/supervisor/conf.d/supervisord.conf
|
||||
[
|
||||
"/usr/bin/supervisord",
|
||||
"-c",
|
||||
"/etc/supervisor/conf.d/supervisord.conf",
|
||||
]
|
||||
env:
|
||||
- name: REDIS_PASSWORD
|
||||
valueFrom:
|
||||
@@ -34,10 +32,6 @@ spec:
|
||||
- configMapRef:
|
||||
name: env-configmap
|
||||
# Uncomment the following lines if you need to include a custom CA certificate
|
||||
# This section allows for the inclusion of a custom CA certificate
|
||||
# If a custom CA certificate is present, it updates the system's CA certificates
|
||||
# This is useful for environments with self-signed or internal CA certificates
|
||||
# The certificate is mounted as a volume and the container checks for its presence
|
||||
# Optional volume mount for CA certificate
|
||||
# volumeMounts:
|
||||
# - name: my-ca-cert-volume
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 10 KiB |
@@ -11,7 +11,6 @@ import {
|
||||
LLM_PROVIDERS_ADMIN_URL,
|
||||
} from "../../configuration/llm/constants";
|
||||
import { mutate } from "swr";
|
||||
import { testEmbedding } from "../pages/utils";
|
||||
|
||||
export function ChangeCredentialsModal({
|
||||
provider,
|
||||
@@ -113,15 +112,16 @@ export function ChangeCredentialsModal({
|
||||
const normalizedProviderType = provider.provider_type
|
||||
.toLowerCase()
|
||||
.split(" ")[0];
|
||||
|
||||
try {
|
||||
const testResponse = await testEmbedding({
|
||||
provider_type: normalizedProviderType,
|
||||
modelName,
|
||||
apiKey,
|
||||
apiUrl,
|
||||
apiVersion: null,
|
||||
deploymentName: null,
|
||||
const testResponse = await fetch("/api/admin/embedding/test-embedding", {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({
|
||||
provider_type: normalizedProviderType,
|
||||
api_key: apiKey,
|
||||
api_url: apiUrl,
|
||||
model_name: modelName,
|
||||
}),
|
||||
});
|
||||
|
||||
if (!testResponse.ok) {
|
||||
|
||||
@@ -110,27 +110,20 @@ export function ProviderCreationModal({
|
||||
setErrorMsg("");
|
||||
try {
|
||||
const customConfig = Object.fromEntries(values.custom_config);
|
||||
const providerType = values.provider_type.toLowerCase().split(" ")[0];
|
||||
const isOpenAI = providerType === "openai";
|
||||
|
||||
const testModelName =
|
||||
isOpenAI || isAzure ? "text-embedding-3-small" : values.model_name;
|
||||
|
||||
const testEmbeddingPayload = {
|
||||
provider_type: providerType,
|
||||
api_key: values.api_key,
|
||||
api_url: values.api_url,
|
||||
model_name: testModelName,
|
||||
api_version: values.api_version,
|
||||
deployment_name: values.deployment_name,
|
||||
};
|
||||
|
||||
const initialResponse = await fetch(
|
||||
"/api/admin/embedding/test-embedding",
|
||||
{
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify(testEmbeddingPayload),
|
||||
body: JSON.stringify({
|
||||
provider_type: values.provider_type.toLowerCase().split(" ")[0],
|
||||
api_key: values.api_key,
|
||||
api_url: values.api_url,
|
||||
model_name: values.model_name,
|
||||
api_version: values.api_version,
|
||||
deployment_name: values.deployment_name,
|
||||
}),
|
||||
}
|
||||
);
|
||||
|
||||
|
||||
@@ -8,37 +8,3 @@ export const deleteSearchSettings = async (search_settings_id: number) => {
|
||||
});
|
||||
return response;
|
||||
};
|
||||
|
||||
export const testEmbedding = async ({
|
||||
provider_type,
|
||||
modelName,
|
||||
apiKey,
|
||||
apiUrl,
|
||||
apiVersion,
|
||||
deploymentName,
|
||||
}: {
|
||||
provider_type: string;
|
||||
modelName: string;
|
||||
apiKey: string | null;
|
||||
apiUrl: string | null;
|
||||
apiVersion: string | null;
|
||||
deploymentName: string | null;
|
||||
}) => {
|
||||
const testModelName =
|
||||
provider_type === "openai" ? "text-embedding-3-small" : modelName;
|
||||
|
||||
const testResponse = await fetch("/api/admin/embedding/test-embedding", {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({
|
||||
provider_type: provider_type,
|
||||
api_key: apiKey,
|
||||
api_url: apiUrl,
|
||||
model_name: testModelName,
|
||||
api_version: apiVersion,
|
||||
deployment_name: deploymentName,
|
||||
}),
|
||||
});
|
||||
|
||||
return testResponse;
|
||||
};
|
||||
|
||||
@@ -4,6 +4,7 @@ import { fetchChatData } from "@/lib/chat/fetchChatData";
|
||||
import { unstable_noStore as noStore } from "next/cache";
|
||||
import { redirect } from "next/navigation";
|
||||
import WrappedAssistantsGallery from "./WrappedAssistantsGallery";
|
||||
import { AssistantsProvider } from "@/components/context/AssistantsContext";
|
||||
import { cookies } from "next/headers";
|
||||
|
||||
export default async function GalleryPage(props: {
|
||||
|
||||
@@ -131,8 +131,6 @@ export function ChatPage({
|
||||
|
||||
const {
|
||||
chatSessions,
|
||||
availableSources,
|
||||
availableDocumentSets,
|
||||
llmProviders,
|
||||
folders,
|
||||
openedFolders,
|
||||
@@ -2173,6 +2171,25 @@ export function ChatPage({
|
||||
) {
|
||||
return <></>;
|
||||
}
|
||||
const mostRecentNonAIParent = messageHistory
|
||||
.slice(0, i)
|
||||
.reverse()
|
||||
.find((msg) => msg.type !== "assistant");
|
||||
|
||||
const hasChildMessage =
|
||||
message.latestChildMessageId !== null &&
|
||||
message.latestChildMessageId !== undefined;
|
||||
const childMessage = hasChildMessage
|
||||
? messageMap.get(
|
||||
message.latestChildMessageId!
|
||||
)
|
||||
: null;
|
||||
|
||||
const hasParentAI =
|
||||
parentMessage?.type == "assistant";
|
||||
const hasChildAI =
|
||||
childMessage?.type == "assistant";
|
||||
|
||||
return (
|
||||
<div
|
||||
id={`message-${message.messageId}`}
|
||||
@@ -2184,6 +2201,9 @@ export function ChatPage({
|
||||
}
|
||||
>
|
||||
<AIMessage
|
||||
setPopup={setPopup}
|
||||
hasChildAI={hasChildAI}
|
||||
hasParentAI={hasParentAI}
|
||||
continueGenerating={
|
||||
i == messageHistory.length - 1 &&
|
||||
currentCanContinue()
|
||||
@@ -2193,7 +2213,7 @@ export function ChatPage({
|
||||
overriddenModel={message.overridden_model}
|
||||
regenerate={createRegenerator({
|
||||
messageId: message.messageId,
|
||||
parentMessage: parentMessage!,
|
||||
parentMessage: mostRecentNonAIParent!,
|
||||
})}
|
||||
otherMessagesCanSwitchTo={
|
||||
parentMessage?.childrenMessageIds || []
|
||||
@@ -2340,6 +2360,7 @@ export function ChatPage({
|
||||
return (
|
||||
<div key={messageReactComponentKey}>
|
||||
<AIMessage
|
||||
setPopup={setPopup}
|
||||
currentPersona={liveAssistant}
|
||||
messageId={message.messageId}
|
||||
content={
|
||||
@@ -2382,6 +2403,7 @@ export function ChatPage({
|
||||
key={`${messageHistory.length}-${chatSessionIdRef.current}`}
|
||||
>
|
||||
<AIMessage
|
||||
setPopup={setPopup}
|
||||
key={-3}
|
||||
currentPersona={liveAssistant}
|
||||
alternativeAssistant={
|
||||
@@ -2406,6 +2428,7 @@ export function ChatPage({
|
||||
{loadingError && (
|
||||
<div key={-1}>
|
||||
<AIMessage
|
||||
setPopup={setPopup}
|
||||
currentPersona={liveAssistant}
|
||||
messageId={-1}
|
||||
content={
|
||||
|
||||
@@ -144,3 +144,10 @@ export interface StreamingError {
|
||||
error: string;
|
||||
stack_trace: string;
|
||||
}
|
||||
|
||||
export interface ImageGenerationResult {
|
||||
revised_prompt: string;
|
||||
url: string;
|
||||
}
|
||||
|
||||
export type ImageGenerationResults = ImageGenerationResult[];
|
||||
|
||||
@@ -1,29 +1,73 @@
|
||||
import { Citation } from "@/components/search/results/Citation";
|
||||
import React, { memo } from "react";
|
||||
import { IMAGE_GENERATION_TOOL_NAME } from "../tools/constants";
|
||||
|
||||
export const MemoizedLink = memo((props: any) => {
|
||||
const { node, ...rest } = props;
|
||||
const value = rest.children;
|
||||
import {
|
||||
Popover,
|
||||
PopoverContent,
|
||||
PopoverTrigger,
|
||||
} from "@/components/ui/popover";
|
||||
import { SearchIcon } from "lucide-react";
|
||||
import DualPromptDisplay from "../tools/ImageCitation";
|
||||
import { PopupSpec, usePopup } from "@/components/admin/connectors/Popup";
|
||||
import { ImageGenerationResults, ToolCallFinalResult } from "../interfaces";
|
||||
|
||||
if (value?.toString().startsWith("*")) {
|
||||
return (
|
||||
<div className="flex-none bg-background-800 inline-block rounded-full h-3 w-3 ml-2" />
|
||||
);
|
||||
} else if (value?.toString().startsWith("[")) {
|
||||
return <Citation link={rest?.href}>{rest.children}</Citation>;
|
||||
} else {
|
||||
return (
|
||||
<a
|
||||
onMouseDown={() =>
|
||||
rest.href ? window.open(rest.href, "_blank") : undefined
|
||||
}
|
||||
className="cursor-pointer text-link hover:text-link-hover"
|
||||
>
|
||||
{rest.children}
|
||||
</a>
|
||||
);
|
||||
export const MemoizedLink = memo(
|
||||
({
|
||||
toolCall,
|
||||
setPopup,
|
||||
...props
|
||||
}: {
|
||||
toolCall?: ToolCallFinalResult;
|
||||
setPopup: (popupSpec: PopupSpec | null) => void;
|
||||
} & any) => {
|
||||
const { node, ...rest } = props;
|
||||
const value = rest.children;
|
||||
|
||||
if (value?.toString().startsWith(IMAGE_GENERATION_TOOL_NAME)) {
|
||||
const imageGenerationResult =
|
||||
toolCall?.tool_result as ImageGenerationResults;
|
||||
|
||||
return (
|
||||
<Popover>
|
||||
<PopoverTrigger asChild>
|
||||
<span className="inline-block">
|
||||
<SearchIcon className="cursor-pointer flex-none text-blue-500 hover:text-blue-700 !h-4 !w-4 inline-block" />
|
||||
</span>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent className="w-96" side="top" align="center">
|
||||
<DualPromptDisplay
|
||||
arg="Prompt"
|
||||
setPopup={setPopup!}
|
||||
prompts={imageGenerationResult.map(
|
||||
(result) => result.revised_prompt
|
||||
)}
|
||||
/>
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
);
|
||||
}
|
||||
|
||||
if (value?.toString().startsWith("*")) {
|
||||
return (
|
||||
<div className="flex-none bg-background-800 inline-block rounded-full h-3 w-3 ml-2" />
|
||||
);
|
||||
} else if (value?.toString().startsWith("[")) {
|
||||
return <Citation link={rest?.href}>{rest.children}</Citation>;
|
||||
} else {
|
||||
return (
|
||||
<a
|
||||
onMouseDown={() =>
|
||||
rest.href ? window.open(rest.href, "_blank") : undefined
|
||||
}
|
||||
className="cursor-pointer text-link hover:text-link-hover"
|
||||
>
|
||||
{rest.children}
|
||||
</a>
|
||||
);
|
||||
}
|
||||
}
|
||||
});
|
||||
);
|
||||
|
||||
export const MemoizedParagraph = memo(({ ...props }: any) => {
|
||||
return <p {...props} className="text-default" />;
|
||||
|
||||
@@ -62,6 +62,7 @@ import { MemoizedLink, MemoizedParagraph } from "./MemoizedTextComponents";
|
||||
import { extractCodeText } from "./codeUtils";
|
||||
import ToolResult from "../../../components/tools/ToolResult";
|
||||
import CsvContent from "../../../components/tools/CSVContent";
|
||||
import { PopupSpec } from "@/components/admin/connectors/Popup";
|
||||
|
||||
const TOOLS_WITH_CUSTOM_HANDLING = [
|
||||
SEARCH_TOOL_NAME,
|
||||
@@ -153,6 +154,9 @@ function FileDisplay({
|
||||
}
|
||||
|
||||
export const AIMessage = ({
|
||||
setPopup,
|
||||
hasChildAI,
|
||||
hasParentAI,
|
||||
regenerate,
|
||||
overriddenModel,
|
||||
continueGenerating,
|
||||
@@ -179,6 +183,9 @@ export const AIMessage = ({
|
||||
otherMessagesCanSwitchTo,
|
||||
onMessageSelection,
|
||||
}: {
|
||||
setPopup?: (popupSpec: PopupSpec | null) => void;
|
||||
hasChildAI?: boolean;
|
||||
hasParentAI?: boolean;
|
||||
shared?: boolean;
|
||||
isActive?: boolean;
|
||||
continueGenerating?: () => void;
|
||||
@@ -227,6 +234,13 @@ export const AIMessage = ({
|
||||
return content;
|
||||
}
|
||||
}
|
||||
if (
|
||||
isComplete &&
|
||||
toolCall?.tool_result &&
|
||||
toolCall.tool_name == IMAGE_GENERATION_TOOL_NAME
|
||||
) {
|
||||
return content + ` [${toolCall.tool_name}]()`;
|
||||
}
|
||||
|
||||
return content + (!isComplete && !toolCallGenerating ? " [*]() " : "");
|
||||
};
|
||||
@@ -296,7 +310,9 @@ export const AIMessage = ({
|
||||
|
||||
const markdownComponents = useMemo(
|
||||
() => ({
|
||||
a: MemoizedLink,
|
||||
a: (props: any) => (
|
||||
<MemoizedLink {...props} toolCall={toolCall} setPopup={setPopup} />
|
||||
),
|
||||
p: MemoizedParagraph,
|
||||
code: ({ node, className, children, ...props }: any) => {
|
||||
const codeText = extractCodeText(
|
||||
@@ -312,7 +328,7 @@ export const AIMessage = ({
|
||||
);
|
||||
},
|
||||
}),
|
||||
[finalContent]
|
||||
[finalContent, toolCall]
|
||||
);
|
||||
|
||||
const renderedMarkdown = useMemo(() => {
|
||||
@@ -338,7 +354,7 @@ export const AIMessage = ({
|
||||
<div
|
||||
id="danswer-ai-message"
|
||||
ref={trackedElementRef}
|
||||
className={"py-5 ml-4 px-5 relative flex "}
|
||||
className={`${hasParentAI ? "pb-5" : "py-5"} px-2 lg:px-5 relative flex `}
|
||||
>
|
||||
<div
|
||||
className={`mx-auto ${
|
||||
@@ -347,10 +363,14 @@ export const AIMessage = ({
|
||||
>
|
||||
<div className={`desktop:mr-12 ${!shared && "mobile:ml-0 md:ml-8"}`}>
|
||||
<div className="flex">
|
||||
<AssistantIcon
|
||||
size="small"
|
||||
assistant={alternativeAssistant || currentPersona}
|
||||
/>
|
||||
{!hasParentAI ? (
|
||||
<AssistantIcon
|
||||
size="small"
|
||||
assistant={alternativeAssistant || currentPersona}
|
||||
/>
|
||||
) : (
|
||||
<div className="w-6" />
|
||||
)}
|
||||
|
||||
<div className="w-full">
|
||||
<div className="max-w-message-max break-words">
|
||||
@@ -514,7 +534,8 @@ export const AIMessage = ({
|
||||
)}
|
||||
</div>
|
||||
|
||||
{handleFeedback &&
|
||||
{!hasChildAI &&
|
||||
handleFeedback &&
|
||||
(isActive ? (
|
||||
<div
|
||||
className={`
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user