Compare commits

..

8 Commits

Author SHA1 Message Date
pablodanswer
59e9a33b30 k 2024-11-18 09:12:37 -08:00
pablodanswer
6e60437c56 quick nits 2024-11-17 21:52:00 -08:00
pablodanswer
9cde51f1a2 scalable but not formalized 2024-11-07 16:01:37 -08:00
pablodanswer
8b8952f117 k 2024-11-07 12:25:44 -08:00
pablodanswer
dc01eea610 add logs 2024-11-07 12:24:10 -08:00
pablodanswer
c89d8318c0 add image prompt citations 2024-11-07 12:22:46 -08:00
pablodanswer
3f2d6557dc functioning albeit janky 2024-11-07 11:57:00 -08:00
pablodanswer
b3818877af initial functioning update 2024-11-07 10:59:40 -08:00
114 changed files with 1325 additions and 1735 deletions

View File

@@ -3,61 +3,61 @@ name: Build and Push Backend Image on Tag
on:
push:
tags:
- "*"
- '*'
env:
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'danswer/danswer-backend-cloud' || 'danswer/danswer-backend' }}
REGISTRY_IMAGE: danswer/danswer-backend
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
jobs:
build-and-push:
# TODO: investigate a matrix build like the web container
# TODO: investigate a matrix build like the web container
# See https://runs-on.com/runners/linux/
runs-on: [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}"]
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Install build-essential
run: |
sudo apt-get update
sudo apt-get install -y build-essential
- name: Install build-essential
run: |
sudo apt-get update
sudo apt-get install -y build-essential
- name: Backend Image Docker Build and Push
uses: docker/build-push-action@v5
with:
context: ./backend
file: ./backend/Dockerfile
platforms: linux/amd64,linux/arm64
push: true
tags: |
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
build-args: |
DANSWER_VERSION=${{ github.ref_name }}
- name: Backend Image Docker Build and Push
uses: docker/build-push-action@v5
with:
context: ./backend
file: ./backend/Dockerfile
platforms: linux/amd64,linux/arm64
push: true
tags: |
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
build-args: |
DANSWER_VERSION=${{ github.ref_name }}
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
env:
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
with:
# To run locally: trivy image --severity HIGH,CRITICAL danswer/danswer-backend
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
severity: "CRITICAL,HIGH"
trivyignores: ./backend/.trivyignore
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
# To run locally: trivy image --severity HIGH,CRITICAL danswer/danswer-backend
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
severity: 'CRITICAL,HIGH'
trivyignores: ./backend/.trivyignore

View File

@@ -4,12 +4,12 @@ name: Build and Push Cloud Web Image on Tag
on:
push:
tags:
- "*"
- '*'
env:
REGISTRY_IMAGE: danswer/danswer-web-server-cloud
REGISTRY_IMAGE: danswer/danswer-cloud-web-server
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
jobs:
build:
runs-on:
@@ -28,11 +28,11 @@ jobs:
- name: Prepare
run: |
platform=${{ matrix.platform }}
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
echo "PLATFORM_PAIR=${platform//\//-}" >> $GITHUB_ENV
- name: Checkout
uses: actions/checkout@v4
- name: Docker meta
id: meta
uses: docker/metadata-action@v5
@@ -41,16 +41,16 @@ jobs:
tags: |
type=raw,value=${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
type=raw,value=${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build and push by digest
id: build
uses: docker/build-push-action@v5
@@ -65,17 +65,17 @@ jobs:
NEXT_PUBLIC_POSTHOG_KEY=${{ secrets.POSTHOG_KEY }}
NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
# needed due to weird interactions with the builds for different platforms
# needed due to weird interactions with the builds for different platforms
no-cache: true
labels: ${{ steps.meta.outputs.labels }}
outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
- name: Export digest
run: |
mkdir -p /tmp/digests
digest="${{ steps.build.outputs.digest }}"
touch "/tmp/digests/${digest#sha256:}"
touch "/tmp/digests/${digest#sha256:}"
- name: Upload digest
uses: actions/upload-artifact@v4
with:
@@ -95,42 +95,42 @@ jobs:
path: /tmp/digests
pattern: digests-*
merge-multiple: true
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Docker meta
id: meta
uses: docker/metadata-action@v5
with:
images: ${{ env.REGISTRY_IMAGE }}
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Create manifest list and push
working-directory: /tmp/digests
run: |
docker buildx imagetools create $(jq -cr '.tags | map("-t " + .) | join(" ")' <<< "$DOCKER_METADATA_OUTPUT_JSON") \
$(printf '${{ env.REGISTRY_IMAGE }}@sha256:%s ' *)
$(printf '${{ env.REGISTRY_IMAGE }}@sha256:%s ' *)
- name: Inspect image
run: |
docker buildx imagetools inspect ${{ env.REGISTRY_IMAGE }}:${{ steps.meta.outputs.version }}
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
env:
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
image-ref: docker.io/${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
severity: "CRITICAL,HIGH"
severity: 'CRITICAL,HIGH'

View File

@@ -3,53 +3,53 @@ name: Build and Push Model Server Image on Tag
on:
push:
tags:
- "*"
- '*'
env:
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'danswer/danswer-model-server-cloud' || 'danswer/danswer-model-server' }}
REGISTRY_IMAGE: danswer/danswer-model-server
LATEST_TAG: ${{ contains(github.ref_name, 'latest') }}
jobs:
build-and-push:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}"]
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Model Server Image Docker Build and Push
uses: docker/build-push-action@v5
with:
context: ./backend
file: ./backend/Dockerfile.model_server
platforms: linux/amd64,linux/arm64
push: true
tags: |
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
build-args: |
DANSWER_VERSION=${{ github.ref_name }}
- name: Model Server Image Docker Build and Push
uses: docker/build-push-action@v5
with:
context: ./backend
file: ./backend/Dockerfile.model_server
platforms: linux/amd64,linux/arm64
push: true
tags: |
${{ env.REGISTRY_IMAGE }}:${{ github.ref_name }}
${{ env.LATEST_TAG == 'true' && format('{0}:latest', env.REGISTRY_IMAGE) || '' }}
build-args: |
DANSWER_VERSION=${{ github.ref_name }}
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
env:
TRIVY_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-db:2"
TRIVY_JAVA_DB_REPOSITORY: "public.ecr.aws/aquasecurity/trivy-java-db:1"
with:
image-ref: docker.io/danswer/danswer-model-server:${{ github.ref_name }}
severity: "CRITICAL,HIGH"
# trivy has their own rate limiting issues causing this action to flake
# we worked around it by hardcoding to different db repos in env
# can re-enable when they figure it out
# https://github.com/aquasecurity/trivy/discussions/7538
# https://github.com/aquasecurity/trivy-action/issues/389
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@master
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
image-ref: docker.io/danswer/danswer-model-server:${{ github.ref_name }}
severity: 'CRITICAL,HIGH'

View File

@@ -210,18 +210,17 @@ jobs:
echo "All integration tests passed successfully."
fi
# save before stopping the containers so the logs can be captured
- name: Stop Docker containers
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
- name: Save Docker logs
if: success() || failure()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack logs > docker-compose.log
mv docker-compose.log ${{ github.workspace }}/docker-compose.log
- name: Stop Docker containers
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
- name: Upload logs
if: success() || failure()

View File

@@ -1,20 +1,24 @@
# This workflow is intentionally disabled while we're still working on it
# It's close to ready, but a race condition needs to be fixed with
# API server and Vespa startup, and it needs to have a way to build/test against
# local containers
name: Helm - Lint and Test Charts
on:
merge_group:
pull_request:
branches: [ main ]
workflow_dispatch: # Allows manual triggering
jobs:
helm-chart-check:
lint-test:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,hdd=256,"run-id=${{ github.run_id }}"]
# fetch-depth 0 is required for helm/chart-testing-action
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@v3
with:
fetch-depth: 0
@@ -24,7 +28,7 @@ jobs:
version: v3.14.4
- name: Set up Python
uses: actions/setup-python@v5
uses: actions/setup-python@v4
with:
python-version: '3.11'
cache: 'pip'
@@ -41,31 +45,24 @@ jobs:
- name: Set up chart-testing
uses: helm/chart-testing-action@v2.6.1
# even though we specify chart-dirs in ct.yaml, it isn't used by ct for the list-changed command...
- name: Run chart-testing (list-changed)
id: list-changed
run: |
echo "default_branch: ${{ github.event.repository.default_branch }}"
changed=$(ct list-changed --remote origin --target-branch ${{ github.event.repository.default_branch }} --chart-dirs deployment/helm/charts)
echo "list-changed output: $changed"
changed=$(ct list-changed --target-branch ${{ github.event.repository.default_branch }})
if [[ -n "$changed" ]]; then
echo "changed=true" >> "$GITHUB_OUTPUT"
fi
# lint all charts if any changes were detected
- name: Run chart-testing (lint)
if: steps.list-changed.outputs.changed == 'true'
run: ct lint --config ct.yaml --all
# the following would lint only changed charts, but linting isn't expensive
# run: ct lint --config ct.yaml --target-branch ${{ github.event.repository.default_branch }}
# if: steps.list-changed.outputs.changed == 'true'
run: ct lint --all --config ct.yaml --target-branch ${{ github.event.repository.default_branch }}
- name: Create kind cluster
if: steps.list-changed.outputs.changed == 'true'
# if: steps.list-changed.outputs.changed == 'true'
uses: helm/kind-action@v1.10.0
- name: Run chart-testing (install)
if: steps.list-changed.outputs.changed == 'true'
run: ct install --all --helm-extra-set-args="--set=nginx.enabled=false" --debug --config ct.yaml
# the following would install only changed charts, but we only have one chart so
# don't worry about that for now
# run: ct install --target-branch ${{ github.event.repository.default_branch }}
# if: steps.list-changed.outputs.changed == 'true'
run: ct install --all --config ct.yaml
# run: ct install --target-branch ${{ github.event.repository.default_branch }}

View File

@@ -26,8 +26,7 @@ def upgrade() -> None:
FROM jsonb_array_elements(starter_messages) elem
)
WHERE starter_messages IS NOT NULL
AND jsonb_typeof(starter_messages) = 'array'
"""
"""
)
)
@@ -42,7 +41,6 @@ def downgrade() -> None:
FROM jsonb_array_elements(starter_messages) elem
)
WHERE starter_messages IS NOT NULL
AND jsonb_typeof(starter_messages) = 'array'
"""
"""
)
)

View File

@@ -48,11 +48,11 @@ from httpx_oauth.integrations.fastapi import OAuth2AuthorizeCallback
from httpx_oauth.oauth2 import BaseOAuth2
from httpx_oauth.oauth2 import OAuth2Token
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy import text
from sqlalchemy.orm import attributes
from sqlalchemy.orm import Session
from danswer.auth.api_key import get_hashed_api_key_from_request
from danswer.auth.invited_users import get_invited_users
from danswer.auth.schemas import UserCreate
from danswer.auth.schemas import UserRole
@@ -75,7 +75,6 @@ from danswer.configs.constants import AuthType
from danswer.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
from danswer.configs.constants import DANSWER_API_KEY_PREFIX
from danswer.configs.constants import UNNAMED_KEY_PLACEHOLDER
from danswer.db.api_key import fetch_user_for_api_key
from danswer.db.auth import get_access_token_db
from danswer.db.auth import get_default_admin_user_emails
from danswer.db.auth import get_user_count
@@ -84,27 +83,24 @@ from danswer.db.auth import SQLAlchemyUserAdminDB
from danswer.db.engine import get_async_session_with_tenant
from danswer.db.engine import get_session
from danswer.db.engine import get_session_with_tenant
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import AccessToken
from danswer.db.models import OAuthAccount
from danswer.db.models import User
from danswer.db.models import UserTenantMapping
from danswer.db.users import get_user_by_email
from danswer.utils.logger import setup_logger
from danswer.utils.telemetry import optional_telemetry
from danswer.utils.telemetry import RecordType
from danswer.utils.variable_functionality import fetch_ee_implementation_or_noop
from danswer.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import async_return_default_schema
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
class BasicAuthenticationError(HTTPException):
def __init__(self, detail: str):
super().__init__(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
def is_user_admin(user: User | None) -> bool:
if AUTH_TYPE == AuthType.DISABLED:
return True
@@ -194,6 +190,20 @@ def verify_email_domain(email: str) -> None:
)
def get_tenant_id_for_email(email: str) -> str:
if not MULTI_TENANT:
return POSTGRES_DEFAULT_SCHEMA
# Implement logic to get tenant_id from the mapping table
with Session(get_sqlalchemy_engine()) as db_session:
result = db_session.execute(
select(UserTenantMapping.tenant_id).where(UserTenantMapping.email == email)
)
tenant_id = result.scalar_one_or_none()
if tenant_id is None:
raise exceptions.UserNotExists()
return tenant_id
def send_user_verification_email(
user_email: str,
token: str,
@@ -228,13 +238,19 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
safe: bool = False,
request: Optional[Request] = None,
) -> User:
tenant_id = await fetch_ee_implementation_or_noop(
"danswer.server.tenants.provisioning",
"get_or_create_tenant_id",
async_return_default_schema,
)(
email=user_create.email,
)
try:
tenant_id = (
get_tenant_id_for_email(user_create.email)
if MULTI_TENANT
else POSTGRES_DEFAULT_SCHEMA
)
except exceptions.UserNotExists:
raise HTTPException(status_code=401, detail="User not found")
if not tenant_id:
raise HTTPException(
status_code=401, detail="User does not belong to an organization"
)
async with get_async_session_with_tenant(tenant_id) as db_session:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
@@ -255,7 +271,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
user_create.role = UserRole.ADMIN
else:
user_create.role = UserRole.BASIC
user = None
try:
user = await super().create(user_create, safe=safe, request=request) # type: ignore
except exceptions.UserAlreadyExists:
@@ -276,9 +292,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
else:
raise exceptions.UserAlreadyExists()
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
return user
async def oauth_callback(
@@ -294,18 +308,19 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
associate_by_email: bool = False,
is_verified_by_default: bool = False,
) -> models.UOAP:
tenant_id = await fetch_ee_implementation_or_noop(
"danswer.server.tenants.provisioning",
"get_or_create_tenant_id",
async_return_default_schema,
)(
email=account_email,
)
# Get tenant_id from mapping table
try:
tenant_id = (
get_tenant_id_for_email(account_email)
if MULTI_TENANT
else POSTGRES_DEFAULT_SCHEMA
)
except exceptions.UserNotExists:
raise HTTPException(status_code=401, detail="User not found")
if not tenant_id:
raise HTTPException(status_code=401, detail="User not found")
# Proceed with the tenant context
token = None
async with get_async_session_with_tenant(tenant_id) as db_session:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
@@ -356,9 +371,9 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
# Explicitly set the Postgres schema for this session to ensure
# OAuth account creation happens in the correct tenant schema
await db_session.execute(text(f'SET search_path = "{tenant_id}"'))
# Add OAuth account
await self.user_db.add_oauth_account(user, oauth_account_dict)
user = await self.user_db.add_oauth_account(
user, oauth_account_dict
)
await self.on_after_register(user, request)
else:
@@ -438,13 +453,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
email = credentials.username
# Get tenant_id from mapping table
tenant_id = await fetch_ee_implementation_or_noop(
"danswer.server.tenants.provisioning",
"get_or_create_tenant_id",
async_return_default_schema,
)(
email=email,
)
tenant_id = get_tenant_id_for_email(email)
if not tenant_id:
# User not found in mapping
self.password_helper.hash(credentials.password)
@@ -468,7 +477,8 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
has_web_login = attributes.get_attribute(user, "has_web_login")
if not has_web_login:
raise BasicAuthenticationError(
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="NO_WEB_LOGIN_AND_HAS_NO_PASSWORD",
)
@@ -501,14 +511,7 @@ cookie_transport = CookieTransport(
# This strategy is used to add tenant_id to the JWT token
class TenantAwareJWTStrategy(JWTStrategy):
async def _create_token_data(self, user: User, impersonate: bool = False) -> dict:
tenant_id = await fetch_ee_implementation_or_noop(
"danswer.server.tenants.provisioning",
"get_or_create_tenant_id",
async_return_default_schema,
)(
email=user.email,
)
tenant_id = get_tenant_id_for_email(user.email)
data = {
"sub": str(user.id),
"aud": self.token_audience,
@@ -625,12 +628,14 @@ async def double_check_user(
return None
if user is None:
raise BasicAuthenticationError(
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. User is not authenticated.",
)
if user_needs_to_be_verified() and not user.is_verified:
raise BasicAuthenticationError(
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. User is not verified.",
)
@@ -639,7 +644,8 @@ async def double_check_user(
and user.oidc_expiry < datetime.now(timezone.utc)
and not include_expired
):
raise BasicAuthenticationError(
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. User's OIDC token has expired.",
)
@@ -665,13 +671,15 @@ async def current_curator_or_admin_user(
return None
if not user or not hasattr(user, "role"):
raise BasicAuthenticationError(
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. User is not authenticated or lacks role information.",
)
allowed_roles = {UserRole.GLOBAL_CURATOR, UserRole.CURATOR, UserRole.ADMIN}
if user.role not in allowed_roles:
raise BasicAuthenticationError(
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. User is not a curator or admin.",
)
@@ -683,7 +691,8 @@ async def current_admin_user(user: User | None = Depends(current_user)) -> User
return None
if not user or not hasattr(user, "role") or user.role != UserRole.ADMIN:
raise BasicAuthenticationError(
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. User must be an admin to perform this action.",
)
@@ -876,22 +885,3 @@ def get_oauth_router(
return redirect_response
return router
def api_key_dep(
request: Request, db_session: Session = Depends(get_session)
) -> User | None:
if AUTH_TYPE == AuthType.DISABLED:
return None
hashed_api_key = get_hashed_api_key_from_request(request)
if not hashed_api_key:
raise HTTPException(status_code=401, detail="Missing API key")
if hashed_api_key:
user = fetch_user_for_api_key(hashed_api_key, db_session)
if user is None:
raise HTTPException(status_code=401, detail="Invalid API key")
return user

View File

@@ -3,7 +3,6 @@ import multiprocessing
import time
from typing import Any
import requests
import sentry_sdk
from celery import Task
from celery.app import trace
@@ -12,15 +11,11 @@ from celery.states import READY_STATES
from celery.utils.log import get_task_logger
from celery.worker import strategy # type: ignore
from sentry_sdk.integrations.celery import CeleryIntegration
from sqlalchemy import text
from sqlalchemy.orm import Session
from danswer.background.celery.apps.task_formatters import CeleryTaskColoredFormatter
from danswer.background.celery.apps.task_formatters import CeleryTaskPlainFormatter
from danswer.background.celery.celery_utils import celery_is_worker_primary
from danswer.configs.constants import DanswerRedisLocks
from danswer.db.engine import get_sqlalchemy_engine
from danswer.document_index.vespa_constants import VESPA_CONFIG_SERVER_URL
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
from danswer.redis.redis_connector_delete import RedisConnectorDelete
@@ -31,6 +26,7 @@ from danswer.redis.redis_usergroup import RedisUserGroup
from danswer.utils.logger import ColoredFormatter
from danswer.utils.logger import PlainFormatter
from danswer.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import SENTRY_DSN
@@ -143,136 +139,45 @@ def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None
def wait_for_redis(sender: Any, **kwargs: Any) -> None:
"""Waits for redis to become ready subject to a hardcoded timeout.
Will raise WorkerShutdown to kill the celery worker if the timeout is reached."""
r = get_redis_client(tenant_id=None)
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
ready = False
time_start = time.monotonic()
logger.info("Redis: Readiness probe starting.")
logger.info("Redis: Readiness check starting.")
while True:
try:
if r.ping():
ready = True
break
except Exception:
pass
time_elapsed = time.monotonic() - time_start
if time_elapsed > WAIT_LIMIT:
break
logger.info(
f"Redis: Readiness probe ongoing. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
f"Redis: Ping failed. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
)
if time_elapsed > WAIT_LIMIT:
msg = (
f"Redis: Readiness check did not succeed within the timeout "
f"({WAIT_LIMIT} seconds). Exiting..."
)
logger.error(msg)
raise WorkerShutdown(msg)
time.sleep(WAIT_INTERVAL)
if not ready:
msg = (
f"Redis: Readiness probe did not succeed within the timeout "
f"({WAIT_LIMIT} seconds). Exiting..."
)
logger.error(msg)
raise WorkerShutdown(msg)
logger.info("Redis: Readiness probe succeeded. Continuing...")
return
def wait_for_db(sender: Any, **kwargs: Any) -> None:
"""Waits for the db to become ready subject to a hardcoded timeout.
Will raise WorkerShutdown to kill the celery worker if the timeout is reached."""
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
ready = False
time_start = time.monotonic()
logger.info("Database: Readiness probe starting.")
while True:
try:
with Session(get_sqlalchemy_engine()) as db_session:
result = db_session.execute(text("SELECT NOW()")).scalar()
if result:
ready = True
break
except Exception:
pass
time_elapsed = time.monotonic() - time_start
if time_elapsed > WAIT_LIMIT:
break
logger.info(
f"Database: Readiness probe ongoing. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
)
time.sleep(WAIT_INTERVAL)
if not ready:
msg = (
f"Database: Readiness probe did not succeed within the timeout "
f"({WAIT_LIMIT} seconds). Exiting..."
)
logger.error(msg)
raise WorkerShutdown(msg)
logger.info("Database: Readiness probe succeeded. Continuing...")
return
def wait_for_vespa(sender: Any, **kwargs: Any) -> None:
"""Waits for Vespa to become ready subject to a hardcoded timeout.
Will raise WorkerShutdown to kill the celery worker if the timeout is reached."""
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
ready = False
time_start = time.monotonic()
logger.info("Vespa: Readiness probe starting.")
while True:
try:
response = requests.get(f"{VESPA_CONFIG_SERVER_URL}/state/v1/health")
response.raise_for_status()
response_dict = response.json()
if response_dict["status"]["code"] == "up":
ready = True
break
except Exception:
pass
time_elapsed = time.monotonic() - time_start
if time_elapsed > WAIT_LIMIT:
break
logger.info(
f"Vespa: Readiness probe ongoing. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
)
time.sleep(WAIT_INTERVAL)
if not ready:
msg = (
f"Vespa: Readiness probe did not succeed within the timeout "
f"({WAIT_LIMIT} seconds). Exiting..."
)
logger.error(msg)
raise WorkerShutdown(msg)
logger.info("Vespa: Readiness probe succeeded. Continuing...")
logger.info("Redis: Readiness check succeeded. Continuing...")
return
def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
logger.info("Running as a secondary celery worker.")
# Exit early if multi-tenant since primary worker check not needed
if MULTI_TENANT:
return
# Set up variables for waiting on primary worker
WAIT_INTERVAL = 5
WAIT_LIMIT = 60

View File

@@ -12,7 +12,6 @@ from danswer.db.engine import get_all_tenant_ids
from danswer.db.engine import SqlEngine
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import MULTI_TENANT
logger = setup_logger(__name__)
@@ -120,10 +119,10 @@ class DynamicTenantScheduler(PersistentScheduler):
else:
logger.info("Schedule is up to date, no changes needed")
except (AttributeError, KeyError):
logger.exception("Failed to process task configuration")
except Exception:
logger.exception("Unexpected error updating tenant tasks")
except (AttributeError, KeyError) as e:
logger.exception(f"Failed to process task configuration: {str(e)}")
except Exception as e:
logger.exception(f"Unexpected error updating tenant tasks: {str(e)}")
def _should_update_schedule(
self, current_schedule: dict, new_schedule: dict
@@ -144,11 +143,6 @@ def on_beat_init(sender: Any, **kwargs: Any) -> None:
# Celery beat shouldn't touch the db at all. But just setting a low minimum here.
SqlEngine.set_app_name(POSTGRES_CELERY_BEAT_APP_NAME)
SqlEngine.init_engine(pool_size=2, max_overflow=0)
# Startup checks are not needed in multi-tenant case
if MULTI_TENANT:
return
app_base.wait_for_redis(sender, **kwargs)

View File

@@ -13,7 +13,6 @@ import danswer.background.celery.apps.app_base as app_base
from danswer.configs.constants import POSTGRES_CELERY_WORKER_HEAVY_APP_NAME
from danswer.db.engine import SqlEngine
from danswer.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
@@ -61,13 +60,7 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_HEAVY_APP_NAME)
SqlEngine.init_engine(pool_size=4, max_overflow=12)
# Startup checks are not needed in multi-tenant case
if MULTI_TENANT:
return
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa(sender, **kwargs)
app_base.on_secondary_worker_init(sender, **kwargs)

View File

@@ -13,7 +13,6 @@ import danswer.background.celery.apps.app_base as app_base
from danswer.configs.constants import POSTGRES_CELERY_WORKER_INDEXING_APP_NAME
from danswer.db.engine import SqlEngine
from danswer.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
@@ -61,13 +60,7 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
SqlEngine.init_engine(pool_size=8, max_overflow=0)
# Startup checks are not needed in multi-tenant case
if MULTI_TENANT:
return
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa(sender, **kwargs)
app_base.on_secondary_worker_init(sender, **kwargs)

View File

@@ -13,7 +13,6 @@ import danswer.background.celery.apps.app_base as app_base
from danswer.configs.constants import POSTGRES_CELERY_WORKER_LIGHT_APP_NAME
from danswer.db.engine import SqlEngine
from danswer.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
@@ -60,13 +59,8 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME)
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8)
# Startup checks are not needed in multi-tenant case
if MULTI_TENANT:
return
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa(sender, **kwargs)
app_base.on_secondary_worker_init(sender, **kwargs)

View File

@@ -75,16 +75,13 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
SqlEngine.init_engine(pool_size=8, max_overflow=0)
# Startup checks are not needed in multi-tenant case
if MULTI_TENANT:
return
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa(sender, **kwargs)
logger.info("Running as the primary celery worker.")
if MULTI_TENANT:
return
# This is singleton work that should be done on startup exactly once
# by the primary worker. This is unnecessary in the multi tenant scenario
r = get_redis_client(tenant_id=None)

View File

@@ -1,4 +0,0 @@
def name_sync_external_doc_permissions_task(
cc_pair_id: int, tenant_id: str | None = None
) -> str:
return f"sync_external_doc_permissions_task__{cc_pair_id}"

View File

@@ -48,6 +48,7 @@ class QADocsResponse(RetrievalDocs):
class StreamStopReason(Enum):
CONTEXT_LENGTH = "context_length"
CANCELLED = "cancelled"
NEW_RESPONSE = "new_response"
class StreamStopInfo(BaseModel):

View File

@@ -19,6 +19,7 @@ from danswer.chat.models import MessageSpecificCitations
from danswer.chat.models import QADocsResponse
from danswer.chat.models import StreamingError
from danswer.chat.models import StreamStopInfo
from danswer.chat.models import StreamStopReason
from danswer.configs.app_configs import AZURE_DALLE_API_BASE
from danswer.configs.app_configs import AZURE_DALLE_API_KEY
from danswer.configs.app_configs import AZURE_DALLE_API_VERSION
@@ -137,6 +138,7 @@ def _translate_citations(
"""Always cites the first instance of the document_id, assumes the db_docs
are sorted in the order displayed in the UI"""
doc_id_to_saved_doc_id_map: dict[str, int] = {}
for db_doc in db_docs:
if db_doc.document_id not in doc_id_to_saved_doc_id_map:
doc_id_to_saved_doc_id_map[db_doc.document_id] = db_doc.id
@@ -687,6 +689,10 @@ def stream_chat_message_objects(
tools: list[Tool] = []
for tool_list in tool_dict.values():
tools.extend(tool_list)
tool_name_to_tool_id = {}
for tool_id, tool_list in tool_dict.items():
for tool in tool_list:
tool_name_to_tool_id[tool.name] = tool_id
# factor in tool definition size when pruning
document_pruning_config.tool_num_tokens = compute_all_tool_tokens(
@@ -729,6 +735,74 @@ def stream_chat_message_objects(
tool_result = None
for packet in answer.processed_streamed_output:
if isinstance(packet, StreamStopInfo):
if packet.stop_reason is not StreamStopReason.NEW_RESPONSE:
break
db_citations = None
if reference_db_search_docs:
db_citations = _translate_citations(
citations_list=answer.citations,
db_docs=reference_db_search_docs,
)
# Saving Gen AI answer and responding with message info
if tool_result is None:
tool_call = None
else:
tool_call = ToolCall(
tool_id=tool_name_to_tool_id[tool_result.tool_name],
tool_name=tool_result.tool_name,
tool_arguments=tool_result.tool_args,
tool_result=tool_result.tool_result,
)
gen_ai_response_message = partial_response(
reserved_message_id=reserved_message_id,
message=answer.llm_answer,
rephrased_query=cast(
QADocsResponse, qa_docs_response
).rephrased_query
if qa_docs_response is not None
else None,
reference_docs=reference_db_search_docs,
files=ai_message_files,
token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
citations=cast(MessageSpecificCitations, db_citations).citation_map
if db_citations is not None
else None,
error=None,
tool_call=tool_call,
)
db_session.commit() # actually save user / assistant message
msg_detail_response = translate_db_message_to_chat_message_detail(
gen_ai_response_message
)
yield msg_detail_response
reserved_message_id = reserve_message_id(
db_session=db_session,
chat_session_id=chat_session_id,
parent_message=gen_ai_response_message.id
if user_message is not None
else gen_ai_response_message.id,
message_type=MessageType.ASSISTANT,
)
partial_response = partial(
create_new_chat_message,
chat_session_id=chat_session_id,
parent_message=gen_ai_response_message,
prompt_id=prompt_id,
overridden_model=overridden_model,
message_type=MessageType.ASSISTANT,
alternate_assistant_id=new_msg_req.alternate_assistant_id,
db_session=db_session,
commit=False,
)
reference_db_search_docs = None
if isinstance(packet, ToolResponse):
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
(
@@ -869,6 +943,8 @@ def stream_chat_message_objects(
for tool_id, tool_list in tool_dict.items():
for tool in tool_list:
tool_name_to_tool_id[tool.name] = tool_id
if answer.llm_answer == "":
return
gen_ai_response_message = partial_response(
reserved_message_id=reserved_message_id,

View File

@@ -493,13 +493,3 @@ JWT_ALGORITHM = "HS256"
# Super Users
SUPER_USERS = json.loads(os.environ.get("SUPER_USERS", '["pablo@danswer.ai"]'))
SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key")
#####
# API Key Configs
#####
# refers to the rounds described here: https://passlib.readthedocs.io/en/stable/lib/passlib.hash.sha256_crypt.html
_API_KEY_HASH_ROUNDS_RAW = os.environ.get("API_KEY_HASH_ROUNDS")
API_KEY_HASH_ROUNDS = (
int(_API_KEY_HASH_ROUNDS_RAW) if _API_KEY_HASH_ROUNDS_RAW else None
)

View File

@@ -126,7 +126,6 @@ class DocumentSource(str, Enum):
XENFORO = "xenforo"
NOT_APPLICABLE = "not_applicable"
FRESHDESK = "freshdesk"
FIREFLIES = "fireflies"
DocumentSourceRequiringTenantContext: list[DocumentSource] = [DocumentSource.FILE]

View File

@@ -16,7 +16,6 @@ from danswer.connectors.discourse.connector import DiscourseConnector
from danswer.connectors.document360.connector import Document360Connector
from danswer.connectors.dropbox.connector import DropboxConnector
from danswer.connectors.file.connector import LocalFileConnector
from danswer.connectors.fireflies.connector import FirefliesConnector
from danswer.connectors.freshdesk.connector import FreshdeskConnector
from danswer.connectors.github.connector import GithubConnector
from danswer.connectors.gitlab.connector import GitlabConnector
@@ -102,7 +101,6 @@ def identify_connector_class(
DocumentSource.OCI_STORAGE: BlobStorageConnector,
DocumentSource.XENFORO: XenforoConnector,
DocumentSource.FRESHDESK: FreshdeskConnector,
DocumentSource.FIREFLIES: FirefliesConnector,
}
connector_by_source = connector_map.get(source, {})

View File

@@ -123,13 +123,9 @@ def _process_file(
"filename",
"file_display_name",
"title",
"connector_type",
]
}
source_type_str = all_metadata.get("connector_type")
source_type = DocumentSource(source_type_str) if source_type_str else None
p_owner_names = all_metadata.get("primary_owners")
s_owner_names = all_metadata.get("secondary_owners")
p_owners = (
@@ -149,7 +145,7 @@ def _process_file(
sections=[
Section(link=all_metadata.get("link"), text=file_content_raw.strip())
],
source=source_type or DocumentSource.FILE,
source=DocumentSource.FILE,
semantic_identifier=file_display_name,
title=title,
doc_updated_at=final_time_updated,

View File

@@ -1,182 +0,0 @@
from collections.abc import Iterator
from datetime import datetime
from datetime import timezone
from typing import List
import requests
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.models import ConnectorMissingCredentialError
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.utils.logger import setup_logger
logger = setup_logger()
_FIREFLIES_ID_PREFIX = "FIREFLIES_"
_FIREFLIES_API_URL = "https://api.fireflies.ai/graphql"
_FIREFLIES_TRANSCRIPT_QUERY_SIZE = 50 # Max page size is 50
_FIREFLIES_API_QUERY = """
query Transcripts($fromDate: DateTime, $toDate: DateTime, $limit: Int!, $skip: Int!) {
transcripts(fromDate: $fromDate, toDate: $toDate, limit: $limit, skip: $skip) {
id
title
host_email
participants
date
transcript_url
sentences {
text
speaker_name
}
}
}
"""
def _create_doc_from_transcript(transcript: dict) -> Document | None:
meeting_text = ""
sentences = transcript.get("sentences", [])
if sentences:
for sentence in sentences:
meeting_text += sentence.get("speaker_name") or "Unknown Speaker"
meeting_text += ": " + sentence.get("text", "") + "\n\n"
else:
return None
meeting_link = transcript["transcript_url"]
fireflies_id = _FIREFLIES_ID_PREFIX + transcript["id"]
meeting_title = transcript["title"] or "No Title"
meeting_date_unix = transcript["date"]
meeting_date = datetime.fromtimestamp(meeting_date_unix / 1000, tz=timezone.utc)
meeting_host_email = transcript["host_email"]
host_email_user_info = [BasicExpertInfo(email=meeting_host_email)]
meeting_participants_email_list = []
for participant in transcript.get("participants", []):
if participant != meeting_host_email and participant:
meeting_participants_email_list.append(BasicExpertInfo(email=participant))
return Document(
id=fireflies_id,
sections=[
Section(
link=meeting_link,
text=meeting_text,
)
],
source=DocumentSource.FIREFLIES,
semantic_identifier=meeting_title,
metadata={},
doc_updated_at=meeting_date,
primary_owners=host_email_user_info,
secondary_owners=meeting_participants_email_list,
)
class FirefliesConnector(PollConnector, LoadConnector):
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
self.batch_size = batch_size
def load_credentials(self, credentials: dict[str, str]) -> None:
api_key = credentials.get("fireflies_api_key")
if not isinstance(api_key, str):
raise ConnectorMissingCredentialError(
"The Fireflies API key must be a string"
)
self.api_key = api_key
return None
def _fetch_transcripts(
self, start_datetime: str | None = None, end_datetime: str | None = None
) -> Iterator[List[dict]]:
if self.api_key is None:
raise ConnectorMissingCredentialError("Missing API key")
headers = {
"Content-Type": "application/json",
"Authorization": "Bearer " + self.api_key,
}
skip = 0
variables: dict[str, int | str] = {
"limit": _FIREFLIES_TRANSCRIPT_QUERY_SIZE,
}
if start_datetime:
variables["fromDate"] = start_datetime
if end_datetime:
variables["toDate"] = end_datetime
while True:
variables["skip"] = skip
response = requests.post(
_FIREFLIES_API_URL,
headers=headers,
json={"query": _FIREFLIES_API_QUERY, "variables": variables},
)
response.raise_for_status()
if response.status_code == 204:
break
recieved_transcripts = response.json()
parsed_transcripts = recieved_transcripts.get("data", {}).get(
"transcripts", []
)
yield parsed_transcripts
if len(parsed_transcripts) < _FIREFLIES_TRANSCRIPT_QUERY_SIZE:
break
skip += _FIREFLIES_TRANSCRIPT_QUERY_SIZE
def _process_transcripts(
self, start: str | None = None, end: str | None = None
) -> GenerateDocumentsOutput:
doc_batch: List[Document] = []
for transcript_batch in self._fetch_transcripts(start, end):
for transcript in transcript_batch:
if doc := _create_doc_from_transcript(transcript):
doc_batch.append(doc)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
if doc_batch:
yield doc_batch
def load_from_state(self) -> GenerateDocumentsOutput:
return self._process_transcripts()
def poll_source(
self, start_unixtime: SecondsSinceUnixEpoch, end_unixtime: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
start_datetime = datetime.fromtimestamp(
start_unixtime, tz=timezone.utc
).strftime("%Y-%m-%dT%H:%M:%S.000Z")
end_datetime = datetime.fromtimestamp(end_unixtime, tz=timezone.utc).strftime(
"%Y-%m-%dT%H:%M:%S.000Z"
)
yield from self._process_transcripts(start_datetime, end_datetime)

View File

@@ -77,7 +77,6 @@ from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import SLACK_CHANNEL_ID
logger = setup_logger()
@@ -190,67 +189,59 @@ class SlackbotHandler:
continue
logger.debug(f"Acquired lock for tenant {tenant_id}")
token = CURRENT_TENANT_ID_CONTEXTVAR.set(
tenant_id or POSTGRES_DEFAULT_SCHEMA
)
try:
with get_session_with_tenant(tenant_id) as db_session:
try:
logger.debug(
f"Setting tenant ID context variable for tenant {tenant_id}"
)
slack_bot_tokens = fetch_tokens()
logger.debug(f"Fetched Slack bot tokens for tenant {tenant_id}")
logger.debug(
f"Reset tenant ID context variable for tenant {tenant_id}"
)
with get_session_with_tenant(tenant_id) as db_session:
try:
logger.debug(
f"Setting tenant ID context variable for tenant {tenant_id}"
)
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id or "public")
slack_bot_tokens = fetch_tokens()
logger.debug(f"Fetched Slack bot tokens for tenant {tenant_id}")
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
logger.debug(
f"Reset tenant ID context variable for tenant {tenant_id}"
)
if not slack_bot_tokens:
logger.debug(
f"No Slack bot token found for tenant {tenant_id}"
)
if tenant_id in self.socket_clients:
asyncio.run(self.socket_clients[tenant_id].close())
del self.socket_clients[tenant_id]
del self.slack_bot_tokens[tenant_id]
continue
if (
tenant_id not in self.slack_bot_tokens
or slack_bot_tokens != self.slack_bot_tokens[tenant_id]
):
if tenant_id in self.slack_bot_tokens:
logger.info(
f"Slack Bot tokens have changed for tenant {tenant_id} - reconnecting"
)
else:
search_settings = get_current_search_settings(
db_session
)
embedding_model = EmbeddingModel.from_db_model(
search_settings=search_settings,
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
warm_up_bi_encoder(embedding_model=embedding_model)
self.slack_bot_tokens[tenant_id] = slack_bot_tokens
if tenant_id in self.socket_clients:
asyncio.run(self.socket_clients[tenant_id].close())
self.start_socket_client(tenant_id, slack_bot_tokens)
except KvKeyNotFoundError:
logger.debug(f"Missing Slack Bot tokens for tenant {tenant_id}")
if not slack_bot_tokens:
logger.debug(f"No Slack bot token found for tenant {tenant_id}")
if tenant_id in self.socket_clients:
asyncio.run(self.socket_clients[tenant_id].close())
del self.socket_clients[tenant_id]
del self.slack_bot_tokens[tenant_id]
except Exception as e:
logger.exception(f"Error handling tenant {tenant_id}: {e}")
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
continue
if (
tenant_id not in self.slack_bot_tokens
or slack_bot_tokens != self.slack_bot_tokens[tenant_id]
):
if tenant_id in self.slack_bot_tokens:
logger.info(
f"Slack Bot tokens have changed for tenant {tenant_id} - reconnecting"
)
else:
search_settings = get_current_search_settings(db_session)
embedding_model = EmbeddingModel.from_db_model(
search_settings=search_settings,
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
warm_up_bi_encoder(embedding_model=embedding_model)
self.slack_bot_tokens[tenant_id] = slack_bot_tokens
if tenant_id in self.socket_clients:
asyncio.run(self.socket_clients[tenant_id].close())
self.start_socket_client(tenant_id, slack_bot_tokens)
except KvKeyNotFoundError:
logger.debug(f"Missing Slack Bot tokens for tenant {tenant_id}")
if tenant_id in self.socket_clients:
asyncio.run(self.socket_clients[tenant_id].close())
del self.socket_clients[tenant_id]
del self.slack_bot_tokens[tenant_id]
except Exception as e:
logger.exception(f"Error handling tenant {tenant_id}: {e}")
def send_heartbeats(self) -> None:
current_time = int(time.time())

View File

@@ -14,7 +14,6 @@ from sqlalchemy.orm import Session
from danswer.auth.invited_users import get_invited_users
from danswer.auth.schemas import UserRole
from danswer.db.api_key import get_api_key_email_pattern
from danswer.db.engine import get_async_session
from danswer.db.engine import get_async_session_with_tenant
from danswer.db.models import AccessToken
@@ -23,6 +22,7 @@ from danswer.db.models import User
from danswer.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,
)
from ee.danswer.db.api_key import get_api_key_email_pattern
def get_default_admin_user_emails() -> list[str]:

View File

@@ -25,8 +25,8 @@ from danswer.db.models import UserGroup__ConnectorCredentialPair
from danswer.db.models import UserRole
from danswer.server.models import StatusResponse
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import fetch_ee_implementation_or_noop
from ee.danswer.db.external_perm import delete_user__ext_group_for_cc_pair__no_commit
from ee.danswer.external_permissions.sync_params import check_if_valid_sync_source
logger = setup_logger()
@@ -351,11 +351,7 @@ def add_credential_to_connector(
raise HTTPException(status_code=404, detail="Connector does not exist")
if access_type == AccessType.SYNC:
if not fetch_ee_implementation_or_noop(
"danswer.external_permissions.sync_params",
"check_if_valid_sync_source",
noop_return_value=True,
)(connector.source):
if not check_if_valid_sync_source(connector.source):
raise HTTPException(
status_code=400,
detail=f"Connector of type {connector.source} does not support SYNC access type",
@@ -442,10 +438,7 @@ def remove_credential_from_connector(
)
if association is not None:
fetch_ee_implementation_or_noop(
"danswer.db.external_perm",
"delete_user__ext_group_for_cc_pair__no_commit",
)(
delete_user__ext_group_for_cc_pair__no_commit(
db_session=db_session,
cc_pair_id=association.id,
)

View File

@@ -323,28 +323,16 @@ async def get_async_session_with_tenant(
yield session
@contextmanager
def get_session_with_default_tenant() -> Generator[Session, None, None]:
"""
Get a database session using the current tenant ID from the context variable.
"""
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
with get_session_with_tenant(tenant_id) as session:
yield session
@contextmanager
def get_session_with_tenant(
tenant_id: str | None = None,
) -> Generator[Session, None, None]:
"""
Generate a database session for a specific tenant.
This function:
1. Sets the database schema to the specified tenant's schema.
2. Preserves the tenant ID across the session.
3. Reverts to the previous tenant ID after the session is closed.
4. Uses the default schema if no tenant ID is provided.
Generate a database session bound to a connection with the appropriate tenant schema set.
This preserves the tenant ID across the session and reverts to the previous tenant ID
after the session is closed.
If tenant ID is set, we save the previous tenant ID from the context var to set
after the session is closed. The value `None` evaluates to the default schema.
"""
engine = get_sqlalchemy_engine()
@@ -352,9 +340,9 @@ def get_session_with_tenant(
previous_tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() or POSTGRES_DEFAULT_SCHEMA
if tenant_id is None:
tenant_id = POSTGRES_DEFAULT_SCHEMA
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
tenant_id = previous_tenant_id
else:
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
event.listen(engine, "checkout", set_search_path_on_checkout)

View File

@@ -12,7 +12,7 @@ from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
from danswer.configs.model_configs import OLD_DEFAULT_DOCUMENT_ENCODER_MODEL
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
from danswer.db.engine import get_session_with_default_tenant
from danswer.db.engine import get_session_with_tenant
from danswer.db.llm import fetch_embedding_provider
from danswer.db.models import CloudEmbeddingProvider
from danswer.db.models import IndexAttempt
@@ -152,7 +152,7 @@ def get_all_search_settings(db_session: Session) -> list[SearchSettings]:
def get_multilingual_expansion(db_session: Session | None = None) -> list[str]:
if db_session is None:
with get_session_with_default_tenant() as db_session:
with get_session_with_tenant() as db_session:
search_settings = get_current_search_settings(db_session)
else:
search_settings = get_current_search_settings(db_session)

View File

@@ -1,111 +0,0 @@
from collections.abc import Sequence
from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.configs.constants import TokenRateLimitScope
from danswer.db.models import TokenRateLimit
from danswer.db.models import TokenRateLimit__UserGroup
from danswer.server.token_rate_limits.models import TokenRateLimitArgs
def fetch_all_user_token_rate_limits(
db_session: Session,
enabled_only: bool = False,
ordered: bool = True,
) -> Sequence[TokenRateLimit]:
query = select(TokenRateLimit).where(
TokenRateLimit.scope == TokenRateLimitScope.USER
)
if enabled_only:
query = query.where(TokenRateLimit.enabled.is_(True))
if ordered:
query = query.order_by(TokenRateLimit.created_at.desc())
return db_session.scalars(query).all()
def fetch_all_global_token_rate_limits(
db_session: Session,
enabled_only: bool = False,
ordered: bool = True,
) -> Sequence[TokenRateLimit]:
query = select(TokenRateLimit).where(
TokenRateLimit.scope == TokenRateLimitScope.GLOBAL
)
if enabled_only:
query = query.where(TokenRateLimit.enabled.is_(True))
if ordered:
query = query.order_by(TokenRateLimit.created_at.desc())
token_rate_limits = db_session.scalars(query).all()
return token_rate_limits
def insert_user_token_rate_limit(
db_session: Session,
token_rate_limit_settings: TokenRateLimitArgs,
) -> TokenRateLimit:
token_limit = TokenRateLimit(
enabled=token_rate_limit_settings.enabled,
token_budget=token_rate_limit_settings.token_budget,
period_hours=token_rate_limit_settings.period_hours,
scope=TokenRateLimitScope.USER,
)
db_session.add(token_limit)
db_session.commit()
return token_limit
def insert_global_token_rate_limit(
db_session: Session,
token_rate_limit_settings: TokenRateLimitArgs,
) -> TokenRateLimit:
token_limit = TokenRateLimit(
enabled=token_rate_limit_settings.enabled,
token_budget=token_rate_limit_settings.token_budget,
period_hours=token_rate_limit_settings.period_hours,
scope=TokenRateLimitScope.GLOBAL,
)
db_session.add(token_limit)
db_session.commit()
return token_limit
def update_token_rate_limit(
db_session: Session,
token_rate_limit_id: int,
token_rate_limit_settings: TokenRateLimitArgs,
) -> TokenRateLimit:
token_limit = db_session.get(TokenRateLimit, token_rate_limit_id)
if token_limit is None:
raise ValueError(f"TokenRateLimit with id '{token_rate_limit_id}' not found")
token_limit.enabled = token_rate_limit_settings.enabled
token_limit.token_budget = token_rate_limit_settings.token_budget
token_limit.period_hours = token_rate_limit_settings.period_hours
db_session.commit()
return token_limit
def delete_token_rate_limit(
db_session: Session,
token_rate_limit_id: int,
) -> None:
token_limit = db_session.get(TokenRateLimit, token_rate_limit_id)
if token_limit is None:
raise ValueError(f"TokenRateLimit with id '{token_rate_limit_id}' not found")
db_session.query(TokenRateLimit__UserGroup).filter(
TokenRateLimit__UserGroup.rate_limit_id == token_rate_limit_id
).delete()
db_session.delete(token_limit)
db_session.commit()

View File

@@ -9,6 +9,8 @@ from langchain_core.messages import ToolCall
from danswer.chat.models import AnswerQuestionPossibleReturn
from danswer.chat.models import CitationInfo
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import StreamStopInfo
from danswer.chat.models import StreamStopReason
from danswer.file_store.utils import InMemoryChatFile
from danswer.llm.answering.llm_response_handler import LLMCall
from danswer.llm.answering.llm_response_handler import LLMResponseHandlerManager
@@ -118,6 +120,9 @@ class Answer:
)
and not skip_explicit_tool_calling
)
self.current_streamed_output: list = []
self.processing_stream: list = []
def _get_tools_list(self) -> list[Tool]:
if not self.force_use_tool.force_use:
@@ -155,6 +160,7 @@ class Answer:
response_handler_manager = LLMResponseHandlerManager(
tool_handler, DummyAnswerResponseHandler(), self.is_cancelled
)
yield from response_handler_manager.handle_llm_response(
iter([dummy_tool_call_chunk])
)
@@ -165,7 +171,13 @@ class Answer:
else:
raise RuntimeError("Tool call handler did not return a new LLM call")
def _get_response(self, llm_calls: list[LLMCall]) -> AnswerStream:
def _get_response(
self,
llm_calls: list[LLMCall],
check_for_tool_call: bool = False,
previously_used_tool: Tool | None = None,
previous_tool_response: ToolResponse | None = None,
) -> AnswerStream:
current_llm_call = llm_calls[-1]
# handle the case where no decision has to be made; we simply run the tool
@@ -231,7 +243,6 @@ class Answer:
tool_call_handler, answer_handler, self.is_cancelled
)
# DEBUG: good breakpoint
stream = self.llm.stream(
prompt=current_llm_call.prompt_builder.build(),
tools=[tool.tool_definition() for tool in current_llm_call.tools] or None,
@@ -242,11 +253,101 @@ class Answer:
),
structured_response_format=self.answer_style_config.structured_response_format,
)
yield from response_handler_manager.handle_llm_response(stream)
new_llm_call = response_handler_manager.next_llm_call(current_llm_call)
tool_call_made = False
tool_call_name: str | None = None
buffered_packets = []
tool_response = None
for packet in response_handler_manager.handle_llm_response(stream):
if isinstance(packet, DanswerAnswerPiece):
pass
if isinstance(packet, ToolResponse):
tool_response = packet
if check_for_tool_call:
buffered_packets.append(packet)
if isinstance(packet, ToolCallKickoff):
# if has_streamed_text and not has_completed:
# yield StreamStopInfo(stop_reason=StreamStopReason.NEW_RESPONSE)
# has_completed = True
tool_call_name = packet.tool_name
tool_call_made = True
for buffered_packet in buffered_packets:
yield buffered_packet
buffered_packets = []
else:
yield packet
if isinstance(packet, ToolCallKickoff):
# if has_streamed_text and not has_completed:
# yield StreamStopInfo(stop_reason=StreamStopReason.NEW_RESPONSE)
# has_completed = True
tool_call_name = packet.tool_name
tool_call_made = True
if check_for_tool_call and not tool_call_made:
for remaining_packet in buffered_packets:
yield remaining_packet
return
for remaining_packet in buffered_packets:
yield remaining_packet
new_llm_call = response_handler_manager.next_llm_call(
current_llm_call, tool_call_made
)
tool_used: Tool | None = None
if tool_call_made:
tool_used = next(
(tool for tool in self.tools if tool.name == tool_call_name), None
)
if new_llm_call:
yield from self._get_response(llm_calls + [new_llm_call])
yield from self._get_response(
llm_calls + [new_llm_call],
check_for_tool_call=not tool_call_made,
previously_used_tool=tool_used,
previous_tool_response=tool_response,
)
else:
yield StreamStopInfo(stop_reason=StreamStopReason.NEW_RESPONSE)
# Logic here
if (
not check_for_tool_call
and not tool_call_made
and not previously_used_tool
):
return
if previously_used_tool:
previously_used_tool.build_prompt_after_tool_call(
current_llm_call.prompt_builder,
self.question,
self.llm_answer,
previous_tool_response,
)
# Build next prompter with the original question and the LLM's last answer
# current_llm_call.prompt_builder.update_user_prompt(HumanMessage(content=self.question))
# current_llm_call.prompt_builder.build_next_prompter(self.question, self.llm_answer)
llm_call = LLMCall(
prompt_builder=current_llm_call.prompt_builder,
tools=self._get_tools_list(),
force_use_tool=self.force_use_tool,
files=self.latest_query_files,
tool_call_info=[],
using_tool_calling_llm=self.using_tool_calling_llm,
)
yield from self._get_response(
[llm_call],
check_for_tool_call=not tool_call_made,
previously_used_tool=tool_used,
previous_tool_response=tool_response,
)
@property
def processed_streamed_output(self) -> AnswerStream:
@@ -276,26 +377,32 @@ class Answer:
using_tool_calling_llm=self.using_tool_calling_llm,
)
processed_stream = []
for processed_packet in self._get_response([llm_call]):
processed_stream.append(processed_packet)
if (
isinstance(processed_packet, StreamStopInfo)
and processed_packet.stop_reason == StreamStopReason.NEW_RESPONSE
):
self.current_streamed_output = self.processing_stream
self.processing_stream = []
self.processing_stream.append(processed_packet)
yield processed_packet
self._processed_stream = processed_stream
self.current_streamed_output = self.processing_stream
self._processed_stream = self.processing_stream
@property
def llm_answer(self) -> str:
answer = ""
for packet in self.processed_streamed_output:
if not self._processed_stream and not self.current_streamed_output:
return ""
for packet in self.current_streamed_output or self._processed_stream or []:
if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece:
answer += packet.answer_piece
return answer
@property
def citations(self) -> list[CitationInfo]:
citations: list[CitationInfo] = []
for packet in self.processed_streamed_output:
for packet in self.current_streamed_output or self._processed_stream or []:
if isinstance(packet, CitationInfo):
citations.append(packet)

View File

@@ -80,5 +80,7 @@ class LLMResponseHandlerManager:
yield from self.tool_handler.handle_response_part(None, all_messages)
yield from self.answer_handler.handle_response_part(None, all_messages)
def next_llm_call(self, llm_call: LLMCall) -> LLMCall | None:
return self.tool_handler.next_llm_call(llm_call)
def next_llm_call(
self, llm_call: LLMCall, tool_call_made: bool = False
) -> LLMCall | None:
return self.tool_handler.next_llm_call(llm_call, tool_call_made)

View File

@@ -1,6 +1,7 @@
from collections.abc import Callable
from typing import cast
from langchain_core.messages import AIMessage
from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage
from langchain_core.messages import SystemMessage
@@ -36,7 +37,10 @@ def default_build_system_message(
def default_build_user_message(
user_query: str, prompt_config: PromptConfig, files: list[InMemoryChatFile] = []
user_query: str,
prompt_config: PromptConfig,
files: list[InMemoryChatFile] = [],
previous_tool_call_count: int = 0,
) -> HumanMessage:
user_prompt = (
CHAT_USER_CONTEXT_FREE_PROMPT.format(
@@ -45,10 +49,16 @@ def default_build_user_message(
if prompt_config.task_prompt
else user_query
)
if previous_tool_call_count > 0:
user_prompt = (
f"You have already generated the above so do not call a tool if not necessary. "
f"Remember the query is: `{user_prompt}`"
)
user_prompt = user_prompt.strip()
user_msg = HumanMessage(
content=build_content_with_imgs(user_prompt, files) if files else user_prompt
)
return user_msg
@@ -87,6 +97,30 @@ class AnswerPromptBuilder:
)
self.new_messages_and_token_cnts: list[tuple[BaseMessage, int]] = []
self.task_reminder: tuple[HumanMessage, int] | None = None
def update_task_reminder(self, task_reminder: HumanMessage) -> None:
token_count = check_message_tokens(
task_reminder, self.llm_tokenizer_encode_func
)
self.task_reminder = (task_reminder, token_count)
def build_next_prompter(
self, question: str, llm_answer: str, task_reminder: str | None = None
):
# Append the AI's previous response
self.append_message(AIMessage(content=llm_answer))
# Add a new user message prompting the assistant to continue
self.append_message(
HumanMessage(
content=(
f"If your previous responses did not fully answer the original query: '{question}', "
"please continue and complete the answer. Only add information if the original question "
"wasn't fully addressed. Use any necessary tools to provide a comprehensive response. "
"If the original query was already completely fulfilled, do NOT call a tool."
)
)
)
def update_system_prompt(self, system_message: SystemMessage | None) -> None:
if not system_message:
@@ -132,6 +166,8 @@ class AnswerPromptBuilder:
if self.new_messages_and_token_cnts:
final_messages_with_tokens.extend(self.new_messages_and_token_cnts)
if self.task_reminder:
final_messages_with_tokens.append(self.task_reminder)
return drop_messages_history_overflow(
final_messages_with_tokens, self.max_tokens

View File

@@ -173,7 +173,9 @@ class ToolResponseHandler:
return
def next_llm_call(self, current_llm_call: LLMCall) -> LLMCall | None:
def next_llm_call(
self, current_llm_call: LLMCall, tool_call_made: bool
) -> LLMCall | None:
if (
self.tool_runner is None
or self.tool_call_summary is None
@@ -191,7 +193,9 @@ class ToolResponseHandler:
)
return LLMCall(
prompt_builder=new_prompt_builder,
tools=[], # for now, only allow one tool call per response
tools=self.tools
if not tool_call_made
else [], # for now, only allow one tool call per response
force_use_tool=ForceUseTool(
force_use=False,
tool_name="",

View File

@@ -25,7 +25,6 @@ from danswer.auth.schemas import UserCreate
from danswer.auth.schemas import UserRead
from danswer.auth.schemas import UserUpdate
from danswer.auth.users import auth_backend
from danswer.auth.users import BasicAuthenticationError
from danswer.auth.users import fastapi_users
from danswer.configs.app_configs import APP_API_PREFIX
from danswer.configs.app_configs import APP_HOST
@@ -195,12 +194,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
def log_http_error(_: Request, exc: Exception) -> JSONResponse:
status_code = getattr(exc, "status_code", 500)
if isinstance(exc, BasicAuthenticationError):
# For BasicAuthenticationError, just log a brief message without stack trace (almost always spam)
logger.error(f"Authentication failed: {str(exc)}")
elif status_code >= 400:
if status_code >= 400:
error_msg = f"{str(exc)}\n"
error_msg += "".join(traceback.format_tb(exc.__traceback__))
logger.error(error_msg)
@@ -226,6 +220,7 @@ def get_application() -> FastAPI:
else:
logger.debug("Sentry DSN not provided, skipping Sentry initialization")
# Add the custom exception handler
application.add_exception_handler(status.HTTP_400_BAD_REQUEST, log_http_error)
application.add_exception_handler(status.HTTP_401_UNAUTHORIZED, log_http_error)
application.add_exception_handler(status.HTTP_403_FORBIDDEN, log_http_error)
@@ -282,14 +277,12 @@ def get_application() -> FastAPI:
prefix="/auth",
tags=["auth"],
)
include_router_with_global_prefix_prepended(
application,
fastapi_users.get_register_router(UserRead, UserCreate),
prefix="/auth",
tags=["auth"],
)
include_router_with_global_prefix_prepended(
application,
fastapi_users.get_reset_password_router(),

View File

@@ -35,31 +35,23 @@ class BaseTokenizer(ABC):
class TiktokenTokenizer(BaseTokenizer):
_instances: dict[str, "TiktokenTokenizer"] = {}
def __new__(cls, model_name: str) -> "TiktokenTokenizer":
if model_name not in cls._instances:
cls._instances[model_name] = super(TiktokenTokenizer, cls).__new__(cls)
return cls._instances[model_name]
def __new__(cls, encoding_name: str = "cl100k_base") -> "TiktokenTokenizer":
if encoding_name not in cls._instances:
cls._instances[encoding_name] = super(TiktokenTokenizer, cls).__new__(cls)
return cls._instances[encoding_name]
def __init__(self, model_name: str):
def __init__(self, encoding_name: str = "cl100k_base"):
if not hasattr(self, "encoder"):
import tiktoken
self.encoder = tiktoken.encoding_for_model(model_name)
self.encoder = tiktoken.get_encoding(encoding_name)
def encode(self, string: str) -> list[int]:
# this ignores special tokens that the model is trained on, see encode_ordinary for details
# this returns no special tokens
return self.encoder.encode_ordinary(string)
def tokenize(self, string: str) -> list[str]:
encoded = self.encode(string)
decoded = [self.encoder.decode([token]) for token in encoded]
if len(decoded) != len(encoded):
logger.warning(
f"OpenAI tokenized length {len(decoded)} does not match encoded length {len(encoded)} for string: {string}"
)
return decoded
return [self.encoder.decode([token]) for token in self.encode(string)]
def decode(self, tokens: list[int]) -> str:
return self.encoder.decode(tokens)
@@ -82,35 +74,22 @@ class HuggingFaceTokenizer(BaseTokenizer):
return self.encoder.decode(tokens)
_TOKENIZER_CACHE: dict[tuple[EmbeddingProvider | None, str | None], BaseTokenizer] = {}
_TOKENIZER_CACHE: dict[str, BaseTokenizer] = {}
def _check_tokenizer_cache(
model_provider: EmbeddingProvider | None, model_name: str | None
) -> BaseTokenizer:
def _check_tokenizer_cache(tokenizer_name: str) -> BaseTokenizer:
global _TOKENIZER_CACHE
id_tuple = (model_provider, model_name)
if id_tuple not in _TOKENIZER_CACHE:
if model_provider in [EmbeddingProvider.OPENAI, EmbeddingProvider.AZURE]:
if model_name is None:
raise ValueError(
"model_name is required for OPENAI and AZURE embeddings"
)
_TOKENIZER_CACHE[id_tuple] = TiktokenTokenizer(model_name)
return _TOKENIZER_CACHE[id_tuple]
if tokenizer_name not in _TOKENIZER_CACHE:
if tokenizer_name == "openai":
_TOKENIZER_CACHE[tokenizer_name] = TiktokenTokenizer("cl100k_base")
return _TOKENIZER_CACHE[tokenizer_name]
try:
if model_name is None:
model_name = DOCUMENT_ENCODER_MODEL
logger.debug(f"Initializing HuggingFaceTokenizer for: {model_name}")
_TOKENIZER_CACHE[id_tuple] = HuggingFaceTokenizer(model_name)
logger.debug(f"Initializing HuggingFaceTokenizer for: {tokenizer_name}")
_TOKENIZER_CACHE[tokenizer_name] = HuggingFaceTokenizer(tokenizer_name)
except Exception as primary_error:
logger.error(
f"Error initializing HuggingFaceTokenizer for {model_name}: {primary_error}"
f"Error initializing HuggingFaceTokenizer for {tokenizer_name}: {primary_error}"
)
logger.warning(
f"Falling back to default embedding model: {DOCUMENT_ENCODER_MODEL}"
@@ -119,7 +98,7 @@ def _check_tokenizer_cache(
try:
# Cache this tokenizer name to the default so we don't have to try to load it again
# and fail again
_TOKENIZER_CACHE[id_tuple] = HuggingFaceTokenizer(
_TOKENIZER_CACHE[tokenizer_name] = HuggingFaceTokenizer(
DOCUMENT_ENCODER_MODEL
)
except Exception as fallback_error:
@@ -127,10 +106,10 @@ def _check_tokenizer_cache(
f"Error initializing fallback HuggingFaceTokenizer: {fallback_error}"
)
raise ValueError(
f"Failed to initialize tokenizer for {model_name} and fallback model"
f"Failed to initialize tokenizer for {tokenizer_name} and fallback model"
) from fallback_error
return _TOKENIZER_CACHE[id_tuple]
return _TOKENIZER_CACHE[tokenizer_name]
_DEFAULT_TOKENIZER: BaseTokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)
@@ -139,16 +118,11 @@ _DEFAULT_TOKENIZER: BaseTokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)
def get_tokenizer(
model_name: str | None, provider_type: EmbeddingProvider | str | None
) -> BaseTokenizer:
if provider_type is not None:
if isinstance(provider_type, str):
try:
provider_type = EmbeddingProvider(provider_type)
except ValueError:
logger.debug(
f"Invalid provider_type '{provider_type}'. Falling back to default tokenizer."
)
return _DEFAULT_TOKENIZER
return _check_tokenizer_cache(provider_type, model_name)
# Currently all of the viable models use the same sentencepiece tokenizer
# OpenAI uses a different one but currently it's not supported due to quality issues
# the inconsistent chunking makes using the sentencepiece tokenizer default better for now
# LLM tokenizers are specified by strings
global _DEFAULT_TOKENIZER
return _DEFAULT_TOKENIZER

View File

@@ -65,7 +65,7 @@ from danswer.tools.tool_implementations.search.search_tool import (
from danswer.tools.tool_runner import ToolCallKickoff
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_generator_function_time
from danswer.utils.variable_functionality import fetch_ee_implementation_or_noop
from ee.danswer.server.query_and_chat.utils import create_temporary_persona
logger = setup_logger()
@@ -125,11 +125,11 @@ def stream_answer_objects(
)
temporary_persona: Persona | None = None
if query_req.persona_config is not None:
temporary_persona = fetch_ee_implementation_or_noop(
"danswer.server.query_and_chat.utils", "create_temporary_persona", None
)(db_session=db_session, persona_config=query_req.persona_config, user=user)
new_persona = create_temporary_persona(
db_session=db_session, persona_config=query_req.persona_config, user=user
)
temporary_persona = new_persona
persona = temporary_persona if temporary_persona else chat_session.persona

View File

@@ -10,7 +10,8 @@ from danswer.auth.users import current_user
from danswer.auth.users import current_user_with_expired_token
from danswer.configs.app_configs import APP_API_PREFIX
from danswer.server.danswer_api.ingestion import api_key_dep
from danswer.utils.variable_functionality import fetch_ee_implementation_or_noop
from ee.danswer.auth.users import current_cloud_superuser
from ee.danswer.server.tenants.access import control_plane_dep
PUBLIC_ENDPOINT_SPECS = [
@@ -80,14 +81,6 @@ def check_router_auth(
(1) have auth enabled OR
(2) are explicitly marked as a public endpoint
"""
control_plane_dep = fetch_ee_implementation_or_noop(
"danswer.server.tenants.access", "control_plane_dep"
)
current_cloud_superuser = fetch_ee_implementation_or_noop(
"danswer.auth.users", "current_cloud_superuser"
)
for route in application.routes:
# explicitly marked as public
if is_route_in_spec_list(route, public_endpoint_specs):

View File

@@ -3,7 +3,6 @@ from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session
from danswer.auth.users import api_key_dep
from danswer.configs.constants import DocumentSource
from danswer.connectors.models import Document
from danswer.connectors.models import IndexAttemptMetadata
@@ -23,6 +22,7 @@ from danswer.server.danswer_api.models import DocMinimalInfo
from danswer.server.danswer_api.models import IngestionDocument
from danswer.server.danswer_api.models import IngestionResult
from danswer.utils.logger import setup_logger
from ee.danswer.auth.users import api_key_dep
logger = setup_logger()

View File

@@ -16,9 +16,6 @@ from danswer.background.celery.tasks.pruning.tasks import (
try_creating_prune_generator_task,
)
from danswer.background.celery.versioned_apps.primary import app as primary_app
from danswer.background.task_name_builders import (
name_sync_external_doc_permissions_task,
)
from danswer.db.connector_credential_pair import add_credential_to_connector
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.connector_credential_pair import remove_credential_from_connector
@@ -50,7 +47,11 @@ from danswer.server.documents.models import ConnectorCredentialPairMetadata
from danswer.server.documents.models import PaginatedIndexAttempts
from danswer.server.models import StatusResponse
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import fetch_ee_implementation_or_noop
from ee.danswer.background.task_name_builders import (
name_sync_external_doc_permissions_task,
)
from ee.danswer.db.user_group import validate_user_creation_permissions
logger = setup_logger()
router = APIRouter(prefix="/manage")
@@ -331,6 +332,9 @@ def sync_cc_pair(
db_session: Session = Depends(get_session),
) -> StatusResponse[list[int]]:
# avoiding circular refs
from ee.danswer.background.celery.apps.primary import (
sync_external_doc_permissions_task,
)
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id=cc_pair_id,
@@ -356,19 +360,12 @@ def sync_cc_pair(
)
logger.info(f"Syncing the {cc_pair.connector.name} connector.")
sync_external_doc_permissions_task = fetch_ee_implementation_or_noop(
"danswer.background.celery.apps.primary",
"sync_external_doc_permissions_task",
None,
sync_external_doc_permissions_task.apply_async(
kwargs=dict(
cc_pair_id=cc_pair_id, tenant_id=CURRENT_TENANT_ID_CONTEXTVAR.get()
),
)
if sync_external_doc_permissions_task:
sync_external_doc_permissions_task.apply_async(
kwargs=dict(
cc_pair_id=cc_pair_id, tenant_id=CURRENT_TENANT_ID_CONTEXTVAR.get()
),
)
return StatusResponse(
success=True,
message="Successfully created the sync task.",
@@ -383,9 +380,7 @@ def associate_credential_to_connector(
user: User | None = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> StatusResponse[int]:
fetch_ee_implementation_or_noop(
"danswer.db.user_group", "validate_user_creation_permissions", None
)(
validate_user_creation_permissions(
db_session=db_session,
user=user,
target_group_ids=metadata.groups,

View File

@@ -108,7 +108,7 @@ from danswer.server.documents.models import ObjectCreationIdResponse
from danswer.server.documents.models import RunConnectorRequest
from danswer.server.models import StatusResponse
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import fetch_ee_implementation_or_noop
from ee.danswer.db.user_group import validate_user_creation_permissions
logger = setup_logger()
@@ -658,10 +658,7 @@ def create_connector_from_model(
) -> ObjectCreationIdResponse:
try:
_validate_connector_allowed(connector_data.source)
fetch_ee_implementation_or_noop(
"danswer.db.user_group", "validate_user_creation_permissions", None
)(
validate_user_creation_permissions(
db_session=db_session,
user=user,
target_group_ids=connector_data.groups,
@@ -735,9 +732,7 @@ def update_connector_from_model(
) -> ConnectorSnapshot | StatusResponse[int]:
try:
_validate_connector_allowed(connector_data.source)
fetch_ee_implementation_or_noop(
"danswer.db.user_group", "validate_user_creation_permissions", None
)(
validate_user_creation_permissions(
db_session=db_session,
user=user,
target_group_ids=connector_data.groups,

View File

@@ -28,7 +28,7 @@ from danswer.server.documents.models import CredentialSwapRequest
from danswer.server.documents.models import ObjectCreationIdResponse
from danswer.server.models import StatusResponse
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import fetch_ee_implementation_or_noop
from ee.danswer.db.user_group import validate_user_creation_permissions
logger = setup_logger()
@@ -121,9 +121,7 @@ def create_credential_from_model(
db_session: Session = Depends(get_session),
) -> ObjectCreationIdResponse:
if not _ignore_credential_permissions(credential_info.source):
fetch_ee_implementation_or_noop(
"danswer.db.user_group", "validate_user_creation_permissions", None
)(
validate_user_creation_permissions(
db_session=db_session,
user=user,
target_group_ids=credential_info.groups,

View File

@@ -18,7 +18,7 @@ from danswer.server.features.document_set.models import CheckDocSetPublicRespons
from danswer.server.features.document_set.models import DocumentSet
from danswer.server.features.document_set.models import DocumentSetCreationRequest
from danswer.server.features.document_set.models import DocumentSetUpdateRequest
from danswer.utils.variable_functionality import fetch_ee_implementation_or_noop
from ee.danswer.db.user_group import validate_user_creation_permissions
router = APIRouter(prefix="/manage")
@@ -30,9 +30,7 @@ def create_document_set(
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> int:
fetch_ee_implementation_or_noop(
"danswer.db.user_group", "validate_user_creation_permissions", None
)(
validate_user_creation_permissions(
db_session=db_session,
user=user,
target_group_ids=document_set_creation_request.groups,
@@ -55,9 +53,7 @@ def patch_document_set(
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> None:
fetch_ee_implementation_or_noop(
"danswer.db.user_group", "validate_user_creation_permissions", None
)(
validate_user_creation_permissions(
db_session=db_session,
user=user,
target_group_ids=document_set_update_request.groups,

View File

@@ -11,6 +11,7 @@ from fastapi import Body
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Request
from fastapi import status
from psycopg2.errors import UniqueViolation
from pydantic import BaseModel
from sqlalchemy import Column
@@ -26,10 +27,10 @@ from danswer.auth.noauth_user import fetch_no_auth_user
from danswer.auth.noauth_user import set_no_auth_user_preferences
from danswer.auth.schemas import UserRole
from danswer.auth.schemas import UserStatus
from danswer.auth.users import BasicAuthenticationError
from danswer.auth.users import current_admin_user
from danswer.auth.users import current_curator_or_admin_user
from danswer.auth.users import current_user
from danswer.auth.users import get_tenant_id_for_email
from danswer.auth.users import optional_user
from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import ENABLE_EMAIL_INVITES
@@ -37,7 +38,6 @@ from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
from danswer.configs.app_configs import SUPER_USERS
from danswer.configs.app_configs import VALID_EMAIL_DOMAINS
from danswer.configs.constants import AuthType
from danswer.db.api_key import is_api_key_email_address
from danswer.db.auth import get_total_users_count
from danswer.db.engine import CURRENT_TENANT_ID_CONTEXTVAR
from danswer.db.engine import get_session
@@ -61,7 +61,12 @@ from danswer.server.models import InvitedUserSnapshot
from danswer.server.models import MinimalUserSnapshot
from danswer.server.utils import send_user_email_invite
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import fetch_ee_implementation_or_noop
from ee.danswer.db.api_key import is_api_key_email_address
from ee.danswer.db.external_perm import delete_user__ext_group_for_user__no_commit
from ee.danswer.db.user_group import remove_curator_status__no_commit
from ee.danswer.server.tenants.billing import register_tenant_users
from ee.danswer.server.tenants.provisioning import add_users_to_tenant
from ee.danswer.server.tenants.provisioning import remove_users_from_tenant
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
@@ -100,10 +105,7 @@ def set_user_role(
)
if user_to_update.role == UserRole.CURATOR:
fetch_ee_implementation_or_noop(
"danswer.db.user_group",
"remove_curator_status__no_commit",
)(db_session, user_to_update)
remove_curator_status__no_commit(db_session, user_to_update)
user_to_update.role = user_role_update_request.new_role.value
@@ -203,9 +205,7 @@ def bulk_invite_users(
if MULTI_TENANT:
try:
fetch_ee_implementation_or_noop(
"danswer.server.tenants.provisioning", "add_users_to_tenant", None
)(normalized_emails, tenant_id)
add_users_to_tenant(normalized_emails, tenant_id)
except IntegrityError as e:
if isinstance(e.orig, UniqueViolation):
@@ -226,9 +226,9 @@ def bulk_invite_users(
return number_of_invited_users
try:
logger.info("Registering tenant users")
fetch_ee_implementation_or_noop(
"danswer.server.tenants.billing", "register_tenant_users", None
)(CURRENT_TENANT_ID_CONTEXTVAR.get(), get_total_users_count(db_session))
register_tenant_users(
CURRENT_TENANT_ID_CONTEXTVAR.get(), get_total_users_count(db_session)
)
if ENABLE_EMAIL_INVITES:
try:
for email in all_emails:
@@ -243,9 +243,7 @@ def bulk_invite_users(
"Reverting changes: removing users from tenant and resetting invited users"
)
write_invited_users(initial_invited_users) # Reset to original state
fetch_ee_implementation_or_noop(
"danswer.server.tenants.user_mapping", "remove_users_from_tenant", None
)(normalized_emails, tenant_id)
remove_users_from_tenant(normalized_emails, tenant_id)
raise e
@@ -259,16 +257,14 @@ def remove_invited_user(
remaining_users = [user for user in user_emails if user != user_email.user_email]
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
fetch_ee_implementation_or_noop(
"danswer.server.tenants.user_mapping", "remove_users_from_tenant", None
)([user_email.user_email], tenant_id)
remove_users_from_tenant([user_email.user_email], tenant_id)
number_of_invited_users = write_invited_users(remaining_users)
try:
if MULTI_TENANT:
fetch_ee_implementation_or_noop(
"danswer.server.tenants.billing", "register_tenant_users", None
)(CURRENT_TENANT_ID_CONTEXTVAR.get(), get_total_users_count(db_session))
register_tenant_users(
CURRENT_TENANT_ID_CONTEXTVAR.get(), get_total_users_count(db_session)
)
except Exception:
logger.error(
"Request to update number of seats taken in control plane failed. "
@@ -335,10 +331,7 @@ async def delete_user(
for oauth_account in user_to_delete.oauth_accounts:
db_session.delete(oauth_account)
fetch_ee_implementation_or_noop(
"danswer.db.external_perm",
"delete_user__ext_group_for_user__no_commit",
)(
delete_user__ext_group_for_user__no_commit(
db_session=db_session,
user_id=user_to_delete.id,
)
@@ -492,19 +485,20 @@ def verify_user_logged_in(
store = get_kv_store()
return fetch_no_auth_user(store)
raise BasicAuthenticationError(detail="User Not Authenticated")
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="User Not Authenticated"
)
if user.oidc_expiry and user.oidc_expiry < datetime.now(timezone.utc):
raise BasicAuthenticationError(
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Access denied. User's OIDC token has expired.",
)
token_created_at = (
None if MULTI_TENANT else get_current_token_creation(user, db_session)
)
organization_name = fetch_ee_implementation_or_noop(
"danswer.server.tenants.user_mapping", "get_tenant_id_for_email", None
)(user.email)
organization_name = get_tenant_id_for_email(user.email)
user_info = UserInfo.from_model(
user,

View File

@@ -359,7 +359,7 @@ def handle_new_chat_message(
yield json.dumps(packet) if isinstance(packet, dict) else packet
except Exception as e:
logger.exception("Error in chat message streaming")
logger.exception(f"Error in chat message streaming: {e}")
yield json.dumps({"error": str(e)})
finally:

View File

@@ -279,7 +279,7 @@ def get_answer_with_quote(
):
yield json.dumps(packet) if isinstance(packet, dict) else packet
except Exception as e:
logger.exception("Error in search answer streaming")
logger.exception(f"Error in search answer streaming: {e}")
yield json.dumps({"error": str(e)})
return StreamingResponse(stream_generator(), media_type="application/json")

View File

@@ -18,9 +18,9 @@ from danswer.db.models import ChatMessage
from danswer.db.models import ChatSession
from danswer.db.models import TokenRateLimit
from danswer.db.models import User
from danswer.db.token_limit import fetch_all_global_token_rate_limits
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import fetch_versioned_implementation
from ee.danswer.db.token_limit import fetch_all_global_token_rate_limits
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

View File

@@ -5,13 +5,13 @@ from sqlalchemy.orm import Session
from danswer.auth.users import current_admin_user
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.db.token_limit import delete_token_rate_limit
from danswer.db.token_limit import fetch_all_global_token_rate_limits
from danswer.db.token_limit import insert_global_token_rate_limit
from danswer.db.token_limit import update_token_rate_limit
from danswer.server.query_and_chat.token_limit import any_rate_limit_exists
from danswer.server.token_rate_limits.models import TokenRateLimitArgs
from danswer.server.token_rate_limits.models import TokenRateLimitDisplay
from ee.danswer.db.token_limit import delete_token_rate_limit
from ee.danswer.db.token_limit import fetch_all_global_token_rate_limits
from ee.danswer.db.token_limit import insert_global_token_rate_limit
from ee.danswer.db.token_limit import update_token_rate_limit
router = APIRouter(prefix="/admin/token-rate-limits")

View File

@@ -80,3 +80,15 @@ class Tool(abc.ABC):
using_tool_calling_llm: bool,
) -> "AnswerPromptBuilder":
raise NotImplementedError
# This is the prompt builder that is used when the tool call AND LLM response has been updated
# and we need to build the next prompt (for LLM calling tools)
# @abc.abstractmethod
def build_prompt_after_tool_call(
self,
prompt_builder: "AnswerPromptBuilder",
query: str,
llm_answer: str,
) -> "AnswerPromptBuilder":
pass

View File

@@ -15,7 +15,7 @@ from langchain_core.messages import SystemMessage
from pydantic import BaseModel
from danswer.configs.constants import FileOrigin
from danswer.db.engine import get_session_with_default_tenant
from danswer.db.engine import get_session_with_tenant
from danswer.file_store.file_store import get_default_file_store
from danswer.file_store.models import ChatFileType
from danswer.file_store.models import InMemoryChatFile
@@ -187,7 +187,7 @@ class CustomTool(BaseTool):
def _save_and_get_file_references(
self, file_content: bytes | str, content_type: str
) -> List[str]:
with get_session_with_default_tenant() as db_session:
with get_session_with_tenant() as db_session:
file_store = get_default_file_store(db_session)
file_id = str(uuid.uuid4())
@@ -299,7 +299,7 @@ class CustomTool(BaseTool):
# Load files from storage
files = []
with get_session_with_default_tenant() as db_session:
with get_session_with_tenant() as db_session:
file_store = get_default_file_store(db_session)
for file_id in response.tool_result.file_ids:

View File

@@ -4,6 +4,8 @@ from enum import Enum
from typing import Any
from typing import cast
from langchain_core.messages import AIMessage
from langchain_core.messages import HumanMessage
from litellm import image_generation # type: ignore
from pydantic import BaseModel
@@ -22,6 +24,9 @@ from danswer.tools.tool import Tool
from danswer.tools.tool_implementations.images.prompt import (
build_image_generation_user_prompt,
)
from danswer.tools.tool_implementations.images.prompt import (
build_image_generation_user_task_prompt,
)
from danswer.utils.headers import build_llm_extra_headers
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
@@ -297,3 +302,49 @@ class ImageGenerationTool(Tool):
)
return prompt_builder
def build_prompt_after_tool_call(
self,
prompt_builder: "AnswerPromptBuilder",
query: str,
llm_answer: str,
tool_responses: "ToolResponse",
) -> "AnswerPromptBuilder":
# Append the assistant's previous response to the message history
img_generation_response = cast(
list[ImageGenerationResponse], tool_responses.response
)
if img_generation_response is None:
raise ValueError("No image generation response found")
img_urls = [img.url for img in img_generation_response]
# Build a user message that includes the images generated
user_message = build_image_generation_user_task_prompt(
img_urls=img_urls,
)
prompt_builder.update_user_prompt(HumanMessage(content=query))
# Update the user prompt with the new message containing images
prompt_builder.append_message(user_message)
prompt_builder.append_message(
AIMessage(
content=f"The images I generated can be described as the following: {llm_answer}"
)
)
# Append a new user message reminding the assistant of the original query and what remains to be done
prompt_builder.update_task_reminder(
HumanMessage(
content=f"Reminder: the original request was: '{query}'.\n\n"
"You generated the above images as part of this request. "
"If any parts have not been fulfilled, please proceed to complete them using the appropriate tools. "
"If the original request has been fulfilled with the prior messages,"
"you can provide a final summary and DO NOT call a tool."
)
)
return prompt_builder

View File

@@ -10,6 +10,11 @@ Can you please summarize them in a sentence or two? Do NOT include image urls or
"""
IMG_GENERATION_USER_PROMPT = """
These are the IMAGES you generated
"""
def build_image_generation_user_prompt(
query: str, img_urls: list[str] | None = None
) -> HumanMessage:
@@ -19,3 +24,14 @@ def build_image_generation_user_prompt(
img_urls=img_urls,
)
)
def build_image_generation_user_task_prompt(
img_urls: list[str] | None = None,
) -> HumanMessage:
return HumanMessage(
content=build_content_with_imgs(
message=IMG_GENERATION_USER_PROMPT,
img_urls=img_urls,
)
)

View File

@@ -119,30 +119,3 @@ def noop_fallback(*args: Any, **kwargs: Any) -> None:
Returns:
None
"""
def fetch_ee_implementation_or_noop(
module: str, attribute: str, noop_return_value: Any = None
) -> Any:
"""
Fetches an EE implementation if EE is enabled, otherwise returns a no-op function.
Raises an exception if EE is enabled but the fetch fails.
Args:
module (str): The name of the module from which to fetch the attribute.
attribute (str): The name of the attribute to fetch from the module.
Returns:
Any: The fetched EE implementation if successful and EE is enabled, otherwise a no-op function.
Raises:
Exception: If EE is enabled but the fetch fails.
"""
if not global_version.is_ee_version():
return lambda *args, **kwargs: noop_return_value
try:
return fetch_versioned_implementation(module, attribute)
except Exception as e:
logger.error(f"Failed to fetch implementation for {module}.{attribute}: {e}")
raise

View File

@@ -8,7 +8,7 @@ from passlib.hash import sha256_crypt
from pydantic import BaseModel
from danswer.auth.schemas import UserRole
from danswer.configs.app_configs import API_KEY_HASH_ROUNDS
from ee.danswer.configs.app_configs import API_KEY_HASH_ROUNDS
_API_KEY_HEADER_NAME = "Authorization"

View File

@@ -4,15 +4,16 @@ from fastapi import Request
from fastapi import status
from sqlalchemy.orm import Session
from danswer.auth.api_key import get_hashed_api_key_from_request
from danswer.auth.users import current_admin_user
from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import SUPER_CLOUD_API_KEY
from danswer.configs.app_configs import SUPER_USERS
from danswer.configs.constants import AuthType
from danswer.db.api_key import fetch_user_for_api_key
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.utils.logger import setup_logger
from ee.danswer.auth.api_key import get_hashed_api_key_from_request
from ee.danswer.db.api_key import fetch_user_for_api_key
from ee.danswer.db.saml import get_saml_account
from ee.danswer.server.seeding import get_seed_config
from ee.danswer.utils.secrets import extract_hashed_cookie
@@ -47,6 +48,25 @@ async def optional_user_(
return user
def api_key_dep(
request: Request, db_session: Session = Depends(get_session)
) -> User | None:
if AUTH_TYPE == AuthType.DISABLED:
return None
hashed_api_key = get_hashed_api_key_from_request(request)
if not hashed_api_key:
raise HTTPException(status_code=401, detail="Missing API key")
if hashed_api_key:
user = fetch_user_for_api_key(hashed_api_key, db_session)
if user is None:
raise HTTPException(status_code=401, detail="Invalid API key")
return user
def get_default_admin_user_emails_() -> list[str]:
seed_config = get_seed_config()
if seed_config and seed_config.admin_user_emails:

View File

@@ -1,7 +1,4 @@
from danswer.background.celery.apps.primary import celery_app
from danswer.background.task_name_builders import (
name_sync_external_doc_permissions_task,
)
from danswer.background.task_utils import build_celery_task_wrapper
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.db.chat import delete_chat_sessions_older_than
@@ -17,6 +14,9 @@ from ee.danswer.background.celery_utils import (
should_perform_external_group_permissions_check,
)
from ee.danswer.background.task_name_builders import name_chat_ttl_task
from ee.danswer.background.task_name_builders import (
name_sync_external_doc_permissions_task,
)
from ee.danswer.background.task_name_builders import (
name_sync_external_group_permissions_task,
)

View File

@@ -3,15 +3,15 @@ from datetime import timezone
from sqlalchemy.orm import Session
from danswer.background.task_name_builders import (
name_sync_external_doc_permissions_task,
)
from danswer.db.enums import AccessType
from danswer.db.models import ConnectorCredentialPair
from danswer.db.tasks import check_task_is_live_and_not_timed_out
from danswer.db.tasks import get_latest_task
from danswer.utils.logger import setup_logger
from ee.danswer.background.task_name_builders import name_chat_ttl_task
from ee.danswer.background.task_name_builders import (
name_sync_external_doc_permissions_task,
)
from ee.danswer.background.task_name_builders import (
name_sync_external_group_permissions_task,
)

View File

@@ -2,6 +2,12 @@ def name_chat_ttl_task(retention_limit_days: int, tenant_id: str | None = None)
return f"chat_ttl_{retention_limit_days}_days"
def name_sync_external_doc_permissions_task(
cc_pair_id: int, tenant_id: str | None = None
) -> str:
return f"sync_external_doc_permissions_task__{cc_pair_id}"
def name_sync_external_group_permissions_task(
cc_pair_id: int, tenant_id: str | None = None
) -> str:

View File

@@ -7,6 +7,16 @@ OPENID_CONFIG_URL = os.environ.get("OPENID_CONFIG_URL", "")
SAML_CONF_DIR = os.environ.get("SAML_CONF_DIR") or "/app/ee/danswer/configs/saml_config"
#####
# API Key Configs
#####
# refers to the rounds described here: https://passlib.readthedocs.io/en/stable/lib/passlib.hash.sha256_crypt.html
_API_KEY_HASH_ROUNDS_RAW = os.environ.get("API_KEY_HASH_ROUNDS")
API_KEY_HASH_ROUNDS = (
int(_API_KEY_HASH_ROUNDS_RAW) if _API_KEY_HASH_ROUNDS_RAW else None
)
#####
# Auto Permission Sync
#####

View File

@@ -5,16 +5,16 @@ from sqlalchemy import select
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from danswer.auth.api_key import ApiKeyDescriptor
from danswer.auth.api_key import build_displayable_api_key
from danswer.auth.api_key import generate_api_key
from danswer.auth.api_key import hash_api_key
from danswer.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
from danswer.configs.constants import DANSWER_API_KEY_PREFIX
from danswer.configs.constants import UNNAMED_KEY_PLACEHOLDER
from danswer.db.models import ApiKey
from danswer.db.models import User
from danswer.server.api_key.models import APIKeyArgs
from ee.danswer.auth.api_key import ApiKeyDescriptor
from ee.danswer.auth.api_key import build_displayable_api_key
from ee.danswer.auth.api_key import generate_api_key
from ee.danswer.auth.api_key import hash_api_key
from ee.danswer.server.api_key.models import APIKeyArgs
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

View File

@@ -65,6 +65,64 @@ def _add_user_filters(
return stmt.where(where_clause)
def fetch_all_user_token_rate_limits(
db_session: Session,
enabled_only: bool = False,
ordered: bool = True,
) -> Sequence[TokenRateLimit]:
query = select(TokenRateLimit).where(
TokenRateLimit.scope == TokenRateLimitScope.USER
)
if enabled_only:
query = query.where(TokenRateLimit.enabled.is_(True))
if ordered:
query = query.order_by(TokenRateLimit.created_at.desc())
return db_session.scalars(query).all()
def fetch_all_global_token_rate_limits(
db_session: Session,
enabled_only: bool = False,
ordered: bool = True,
) -> Sequence[TokenRateLimit]:
query = select(TokenRateLimit).where(
TokenRateLimit.scope == TokenRateLimitScope.GLOBAL
)
if enabled_only:
query = query.where(TokenRateLimit.enabled.is_(True))
if ordered:
query = query.order_by(TokenRateLimit.created_at.desc())
token_rate_limits = db_session.scalars(query).all()
return token_rate_limits
def fetch_user_group_token_rate_limits(
db_session: Session,
group_id: int,
user: User | None = None,
enabled_only: bool = False,
ordered: bool = True,
get_editable: bool = True,
) -> Sequence[TokenRateLimit]:
stmt = select(TokenRateLimit)
stmt = stmt.where(User__UserGroup.user_group_id == group_id)
stmt = _add_user_filters(stmt, user, get_editable)
if enabled_only:
stmt = stmt.where(TokenRateLimit.enabled.is_(True))
if ordered:
stmt = stmt.order_by(TokenRateLimit.created_at.desc())
return db_session.scalars(stmt).all()
def fetch_all_user_group_token_rate_limits_by_group(
db_session: Session,
) -> Sequence[Row[tuple[TokenRateLimit, str]]]:
@@ -80,6 +138,38 @@ def fetch_all_user_group_token_rate_limits_by_group(
return db_session.execute(query).all()
def insert_user_token_rate_limit(
db_session: Session,
token_rate_limit_settings: TokenRateLimitArgs,
) -> TokenRateLimit:
token_limit = TokenRateLimit(
enabled=token_rate_limit_settings.enabled,
token_budget=token_rate_limit_settings.token_budget,
period_hours=token_rate_limit_settings.period_hours,
scope=TokenRateLimitScope.USER,
)
db_session.add(token_limit)
db_session.commit()
return token_limit
def insert_global_token_rate_limit(
db_session: Session,
token_rate_limit_settings: TokenRateLimitArgs,
) -> TokenRateLimit:
token_limit = TokenRateLimit(
enabled=token_rate_limit_settings.enabled,
token_budget=token_rate_limit_settings.token_budget,
period_hours=token_rate_limit_settings.period_hours,
scope=TokenRateLimitScope.GLOBAL,
)
db_session.add(token_limit)
db_session.commit()
return token_limit
def insert_user_group_token_rate_limit(
db_session: Session,
token_rate_limit_settings: TokenRateLimitArgs,
@@ -103,22 +193,34 @@ def insert_user_group_token_rate_limit(
return token_limit
def fetch_user_group_token_rate_limits(
def update_token_rate_limit(
db_session: Session,
group_id: int,
user: User | None = None,
enabled_only: bool = False,
ordered: bool = True,
get_editable: bool = True,
) -> Sequence[TokenRateLimit]:
stmt = select(TokenRateLimit)
stmt = stmt.where(User__UserGroup.user_group_id == group_id)
stmt = _add_user_filters(stmt, user, get_editable)
token_rate_limit_id: int,
token_rate_limit_settings: TokenRateLimitArgs,
) -> TokenRateLimit:
token_limit = db_session.get(TokenRateLimit, token_rate_limit_id)
if token_limit is None:
raise ValueError(f"TokenRateLimit with id '{token_rate_limit_id}' not found")
if enabled_only:
stmt = stmt.where(TokenRateLimit.enabled.is_(True))
token_limit.enabled = token_rate_limit_settings.enabled
token_limit.token_budget = token_rate_limit_settings.token_budget
token_limit.period_hours = token_rate_limit_settings.period_hours
db_session.commit()
if ordered:
stmt = stmt.order_by(TokenRateLimit.created_at.desc())
return token_limit
return db_session.scalars(stmt).all()
def delete_token_rate_limit(
db_session: Session,
token_rate_limit_id: int,
) -> None:
token_limit = db_session.get(TokenRateLimit, token_rate_limit_id)
if token_limit is None:
raise ValueError(f"TokenRateLimit with id '{token_rate_limit_id}' not found")
db_session.query(TokenRateLimit__UserGroup).filter(
TokenRateLimit__UserGroup.rate_limit_id == token_rate_limit_id
).delete()
db_session.delete(token_limit)
db_session.commit()

View File

@@ -12,11 +12,11 @@ from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import AuthType
from danswer.main import get_application as get_application_base
from danswer.main import include_router_with_global_prefix_prepended
from danswer.server.api_key.api import router as api_key_router
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import global_version
from ee.danswer.configs.app_configs import OPENID_CONFIG_URL
from ee.danswer.server.analytics.api import router as analytics_router
from ee.danswer.server.api_key.api import router as api_key_router
from ee.danswer.server.auth_check import check_ee_router_auth
from ee.danswer.server.enterprise_settings.api import (
admin_router as enterprise_settings_admin_router,

View File

@@ -3,15 +3,15 @@ from fastapi import Depends
from sqlalchemy.orm import Session
from danswer.auth.users import current_admin_user
from danswer.db.api_key import ApiKeyDescriptor
from danswer.db.api_key import fetch_api_keys
from danswer.db.api_key import insert_api_key
from danswer.db.api_key import regenerate_api_key
from danswer.db.api_key import remove_api_key
from danswer.db.api_key import update_api_key
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.server.api_key.models import APIKeyArgs
from ee.danswer.db.api_key import ApiKeyDescriptor
from ee.danswer.db.api_key import fetch_api_keys
from ee.danswer.db.api_key import insert_api_key
from ee.danswer.db.api_key import regenerate_api_key
from ee.danswer.db.api_key import remove_api_key
from ee.danswer.db.api_key import update_api_key
from ee.danswer.server.api_key.models import APIKeyArgs
router = APIRouter(prefix="/admin/api-key")

View File

@@ -8,9 +8,9 @@ from fastapi import HTTPException
from fastapi import Request
from fastapi import Response
from danswer.auth.api_key import extract_tenant_from_api_key_header
from danswer.configs.app_configs import USER_AUTH_SECRET
from danswer.db.engine import is_valid_schema_name
from ee.danswer.auth.api_key import extract_tenant_from_api_key_header
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

View File

@@ -12,7 +12,6 @@ from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.db.api_key import is_api_key_email_address
from danswer.db.engine import get_session_with_tenant
from danswer.db.models import ChatMessage
from danswer.db.models import ChatSession
@@ -21,11 +20,12 @@ from danswer.db.models import TokenRateLimit__UserGroup
from danswer.db.models import User
from danswer.db.models import User__UserGroup
from danswer.db.models import UserGroup
from danswer.db.token_limit import fetch_all_user_token_rate_limits
from danswer.server.query_and_chat.token_limit import _get_cutoff_time
from danswer.server.query_and_chat.token_limit import _is_rate_limited
from danswer.server.query_and_chat.token_limit import _user_is_rate_limited_by_global
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from ee.danswer.db.api_key import is_api_key_email_address
from ee.danswer.db.token_limit import fetch_all_user_token_rate_limits
def _check_token_rate_limits(user: User | None, tenant_id: str | None) -> None:

View File

@@ -7,6 +7,7 @@ from fastapi import Response
from danswer.auth.users import auth_backend
from danswer.auth.users import current_admin_user
from danswer.auth.users import get_jwt_strategy
from danswer.auth.users import get_tenant_id_for_email
from danswer.auth.users import User
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.db.engine import get_session_with_tenant
@@ -14,6 +15,7 @@ from danswer.db.notification import create_notification
from danswer.db.users import get_user_by_email
from danswer.server.settings.store import load_settings
from danswer.server.settings.store import store_settings
from danswer.setup import setup_danswer
from danswer.utils.logger import setup_logger
from ee.danswer.auth.users import current_cloud_superuser
from ee.danswer.configs.app_configs import STRIPE_SECRET_KEY
@@ -21,9 +23,15 @@ from ee.danswer.server.tenants.access import control_plane_dep
from ee.danswer.server.tenants.billing import fetch_billing_information
from ee.danswer.server.tenants.billing import fetch_tenant_stripe_information
from ee.danswer.server.tenants.models import BillingInformation
from ee.danswer.server.tenants.models import CreateTenantRequest
from ee.danswer.server.tenants.models import ImpersonateRequest
from ee.danswer.server.tenants.models import ProductGatingRequest
from ee.danswer.server.tenants.user_mapping import get_tenant_id_for_email
from ee.danswer.server.tenants.provisioning import add_users_to_tenant
from ee.danswer.server.tenants.provisioning import configure_default_api_keys
from ee.danswer.server.tenants.provisioning import ensure_schema_exists
from ee.danswer.server.tenants.provisioning import run_alembic_migrations
from ee.danswer.server.tenants.provisioning import user_owns_a_tenant
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
stripe.api_key = STRIPE_SECRET_KEY
@@ -32,6 +40,52 @@ logger = setup_logger()
router = APIRouter(prefix="/tenants")
@router.post("/create")
def create_tenant(
create_tenant_request: CreateTenantRequest, _: None = Depends(control_plane_dep)
) -> dict[str, str]:
if not MULTI_TENANT:
raise HTTPException(status_code=403, detail="Multi-tenancy is not enabled")
tenant_id = create_tenant_request.tenant_id
email = create_tenant_request.initial_admin_email
token = None
if user_owns_a_tenant(email):
raise HTTPException(
status_code=409, detail="User already belongs to an organization"
)
try:
if not ensure_schema_exists(tenant_id):
logger.info(f"Created schema for tenant {tenant_id}")
else:
logger.info(f"Schema already exists for tenant {tenant_id}")
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
run_alembic_migrations(tenant_id)
with get_session_with_tenant(tenant_id) as db_session:
setup_danswer(db_session, tenant_id)
configure_default_api_keys(db_session)
add_users_to_tenant([email], tenant_id)
return {
"status": "success",
"message": f"Tenant {tenant_id} created successfully",
}
except Exception as e:
logger.exception(f"Failed to create tenant {tenant_id}: {str(e)}")
raise HTTPException(
status_code=500, detail=f"Failed to create tenant: {str(e)}"
)
finally:
if token is not None:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
@router.post("/product-gating")
def gate_product(
product_gating_request: ProductGatingRequest, _: None = Depends(control_plane_dep)

View File

@@ -33,8 +33,3 @@ class CheckoutSessionCreationResponse(BaseModel):
class ImpersonateRequest(BaseModel):
email: str
class TenantCreationPayload(BaseModel):
tenant_id: str
email: str

View File

@@ -1,210 +1,145 @@
import asyncio
import logging
import uuid
import os
from types import SimpleNamespace
import aiohttp # Async HTTP client
from fastapi import HTTPException
from sqlalchemy import text
from sqlalchemy.orm import Session
from sqlalchemy.schema import CreateSchema
from danswer.auth.users import exceptions
from danswer.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from alembic import command
from alembic.config import Config
from danswer.db.engine import build_connection_string
from danswer.db.engine import get_session_with_tenant
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.llm import update_default_provider
from danswer.db.llm import upsert_cloud_embedding_provider
from danswer.db.llm import upsert_llm_provider
from danswer.db.models import UserTenantMapping
from danswer.llm.llm_provider_options import ANTHROPIC_MODEL_NAMES
from danswer.llm.llm_provider_options import ANTHROPIC_PROVIDER_NAME
from danswer.llm.llm_provider_options import OPEN_AI_MODEL_NAMES
from danswer.llm.llm_provider_options import OPENAI_PROVIDER_NAME
from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
from danswer.setup import setup_danswer
from danswer.utils.logger import setup_logger
from ee.danswer.configs.app_configs import ANTHROPIC_DEFAULT_API_KEY
from ee.danswer.configs.app_configs import COHERE_DEFAULT_API_KEY
from ee.danswer.configs.app_configs import OPENAI_DEFAULT_API_KEY
from ee.danswer.server.tenants.access import generate_data_plane_token
from ee.danswer.server.tenants.models import TenantCreationPayload
from ee.danswer.server.tenants.schema_management import create_schema_if_not_exists
from ee.danswer.server.tenants.schema_management import drop_schema
from ee.danswer.server.tenants.schema_management import run_alembic_migrations
from ee.danswer.server.tenants.user_mapping import add_users_to_tenant
from ee.danswer.server.tenants.user_mapping import get_tenant_id_for_email
from ee.danswer.server.tenants.user_mapping import user_owns_a_tenant
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import TENANT_ID_PREFIX
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.enums import EmbeddingProvider
logger = logging.getLogger(__name__)
logger = setup_logger()
async def get_or_create_tenant_id(email: str) -> str:
"""Get existing tenant ID for an email or create a new tenant if none exists."""
if not MULTI_TENANT:
return POSTGRES_DEFAULT_SCHEMA
def run_alembic_migrations(schema_name: str) -> None:
logger.info(f"Starting Alembic migrations for schema: {schema_name}")
try:
tenant_id = get_tenant_id_for_email(email)
except exceptions.UserNotExists:
# If tenant does not exist and in Multi tenant mode, provision a new tenant
try:
tenant_id = await create_tenant(email)
except Exception as e:
logger.error(f"Tenant provisioning failed: {e}")
raise HTTPException(status_code=500, detail="Failed to provision tenant.")
current_dir = os.path.dirname(os.path.abspath(__file__))
root_dir = os.path.abspath(os.path.join(current_dir, "..", "..", "..", ".."))
alembic_ini_path = os.path.join(root_dir, "alembic.ini")
if not tenant_id:
raise HTTPException(
status_code=401, detail="User does not belong to an organization"
# Configure Alembic
alembic_cfg = Config(alembic_ini_path)
alembic_cfg.set_main_option("sqlalchemy.url", build_connection_string())
alembic_cfg.set_main_option(
"script_location", os.path.join(root_dir, "alembic")
)
return tenant_id
# Ensure that logging isn't broken
alembic_cfg.attributes["configure_logger"] = False
# Mimic command-line options by adding 'cmd_opts' to the config
alembic_cfg.cmd_opts = SimpleNamespace() # type: ignore
alembic_cfg.cmd_opts.x = [f"schema={schema_name}"] # type: ignore
async def create_tenant(email: str) -> str:
tenant_id = TENANT_ID_PREFIX + str(uuid.uuid4())
try:
# Provision tenant on data plane
await provision_tenant(tenant_id, email)
# Notify control plane
await notify_control_plane(tenant_id, email)
except Exception as e:
logger.error(f"Tenant provisioning failed: {e}")
await rollback_tenant_provisioning(tenant_id)
raise HTTPException(status_code=500, detail="Failed to provision tenant.")
return tenant_id
# Run migrations programmatically
command.upgrade(alembic_cfg, "head")
async def provision_tenant(tenant_id: str, email: str) -> None:
if not MULTI_TENANT:
raise HTTPException(status_code=403, detail="Multi-tenancy is not enabled")
if user_owns_a_tenant(email):
raise HTTPException(
status_code=409, detail="User already belongs to an organization"
# Run migrations programmatically
logger.info(
f"Alembic migrations completed successfully for schema: {schema_name}"
)
logger.info(f"Provisioning tenant: {tenant_id}")
token = None
try:
if not create_schema_if_not_exists(tenant_id):
logger.info(f"Created schema for tenant {tenant_id}")
else:
logger.info(f"Schema already exists for tenant {tenant_id}")
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
# Await the Alembic migrations
await asyncio.to_thread(run_alembic_migrations, tenant_id)
with get_session_with_tenant(tenant_id) as db_session:
setup_danswer(db_session, tenant_id)
configure_default_api_keys(db_session)
add_users_to_tenant([email], tenant_id)
except Exception as e:
logger.exception(f"Failed to create tenant {tenant_id}")
raise HTTPException(
status_code=500, detail=f"Failed to create tenant: {str(e)}"
)
finally:
if token is not None:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
async def notify_control_plane(tenant_id: str, email: str) -> None:
logger.info("Fetching billing information")
token = generate_data_plane_token()
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
}
payload = TenantCreationPayload(tenant_id=tenant_id, email=email)
async with aiohttp.ClientSession() as session:
async with session.post(
f"{CONTROL_PLANE_API_BASE_URL}/tenants/create",
headers=headers,
json=payload.model_dump(),
) as response:
if response.status != 200:
error_text = await response.text()
logger.error(f"Control plane tenant creation failed: {error_text}")
raise Exception(
f"Failed to create tenant on control plane: {error_text}"
)
async def rollback_tenant_provisioning(tenant_id: str) -> None:
# Logic to rollback tenant provisioning on data plane
logger.info(f"Rolling back tenant provisioning for tenant_id: {tenant_id}")
try:
# Drop the tenant's schema to rollback provisioning
drop_schema(tenant_id)
# Remove tenant mapping
with Session(get_sqlalchemy_engine()) as db_session:
db_session.query(UserTenantMapping).filter(
UserTenantMapping.tenant_id == tenant_id
).delete()
db_session.commit()
except Exception as e:
logger.error(f"Failed to rollback tenant provisioning: {e}")
logger.exception(f"Alembic migration failed for schema {schema_name}: {str(e)}")
raise
def configure_default_api_keys(db_session: Session) -> None:
if OPENAI_DEFAULT_API_KEY:
open_provider = LLMProviderUpsertRequest(
name="OpenAI",
provider=OPENAI_PROVIDER_NAME,
api_key=OPENAI_DEFAULT_API_KEY,
default_model_name="gpt-4",
fast_default_model_name="gpt-4o-mini",
model_names=OPEN_AI_MODEL_NAMES,
)
try:
full_provider = upsert_llm_provider(open_provider, db_session)
update_default_provider(full_provider.id, db_session)
except Exception as e:
logger.error(f"Failed to configure OpenAI provider: {e}")
else:
logger.error(
"OPENAI_DEFAULT_API_KEY not set, skipping OpenAI provider configuration"
)
open_provider = LLMProviderUpsertRequest(
name="OpenAI",
provider="OpenAI",
api_key=OPENAI_DEFAULT_API_KEY,
default_model_name="gpt-4o",
)
anthropic_provider = LLMProviderUpsertRequest(
name="Anthropic",
provider="Anthropic",
api_key=ANTHROPIC_DEFAULT_API_KEY,
default_model_name="claude-3-5-sonnet-20240620",
)
upsert_llm_provider(open_provider, db_session)
upsert_llm_provider(anthropic_provider, db_session)
if ANTHROPIC_DEFAULT_API_KEY:
anthropic_provider = LLMProviderUpsertRequest(
name="Anthropic",
provider=ANTHROPIC_PROVIDER_NAME,
api_key=ANTHROPIC_DEFAULT_API_KEY,
default_model_name="claude-3-5-sonnet-20241022",
fast_default_model_name="claude-3-5-sonnet-20241022",
model_names=ANTHROPIC_MODEL_NAMES,
)
try:
full_provider = upsert_llm_provider(anthropic_provider, db_session)
update_default_provider(full_provider.id, db_session)
except Exception as e:
logger.error(f"Failed to configure Anthropic provider: {e}")
else:
logger.error(
"ANTHROPIC_DEFAULT_API_KEY not set, skipping Anthropic provider configuration"
)
cloud_embedding_provider = CloudEmbeddingProviderCreationRequest(
provider_type=EmbeddingProvider.COHERE,
api_key=COHERE_DEFAULT_API_KEY,
)
upsert_cloud_embedding_provider(db_session, cloud_embedding_provider)
if COHERE_DEFAULT_API_KEY:
cloud_embedding_provider = CloudEmbeddingProviderCreationRequest(
provider_type=EmbeddingProvider.COHERE,
api_key=COHERE_DEFAULT_API_KEY,
def ensure_schema_exists(tenant_id: str) -> bool:
with Session(get_sqlalchemy_engine()) as db_session:
with db_session.begin():
result = db_session.execute(
text(
"SELECT schema_name FROM information_schema.schemata WHERE schema_name = :schema_name"
),
{"schema_name": tenant_id},
)
schema_exists = result.scalar() is not None
if not schema_exists:
stmt = CreateSchema(tenant_id)
db_session.execute(stmt)
return True
return False
# For now, we're implementing a primitive mapping between users and tenants.
# This function is only used to determine a user's relationship to a tenant upon creation (implying ownership).
def user_owns_a_tenant(email: str) -> bool:
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
result = (
db_session.query(UserTenantMapping)
.filter(UserTenantMapping.email == email)
.first()
)
return result is not None
def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
try:
upsert_cloud_embedding_provider(db_session, cloud_embedding_provider)
for email in emails:
db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))
except Exception as e:
logger.error(f"Failed to configure Cohere embedding provider: {e}")
else:
logger.error(
"COHERE_DEFAULT_API_KEY not set, skipping Cohere embedding provider configuration"
)
logger.exception(f"Failed to add users to tenant {tenant_id}: {str(e)}")
db_session.commit()
def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
try:
mappings_to_delete = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email.in_(emails),
UserTenantMapping.tenant_id == tenant_id,
)
.all()
)
for mapping in mappings_to_delete:
db_session.delete(mapping)
db_session.commit()
except Exception as e:
logger.exception(
f"Failed to remove users from tenant {tenant_id}: {str(e)}"
)
db_session.rollback()

View File

@@ -1,76 +0,0 @@
import logging
import os
from types import SimpleNamespace
from sqlalchemy import text
from sqlalchemy.orm import Session
from sqlalchemy.schema import CreateSchema
from alembic import command
from alembic.config import Config
from danswer.db.engine import build_connection_string
from danswer.db.engine import get_sqlalchemy_engine
logger = logging.getLogger(__name__)
def run_alembic_migrations(schema_name: str) -> None:
logger.info(f"Starting Alembic migrations for schema: {schema_name}")
try:
current_dir = os.path.dirname(os.path.abspath(__file__))
root_dir = os.path.abspath(os.path.join(current_dir, "..", "..", "..", ".."))
alembic_ini_path = os.path.join(root_dir, "alembic.ini")
# Configure Alembic
alembic_cfg = Config(alembic_ini_path)
alembic_cfg.set_main_option("sqlalchemy.url", build_connection_string())
alembic_cfg.set_main_option(
"script_location", os.path.join(root_dir, "alembic")
)
# Ensure that logging isn't broken
alembic_cfg.attributes["configure_logger"] = False
# Mimic command-line options by adding 'cmd_opts' to the config
alembic_cfg.cmd_opts = SimpleNamespace() # type: ignore
alembic_cfg.cmd_opts.x = [f"schema={schema_name}"] # type: ignore
# Run migrations programmatically
command.upgrade(alembic_cfg, "head")
# Run migrations programmatically
logger.info(
f"Alembic migrations completed successfully for schema: {schema_name}"
)
except Exception as e:
logger.exception(f"Alembic migration failed for schema {schema_name}: {str(e)}")
raise
def create_schema_if_not_exists(tenant_id: str) -> bool:
with Session(get_sqlalchemy_engine()) as db_session:
with db_session.begin():
result = db_session.execute(
text(
"SELECT schema_name FROM information_schema.schemata WHERE schema_name = :schema_name"
),
{"schema_name": tenant_id},
)
schema_exists = result.scalar() is not None
if not schema_exists:
stmt = CreateSchema(tenant_id)
db_session.execute(stmt)
return True
return False
def drop_schema(tenant_id: str) -> None:
if not tenant_id.isidentifier():
raise ValueError("Invalid tenant_id.")
with get_sqlalchemy_engine().connect() as connection:
connection.execute(
text("DROP SCHEMA IF EXISTS %(schema_name)s CASCADE"),
{"schema_name": tenant_id},
)

View File

@@ -1,70 +0,0 @@
import logging
from fastapi_users import exceptions
from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.db.engine import get_session_with_tenant
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import UserTenantMapping
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
logger = logging.getLogger(__name__)
def get_tenant_id_for_email(email: str) -> str:
if not MULTI_TENANT:
return POSTGRES_DEFAULT_SCHEMA
# Implement logic to get tenant_id from the mapping table
with Session(get_sqlalchemy_engine()) as db_session:
result = db_session.execute(
select(UserTenantMapping.tenant_id).where(UserTenantMapping.email == email)
)
tenant_id = result.scalar_one_or_none()
if tenant_id is None:
raise exceptions.UserNotExists()
return tenant_id
def user_owns_a_tenant(email: str) -> bool:
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
result = (
db_session.query(UserTenantMapping)
.filter(UserTenantMapping.email == email)
.first()
)
return result is not None
def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
try:
for email in emails:
db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))
except Exception:
logger.exception(f"Failed to add users to tenant {tenant_id}")
db_session.commit()
def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
try:
mappings_to_delete = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email.in_(emails),
UserTenantMapping.tenant_id == tenant_id,
)
.all()
)
for mapping in mappings_to_delete:
db_session.delete(mapping)
db_session.commit()
except Exception as e:
logger.exception(
f"Failed to remove users from tenant {tenant_id}: {str(e)}"
)
db_session.rollback()

View File

@@ -8,14 +8,14 @@ from danswer.auth.users import current_admin_user
from danswer.auth.users import current_curator_or_admin_user
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.db.token_limit import fetch_all_user_token_rate_limits
from danswer.db.token_limit import insert_user_token_rate_limit
from danswer.server.query_and_chat.token_limit import any_rate_limit_exists
from danswer.server.token_rate_limits.models import TokenRateLimitArgs
from danswer.server.token_rate_limits.models import TokenRateLimitDisplay
from ee.danswer.db.token_limit import fetch_all_user_group_token_rate_limits_by_group
from ee.danswer.db.token_limit import fetch_all_user_token_rate_limits
from ee.danswer.db.token_limit import fetch_user_group_token_rate_limits
from ee.danswer.db.token_limit import insert_user_group_token_rate_limit
from ee.danswer.db.token_limit import insert_user_token_rate_limit
router = APIRouter(prefix="/admin/token-rate-limits")

View File

@@ -15,9 +15,9 @@ docker rm danswer_postgres danswer_vespa danswer_redis
# Start the PostgreSQL container with optional volume
echo "Starting PostgreSQL container..."
if [[ -n "$POSTGRES_VOLUME" ]]; then
docker run -p 5432:5432 --name danswer_postgres -e POSTGRES_PASSWORD=password -d -v $POSTGRES_VOLUME:/var/lib/postgresql/data postgres -c max_connections=250
docker run -p 5432:5432 --name danswer_postgres -e POSTGRES_PASSWORD=password -d -v $POSTGRES_VOLUME:/var/lib/postgresql/data postgres
else
docker run -p 5432:5432 --name danswer_postgres -e POSTGRES_PASSWORD=password -d postgres -c max_connections=250
docker run -p 5432:5432 --name danswer_postgres -e POSTGRES_PASSWORD=password -d postgres
fi
# Start the Vespa container with optional volume

View File

@@ -1,5 +1,4 @@
import os
from typing import Any
from typing import List
from urllib.parse import urlparse
@@ -134,11 +133,6 @@ MULTI_TENANT = os.environ.get("MULTI_TENANT", "").lower() == "true"
POSTGRES_DEFAULT_SCHEMA = os.environ.get("POSTGRES_DEFAULT_SCHEMA") or "public"
async def async_return_default_schema(*args: Any, **kwargs: Any) -> str:
return POSTGRES_DEFAULT_SCHEMA
# Prefix used for all tenant ids
TENANT_ID_PREFIX = "tenant_"

View File

@@ -3,7 +3,7 @@ from uuid import uuid4
import requests
from danswer.db.models import UserRole
from danswer.server.api_key.models import APIKeyArgs
from ee.danswer.server.api_key.models import APIKeyArgs
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestAPIKey

10
ct.yaml
View File

@@ -1,18 +1,12 @@
# See https://github.com/helm/chart-testing#configuration
# still have to specify this on the command line for list-changed
chart-dirs:
- deployment/helm/charts
# must be kept in sync with Chart.yaml
chart-repos:
- vespa=https://danswer-ai.github.io/vespa-helm-charts
- vespa=https://unoplat.github.io/vespa-helm-charts
- postgresql=https://charts.bitnami.com/bitnami
helm-extra-args: --debug --timeout 600s
# nginx appears to not work on kind, likely due to lack of loadbalancer support
# helm-extra-set-args also only works on the command line, not in this yaml
# helm-extra-set-args: --set=nginx.enabled=false
helm-extra-args: --timeout 600s
validate-maintainers: false

View File

@@ -1,13 +0,0 @@
apiVersion: keda.sh/v1alpha1
kind: TriggerAuthentication
metadata:
name: celery-worker-auth
namespace: danswer
spec:
secretTargetRef:
- parameter: host
name: keda-redis-secret
key: host
- parameter: password
name: keda-redis-secret
key: password

View File

@@ -1,46 +0,0 @@
apiVersion: keda.sh/v1alpha1
kind: ScaledObject
metadata:
name: celery-worker-indexing-scaledobject
namespace: danswer
labels:
app: celery-worker-indexing
spec:
scaleTargetRef:
name: celery-worker-indexing
minReplicaCount: 1
maxReplicaCount: 10
triggers:
- type: redis
metadata:
sslEnabled: "true"
host: "{host}"
port: "6379"
enableTLS: "true"
listName: connector_indexing
listLength: "1"
databaseIndex: "15"
authenticationRef:
name: celery-worker-auth
- type: redis
metadata:
sslEnabled: "true"
host: "{host}"
port: "6379"
enableTLS: "true"
listName: connector_indexing:2
listLength: "1"
databaseIndex: "15"
authenticationRef:
name: celery-worker-auth
- type: redis
metadata:
sslEnabled: "true"
host: "{host}"
port: "6379"
enableTLS: "true"
listName: connector_indexing:3
listLength: "1"
databaseIndex: "15"
authenticationRef:
name: celery-worker-auth

View File

@@ -1,63 +0,0 @@
apiVersion: keda.sh/v1alpha1
kind: ScaledObject
metadata:
name: celery-worker-light-scaledobject
namespace: danswer
labels:
app: celery-worker-light
spec:
scaleTargetRef:
name: celery-worker-light
minReplicaCount: 1
maxReplicaCount: 20
triggers:
- type: redis
metadata:
host: "{host}"
port: "6379"
enableTLS: "true"
listName: vespa_metadata_sync
listLength: "1"
databaseIndex: "15"
authenticationRef:
name: celery-worker-auth
- type: redis
metadata:
host: "{host}"
port: "6379"
enableTLS: "true"
listName: vespa_metadata_sync:2
listLength: "1"
databaseIndex: "15"
authenticationRef:
name: celery-worker-auth
- type: redis
metadata:
host: "{host}"
port: "6379"
enableTLS: "true"
listName: vespa_metadata_sync:3
listLength: "1"
databaseIndex: "15"
authenticationRef:
name: celery-worker-auth
- type: redis
metadata:
host: "{host}"
port: "6379"
enableTLS: "true"
listName: connector_deletion
listLength: "1"
databaseIndex: "15"
authenticationRef:
name: celery-worker-auth
- type: redis
metadata:
host: "{host}"
port: "6379"
enableTLS: "true"
listName: connector_deletion:2
listLength: "1"
databaseIndex: "15"
authenticationRef:
name: celery-worker-auth

View File

@@ -1,76 +0,0 @@
apiVersion: keda.sh/v1alpha1
kind: ScaledObject
metadata:
name: celery-worker-primary-scaledobject
namespace: danswer
labels:
app: celery-worker-primary
spec:
scaleTargetRef:
name: celery-worker-primary
pollingInterval: 15 # Check every 15 seconds
cooldownPeriod: 30 # Wait 30 seconds before scaling down
minReplicaCount: 1
maxReplicaCount: 1
triggers:
- type: redis
metadata:
host: "{host}"
port: "6379"
enableTLS: "true"
listName: celery
listLength: "1"
databaseIndex: "15"
authenticationRef:
name: celery-worker-auth
- type: redis
metadata:
host: "{host}"
port: "6379"
enableTLS: "true"
listName: celery:1
listLength: "1"
databaseIndex: "15"
authenticationRef:
name: celery-worker-auth
- type: redis
metadata:
host: "{host}"
port: "6379"
enableTLS: "true"
listName: celery:2
listLength: "1"
databaseIndex: "15"
authenticationRef:
name: celery-worker-auth
- type: redis
metadata:
host: "{host}"
port: "6379"
enableTLS: "true"
listName: celery:3
listLength: "1"
databaseIndex: "15"
authenticationRef:
name: celery-worker-auth
- type: redis
metadata:
host: "{host}"
port: "6379"
enableTLS: "true"
listName: periodic_tasks
listLength: "1"
databaseIndex: "15"
authenticationRef:
name: celery-worker-auth
- type: redis
metadata:
host: "{host}"
port: "6379"
enableTLS: "true"
listName: periodic_tasks:2
listLength: "1"
databaseIndex: "15"
authenticationRef:
name: celery-worker-auth

View File

@@ -1,9 +0,0 @@
apiVersion: v1
kind: Secret
metadata:
name: keda-redis-secret
namespace: danswer
type: Opaque
data:
host: { { base64-encoded-hostname } }
password: { { base64-encoded-password } }

View File

@@ -14,7 +14,7 @@ spec:
spec:
containers:
- name: celery-beat
image: danswer/danswer-backend-cloud:v0.12.0-cloud.beta.2
image: danswer/danswer-backend:v0.11.0-cloud.beta.4
imagePullPolicy: Always
command:
[
@@ -31,7 +31,7 @@ spec:
name: danswer-secrets
key: redis_password
- name: DANSWER_VERSION
value: "v0.11.0-cloud.beta.8"
value: "v0.11.0-cloud.beta.4"
envFrom:
- configMapRef:
name: env-configmap

View File

@@ -14,7 +14,7 @@ spec:
spec:
containers:
- name: celery-worker-heavy
image: danswer/danswer-backend-cloud:v0.12.0-cloud.beta.2
image: danswer/danswer-backend:v0.11.0-cloud.beta.4
imagePullPolicy: Always
command:
[
@@ -34,7 +34,7 @@ spec:
name: danswer-secrets
key: redis_password
- name: DANSWER_VERSION
value: "v0.11.0-cloud.beta.8"
value: "v0.11.0-cloud.beta.4"
envFrom:
- configMapRef:
name: env-configmap

View File

@@ -14,7 +14,7 @@ spec:
spec:
containers:
- name: celery-worker-indexing
image: danswer/danswer-backend-cloud:v0.12.0-cloud.beta.2
image: danswer/danswer-backend:v0.11.0-cloud.beta.4
imagePullPolicy: Always
command:
[
@@ -26,8 +26,6 @@ spec:
"--hostname=indexing@%n",
"-Q",
"connector_indexing",
"--prefetch-multiplier=1",
"--concurrency=10",
]
env:
- name: REDIS_PASSWORD
@@ -36,7 +34,7 @@ spec:
name: danswer-secrets
key: redis_password
- name: DANSWER_VERSION
value: "v0.11.0-cloud.beta.8"
value: "v0.11.0-cloud.beta.4"
envFrom:
- configMapRef:
name: env-configmap

View File

@@ -14,7 +14,7 @@ spec:
spec:
containers:
- name: celery-worker-light
image: danswer/danswer-backend-cloud:v0.12.0-cloud.beta.2
image: danswer/danswer-backend:v0.11.0-cloud.beta.4
imagePullPolicy: Always
command:
[
@@ -26,8 +26,6 @@ spec:
"--hostname=light@%n",
"-Q",
"vespa_metadata_sync,connector_deletion",
"--prefetch-multiplier=1",
"--concurrency=10",
]
env:
- name: REDIS_PASSWORD
@@ -36,7 +34,7 @@ spec:
name: danswer-secrets
key: redis_password
- name: DANSWER_VERSION
value: "v0.11.0-cloud.beta.8"
value: "v0.11.0-cloud.beta.4"
envFrom:
- configMapRef:
name: env-configmap

View File

@@ -14,7 +14,7 @@ spec:
spec:
containers:
- name: celery-worker-primary
image: danswer/danswer-backend-cloud:v0.12.0-cloud.beta.2
image: danswer/danswer-backend:v0.11.0-cloud.beta.4
imagePullPolicy: Always
command:
[
@@ -25,9 +25,7 @@ spec:
"--loglevel=INFO",
"--hostname=primary@%n",
"-Q",
"celery,periodic_tasks",
"--prefetch-multiplier=1",
"--concurrency=10",
"celery,periodic_tasks,vespa_metadata_sync",
]
env:
- name: REDIS_PASSWORD
@@ -36,7 +34,7 @@ spec:
name: danswer-secrets
key: redis_password
- name: DANSWER_VERSION
value: "v0.11.0-cloud.beta.8"
value: "v0.11.0-cloud.beta.4"
envFrom:
- configMapRef:
name: env-configmap

View File

@@ -221,10 +221,6 @@ services:
max-size: "50m"
max-file: "6"
# Uncomment the following lines if you need to include a custom CA certificate
# This section enables the use of a custom CA certificate
# If present, the custom CA certificate is mounted as a volume
# The container checks for its existence and updates the system's CA certificates
# This allows for secure communication with services using custom SSL certificates
# Optional volume mount for CA certificate
# volumes:
# # Maps to the CA_CERT_PATH environment variable in the Dockerfile

View File

@@ -65,10 +65,6 @@ services:
max-size: "50m"
max-file: "6"
# Uncomment the following lines if you need to include a custom CA certificate
# This section enables the use of a custom CA certificate
# If present, the custom CA certificate is mounted as a volume
# The container checks for its existence and updates the system's CA certificates
# This allows for secure communication with services using custom SSL certificates
# volumes:
# # Maps to the CA_CERT_PATH environment variable in the Dockerfile
# - ${CA_CERT_PATH:-./custom-ca.crt}:/etc/ssl/certs/custom-ca.crt:ro

View File

@@ -3,13 +3,13 @@ dependencies:
repository: https://charts.bitnami.com/bitnami
version: 14.3.1
- name: vespa
repository: https://danswer-ai.github.io/vespa-helm-charts
version: 0.2.16
repository: https://unoplat.github.io/vespa-helm-charts
version: 0.2.3
- name: nginx
repository: oci://registry-1.docker.io/bitnamicharts
version: 15.14.0
- name: redis
repository: https://charts.bitnami.com/bitnami
version: 20.1.0
digest: sha256:711bbb76ba6ab604a36c9bf1839ab6faa5610afb21e535afd933c78f2d102232
generated: "2024-11-07T09:39:30.17171-08:00"
digest: sha256:fb42426c1d13667a4929d0d6a7d681bf08120e4a4eb1d15437e4ec70920be3f8
generated: "2024-09-11T09:16:03.312328-07:00"

View File

@@ -5,7 +5,7 @@ home: https://www.danswer.ai/
sources:
- "https://github.com/danswer-ai/danswer"
type: application
version: 0.2.1
version: 0.2.0
appVersion: "latest"
annotations:
category: Productivity
@@ -23,8 +23,8 @@ dependencies:
repository: https://charts.bitnami.com/bitnami
condition: postgresql.enabled
- name: vespa
version: 0.2.16
repository: https://danswer-ai.github.io/vespa-helm-charts
version: 0.2.3
repository: https://unoplat.github.io/vespa-helm-charts
condition: vespa.enabled
- name: nginx
version: 15.14.0

View File

@@ -7,7 +7,7 @@ metadata:
data:
INTERNAL_URL: "http://{{ include "danswer-stack.fullname" . }}-api-service:{{ .Values.api.service.port | default 8080 }}"
POSTGRES_HOST: {{ .Release.Name }}-postgresql
VESPA_HOST: da-vespa-0.vespa-service
VESPA_HOST: "document-index-service"
REDIS_HOST: {{ .Release.Name }}-redis-master
MODEL_SERVER_HOST: "{{ include "danswer-stack.fullname" . }}-inference-model-service"
INDEXING_MODEL_SERVER_HOST: "{{ include "danswer-stack.fullname" . }}-indexing-model-service"

View File

@@ -11,5 +11,5 @@ spec:
- name: wget
image: busybox
command: ['wget']
args: ['{{ include "danswer-stack.fullname" . }}-webserver:{{ .Values.webserver.service.port }}']
args: ['{{ include "danswer-stack.fullname" . }}:{{ .Values.webserver.service.port }}']
restartPolicy: Never

View File

@@ -17,13 +17,11 @@ spec:
image: danswer/danswer-backend:latest
imagePullPolicy: IfNotPresent
command:
- "/bin/sh"
- "-c"
- |
if [ -f /etc/ssl/certs/custom-ca.crt ]; then
update-ca-certificates;
fi &&
/usr/bin/supervisord -c /etc/supervisor/conf.d/supervisord.conf
[
"/usr/bin/supervisord",
"-c",
"/etc/supervisor/conf.d/supervisord.conf",
]
env:
- name: REDIS_PASSWORD
valueFrom:
@@ -34,10 +32,6 @@ spec:
- configMapRef:
name: env-configmap
# Uncomment the following lines if you need to include a custom CA certificate
# This section allows for the inclusion of a custom CA certificate
# If a custom CA certificate is present, it updates the system's CA certificates
# This is useful for environments with self-signed or internal CA certificates
# The certificate is mounted as a volume and the container checks for its presence
# Optional volume mount for CA certificate
# volumeMounts:
# - name: my-ca-cert-volume

Binary file not shown.

Before

Width:  |  Height:  |  Size: 10 KiB

View File

@@ -11,7 +11,6 @@ import {
LLM_PROVIDERS_ADMIN_URL,
} from "../../configuration/llm/constants";
import { mutate } from "swr";
import { testEmbedding } from "../pages/utils";
export function ChangeCredentialsModal({
provider,
@@ -113,15 +112,16 @@ export function ChangeCredentialsModal({
const normalizedProviderType = provider.provider_type
.toLowerCase()
.split(" ")[0];
try {
const testResponse = await testEmbedding({
provider_type: normalizedProviderType,
modelName,
apiKey,
apiUrl,
apiVersion: null,
deploymentName: null,
const testResponse = await fetch("/api/admin/embedding/test-embedding", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
provider_type: normalizedProviderType,
api_key: apiKey,
api_url: apiUrl,
model_name: modelName,
}),
});
if (!testResponse.ok) {

View File

@@ -110,27 +110,20 @@ export function ProviderCreationModal({
setErrorMsg("");
try {
const customConfig = Object.fromEntries(values.custom_config);
const providerType = values.provider_type.toLowerCase().split(" ")[0];
const isOpenAI = providerType === "openai";
const testModelName =
isOpenAI || isAzure ? "text-embedding-3-small" : values.model_name;
const testEmbeddingPayload = {
provider_type: providerType,
api_key: values.api_key,
api_url: values.api_url,
model_name: testModelName,
api_version: values.api_version,
deployment_name: values.deployment_name,
};
const initialResponse = await fetch(
"/api/admin/embedding/test-embedding",
{
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify(testEmbeddingPayload),
body: JSON.stringify({
provider_type: values.provider_type.toLowerCase().split(" ")[0],
api_key: values.api_key,
api_url: values.api_url,
model_name: values.model_name,
api_version: values.api_version,
deployment_name: values.deployment_name,
}),
}
);

View File

@@ -8,37 +8,3 @@ export const deleteSearchSettings = async (search_settings_id: number) => {
});
return response;
};
export const testEmbedding = async ({
provider_type,
modelName,
apiKey,
apiUrl,
apiVersion,
deploymentName,
}: {
provider_type: string;
modelName: string;
apiKey: string | null;
apiUrl: string | null;
apiVersion: string | null;
deploymentName: string | null;
}) => {
const testModelName =
provider_type === "openai" ? "text-embedding-3-small" : modelName;
const testResponse = await fetch("/api/admin/embedding/test-embedding", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
provider_type: provider_type,
api_key: apiKey,
api_url: apiUrl,
model_name: testModelName,
api_version: apiVersion,
deployment_name: deploymentName,
}),
});
return testResponse;
};

View File

@@ -4,6 +4,7 @@ import { fetchChatData } from "@/lib/chat/fetchChatData";
import { unstable_noStore as noStore } from "next/cache";
import { redirect } from "next/navigation";
import WrappedAssistantsGallery from "./WrappedAssistantsGallery";
import { AssistantsProvider } from "@/components/context/AssistantsContext";
import { cookies } from "next/headers";
export default async function GalleryPage(props: {

View File

@@ -131,8 +131,6 @@ export function ChatPage({
const {
chatSessions,
availableSources,
availableDocumentSets,
llmProviders,
folders,
openedFolders,
@@ -2173,6 +2171,25 @@ export function ChatPage({
) {
return <></>;
}
const mostRecentNonAIParent = messageHistory
.slice(0, i)
.reverse()
.find((msg) => msg.type !== "assistant");
const hasChildMessage =
message.latestChildMessageId !== null &&
message.latestChildMessageId !== undefined;
const childMessage = hasChildMessage
? messageMap.get(
message.latestChildMessageId!
)
: null;
const hasParentAI =
parentMessage?.type == "assistant";
const hasChildAI =
childMessage?.type == "assistant";
return (
<div
id={`message-${message.messageId}`}
@@ -2184,6 +2201,9 @@ export function ChatPage({
}
>
<AIMessage
setPopup={setPopup}
hasChildAI={hasChildAI}
hasParentAI={hasParentAI}
continueGenerating={
i == messageHistory.length - 1 &&
currentCanContinue()
@@ -2193,7 +2213,7 @@ export function ChatPage({
overriddenModel={message.overridden_model}
regenerate={createRegenerator({
messageId: message.messageId,
parentMessage: parentMessage!,
parentMessage: mostRecentNonAIParent!,
})}
otherMessagesCanSwitchTo={
parentMessage?.childrenMessageIds || []
@@ -2340,6 +2360,7 @@ export function ChatPage({
return (
<div key={messageReactComponentKey}>
<AIMessage
setPopup={setPopup}
currentPersona={liveAssistant}
messageId={message.messageId}
content={
@@ -2382,6 +2403,7 @@ export function ChatPage({
key={`${messageHistory.length}-${chatSessionIdRef.current}`}
>
<AIMessage
setPopup={setPopup}
key={-3}
currentPersona={liveAssistant}
alternativeAssistant={
@@ -2406,6 +2428,7 @@ export function ChatPage({
{loadingError && (
<div key={-1}>
<AIMessage
setPopup={setPopup}
currentPersona={liveAssistant}
messageId={-1}
content={

View File

@@ -144,3 +144,10 @@ export interface StreamingError {
error: string;
stack_trace: string;
}
export interface ImageGenerationResult {
revised_prompt: string;
url: string;
}
export type ImageGenerationResults = ImageGenerationResult[];

View File

@@ -1,29 +1,73 @@
import { Citation } from "@/components/search/results/Citation";
import React, { memo } from "react";
import { IMAGE_GENERATION_TOOL_NAME } from "../tools/constants";
export const MemoizedLink = memo((props: any) => {
const { node, ...rest } = props;
const value = rest.children;
import {
Popover,
PopoverContent,
PopoverTrigger,
} from "@/components/ui/popover";
import { SearchIcon } from "lucide-react";
import DualPromptDisplay from "../tools/ImageCitation";
import { PopupSpec, usePopup } from "@/components/admin/connectors/Popup";
import { ImageGenerationResults, ToolCallFinalResult } from "../interfaces";
if (value?.toString().startsWith("*")) {
return (
<div className="flex-none bg-background-800 inline-block rounded-full h-3 w-3 ml-2" />
);
} else if (value?.toString().startsWith("[")) {
return <Citation link={rest?.href}>{rest.children}</Citation>;
} else {
return (
<a
onMouseDown={() =>
rest.href ? window.open(rest.href, "_blank") : undefined
}
className="cursor-pointer text-link hover:text-link-hover"
>
{rest.children}
</a>
);
export const MemoizedLink = memo(
({
toolCall,
setPopup,
...props
}: {
toolCall?: ToolCallFinalResult;
setPopup: (popupSpec: PopupSpec | null) => void;
} & any) => {
const { node, ...rest } = props;
const value = rest.children;
if (value?.toString().startsWith(IMAGE_GENERATION_TOOL_NAME)) {
const imageGenerationResult =
toolCall?.tool_result as ImageGenerationResults;
return (
<Popover>
<PopoverTrigger asChild>
<span className="inline-block">
<SearchIcon className="cursor-pointer flex-none text-blue-500 hover:text-blue-700 !h-4 !w-4 inline-block" />
</span>
</PopoverTrigger>
<PopoverContent className="w-96" side="top" align="center">
<DualPromptDisplay
arg="Prompt"
setPopup={setPopup!}
prompts={imageGenerationResult.map(
(result) => result.revised_prompt
)}
/>
</PopoverContent>
</Popover>
);
}
if (value?.toString().startsWith("*")) {
return (
<div className="flex-none bg-background-800 inline-block rounded-full h-3 w-3 ml-2" />
);
} else if (value?.toString().startsWith("[")) {
return <Citation link={rest?.href}>{rest.children}</Citation>;
} else {
return (
<a
onMouseDown={() =>
rest.href ? window.open(rest.href, "_blank") : undefined
}
className="cursor-pointer text-link hover:text-link-hover"
>
{rest.children}
</a>
);
}
}
});
);
export const MemoizedParagraph = memo(({ ...props }: any) => {
return <p {...props} className="text-default" />;

View File

@@ -62,6 +62,7 @@ import { MemoizedLink, MemoizedParagraph } from "./MemoizedTextComponents";
import { extractCodeText } from "./codeUtils";
import ToolResult from "../../../components/tools/ToolResult";
import CsvContent from "../../../components/tools/CSVContent";
import { PopupSpec } from "@/components/admin/connectors/Popup";
const TOOLS_WITH_CUSTOM_HANDLING = [
SEARCH_TOOL_NAME,
@@ -153,6 +154,9 @@ function FileDisplay({
}
export const AIMessage = ({
setPopup,
hasChildAI,
hasParentAI,
regenerate,
overriddenModel,
continueGenerating,
@@ -179,6 +183,9 @@ export const AIMessage = ({
otherMessagesCanSwitchTo,
onMessageSelection,
}: {
setPopup?: (popupSpec: PopupSpec | null) => void;
hasChildAI?: boolean;
hasParentAI?: boolean;
shared?: boolean;
isActive?: boolean;
continueGenerating?: () => void;
@@ -227,6 +234,13 @@ export const AIMessage = ({
return content;
}
}
if (
isComplete &&
toolCall?.tool_result &&
toolCall.tool_name == IMAGE_GENERATION_TOOL_NAME
) {
return content + ` [${toolCall.tool_name}]()`;
}
return content + (!isComplete && !toolCallGenerating ? " [*]() " : "");
};
@@ -296,7 +310,9 @@ export const AIMessage = ({
const markdownComponents = useMemo(
() => ({
a: MemoizedLink,
a: (props: any) => (
<MemoizedLink {...props} toolCall={toolCall} setPopup={setPopup} />
),
p: MemoizedParagraph,
code: ({ node, className, children, ...props }: any) => {
const codeText = extractCodeText(
@@ -312,7 +328,7 @@ export const AIMessage = ({
);
},
}),
[finalContent]
[finalContent, toolCall]
);
const renderedMarkdown = useMemo(() => {
@@ -338,7 +354,7 @@ export const AIMessage = ({
<div
id="danswer-ai-message"
ref={trackedElementRef}
className={"py-5 ml-4 px-5 relative flex "}
className={`${hasParentAI ? "pb-5" : "py-5"} px-2 lg:px-5 relative flex `}
>
<div
className={`mx-auto ${
@@ -347,10 +363,14 @@ export const AIMessage = ({
>
<div className={`desktop:mr-12 ${!shared && "mobile:ml-0 md:ml-8"}`}>
<div className="flex">
<AssistantIcon
size="small"
assistant={alternativeAssistant || currentPersona}
/>
{!hasParentAI ? (
<AssistantIcon
size="small"
assistant={alternativeAssistant || currentPersona}
/>
) : (
<div className="w-6" />
)}
<div className="w-full">
<div className="max-w-message-max break-words">
@@ -514,7 +534,8 @@ export const AIMessage = ({
)}
</div>
{handleFeedback &&
{!hasChildAI &&
handleFeedback &&
(isActive ? (
<div
className={`

Some files were not shown because too many files have changed in this diff Show More