Compare commits


1 Commit

Author        SHA1        Message            Date
pablodanswer  87c540246d  add cloud configs  2024-10-10 12:29:20 -07:00
253 changed files with 3712 additions and 10653 deletions

View File

@@ -167,7 +167,7 @@ jobs:
- name: Upload logs
if: success() || failure()
uses: actions/upload-artifact@v4
uses: actions/upload-artifact@v3
with:
name: docker-logs
path: ${{ github.workspace }}/docker-compose.log

View File

@@ -14,10 +14,10 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v4
uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v5
uses: actions/setup-python@v4
with:
python-version: '3.11'
cache: 'pip'

View File

@@ -32,7 +32,7 @@ jobs:
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
uses: actions/setup-python@v4
with:
python-version: "3.11"
cache: "pip"

View File

@@ -27,7 +27,7 @@ jobs:
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
uses: actions/setup-python@v4
with:
python-version: "3.11"
cache: "pip"

View File

@@ -21,7 +21,7 @@ jobs:
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
uses: actions/setup-python@v4
with:
python-version: '3.11'
cache: 'pip'

View File

@@ -18,6 +18,6 @@ jobs:
- uses: actions/setup-python@v5
with:
python-version: "3.11"
- uses: pre-commit/action@v3.0.1
- uses: pre-commit/action@v3.0.0
with:
extra_args: ${{ github.event_name == 'pull_request' && format('--from-ref {0} --to-ref {1}', github.event.pull_request.base.sha, github.event.pull_request.head.sha) || '' }}
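
The extra_args expression above limits pre-commit to the files changed between the pull request base and head. A rough local equivalent, sketched in Python (assumes pre-commit is installed; the ref names are placeholders):

import subprocess

# Run hooks only on the diff between a base and head ref, mirroring the
# --from-ref/--to-ref arguments the workflow passes for pull requests.
base_ref = "origin/main"  # placeholder base ref
head_ref = "HEAD"         # placeholder head ref
subprocess.run(
    ["pre-commit", "run", "--from-ref", base_ref, "--to-ref", head_ref],
    check=True,
)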

View File

@@ -74,7 +74,7 @@ We also have built-in support for deployment on Kubernetes. Files for that can b
* Organizational understanding and ability to locate and suggest experts from your team.
## Other Notable Benefits of Danswer
## Other Noteable Benefits of Danswer
* User Authentication with document level access management.
* Best in class Hybrid Search across all sources (BM-25 + prefix aware embedding models).
* Admin Dashboard to configure connectors, document-sets, access, etc.
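
Hybrid search here means blending a lexical BM-25 score with an embedding similarity. Danswer's actual weighting lives in its search stack; the toy sketch below only illustrates the idea of combining the two signals, with made-up numbers.

def hybrid_score(bm25: float, embedding_sim: float, alpha: float = 0.5) -> float:
    # alpha is a hypothetical blend weight, not the value Danswer uses.
    return alpha * bm25 + (1 - alpha) * embedding_sim

docs = ["doc_a", "doc_b"]
bm25_scores = [0.9, 0.4]      # normalized lexical scores (illustrative)
embedding_sims = [0.3, 0.8]   # cosine similarities (illustrative)
ranked = sorted(
    zip(docs, bm25_scores, embedding_sims),
    key=lambda t: hybrid_score(t[1], t[2]),
    reverse=True,
)
print([d for d, _, _ in ranked])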

View File

@@ -8,12 +8,10 @@ Edition features outside of personal development or testing purposes. Please rea
founders@danswer.ai for more information. Please visit https://github.com/danswer-ai/danswer"
# Default DANSWER_VERSION, typically overridden during builds by GitHub Actions.
ARG DANSWER_VERSION=0.8-dev
ARG DANSWER_VERSION=0.3-dev
ENV DANSWER_VERSION=${DANSWER_VERSION} \
DANSWER_RUNNING_IN_DOCKER="true"
ARG CA_CERT_CONTENT=""
RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}"
# Install system dependencies
# cmake needed for psycopg (postgres)
@@ -38,17 +36,6 @@ RUN apt-get update && \
rm -rf /var/lib/apt/lists/* && \
apt-get clean
# Conditionally write the CA certificate and update certificates
RUN if [ -n "$CA_CERT_CONTENT" ]; then \
echo "Adding custom CA certificate"; \
echo "$CA_CERT_CONTENT" > /usr/local/share/ca-certificates/my-ca.crt && \
chmod 644 /usr/local/share/ca-certificates/my-ca.crt && \
update-ca-certificates; \
else \
echo "No custom CA certificate provided"; \
fi
# Install Python dependencies
# Remove py which is pulled in by retry, py is not needed and is a CVE
COPY ./requirements/default.txt /tmp/requirements.txt
@@ -105,7 +92,6 @@ COPY supervisord.conf /etc/supervisor/conf.d/supervisord.conf
COPY ./danswer /app/danswer
COPY ./shared_configs /app/shared_configs
COPY ./alembic /app/alembic
COPY ./alembic_tenants /app/alembic_tenants
COPY ./alembic.ini /app/alembic.ini
COPY supervisord.conf /usr/etc/supervisord.conf

backend/Dockerfile.cloud (new file, 109 lines)
View File

@@ -0,0 +1,109 @@
FROM python:3.11.7-slim-bookworm
LABEL com.danswer.maintainer="founders@danswer.ai"
LABEL com.danswer.description="This image is the web/frontend container of Danswer which \
contains code for both the Community and Enterprise editions of Danswer. If you do not \
have a contract or agreement with DanswerAI, you are not permitted to use the Enterprise \
Edition features outside of personal development or testing purposes. Please reach out to \
founders@danswer.ai for more information. Please visit https://github.com/danswer-ai/danswer"
# Default DANSWER_VERSION, typically overridden during builds by GitHub Actions.
ARG DANSWER_VERSION=0.3-dev
ENV DANSWER_VERSION=${DANSWER_VERSION} \
DANSWER_RUNNING_IN_DOCKER="true"
RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}"
# Install system dependencies
# cmake needed for psycopg (postgres)
# libpq-dev needed for psycopg (postgres)
# curl included just for users' convenience
# zip for Vespa step further down
# ca-certificates for HTTPS
RUN apt-get update && \
apt-get install -y \
cmake \
curl \
zip \
ca-certificates \
libgnutls30=3.7.9-2+deb12u3 \
libblkid1=2.38.1-5+deb12u1 \
libmount1=2.38.1-5+deb12u1 \
libsmartcols1=2.38.1-5+deb12u1 \
libuuid1=2.38.1-5+deb12u1 \
libxmlsec1-dev \
pkg-config \
gcc && \
rm -rf /var/lib/apt/lists/* && \
apt-get clean
# Install Python dependencies
# Remove py which is pulled in by retry, py is not needed and is a CVE
COPY ./requirements/default.txt /tmp/requirements.txt
COPY ./requirements/ee.txt /tmp/ee-requirements.txt
RUN pip install --no-cache-dir --upgrade \
--retries 5 \
--timeout 30 \
-r /tmp/requirements.txt \
-r /tmp/ee-requirements.txt && \
pip uninstall -y py && \
playwright install chromium && \
playwright install-deps chromium && \
ln -s /usr/local/bin/supervisord /usr/bin/supervisord
# Cleanup for CVEs and size reduction
# https://github.com/tornadoweb/tornado/issues/3107
# xserver-common and xvfb included by playwright installation but not needed after
# perl-base is part of the base Python Debian image but not needed for Danswer functionality
# perl-base could only be removed with --allow-remove-essential
RUN apt-get update && \
apt-get remove -y --allow-remove-essential \
perl-base \
xserver-common \
xvfb \
cmake \
libldap-2.5-0 \
libxmlsec1-dev \
pkg-config \
gcc && \
apt-get install -y libxmlsec1-openssl && \
apt-get autoremove -y && \
rm -rf /var/lib/apt/lists/* && \
rm -f /usr/local/lib/python3.11/site-packages/tornado/test/test.key
# Pre-downloading models for setups with limited egress
RUN python -c "from tokenizers import Tokenizer; \
Tokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1')"
# Pre-downloading NLTK for setups with limited egress
RUN python -c "import nltk; \
nltk.download('stopwords', quiet=True); \
nltk.download('punkt', quiet=True);"
# nltk.download('wordnet', quiet=True); introduce this back if lemmatization is needed
# Set up application files
WORKDIR /app
# Enterprise Version Files
COPY ./ee /app/ee
COPY supervisord.conf /etc/supervisor/conf.d/supervisord.conf
# Set up application files
COPY ./danswer /app/danswer
COPY ./shared_configs /app/shared_configs
COPY ./alembic /app/alembic
COPY ./alembic_tenants /app/alembic_tenants
COPY ./alembic.ini /app/alembic.ini
COPY supervisord.conf /usr/etc/supervisord.conf
# Escape hatch
COPY ./scripts/force_delete_connector_by_id.py /app/scripts/force_delete_connector_by_id.py
# Put logo in assets
COPY ./assets /app/assets
ENV PYTHONPATH=/app
# Default command which does nothing
# This container is used by api server and background which specify their own CMD
CMD ["tail", "-f", "/dev/null"]
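
Because the tokenizer and NLTK corpora are baked into the image, containers with no egress can still load them at runtime. A minimal sanity check, sketched under the assumption that the image's Python environment is used (fully offline model loading may additionally require HF_HUB_OFFLINE=1):

import nltk
from tokenizers import Tokenizer

# These should resolve from the caches populated during the image build.
tokenizer = Tokenizer.from_pretrained("nomic-ai/nomic-embed-text-v1")
print(len(tokenizer.encode("warm cache check").tokens))

nltk.data.find("corpora/stopwords")   # raises LookupError only if the build step was skipped
nltk.data.find("tokenizers/punkt")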

View File

@@ -7,7 +7,7 @@ You can find it at https://hub.docker.com/r/danswer/danswer-model-server. For mo
visit https://github.com/danswer-ai/danswer."
# Default DANSWER_VERSION, typically overridden during builds by GitHub Actions.
ARG DANSWER_VERSION=0.8-dev
ARG DANSWER_VERSION=0.3-dev
ENV DANSWER_VERSION=${DANSWER_VERSION} \
DANSWER_RUNNING_IN_DOCKER="true"

View File

@@ -1,11 +1,10 @@
from sqlalchemy.engine.base import Connection
from typing import Any
import asyncio
from logging.config import fileConfig
import logging
from alembic import context
from sqlalchemy import pool
from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.sql import text
@@ -13,7 +12,6 @@ from danswer.configs.app_configs import MULTI_TENANT
from danswer.db.engine import build_connection_string
from danswer.db.models import Base
from celery.backends.database.session import ResultModelBase # type: ignore
from danswer.background.celery.celery_app import get_all_tenant_ids
# Alembic Config object
config = context.config
@@ -24,32 +22,14 @@ if config.config_file_name is not None and config.attributes.get(
):
fileConfig(config.config_file_name)
# Add your model's MetaData object here for 'autogenerate' support
# Add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
target_metadata = [Base.metadata, ResultModelBase.metadata]
EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
# Set up logging
logger = logging.getLogger(__name__)
def include_object(
object: Any, name: str, type_: str, reflected: bool, compare_to: Any
) -> bool:
"""
Determines whether a database object should be included in migrations.
Excludes specified tables from migrations.
"""
if type_ == "table" and name in EXCLUDE_TABLES:
return False
return True
def get_schema_options() -> tuple[str, bool, bool]:
"""
Parses command-line options passed via '-x' in Alembic commands.
Recognizes 'schema', 'create_schema', and 'upgrade_all_tenants' options.
"""
def get_schema_options() -> tuple[str, bool]:
x_args_raw = context.get_x_argument()
x_args = {}
for arg in x_args_raw:
@@ -59,7 +39,49 @@ def get_schema_options() -> tuple[str, bool, bool]:
x_args[key.strip()] = value.strip()
schema_name = x_args.get("schema", "public")
create_schema = x_args.get("create_schema", "true").lower() == "true"
upgrade_all_tenants = x_args.get("upgrade_all_tenants", "false").lower() == "true"
return schema_name, create_schema
EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
def include_object(
object: Any, name: str, type_: str, reflected: bool, compare_to: Any
) -> bool:
if type_ == "table" and name in EXCLUDE_TABLES:
return False
return True
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
schema_name, _ = get_schema_options()
url = build_connection_string()
context.configure(
url=url,
target_metadata=target_metadata, # type: ignore
literal_binds=True,
include_object=include_object,
version_table_schema=schema_name,
include_schemas=True,
script_location=config.get_main_option("script_location"),
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def do_run_migrations(connection: Connection) -> None:
schema_name, create_schema = get_schema_options()
if MULTI_TENANT and schema_name == "public":
raise ValueError(
@@ -67,17 +89,6 @@ def get_schema_options() -> tuple[str, bool, bool]:
"Please specify a tenant-specific schema."
)
return schema_name, create_schema, upgrade_all_tenants
def do_run_migrations(
connection: Connection, schema_name: str, create_schema: bool
) -> None:
"""
Executes migrations in the specified schema.
"""
logger.info(f"About to migrate schema: {schema_name}")
if create_schema:
connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema_name}"'))
connection.execute(text("COMMIT"))
@@ -101,98 +112,18 @@ def do_run_migrations(
async def run_async_migrations() -> None:
"""
Determines whether to run migrations for a single schema or all schemas,
and executes migrations accordingly.
"""
schema_name, create_schema, upgrade_all_tenants = get_schema_options()
engine = create_async_engine(
connectable = create_async_engine(
build_connection_string(),
poolclass=pool.NullPool,
)
if upgrade_all_tenants:
# Run migrations for all tenant schemas sequentially
tenant_schemas = get_all_tenant_ids()
async with connectable.connect() as connection:
await connection.run_sync(do_run_migrations)
for schema in tenant_schemas:
try:
logger.info(f"Migrating schema: {schema}")
async with engine.connect() as connection:
await connection.run_sync(
do_run_migrations,
schema_name=schema,
create_schema=create_schema,
)
except Exception as e:
logger.error(f"Error migrating schema {schema}: {e}")
raise
else:
try:
logger.info(f"Migrating schema: {schema_name}")
async with engine.connect() as connection:
await connection.run_sync(
do_run_migrations,
schema_name=schema_name,
create_schema=create_schema,
)
except Exception as e:
logger.error(f"Error migrating schema {schema_name}: {e}")
raise
await engine.dispose()
def run_migrations_offline() -> None:
"""
Run migrations in 'offline' mode.
"""
schema_name, _, upgrade_all_tenants = get_schema_options()
url = build_connection_string()
if upgrade_all_tenants:
# Run offline migrations for all tenant schemas
engine = create_async_engine(url)
tenant_schemas = get_all_tenant_ids()
engine.sync_engine.dispose()
for schema in tenant_schemas:
logger.info(f"Migrating schema: {schema}")
context.configure(
url=url,
target_metadata=target_metadata, # type: ignore
literal_binds=True,
include_object=include_object,
version_table_schema=schema,
include_schemas=True,
script_location=config.get_main_option("script_location"),
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
else:
logger.info(f"Migrating schema: {schema_name}")
context.configure(
url=url,
target_metadata=target_metadata, # type: ignore
literal_binds=True,
include_object=include_object,
version_table_schema=schema_name,
include_schemas=True,
script_location=config.get_main_option("script_location"),
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
await connectable.dispose()
def run_migrations_online() -> None:
"""
Runs migrations in 'online' mode using an asynchronous engine.
"""
asyncio.run(run_async_migrations())
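
get_schema_options reads key=value pairs supplied through Alembic's -x flag (for example `alembic -x schema=tenant_123 -x upgrade_all_tenants=false upgrade head`). A standalone sketch of that parsing with hypothetical tenant values; the real function pulls the raw strings from context.get_x_argument():

def parse_x_args(x_args_raw: list[str]) -> tuple[str, bool, bool]:
    x_args = {}
    for arg in x_args_raw:
        if "=" in arg:
            key, value = arg.split("=", 1)
            x_args[key.strip()] = value.strip()
    schema_name = x_args.get("schema", "public")
    create_schema = x_args.get("create_schema", "true").lower() == "true"
    upgrade_all_tenants = x_args.get("upgrade_all_tenants", "false").lower() == "true"
    return schema_name, create_schema, upgrade_all_tenants

print(parse_x_args(["schema=tenant_123", "upgrade_all_tenants=false"]))
# -> ('tenant_123', True, False)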

View File

@@ -1,26 +0,0 @@
"""add additional data to notifications
Revision ID: 1b10e1fda030
Revises: 6756efa39ada
Create Date: 2024-10-15 19:26:44.071259
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "1b10e1fda030"
down_revision = "6756efa39ada"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"notification", sa.Column("additional_data", postgresql.JSONB(), nullable=True)
)
def downgrade() -> None:
op.drop_column("notification", "additional_data")

View File

@@ -1,30 +0,0 @@
"""add api_version and deployment_name to search settings
Revision ID: 5d12a446f5c0
Revises: e4334d5b33ba
Create Date: 2024-10-08 15:56:07.975636
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "5d12a446f5c0"
down_revision = "e4334d5b33ba"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"embedding_provider", sa.Column("api_version", sa.String(), nullable=True)
)
op.add_column(
"embedding_provider", sa.Column("deployment_name", sa.String(), nullable=True)
)
def downgrade() -> None:
op.drop_column("embedding_provider", "deployment_name")
op.drop_column("embedding_provider", "api_version")

View File

@@ -1,153 +0,0 @@
"""
Revision ID: 6756efa39ada
Revises: 5d12a446f5c0
Create Date: 2024-10-15 17:47:44.108537
"""
from alembic import op
import sqlalchemy as sa
revision = "6756efa39ada"
down_revision = "5d12a446f5c0"
branch_labels = None
depends_on = None
"""
Migrate chat_session and chat_message tables to use UUID primary keys.
This script:
1. Adds UUID columns to chat_session and chat_message
2. Populates new columns with UUIDs
3. Updates foreign key relationships
4. Removes old integer ID columns
Note: Downgrade will assign new integer IDs, not restore original ones.
"""
def upgrade() -> None:
op.execute("CREATE EXTENSION IF NOT EXISTS pgcrypto;")
op.add_column(
"chat_session",
sa.Column(
"new_id",
sa.UUID(as_uuid=True),
server_default=sa.text("gen_random_uuid()"),
nullable=False,
),
)
op.execute("UPDATE chat_session SET new_id = gen_random_uuid();")
op.add_column(
"chat_message",
sa.Column("new_chat_session_id", sa.UUID(as_uuid=True), nullable=True),
)
op.execute(
"""
UPDATE chat_message
SET new_chat_session_id = cs.new_id
FROM chat_session cs
WHERE chat_message.chat_session_id = cs.id;
"""
)
op.drop_constraint(
"chat_message_chat_session_id_fkey", "chat_message", type_="foreignkey"
)
op.drop_column("chat_message", "chat_session_id")
op.alter_column(
"chat_message", "new_chat_session_id", new_column_name="chat_session_id"
)
op.drop_constraint("chat_session_pkey", "chat_session", type_="primary")
op.drop_column("chat_session", "id")
op.alter_column("chat_session", "new_id", new_column_name="id")
op.create_primary_key("chat_session_pkey", "chat_session", ["id"])
op.create_foreign_key(
"chat_message_chat_session_id_fkey",
"chat_message",
"chat_session",
["chat_session_id"],
["id"],
ondelete="CASCADE",
)
def downgrade() -> None:
op.drop_constraint(
"chat_message_chat_session_id_fkey", "chat_message", type_="foreignkey"
)
op.add_column(
"chat_session",
sa.Column("old_id", sa.Integer, autoincrement=True, nullable=True),
)
op.execute("CREATE SEQUENCE chat_session_old_id_seq OWNED BY chat_session.old_id;")
op.execute(
"ALTER TABLE chat_session ALTER COLUMN old_id SET DEFAULT nextval('chat_session_old_id_seq');"
)
op.execute(
"UPDATE chat_session SET old_id = nextval('chat_session_old_id_seq') WHERE old_id IS NULL;"
)
op.alter_column("chat_session", "old_id", nullable=False)
op.drop_constraint("chat_session_pkey", "chat_session", type_="primary")
op.create_primary_key("chat_session_pkey", "chat_session", ["old_id"])
op.add_column(
"chat_message",
sa.Column("old_chat_session_id", sa.Integer, nullable=True),
)
op.execute(
"""
UPDATE chat_message
SET old_chat_session_id = cs.old_id
FROM chat_session cs
WHERE chat_message.chat_session_id = cs.id;
"""
)
op.drop_column("chat_message", "chat_session_id")
op.alter_column(
"chat_message", "old_chat_session_id", new_column_name="chat_session_id"
)
op.create_foreign_key(
"chat_message_chat_session_id_fkey",
"chat_message",
"chat_session",
["chat_session_id"],
["old_id"],
ondelete="CASCADE",
)
op.drop_column("chat_session", "id")
op.alter_column("chat_session", "old_id", new_column_name="id")
op.alter_column(
"chat_session",
"id",
type_=sa.Integer(),
existing_type=sa.Integer(),
existing_nullable=False,
existing_server_default=False,
)
# Rename the sequence
op.execute("ALTER SEQUENCE chat_session_old_id_seq RENAME TO chat_session_id_seq;")
# Update the default value to use the renamed sequence
op.alter_column(
"chat_session",
"id",
server_default=sa.text("nextval('chat_session_id_seq'::regclass)"),
)
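
As the docstring warns, downgrading regenerates integer ids rather than restoring the originals. A small post-upgrade check, sketched with a placeholder DSN, confirming both columns now use the uuid type:

from sqlalchemy import create_engine, text

engine = create_engine("postgresql+psycopg2://user:pass@localhost/danswer")  # placeholder DSN

with engine.connect() as conn:
    for table, column in [("chat_session", "id"), ("chat_message", "chat_session_id")]:
        dtype = conn.execute(
            text(
                "SELECT data_type FROM information_schema.columns "
                "WHERE table_name = :t AND column_name = :c"
            ),
            {"t": table, "c": column},
        ).scalar()
        print(table, column, dtype)  # expect 'uuid' for both after the upgrade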

View File

@@ -5,23 +5,18 @@ from datetime import datetime
from datetime import timezone
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
import jwt
from email_validator import EmailNotValidError
from email_validator import EmailUndeliverableError
from email_validator import validate_email
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Query
from fastapi import Request
from fastapi import Response
from fastapi import status
from fastapi.responses import RedirectResponse
from fastapi.security import OAuth2PasswordRequestForm
from fastapi_users import BaseUserManager
from fastapi_users import exceptions
@@ -35,19 +30,8 @@ from fastapi_users.authentication import JWTStrategy
from fastapi_users.authentication import Strategy
from fastapi_users.authentication.strategy.db import AccessTokenDatabase
from fastapi_users.authentication.strategy.db import DatabaseStrategy
from fastapi_users.exceptions import UserAlreadyExists
from fastapi_users.jwt import decode_jwt
from fastapi_users.jwt import generate_jwt
from fastapi_users.jwt import SecretType
from fastapi_users.manager import UserManagerDependency
from fastapi_users.openapi import OpenAPIResponseType
from fastapi_users.router.common import ErrorCode
from fastapi_users.router.common import ErrorModel
from fastapi_users_db_sqlalchemy import SQLAlchemyUserDatabase
from httpx_oauth.integrations.fastapi import OAuth2AuthorizeCallback
from httpx_oauth.oauth2 import BaseOAuth2
from httpx_oauth.oauth2 import OAuth2Token
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import attributes
from sqlalchemy.orm import Session
@@ -57,8 +41,10 @@ from danswer.auth.schemas import UserCreate
from danswer.auth.schemas import UserRole
from danswer.auth.schemas import UserUpdate
from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import DATA_PLANE_SECRET
from danswer.configs.app_configs import DISABLE_AUTH
from danswer.configs.app_configs import EMAIL_FROM
from danswer.configs.app_configs import EXPECTED_API_KEY
from danswer.configs.app_configs import MULTI_TENANT
from danswer.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
from danswer.configs.app_configs import SECRET_JWT_KEY
@@ -143,10 +129,7 @@ def verify_email_is_invited(email: str) -> None:
if not email:
raise PermissionError("Email must be specified")
try:
email_info = validate_email(email)
except EmailUndeliverableError:
raise PermissionError("Email is not valid")
email_info = validate_email(email) # can raise EmailNotValidError
for email_whitelist in whitelist:
try:
@@ -233,60 +216,35 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
safe: bool = False,
request: Optional[Request] = None,
) -> User:
verify_email_is_invited(user_create.email)
verify_email_domain(user_create.email)
if hasattr(user_create, "role"):
user_count = await get_user_count()
if user_count == 0 or user_create.email in get_default_admin_user_emails():
user_create.role = UserRole.ADMIN
else:
user_create.role = UserRole.BASIC
user = None
try:
tenant_id = (
get_tenant_id_for_email(user_create.email) if MULTI_TENANT else "public"
)
except exceptions.UserNotExists:
raise HTTPException(status_code=401, detail="User not found")
if not tenant_id:
raise HTTPException(
status_code=401, detail="User does not belong to an organization"
)
async with get_async_session_with_tenant(tenant_id) as db_session:
token = current_tenant_id.set(tenant_id)
verify_email_is_invited(user_create.email)
verify_email_domain(user_create.email)
if MULTI_TENANT:
tenant_user_db = SQLAlchemyUserAdminDB(db_session, User, OAuthAccount)
self.user_db = tenant_user_db
self.database = tenant_user_db
if hasattr(user_create, "role"):
user_count = await get_user_count()
if (
user_count == 0
or user_create.email in get_default_admin_user_emails()
):
user_create.role = UserRole.ADMIN
else:
user_create.role = UserRole.BASIC
user = None
try:
user = await super().create(user_create, safe=safe, request=request) # type: ignore
except exceptions.UserAlreadyExists:
user = await self.get_by_email(user_create.email)
# Handle case where user has used product outside of web and is now creating an account through web
if (
not user.has_web_login
and hasattr(user_create, "has_web_login")
and user_create.has_web_login
):
user_update = UserUpdate(
password=user_create.password,
has_web_login=True,
role=user_create.role,
is_verified=user_create.is_verified,
)
user = await self.update(user_update, user)
else:
raise exceptions.UserAlreadyExists()
current_tenant_id.reset(token)
return user
user = await super().create(user_create, safe=safe, request=request) # type: ignore
except exceptions.UserAlreadyExists:
user = await self.get_by_email(user_create.email)
# Handle case where user has used product outside of web and is now creating an account through web
if (
not user.has_web_login
and hasattr(user_create, "has_web_login")
and user_create.has_web_login
):
user_update = UserUpdate(
password=user_create.password,
has_web_login=True,
role=user_create.role,
is_verified=user_create.is_verified,
)
user = await self.update(user_update, user)
else:
raise exceptions.UserAlreadyExists()
return user
async def on_after_login(
self,
@@ -338,13 +296,13 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
token = None
async with get_async_session_with_tenant(tenant_id) as db_session:
token = current_tenant_id.set(tenant_id)
# Print a list of tables in the current database session
verify_email_in_whitelist(account_email, tenant_id)
verify_email_domain(account_email)
if MULTI_TENANT:
tenant_user_db = SQLAlchemyUserAdminDB(db_session, User, OAuthAccount)
self.user_db = tenant_user_db
self.database = tenant_user_db # type: ignore
self.database = tenant_user_db
oauth_account_dict = {
"oauth_name": oauth_name,
@@ -462,6 +420,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
email = credentials.username
# Get tenant_id from mapping table
tenant_id = get_tenant_id_for_email(email)
if not tenant_id:
# User not found in mapping
@@ -695,184 +654,26 @@ def get_default_admin_user_emails_() -> list[str]:
return []
STATE_TOKEN_AUDIENCE = "fastapi-users:oauth-state"
async def control_plane_dep(request: Request) -> None:
api_key = request.headers.get("X-API-KEY")
if api_key != EXPECTED_API_KEY:
logger.warning("Invalid API key")
raise HTTPException(status_code=401, detail="Invalid API key")
auth_header = request.headers.get("Authorization")
if not auth_header or not auth_header.startswith("Bearer "):
logger.warning("Invalid authorization header")
raise HTTPException(status_code=401, detail="Invalid authorization header")
class OAuth2AuthorizeResponse(BaseModel):
authorization_url: str
def generate_state_token(
data: Dict[str, str], secret: SecretType, lifetime_seconds: int = 3600
) -> str:
data["aud"] = STATE_TOKEN_AUDIENCE
return generate_jwt(data, secret, lifetime_seconds)
# refer to https://github.com/fastapi-users/fastapi-users/blob/42ddc241b965475390e2bce887b084152ae1a2cd/fastapi_users/fastapi_users.py#L91
def create_danswer_oauth_router(
oauth_client: BaseOAuth2,
backend: AuthenticationBackend,
state_secret: SecretType,
redirect_url: Optional[str] = None,
associate_by_email: bool = False,
is_verified_by_default: bool = False,
) -> APIRouter:
return get_oauth_router(
oauth_client,
backend,
get_user_manager,
state_secret,
redirect_url,
associate_by_email,
is_verified_by_default,
)
def get_oauth_router(
oauth_client: BaseOAuth2,
backend: AuthenticationBackend,
get_user_manager: UserManagerDependency[models.UP, models.ID],
state_secret: SecretType,
redirect_url: Optional[str] = None,
associate_by_email: bool = False,
is_verified_by_default: bool = False,
) -> APIRouter:
"""Generate a router with the OAuth routes."""
router = APIRouter()
callback_route_name = f"oauth:{oauth_client.name}.{backend.name}.callback"
if redirect_url is not None:
oauth2_authorize_callback = OAuth2AuthorizeCallback(
oauth_client,
redirect_url=redirect_url,
)
else:
oauth2_authorize_callback = OAuth2AuthorizeCallback(
oauth_client,
route_name=callback_route_name,
)
@router.get(
"/authorize",
name=f"oauth:{oauth_client.name}.{backend.name}.authorize",
response_model=OAuth2AuthorizeResponse,
)
async def authorize(
request: Request, scopes: List[str] = Query(None)
) -> OAuth2AuthorizeResponse:
if redirect_url is not None:
authorize_redirect_url = redirect_url
else:
authorize_redirect_url = str(request.url_for(callback_route_name))
next_url = request.query_params.get("next", "/")
state_data: Dict[str, str] = {"next_url": next_url}
state = generate_state_token(state_data, state_secret)
authorization_url = await oauth_client.get_authorization_url(
authorize_redirect_url,
state,
scopes,
)
return OAuth2AuthorizeResponse(authorization_url=authorization_url)
@router.get(
"/callback",
name=callback_route_name,
description="The response varies based on the authentication backend used.",
responses={
status.HTTP_400_BAD_REQUEST: {
"model": ErrorModel,
"content": {
"application/json": {
"examples": {
"INVALID_STATE_TOKEN": {
"summary": "Invalid state token.",
"value": None,
},
ErrorCode.LOGIN_BAD_CREDENTIALS: {
"summary": "User is inactive.",
"value": {"detail": ErrorCode.LOGIN_BAD_CREDENTIALS},
},
}
}
},
},
},
)
async def callback(
request: Request,
access_token_state: Tuple[OAuth2Token, str] = Depends(
oauth2_authorize_callback
),
user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager),
strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy),
) -> RedirectResponse:
token, state = access_token_state
account_id, account_email = await oauth_client.get_id_email(
token["access_token"]
)
if account_email is None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ErrorCode.OAUTH_NOT_AVAILABLE_EMAIL,
)
try:
state_data = decode_jwt(state, state_secret, [STATE_TOKEN_AUDIENCE])
except jwt.DecodeError:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST)
next_url = state_data.get("next_url", "/")
# Authenticate user
try:
user = await user_manager.oauth_callback(
oauth_client.name,
token["access_token"],
account_id,
account_email,
token.get("expires_at"),
token.get("refresh_token"),
request,
associate_by_email=associate_by_email,
is_verified_by_default=is_verified_by_default,
)
except UserAlreadyExists:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ErrorCode.OAUTH_USER_ALREADY_EXISTS,
)
if not user.is_active:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=ErrorCode.LOGIN_BAD_CREDENTIALS,
)
# Login user
response = await backend.login(strategy, user)
await user_manager.on_after_login(user, request, response)
# Prepare redirect response
redirect_response = RedirectResponse(next_url, status_code=302)
# Copy headers and other attributes from 'response' to 'redirect_response'
for header_name, header_value in response.headers.items():
redirect_response.headers[header_name] = header_value
if hasattr(response, "body"):
redirect_response.body = response.body
if hasattr(response, "status_code"):
redirect_response.status_code = response.status_code
if hasattr(response, "media_type"):
redirect_response.media_type = response.media_type
return redirect_response
return router
token = auth_header.split(" ")[1]
try:
payload = jwt.decode(token, DATA_PLANE_SECRET, algorithms=["HS256"])
if payload.get("scope") != "tenant:create":
logger.warning("Insufficient permissions")
raise HTTPException(status_code=403, detail="Insufficient permissions")
except jwt.ExpiredSignatureError:
logger.warning("Token has expired")
raise HTTPException(status_code=401, detail="Token has expired")
except jwt.InvalidTokenError:
logger.warning("Invalid token")
raise HTTPException(status_code=401, detail="Invalid token")

View File

@@ -1,10 +1,9 @@
import logging
import multiprocessing
import time
from datetime import timedelta
from typing import Any
import sentry_sdk
import redis
from celery import bootsteps # type: ignore
from celery import Celery
from celery import current_task
@@ -12,86 +11,50 @@ from celery import signals
from celery import Task
from celery.exceptions import WorkerShutdown
from celery.signals import beat_init
from celery.signals import celeryd_init
from celery.signals import worker_init
from celery.signals import worker_ready
from celery.signals import worker_shutdown
from celery.states import READY_STATES
from celery.utils.log import get_task_logger
from sentry_sdk.integrations.celery import CeleryIntegration
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisConnectorIndexing
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_redis import RedisDocumentSet
from danswer.background.celery.celery_redis import RedisUserGroup
from danswer.background.celery.celery_utils import celery_is_worker_primary
from danswer.background.celery.celery_utils import get_all_tenant_ids
from danswer.background.update import get_all_tenant_ids
from danswer.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerRedisLocks
from danswer.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
from danswer.configs.constants import POSTGRES_CELERY_WORKER_HEAVY_APP_NAME
from danswer.configs.constants import POSTGRES_CELERY_WORKER_INDEXING_APP_NAME
from danswer.configs.constants import POSTGRES_CELERY_WORKER_LIGHT_APP_NAME
from danswer.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME
from danswer.db.engine import get_session_with_tenant
from danswer.db.engine import SqlEngine
from danswer.db.search_settings import get_current_search_settings
from danswer.db.swap_index import check_index_swap
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import ColoredFormatter
from danswer.utils.logger import PlainFormatter
from danswer.utils.logger import setup_logger
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.configs import SENTRY_DSN
logger = setup_logger()
# use this within celery tasks to get celery task specific logging
task_logger = get_task_logger(__name__)
if SENTRY_DSN:
sentry_sdk.init(
dsn=SENTRY_DSN,
integrations=[CeleryIntegration()],
traces_sample_rate=0.5,
)
logger.info("Sentry initialized")
else:
logger.debug("Sentry DSN not provided, skipping Sentry initialization")
celery_app = Celery(__name__)
celery_app.config_from_object(
"danswer.background.celery.celeryconfig"
) # Load configuration from 'celeryconfig.py'
@signals.task_prerun.connect
def on_task_prerun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
tenant_id: str | None = None,
kwargs: dict | None = None,
**kwds: Any,
) -> None:
pass
@signals.task_postrun.connect
def on_task_postrun(
def celery_task_postrun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple | None = None,
kwargs: dict[str, Any] | None = None,
kwargs: dict | None = None,
retval: Any | None = None,
state: str | None = None,
**kwds: Any,
@@ -103,24 +66,11 @@ def on_task_postrun(
This function runs after any task completes (both success and failure)
Note that this signal does not fire on a task that failed to complete and is going
to be retried.
This also does not fire if a worker with acks_late=False crashes (which all of our
long running workers are)
"""
if not task:
return
# Get tenant_id directly from kwargs- each celery task has a tenant_id kwarg
if not kwargs:
logger.error(f"Task {task.name} (ID: {task_id}) is missing kwargs")
tenant_id = None
else:
tenant_id = kwargs.get("tenant_id")
task_logger.debug(
f"Task {task.name} (ID: {task_id}) completed with state: {state} "
f"{f'for tenant_id={tenant_id}' if tenant_id else ''}"
)
task_logger.debug(f"Task {task.name} (ID: {task_id}) completed with state: {state}")
if state not in READY_STATES:
return
@@ -128,7 +78,7 @@ def on_task_postrun(
if not task_id:
return
r = get_redis_client(tenant_id=tenant_id)
r = get_redis_client()
if task_id.startswith(RedisConnectorCredentialPair.PREFIX):
r.srem(RedisConnectorCredentialPair.get_taskset_key(), task_id)
@@ -137,38 +87,32 @@ def on_task_postrun(
if task_id.startswith(RedisDocumentSet.PREFIX):
document_set_id = RedisDocumentSet.get_id_from_task_id(task_id)
if document_set_id is not None:
rds = RedisDocumentSet(int(document_set_id))
rds = RedisDocumentSet(document_set_id)
r.srem(rds.taskset_key, task_id)
return
if task_id.startswith(RedisUserGroup.PREFIX):
usergroup_id = RedisUserGroup.get_id_from_task_id(task_id)
if usergroup_id is not None:
rug = RedisUserGroup(int(usergroup_id))
rug = RedisUserGroup(usergroup_id)
r.srem(rug.taskset_key, task_id)
return
if task_id.startswith(RedisConnectorDeletion.PREFIX):
cc_pair_id = RedisConnectorDeletion.get_id_from_task_id(task_id)
if cc_pair_id is not None:
rcd = RedisConnectorDeletion(int(cc_pair_id))
rcd = RedisConnectorDeletion(cc_pair_id)
r.srem(rcd.taskset_key, task_id)
return
if task_id.startswith(RedisConnectorPruning.SUBTASK_PREFIX):
cc_pair_id = RedisConnectorPruning.get_id_from_task_id(task_id)
if cc_pair_id is not None:
rcp = RedisConnectorPruning(int(cc_pair_id))
rcp = RedisConnectorPruning(cc_pair_id)
r.srem(rcp.taskset_key, task_id)
return
@celeryd_init.connect
def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None:
"""The first signal sent on celery worker startup"""
multiprocessing.set_start_method("spawn") # fork is unsafe, set to spawn
@beat_init.connect
def on_beat_init(sender: Any, **kwargs: Any) -> None:
SqlEngine.set_app_name(POSTGRES_CELERY_BEAT_APP_NAME)
@@ -177,12 +121,8 @@ def on_beat_init(sender: Any, **kwargs: Any) -> None:
@worker_init.connect
def on_worker_init(sender: Any, **kwargs: Any) -> None:
logger.info("worker_init signal received.")
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
# decide some initial startup settings based on the celery worker's hostname
# (set at the command line)'
# (set at the command line)
hostname = sender.hostname
if hostname.startswith("light"):
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME)
@@ -190,171 +130,131 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
elif hostname.startswith("heavy"):
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_HEAVY_APP_NAME)
SqlEngine.init_engine(pool_size=8, max_overflow=0)
elif hostname.startswith("indexing"):
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
SqlEngine.init_engine(pool_size=8, max_overflow=0)
tenant_ids = get_all_tenant_ids()
for tenant_id in tenant_ids:
# TODO: why is this necessary for the indexer to do?
with get_session_with_tenant(tenant_id) as db_session:
check_index_swap(db_session=db_session)
search_settings = get_current_search_settings(db_session)
# So that the first time users aren't surprised by really slow speed of first
# batch of documents indexed
if search_settings.provider_type is None:
logger.notice(
"Running a first inference to warm up embedding model"
)
embedding_model = EmbeddingModel.from_db_model(
search_settings=search_settings,
server_host=INDEXING_MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
warm_up_bi_encoder(
embedding_model=embedding_model,
)
logger.notice("First inference complete.")
else:
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
SqlEngine.init_engine(pool_size=8, max_overflow=0)
if not hasattr(sender, "primary_worker_locks"):
sender.primary_worker_locks = {}
r = get_redis_client()
tenant_ids = get_all_tenant_ids()
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
time_start = time.monotonic()
logger.info("Redis: Readiness check starting.")
while True:
try:
if r.ping():
break
except Exception:
pass
time_elapsed = time.monotonic() - time_start
logger.info(
f"Redis: Ping failed. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
)
if time_elapsed > WAIT_LIMIT:
msg = (
f"Redis: Readiness check did not succeed within the timeout "
f"({WAIT_LIMIT} seconds). Exiting..."
)
logger.error(msg)
raise WorkerShutdown(msg)
time.sleep(WAIT_INTERVAL)
logger.info("Redis: Readiness check succeeded. Continuing...")
if not celery_is_worker_primary(sender):
logger.info("Running as a secondary celery worker.")
for tenant_id in tenant_ids:
r = get_redis_client(tenant_id=tenant_id)
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
time_start = time.monotonic()
logger.notice("Redis: Readiness check starting.")
while True:
# Log all the locks in Redis
all_locks = r.keys("*")
logger.notice(f"Current Redis locks: {all_locks}")
if r.exists(DanswerRedisLocks.PRIMARY_WORKER):
break
time_elapsed = time.monotonic() - time_start
logger.info(
f"Redis: Ping failed. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
)
if time_elapsed > WAIT_LIMIT:
msg = (
"Redis: Readiness check did not succeed within the timeout "
f"({WAIT_LIMIT} seconds). Exiting..."
)
logger.error(msg)
raise WorkerShutdown(msg)
time.sleep(WAIT_INTERVAL)
logger.info("Wait for primary worker completed successfully. Continuing...")
return # Exit the function for secondary workers
for tenant_id in tenant_ids:
r = get_redis_client(tenant_id=tenant_id)
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
logger.info("Waiting for primary worker to be ready...")
time_start = time.monotonic()
logger.info("Running as the primary celery worker.")
while True:
if r.exists(DanswerRedisLocks.PRIMARY_WORKER):
break
# This is singleton work that should be done on startup exactly once
# by the primary worker
r = get_redis_client(tenant_id=tenant_id)
time.monotonic()
time_elapsed = time.monotonic() - time_start
logger.info(
f"Primary worker is not ready yet. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
)
if time_elapsed > WAIT_LIMIT:
msg = (
f"Primary worker was not ready within the timeout. "
f"({WAIT_LIMIT} seconds). Exiting..."
)
logger.error(msg)
raise WorkerShutdown(msg)
# For the moment, we're assuming that we are the only primary worker
# that should be running.
# TODO: maybe check for or clean up another zombie primary worker if we detect it
r.delete(DanswerRedisLocks.PRIMARY_WORKER)
time.sleep(WAIT_INTERVAL)
# this process wide lock is taken to help other workers start up in order.
# it is planned to use this lock to enforce singleton behavior on the primary
# worker, since the primary worker does redis cleanup on startup, but this isn't
# implemented yet.
lock = r.lock(
DanswerRedisLocks.PRIMARY_WORKER,
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
)
logger.info("Wait for primary worker completed successfully. Continuing...")
return
logger.info("Primary worker lock: Acquire starting.")
acquired = lock.acquire(blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2)
if acquired:
logger.info("Primary worker lock: Acquire succeeded.")
else:
logger.error("Primary worker lock: Acquire failed!")
raise WorkerShutdown("Primary worker lock could not be acquired!")
logger.info("Running as the primary celery worker.")
sender.primary_worker_locks[tenant_id] = lock
# This is singleton work that should be done on startup exactly once
# by the primary worker
r = get_redis_client()
# As currently designed, when this worker starts as "primary", we reinitialize redis
# to a clean state (for our purposes, anyway)
r.delete(DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK)
r.delete(DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
# For the moment, we're assuming that we are the only primary worker
# that should be running.
# TODO: maybe check for or clean up another zombie primary worker if we detect it
r.delete(DanswerRedisLocks.PRIMARY_WORKER)
r.delete(RedisConnectorCredentialPair.get_taskset_key())
r.delete(RedisConnectorCredentialPair.get_fence_key())
# this process wide lock is taken to help other workers start up in order.
# it is planned to use this lock to enforce singleton behavior on the primary
# worker, since the primary worker does redis cleanup on startup, but this isn't
# implemented yet.
lock = r.lock(
DanswerRedisLocks.PRIMARY_WORKER,
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
)
for key in r.scan_iter(RedisDocumentSet.TASKSET_PREFIX + "*"):
r.delete(key)
logger.info("Primary worker lock: Acquire starting.")
acquired = lock.acquire(blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2)
if acquired:
logger.info("Primary worker lock: Acquire succeeded.")
else:
logger.error("Primary worker lock: Acquire failed!")
raise WorkerShutdown("Primary worker lock could not be acquired!")
for key in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
r.delete(key)
sender.primary_worker_lock = lock
for key in r.scan_iter(RedisUserGroup.TASKSET_PREFIX + "*"):
r.delete(key)
r.delete(DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK)
r.delete(DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
for key in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
r.delete(key)
r.delete(RedisConnectorCredentialPair.get_taskset_key())
r.delete(RedisConnectorCredentialPair.get_fence_key())
for key in r.scan_iter(RedisConnectorDeletion.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisDocumentSet.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisUserGroup.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.GENERATOR_COMPLETE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.GENERATOR_PROGRESS_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorDeletion.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndexing.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_COMPLETE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.GENERATOR_COMPLETE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndexing.GENERATOR_PROGRESS_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPruning.GENERATOR_PROGRESS_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"):
r.delete(key)
# @worker_process_init.connect
# def on_worker_process_init(sender: Any, **kwargs: Any) -> None:
# """This only runs inside child processes when the worker is in pool=prefork mode.
# This may be technically unnecessary since we're finding prefork pools to be
# unstable and currently aren't planning on using them."""
# logger.info("worker_process_init signal received.")
# SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME)
# SqlEngine.init_engine(pool_size=5, max_overflow=0)
# # https://stackoverflow.com/questions/43944787/sqlalchemy-celery-with-scoped-session-error
# SqlEngine.get_engine().dispose(close=False)
for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
r.delete(key)
@worker_ready.connect
@@ -367,15 +267,14 @@ def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
if not celery_is_worker_primary(sender):
return
if not hasattr(sender, "primary_worker_locks"):
if not sender.primary_worker_lock:
return
logger.info("Releasing primary worker lock.")
for tenant_id, lock in sender.primary_worker_locks.items():
logger.info(f"Releasing primary worker lock for tenant {tenant_id}.")
if lock.owned():
lock.release()
sender.primary_worker_locks = {}
lock = sender.primary_worker_lock
if lock.owned():
lock.release()
sender.primary_worker_lock = None
class CeleryTaskPlainFormatter(PlainFormatter):
@@ -405,7 +304,7 @@ def on_setup_logging(
# TODO: could unhardcode format and colorize and accept these as options from
# celery's config
# reformats the root logger
# reformats celery's worker logger
root_logger = logging.getLogger()
root_handler = logging.StreamHandler() # Set up a handler for the root logger
@@ -450,18 +349,17 @@ def on_setup_logging(
class HubPeriodicTask(bootsteps.StartStopStep):
"""Regularly reacquires the primary worker locks for all tenants outside of the task queue.
"""Regularly reacquires the primary worker lock outside of the task queue.
Use the task_logger in this class to avoid double logging.
This cannot be done inside a regular beat task because it must run on schedule and
a queue of existing work would starve the task from running.
"""
# Requires the Hub component
# it's unclear to me whether using the hub's timer or the bootstep timer is better
requires = {"celery.worker.components:Hub"}
def __init__(self, worker: Any, **kwargs: Any) -> None:
super().__init__(worker, **kwargs)
self.interval = CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 8 # Interval in seconds
self.task_tref = None
@@ -480,58 +378,42 @@ class HubPeriodicTask(bootsteps.StartStopStep):
def run_periodic_task(self, worker: Any) -> None:
try:
if not celery_is_worker_primary(worker):
if not worker.primary_worker_lock:
return
if not hasattr(worker, "primary_worker_locks"):
if not hasattr(worker, "primary_worker_lock"):
return
# Retrieve all tenant IDs
tenant_ids = get_all_tenant_ids()
r = get_redis_client()
for tenant_id in tenant_ids:
lock = worker.primary_worker_locks.get(tenant_id)
if not lock:
continue # Skip if no lock for this tenant
lock: redis.lock.Lock = worker.primary_worker_lock
r = get_redis_client(tenant_id=tenant_id)
if lock.owned():
task_logger.debug("Reacquiring primary worker lock.")
lock.reacquire()
else:
task_logger.warning(
"Full acquisition of primary worker lock. "
"Reasons could be computer sleep or a clock change."
)
lock = r.lock(
DanswerRedisLocks.PRIMARY_WORKER,
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
)
if lock.owned():
task_logger.debug(
f"Reacquiring primary worker lock for tenant {tenant_id}."
)
lock.reacquire()
task_logger.info("Primary worker lock: Acquire starting.")
acquired = lock.acquire(
blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2
)
if acquired:
task_logger.info("Primary worker lock: Acquire succeeded.")
else:
task_logger.warning(
f"Full acquisition of primary worker lock for tenant {tenant_id}. "
"Reasons could be worker restart or lock expiration."
)
lock = r.lock(
DanswerRedisLocks.PRIMARY_WORKER,
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
)
task_logger.error("Primary worker lock: Acquire failed!")
raise TimeoutError("Primary worker lock could not be acquired!")
task_logger.info(
f"Primary worker lock for tenant {tenant_id}: Acquire starting."
)
acquired = lock.acquire(
blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2
)
if acquired:
task_logger.info(
f"Primary worker lock for tenant {tenant_id}: Acquire succeeded."
)
worker.primary_worker_locks[tenant_id] = lock
else:
task_logger.error(
f"Primary worker lock for tenant {tenant_id}: Acquire failed!"
)
raise TimeoutError(
f"Primary worker lock for tenant {tenant_id} could not be acquired!"
)
except Exception as e:
task_logger.error(f"Error in periodic task: {e}")
worker.primary_worker_lock = lock
except Exception:
task_logger.exception("HubPeriodicTask.run_periodic_task exceptioned.")
def stop(self, worker: Any) -> None:
# Cancel the scheduled task when the worker stops
@@ -545,7 +427,6 @@ celery_app.steps["worker"].add(HubPeriodicTask)
celery_app.autodiscover_tasks(
[
"danswer.background.celery.tasks.connector_deletion",
"danswer.background.celery.tasks.indexing",
"danswer.background.celery.tasks.periodic",
"danswer.background.celery.tasks.pruning",
"danswer.background.celery.tasks.shared",
@@ -572,15 +453,9 @@ tasks_to_schedule = [
"schedule": timedelta(seconds=60),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-indexing",
"task": "check_for_indexing",
"schedule": timedelta(seconds=10),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
{
"name": "check-for-prune",
"task": "check_for_pruning",
"task": "check_for_prune_task_2",
"schedule": timedelta(seconds=10),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
@@ -601,14 +476,14 @@ tasks_to_schedule = [
# Build the celery beat schedule dynamically
beat_schedule = {}
for id in tenant_ids:
for tenant_id in tenant_ids:
for task in tasks_to_schedule:
task_name = f"{task['name']}-{id}" # Unique name for each scheduled task
task_name = f"{task['name']}-{tenant_id}" # Unique name for each scheduled task
beat_schedule[task_name] = {
"task": task["task"],
"schedule": task["schedule"],
"options": task["options"],
"kwargs": {"tenant_id": id}, # Must pass tenant_id as an argument
"args": (tenant_id,), # Must pass tenant_id as an argument
}
# Include any existing beat schedules
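
For a single hypothetical tenant, the loop above expands each scheduled task into an entry like the following. This sketch mirrors the kwargs-passing side of the diff, and the priority value is a stand-in for DanswerCeleryPriority.HIGH:

from datetime import timedelta

tenant_id = "tenant_123"  # hypothetical tenant
beat_schedule = {
    f"check-for-prune-{tenant_id}": {
        "task": "check_for_pruning",
        "schedule": timedelta(seconds=10),
        "options": {"priority": 0},          # placeholder for DanswerCeleryPriority.HIGH
        "kwargs": {"tenant_id": tenant_id},  # every beat task carries its tenant explicitly
    }
}
print(beat_schedule)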

View File

@@ -29,8 +29,8 @@ class RedisObjectHelper(ABC):
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def __init__(self, id: str):
self._id: str = id
def __init__(self, id: int):
self._id: int = id
@property
def task_id_prefix(self) -> str:
@@ -47,7 +47,7 @@ class RedisObjectHelper(ABC):
return f"{self.TASKSET_PREFIX}_{self._id}"
@staticmethod
def get_id_from_fence_key(key: str) -> str | None:
def get_id_from_fence_key(key: str) -> int | None:
"""
Extracts the object ID from a fence key in the format `PREFIX_fence_X`.
@@ -61,11 +61,15 @@ class RedisObjectHelper(ABC):
if len(parts) != 3:
return None
object_id = parts[2]
try:
object_id = int(parts[2])
except ValueError:
return None
return object_id
@staticmethod
def get_id_from_task_id(task_id: str) -> str | None:
def get_id_from_task_id(task_id: str) -> int | None:
"""
Extracts the object ID from a task ID string.
@@ -89,7 +93,11 @@ class RedisObjectHelper(ABC):
if len(parts) != 3:
return None
object_id = parts[1]
try:
object_id = int(parts[1])
except ValueError:
return None
return object_id
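
The helpers above assume task ids of the form "<prefix>_<object_id>_<random>", so splitting on underscores yields exactly three parts with the object id in the middle. A standalone sketch of that round trip; the "documentset" prefix is hypothetical, and this mirrors the int-returning variant shown in the diff:

import uuid

def build_task_id(prefix: str, object_id: int) -> str:
    return f"{prefix}_{object_id}_{uuid.uuid4()}"

def parse_object_id(task_id: str) -> int | None:
    parts = task_id.split("_")
    if len(parts) != 3:
        return None
    try:
        return int(parts[1])
    except ValueError:
        return None

task_id = build_task_id("documentset", 42)
assert parse_object_id(task_id) == 42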
@abstractmethod
@@ -109,9 +117,6 @@ class RedisDocumentSet(RedisObjectHelper):
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def __init__(self, id: int) -> None:
super().__init__(str(id))
def generate_tasks(
self,
celery_app: Celery,
@@ -123,7 +128,7 @@ class RedisDocumentSet(RedisObjectHelper):
last_lock_time = time.monotonic()
async_results = []
stmt = construct_document_select_by_docset(int(self._id), current_only=False)
stmt = construct_document_select_by_docset(self._id, current_only=False)
for doc in db_session.scalars(stmt).yield_per(1):
current_time = time.monotonic()
if current_time - last_lock_time >= (
@@ -159,9 +164,6 @@ class RedisUserGroup(RedisObjectHelper):
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def __init__(self, id: int) -> None:
super().__init__(str(id))
def generate_tasks(
self,
celery_app: Celery,
@@ -185,7 +187,7 @@ class RedisUserGroup(RedisObjectHelper):
except ModuleNotFoundError:
return 0
stmt = construct_document_select_by_usergroup(int(self._id))
stmt = construct_document_select_by_usergroup(self._id)
for doc in db_session.scalars(stmt).yield_per(1):
current_time = time.monotonic()
if current_time - last_lock_time >= (
@@ -217,19 +219,13 @@ class RedisUserGroup(RedisObjectHelper):
class RedisConnectorCredentialPair(RedisObjectHelper):
"""This class is used to scan documents by cc_pair in the db and collect them into
a unified set for syncing.
It differs from the other redis helpers in that the taskset used spans
"""This class differs from the default in that the taskset used spans
all connectors and is not per connector."""
PREFIX = "connectorsync"
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def __init__(self, id: int) -> None:
super().__init__(str(id))
@classmethod
def get_fence_key(cls) -> str:
return RedisConnectorCredentialPair.FENCE_PREFIX
@@ -256,7 +252,7 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
last_lock_time = time.monotonic()
async_results = []
cc_pair = get_connector_credential_pair_from_id(int(self._id), db_session)
cc_pair = get_connector_credential_pair_from_id(self._id, db_session)
if not cc_pair:
return None
@@ -302,9 +298,6 @@ class RedisConnectorDeletion(RedisObjectHelper):
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def __init__(self, id: int) -> None:
super().__init__(str(id))
def generate_tasks(
self,
celery_app: Celery,
@@ -316,7 +309,7 @@ class RedisConnectorDeletion(RedisObjectHelper):
last_lock_time = time.monotonic()
async_results = []
cc_pair = get_connector_credential_pair_from_id(int(self._id), db_session)
cc_pair = get_connector_credential_pair_from_id(self._id, db_session)
if not cc_pair:
return None
@@ -393,7 +386,9 @@ class RedisConnectorPruning(RedisObjectHelper):
) # a signal that the generator has finished
def __init__(self, id: int) -> None:
super().__init__(str(id))
"""id: the cc_pair_id of the connector credential pair"""
super().__init__(id)
self.documents_to_prune: set[str] = set()
@property
@@ -425,7 +420,7 @@ class RedisConnectorPruning(RedisObjectHelper):
last_lock_time = time.monotonic()
async_results = []
cc_pair = get_connector_credential_pair_from_id(int(self._id), db_session)
cc_pair = get_connector_credential_pair_from_id(self._id, db_session)
if not cc_pair:
return None
@@ -468,7 +463,7 @@ class RedisConnectorPruning(RedisObjectHelper):
def is_pruning(self, db_session: Session, redis_client: Redis) -> bool:
"""A single example of a helper method being refactored into the redis helper"""
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id=int(self._id), db_session=db_session
cc_pair_id=self._id, db_session=db_session
)
if not cc_pair:
raise ValueError(f"cc_pair_id {self._id} does not exist.")
@@ -479,66 +474,6 @@ class RedisConnectorPruning(RedisObjectHelper):
return False
class RedisConnectorIndexing(RedisObjectHelper):
"""Celery will kick off a long running indexing task to crawl the connector and
find any new or updated docs, which will each then get a new sync task or be
indexed inline.
ID should be a concatenation of cc_pair_id and search_setting_id, delimited by "/".
e.g. "2/5"
"""
PREFIX = "connectorindexing"
FENCE_PREFIX = PREFIX + "_fence" # a fence for the entire indexing process
GENERATOR_TASK_PREFIX = PREFIX + "+generator"
TASKSET_PREFIX = PREFIX + "_taskset" # stores a list of prune tasks id's
SUBTASK_PREFIX = PREFIX + "+sub"
GENERATOR_LOCK_PREFIX = "da_lock:indexing"
GENERATOR_PROGRESS_PREFIX = (
PREFIX + "_generator_progress"
) # a signal that contains generator progress
GENERATOR_COMPLETE_PREFIX = (
PREFIX + "_generator_complete"
) # a signal that the generator has finished
def __init__(self, cc_pair_id: int, search_settings_id: int) -> None:
super().__init__(f"{cc_pair_id}/{search_settings_id}")
@property
def generator_lock_key(self) -> str:
return f"{self.GENERATOR_LOCK_PREFIX}_{self._id}"
@property
def generator_task_id_prefix(self) -> str:
return f"{self.GENERATOR_TASK_PREFIX}_{self._id}"
@property
def generator_progress_key(self) -> str:
# example: connectorpruning_generator_progress_1
return f"{self.GENERATOR_PROGRESS_PREFIX}_{self._id}"
@property
def generator_complete_key(self) -> str:
# example: connectorpruning_generator_complete_1
return f"{self.GENERATOR_COMPLETE_PREFIX}_{self._id}"
@property
def subtask_id_prefix(self) -> str:
return f"{self.SUBTASK_PREFIX}_{self._id}"
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock | None,
tenant_id: str | None,
) -> int | None:
return None
def celery_get_queue_length(queue: str, r: Redis) -> int:
"""This is a redis specific way to get the length of a celery queue.
It is priority aware and knows how to count across the multiple redis lists
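As a rough illustration of what the docstring above describes (this sketch is not part of the diff; the separator and priority steps are kombu's documented defaults, assumed here), the length of a priority-enabled Celery queue on Redis is the sum of llen() over the per-priority lists:

from redis import Redis

KOMBU_PRIORITY_SEPARATOR = "\x06\x16"  # kombu's default priority separator (assumption)
KOMBU_PRIORITY_STEPS = [0, 3, 6, 9]    # kombu's default priority buckets (assumption)

def example_queue_length(queue: str, r: Redis) -> int:
    # priority 0 uses the bare queue name; other priorities get a suffixed Redis list
    total = 0
    for priority in KOMBU_PRIORITY_STEPS:
        key = queue if priority == 0 else f"{queue}{KOMBU_PRIORITY_SEPARATOR}{priority}"
        total += int(r.llen(key))
    return total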

View File

@@ -3,13 +3,10 @@ from datetime import datetime
from datetime import timezone
from typing import Any
from sqlalchemy import text
from sqlalchemy.orm import Session
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
from danswer.configs.app_configs import MULTI_TENANT
from danswer.configs.constants import TENANT_ID_PREFIX
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
)
@@ -19,7 +16,6 @@ from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.models import Document
from danswer.db.connector_credential_pair import get_connector_credential_pair
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import TaskStatus
from danswer.db.models import TaskQueueState
from danswer.redis.redis_pool import get_redis_client
@@ -31,10 +27,7 @@ logger = setup_logger()
def _get_deletion_status(
connector_id: int,
credential_id: int,
db_session: Session,
tenant_id: str | None = None,
connector_id: int, credential_id: int, db_session: Session
) -> TaskQueueState | None:
"""We no longer store TaskQueueState in the DB for a deletion attempt.
This function populates TaskQueueState by just checking redis.
@@ -47,7 +40,7 @@ def _get_deletion_status(
rcd = RedisConnectorDeletion(cc_pair.id)
r = get_redis_client(tenant_id=tenant_id)
r = get_redis_client()
if not r.exists(rcd.fence_key):
return None
@@ -57,14 +50,9 @@ def _get_deletion_status(
def get_deletion_attempt_snapshot(
connector_id: int,
credential_id: int,
db_session: Session,
tenant_id: str | None = None,
connector_id: int, credential_id: int, db_session: Session
) -> DeletionAttemptSnapshot | None:
deletion_task = _get_deletion_status(
connector_id, credential_id, db_session, tenant_id
)
deletion_task = _get_deletion_status(connector_id, credential_id, db_session)
if not deletion_task:
return None
@@ -136,30 +124,10 @@ def celery_is_worker_primary(worker: Any) -> bool:
for the celery worker, which can be done either in celeryconfig.py or on the
command line with '--hostname'."""
hostname = worker.hostname
if hostname.startswith("primary"):
return True
if hostname.startswith("light"):
return False
return False
if hostname.startswith("heavy"):
return False
def get_all_tenant_ids() -> list[str] | list[None]:
if not MULTI_TENANT:
return [None]
with get_session_with_tenant(tenant_id="public") as session:
result = session.execute(
text(
"""
SELECT schema_name
FROM information_schema.schemata
WHERE schema_name NOT IN ('pg_catalog', 'information_schema', 'public')"""
)
)
tenant_ids = [row[0] for row in result]
valid_tenants = [
tenant
for tenant in tenant_ids
if tenant is None or tenant.startswith(TENANT_ID_PREFIX)
]
return valid_tenants
return True
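For context on how get_all_tenant_ids is meant to be consumed, a hypothetical fan-out sketch (the Celery app, broker URL, and task name below are placeholders, not names taken from this diff; the import path is assumed):

from celery import Celery

from danswer.background.celery.celery_utils import get_all_tenant_ids  # module path assumed

celery_app = Celery("example", broker="redis://localhost:6379/0")  # placeholder app/broker

def dispatch_per_tenant() -> None:
    # get_all_tenant_ids() yields [None] when MULTI_TENANT is off, so a
    # single-tenant deployment still gets exactly one dispatch
    for tenant_id in get_all_tenant_ids():
        celery_app.send_task(
            "example_per_tenant_beat_task",  # hypothetical task name
            kwargs={"tenant_id": tenant_id},
        )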

View File

@@ -41,11 +41,6 @@ result_backend = f"{REDIS_SCHEME}://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PO
# can stall other tasks.
worker_prefetch_multiplier = 4
# Leaving this to the default of True may cause double logging since both our own app
# and celery think they are controlling the logger.
# TODO: Configure celery's logger entirely manually and set this to False
# worker_hijack_root_logger = False
broker_connection_retry_on_startup = True
broker_pool_limit = CELERY_BROKER_POOL_LIMIT

View File

@@ -12,7 +12,7 @@ from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerRedisLocks
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.engine import get_session_with_tenant
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import ConnectorCredentialPair
from danswer.redis.redis_pool import get_redis_client
@@ -23,8 +23,8 @@ from danswer.redis.redis_pool import get_redis_client
soft_time_limit=JOB_TIMEOUT,
trail=False,
)
def check_for_connector_deletion_task(*, tenant_id: str | None) -> None:
r = get_redis_client(tenant_id=tenant_id)
def check_for_connector_deletion_task(tenant_id: str | None) -> None:
r = get_redis_client()
lock_beat = r.lock(
DanswerRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
@@ -36,7 +36,7 @@ def check_for_connector_deletion_task(*, tenant_id: str | None) -> None:
if not lock_beat.acquire(blocking=False):
return
with get_session_with_tenant(tenant_id) as db_session:
with Session(get_sqlalchemy_engine()) as db_session:
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
try_generate_document_cc_pair_cleanup_tasks(
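All of the beat tasks touched in this section share the same non-blocking Redis lock guard; a stripped-down sketch of that pattern (the lock name below is illustrative):

from redis import Redis

def run_guarded_beat_work(r: Redis) -> None:
    # a short timeout so a crashed worker cannot hold the lock forever
    lock_beat = r.lock("da_lock:example_beat", timeout=60)
    if not lock_beat.acquire(blocking=False):
        return  # another invocation of this beat task is already running
    try:
        pass  # periodic work goes here (e.g. queueing cleanup tasks per cc_pair)
    finally:
        if lock_beat.owned():
            lock_beat.release()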

View File

@@ -1,455 +0,0 @@
from datetime import datetime
from datetime import timezone
from http import HTTPStatus
from time import sleep
from typing import cast
from uuid import uuid4
from celery import shared_task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from sqlalchemy.orm import Session
from danswer.background.celery.celery_app import celery_app
from danswer.background.celery.celery_app import task_logger
from danswer.background.celery.celery_redis import RedisConnectorIndexing
from danswer.background.celery.tasks.shared.tasks import RedisConnectorIndexingFenceData
from danswer.background.indexing.job_client import SimpleJobClient
from danswer.background.indexing.run_indexing import run_indexing_entrypoint
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from danswer.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.configs.constants import DanswerRedisLocks
from danswer.configs.constants import DocumentSource
from danswer.db.connector_credential_pair import fetch_connector_credential_pairs
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.engine import get_db_current_time
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.enums import IndexingStatus
from danswer.db.enums import IndexModelStatus
from danswer.db.index_attempt import create_index_attempt
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import get_last_attempt_for_cc_pair
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import IndexAttempt
from danswer.db.models import SearchSettings
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import global_version
logger = setup_logger()
@shared_task(
name="check_for_indexing",
soft_time_limit=300,
)
def check_for_indexing(*, tenant_id: str | None) -> int | None:
tasks_created = 0
r = get_redis_client(tenant_id=tenant_id)
lock_beat = r.lock(
DanswerRedisLocks.CHECK_INDEXING_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
try:
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
task_logger.info(f"Lock acquired for tenant (Y): {tenant_id}")
return None
else:
task_logger.info(f"Lock acquired for tenant (N): {tenant_id}")
with get_session_with_tenant(tenant_id) as db_session:
# Get the primary search settings
primary_search_settings = get_current_search_settings(db_session)
search_settings = [primary_search_settings]
# Check for secondary search settings
secondary_search_settings = get_secondary_search_settings(db_session)
if secondary_search_settings is not None:
# If secondary settings exist, add them to the list
search_settings.append(secondary_search_settings)
cc_pairs = fetch_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
for search_settings_instance in search_settings:
rci = RedisConnectorIndexing(
cc_pair.id, search_settings_instance.id
)
if r.exists(rci.fence_key):
continue
last_attempt = get_last_attempt_for_cc_pair(
cc_pair.id, search_settings_instance.id, db_session
)
if not _should_index(
cc_pair=cc_pair,
last_index=last_attempt,
search_settings_instance=search_settings_instance,
secondary_index_building=len(search_settings) > 1,
db_session=db_session,
):
continue
# using a task queue and only allowing one task per cc_pair/search_setting
# prevents us from starving out certain attempts
attempt_id = try_creating_indexing_task(
cc_pair,
search_settings_instance,
False,
db_session,
r,
tenant_id,
)
if attempt_id:
task_logger.info(
f"Indexing queued: cc_pair_id={cc_pair.id} index_attempt_id={attempt_id}"
)
tasks_created += 1
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception:
task_logger.exception("Unexpected exception")
finally:
if lock_beat.owned():
lock_beat.release()
return tasks_created
def _should_index(
cc_pair: ConnectorCredentialPair,
last_index: IndexAttempt | None,
search_settings_instance: SearchSettings,
secondary_index_building: bool,
db_session: Session,
) -> bool:
"""Checks various global settings and past indexing attempts to determine if
we should try to start indexing the cc pair / search setting combination.
Note that tactical checks such as preventing overlap with a currently running task
are not handled here.
Return True if we should try to index, False if not.
"""
connector = cc_pair.connector
# uncomment for debugging
# task_logger.info(f"_should_index: "
# f"cc_pair={cc_pair.id} "
# f"connector={cc_pair.connector_id} "
# f"refresh_freq={connector.refresh_freq}")
# don't kick off indexing for `NOT_APPLICABLE` sources
if connector.source == DocumentSource.NOT_APPLICABLE:
return False
# User can still manually create single indexing attempts via the UI for the
# currently in use index
if DISABLE_INDEX_UPDATE_ON_SWAP:
if (
search_settings_instance.status == IndexModelStatus.PRESENT
and secondary_index_building
):
return False
# When switching over models, always index at least once
if search_settings_instance.status == IndexModelStatus.FUTURE:
if last_index:
# No new index if the last index attempt succeeded
# Once is enough. The model will never be able to swap otherwise.
if last_index.status == IndexingStatus.SUCCESS:
return False
# No new index if the last index attempt is waiting to start
if last_index.status == IndexingStatus.NOT_STARTED:
return False
# No new index if the last index attempt is running
if last_index.status == IndexingStatus.IN_PROGRESS:
return False
else:
if (
connector.id == 0 or connector.source == DocumentSource.INGESTION_API
): # Ingestion API
return False
return True
# If the connector is paused or is the ingestion API, don't index
# NOTE: during an embedding model switch over, the following logic
# is bypassed by the above check for a future model
if (
not cc_pair.status.is_active()
or connector.id == 0
or connector.source == DocumentSource.INGESTION_API
):
return False
# if no attempt has ever occurred, we should index regardless of refresh_freq
if not last_index:
return True
if connector.refresh_freq is None:
return False
current_db_time = get_db_current_time(db_session)
time_since_index = current_db_time - last_index.time_updated
if time_since_index.total_seconds() < connector.refresh_freq:
return False
return True
def try_creating_indexing_task(
cc_pair: ConnectorCredentialPair,
search_settings: SearchSettings,
reindex: bool,
db_session: Session,
r: Redis,
tenant_id: str | None,
) -> int | None:
"""Checks for any conditions that should block the indexing task from being
created, then creates the task.
Does not check for scheduling related conditions as this function
is used to trigger indexing immediately.
"""
LOCK_TIMEOUT = 30
# we need to serialize any attempt to trigger indexing since it can be triggered
# either via celery beat or manually (API call)
lock = r.lock(
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_creating_indexing_task",
timeout=LOCK_TIMEOUT,
)
acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2)
if not acquired:
return None
try:
rci = RedisConnectorIndexing(cc_pair.id, search_settings.id)
# skip if already indexing
if r.exists(rci.fence_key):
return None
# skip indexing if the cc_pair is deleting
db_session.refresh(cc_pair)
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
return None
# add a long running generator task to the queue
r.delete(rci.generator_complete_key)
r.delete(rci.taskset_key)
custom_task_id = f"{rci.generator_task_id_prefix}_{uuid4()}"
# create the index attempt ... just for tracking purposes
index_attempt_id = create_index_attempt(
cc_pair.id,
search_settings.id,
from_beginning=reindex,
db_session=db_session,
)
result = celery_app.send_task(
"connector_indexing_proxy_task",
kwargs=dict(
index_attempt_id=index_attempt_id,
cc_pair_id=cc_pair.id,
search_settings_id=search_settings.id,
tenant_id=tenant_id,
),
queue=DanswerCeleryQueues.CONNECTOR_INDEXING,
task_id=custom_task_id,
priority=DanswerCeleryPriority.MEDIUM,
)
if not result:
return None
# set this only after all tasks have been added
fence_value = RedisConnectorIndexingFenceData(
index_attempt_id=index_attempt_id,
started=None,
submitted=datetime.now(timezone.utc),
celery_task_id=result.id,
)
r.set(rci.fence_key, fence_value.model_dump_json())
except Exception:
task_logger.exception("Unexpected exception")
return None
finally:
if lock.owned():
lock.release()
return index_attempt_id
@shared_task(name="connector_indexing_proxy_task", acks_late=False, track_started=True)
def connector_indexing_proxy_task(
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
tenant_id: str | None,
) -> None:
"""celery tasks are forked, but forking is unstable. This proxies work to a spawned task."""
client = SimpleJobClient()
job = client.submit(
connector_indexing_task,
index_attempt_id,
cc_pair_id,
search_settings_id,
tenant_id,
global_version.is_ee_version(),
pure=False,
)
if not job:
return
while True:
sleep(10)
with get_session_with_tenant(tenant_id) as db_session:
index_attempt = get_index_attempt(
db_session=db_session, index_attempt_id=index_attempt_id
)
# do nothing for ongoing jobs that haven't been stopped
if not job.done():
if not index_attempt:
continue
if not index_attempt.is_finished():
continue
if job.status == "error":
logger.error(job.exception())
job.release()
break
return
def connector_indexing_task(
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
tenant_id: str | None,
is_ee: bool,
) -> int | None:
"""Indexing task. For a cc pair, this task pulls all document IDs from the source
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
from the most recently pulled document ID list
acks_late must be set to False. Otherwise, celery's visibility timeout will
cause any task that runs longer than the timeout to be redispatched by the broker.
There appears to be no good workaround for this, so we need to handle redispatching
manually.
Returns None if the task did not run (possibly due to a conflict).
Otherwise, returns an int >= 0 representing the number of indexed docs.
"""
attempt = None
n_final_progress = 0
r = get_redis_client(tenant_id=tenant_id)
rci = RedisConnectorIndexing(cc_pair_id, search_settings_id)
lock = r.lock(
rci.generator_lock_key,
timeout=CELERY_INDEXING_LOCK_TIMEOUT,
)
acquired = lock.acquire(blocking=False)
if not acquired:
task_logger.warning(
f"Indexing task already running, exiting...: "
f"cc_pair_id={cc_pair_id} search_settings_id={search_settings_id}"
)
# r.set(rci.generator_complete_key, HTTPStatus.CONFLICT.value)
return None
try:
with get_session_with_tenant(tenant_id) as db_session:
attempt = get_index_attempt(db_session, index_attempt_id)
if not attempt:
raise ValueError(
f"Index attempt not found: index_attempt_id={index_attempt_id}"
)
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id=cc_pair_id,
db_session=db_session,
)
if not cc_pair:
raise ValueError(f"cc_pair not found: cc_pair_id={cc_pair_id}")
if not cc_pair.connector:
raise ValueError(
f"Connector not found: connector_id={cc_pair.connector_id}"
)
if not cc_pair.credential:
raise ValueError(
f"Credential not found: credential_id={cc_pair.credential_id}"
)
rci = RedisConnectorIndexing(cc_pair_id, search_settings_id)
# Define the callback function
def redis_increment_callback(amount: int) -> None:
lock.reacquire()
r.incrby(rci.generator_progress_key, amount)
run_indexing_entrypoint(
index_attempt_id,
tenant_id,
cc_pair_id,
is_ee,
progress_callback=redis_increment_callback,
)
# get back the total number of indexed docs and return it
generator_progress_value = r.get(rci.generator_progress_key)
if generator_progress_value is not None:
try:
n_final_progress = int(cast(int, generator_progress_value))
except ValueError:
pass
r.set(rci.generator_complete_key, HTTPStatus.OK.value)
except Exception as e:
task_logger.exception(f"Failed to run indexing for cc_pair_id={cc_pair_id}.")
if attempt:
mark_attempt_failed(attempt, db_session, failure_reason=str(e))
r.delete(rci.generator_lock_key)
r.delete(rci.generator_progress_key)
r.delete(rci.taskset_key)
r.delete(rci.fence_key)
raise e
finally:
if lock.owned():
lock.release()
return n_final_progress
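The indexing flow above coordinates entirely through a handful of Redis keys per cc_pair/search-settings pair: a JSON fence set by the producer, a progress counter incremented by the worker, and a completion status read by the monitor. A hedged sketch of that protocol (key names here are illustrative; the real ones come from RedisConnectorIndexing and RedisConnectorIndexingFenceData):

import json
from datetime import datetime, timezone
from http import HTTPStatus
from redis import Redis

def submit_fence(r: Redis, fence_key: str, index_attempt_id: int, celery_task_id: str) -> None:
    # producer: set the fence only after the proxy task has been queued
    payload = {
        "index_attempt_id": index_attempt_id,
        "started": None,
        "submitted": datetime.now(timezone.utc).isoformat(),
        "celery_task_id": celery_task_id,
    }
    r.set(fence_key, json.dumps(payload))

def report_progress(r: Redis, progress_key: str, num_docs: int) -> None:
    # worker: bump the counter as each document batch is indexed
    r.incrby(progress_key, num_docs)

def signal_complete(r: Redis, complete_key: str) -> None:
    # worker: publish an HTTP-style status code so the monitor can finalize
    r.set(complete_key, HTTPStatus.OK.value)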

View File

@@ -14,7 +14,7 @@ from sqlalchemy.orm import Session
from danswer.background.celery.celery_app import task_logger
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import PostgresAdvisoryLocks
from danswer.db.engine import get_session_with_tenant
from danswer.db.engine import get_sqlalchemy_engine # type: ignore
@shared_task(
@@ -23,7 +23,7 @@ from danswer.db.engine import get_session_with_tenant
bind=True,
base=AbortableTask,
)
def kombu_message_cleanup_task(self: Any, tenant_id: str | None) -> int:
def kombu_message_cleanup_task(self: Any) -> int:
"""Runs periodically to clean up the kombu_message table"""
# we will select messages older than this amount to clean up
@@ -35,7 +35,7 @@ def kombu_message_cleanup_task(self: Any, tenant_id: str | None) -> int:
ctx["deleted"] = 0
ctx["cleanup_age"] = KOMBU_MESSAGE_CLEANUP_AGE
ctx["page_limit"] = KOMBU_MESSAGE_CLEANUP_PAGE_LIMIT
with get_session_with_tenant(tenant_id) as db_session:
with Session(get_sqlalchemy_engine()) as db_session:
# Exit the task if we can't take the advisory lock
result = db_session.execute(
text("SELECT pg_try_advisory_lock(:id)"),

View File

@@ -3,6 +3,7 @@ from datetime import timedelta
from datetime import timezone
from uuid import uuid4
import redis
from celery import shared_task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
@@ -14,9 +15,7 @@ from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector
from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_PRUNING_LOCK_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.configs.constants import DanswerRedisLocks
@@ -31,15 +30,16 @@ from danswer.db.models import ConnectorCredentialPair
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
logger = setup_logger()
@shared_task(
name="check_for_pruning",
name="check_for_prune_task_2",
soft_time_limit=JOB_TIMEOUT,
)
def check_for_pruning(*, tenant_id: str | None) -> None:
r = get_redis_client(tenant_id=tenant_id)
def check_for_prune_task_2(tenant_id: str | None) -> None:
r = get_redis_client()
lock_beat = r.lock(
DanswerRedisLocks.CHECK_PRUNE_BEAT_LOCK,
@@ -54,17 +54,13 @@ def check_for_pruning(*, tenant_id: str | None) -> None:
with get_session_with_tenant(tenant_id) as db_session:
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
lock_beat.reacquire()
if not is_pruning_due(cc_pair, db_session, r):
continue
tasks_created = try_creating_prune_generator_task(
cc_pair, db_session, r, tenant_id
tasks_created = ccpair_pruning_generator_task_creation_helper(
cc_pair, db_session, tenant_id, r, lock_beat
)
if not tasks_created:
continue
task_logger.info(f"Pruning queued: cc_pair_id={cc_pair.id}")
task_logger.info(f"Pruning started: cc_pair_id={cc_pair.id}")
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
@@ -76,11 +72,13 @@ def check_for_pruning(*, tenant_id: str | None) -> None:
lock_beat.release()
def is_pruning_due(
def ccpair_pruning_generator_task_creation_helper(
cc_pair: ConnectorCredentialPair,
db_session: Session,
tenant_id: str | None,
r: Redis,
) -> bool:
lock_beat: redis.lock.Lock,
) -> int | None:
"""Returns an int if pruning is triggered.
The int represents the number of prune tasks generated (in this case, only one,
because the task is a long-running generator task).
@@ -91,30 +89,24 @@ def is_pruning_due(
try_creating_prune_generator_task.
"""
lock_beat.reacquire()
# skip pruning if no prune frequency is set
# pruning can still be forced via the API which will run a pruning task directly
if not cc_pair.connector.prune_freq:
return False
# skip pruning if not active
if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
return False
return None
# skip pruning if the next scheduled prune time hasn't been reached yet
last_pruned = cc_pair.last_pruned
if not last_pruned:
if not cc_pair.last_successful_index_time:
# if we've never indexed, we can't prune
return False
# if never pruned, use the last time the connector indexed successfully
last_pruned = cc_pair.last_successful_index_time
# if never pruned, use the connector time created as the last_pruned time
last_pruned = cc_pair.connector.time_created
next_prune = last_pruned + timedelta(seconds=cc_pair.connector.prune_freq)
if datetime.now(timezone.utc) < next_prune:
return False
return None
return True
return try_creating_prune_generator_task(cc_pair, db_session, r, tenant_id)
def try_creating_prune_generator_task(
@@ -127,101 +119,59 @@ def try_creating_prune_generator_task(
created, then creates the task.
Does not check for scheduling related conditions as this function
is used to trigger prunes immediately, e.g. via the web ui.
is used to trigger prunes immediately.
"""
if not ALLOW_SIMULTANEOUS_PRUNING:
for key in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
return None
LOCK_TIMEOUT = 30
rcp = RedisConnectorPruning(cc_pair.id)
# we need to serialize starting pruning since it can be triggered either via
# celery beat or manually (API call)
lock = r.lock(
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_creating_prune_generator_task",
timeout=LOCK_TIMEOUT,
# skip pruning if already pruning
if r.exists(rcp.fence_key):
return None
# skip pruning if the cc_pair is deleting
db_session.refresh(cc_pair)
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
return None
# add a long running generator task to the queue
r.delete(rcp.generator_complete_key)
r.delete(rcp.taskset_key)
custom_task_id = f"{rcp.generator_task_id_prefix}_{uuid4()}"
celery_app.send_task(
"connector_pruning_generator_task",
kwargs=dict(
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
tenant_id=tenant_id,
),
queue=DanswerCeleryQueues.CONNECTOR_PRUNING,
task_id=custom_task_id,
priority=DanswerCeleryPriority.LOW,
)
acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2)
if not acquired:
return None
try:
rcp = RedisConnectorPruning(cc_pair.id)
# skip pruning if already pruning
if r.exists(rcp.fence_key):
return None
# skip pruning if the cc_pair is deleting
db_session.refresh(cc_pair)
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
return None
# add a long running generator task to the queue
r.delete(rcp.generator_complete_key)
r.delete(rcp.taskset_key)
custom_task_id = f"{rcp.generator_task_id_prefix}_{uuid4()}"
celery_app.send_task(
"connector_pruning_generator_task",
kwargs=dict(
cc_pair_id=cc_pair.id,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
tenant_id=tenant_id,
),
queue=DanswerCeleryQueues.CONNECTOR_PRUNING,
task_id=custom_task_id,
priority=DanswerCeleryPriority.LOW,
)
# set this only after all tasks have been added
r.set(rcp.fence_key, 1)
except Exception:
task_logger.exception("Unexpected exception")
return None
finally:
if lock.owned():
lock.release()
# set this only after all tasks have been added
r.set(rcp.fence_key, 1)
return 1
@shared_task(
name="connector_pruning_generator_task",
acks_late=False,
soft_time_limit=JOB_TIMEOUT,
track_started=True,
trail=False,
)
@shared_task(name="connector_pruning_generator_task", soft_time_limit=JOB_TIMEOUT)
def connector_pruning_generator_task(
cc_pair_id: int, connector_id: int, credential_id: int, tenant_id: str | None
connector_id: int, credential_id: int, tenant_id: str | None
) -> None:
"""connector pruning task. For a cc pair, this task pulls all document IDs from the source
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
from the most recently pulled document ID list"""
r = get_redis_client(tenant_id=tenant_id)
r = get_redis_client()
rcp = RedisConnectorPruning(cc_pair_id)
lock = r.lock(
DanswerRedisLocks.PRUNING_LOCK_PREFIX + f"_{rcp._id}",
timeout=CELERY_PRUNING_LOCK_TIMEOUT,
)
acquired = lock.acquire(blocking=False)
if not acquired:
task_logger.warning(
f"Pruning task already running, exiting...: cc_pair_id={cc_pair_id}"
)
return None
try:
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_tenant(tenant_id) as db_session:
try:
cc_pair = get_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,
@@ -230,13 +180,14 @@ def connector_pruning_generator_task(
if not cc_pair:
task_logger.warning(
f"cc_pair not found for {connector_id} {credential_id}"
f"ccpair not found for {connector_id} {credential_id}"
)
return
rcp = RedisConnectorPruning(cc_pair.id)
# Define the callback function
def redis_increment_callback(amount: int) -> None:
lock.reacquire()
r.incrby(rcp.generator_progress_key, amount)
runnable_connector = instantiate_connector(
@@ -289,13 +240,12 @@ def connector_pruning_generator_task(
)
r.set(rcp.generator_complete_key, tasks_generated)
except Exception as e:
task_logger.exception(f"Failed to run pruning for connector id {connector_id}.")
except Exception as e:
task_logger.exception(
f"Failed to run pruning for connector id {connector_id}."
)
r.delete(rcp.generator_progress_key)
r.delete(rcp.taskset_key)
r.delete(rcp.fence_key)
raise e
finally:
if lock.owned():
lock.release()
r.delete(rcp.generator_progress_key)
r.delete(rcp.taskset_key)
r.delete(rcp.fence_key)
raise e
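The scheduling check earlier in this file (is_pruning_due, and the older helper it replaces) reduces to simple time arithmetic; a minimal sketch under assumed argument names:

from datetime import datetime, timedelta, timezone

def pruning_is_due(last_pruned: datetime, prune_freq_seconds: int) -> bool:
    # prune again once prune_freq_seconds have elapsed since the last prune
    next_prune = last_pruned + timedelta(seconds=prune_freq_seconds)
    return datetime.now(timezone.utc) >= next_prune

# example: pruned 2 days ago with a 1-day frequency -> due
print(pruning_is_due(datetime.now(timezone.utc) - timedelta(days=2), 86400))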

View File

@@ -1,9 +1,6 @@
from datetime import datetime
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from pydantic import BaseModel
from danswer.access.access import get_access_for_document
from danswer.background.celery.celery_app import task_logger
@@ -20,13 +17,6 @@ from danswer.document_index.interfaces import VespaDocumentFields
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
class RedisConnectorIndexingFenceData(BaseModel):
index_attempt_id: int
started: datetime | None
submitted: datetime
celery_task_id: str
@shared_task(
name="document_by_cc_pair_cleanup_task",
bind=True,
@@ -56,8 +46,6 @@ def document_by_cc_pair_cleanup_task(
connector / credential pair from the access list
(6) delete all relevant entries from postgres
"""
task_logger.info(f"document_id={document_id}")
try:
with get_session_with_tenant(tenant_id) as db_session:
action = "skip"
@@ -123,17 +111,11 @@ def document_by_cc_pair_cleanup_task(
pass
task_logger.info(
f"tenant_id={tenant_id} "
f"document_id={document_id} "
f"action={action} "
f"refcount={count} "
f"chunks={chunks_affected}"
f"document_id={document_id} refcount={count} action={action} chunks={chunks_affected}"
)
db_session.commit()
except SoftTimeLimitExceeded:
task_logger.info(
f"SoftTimeLimitExceeded exception. tenant_id={tenant_id} doc_id={document_id}"
)
task_logger.info(f"SoftTimeLimitExceeded exception. doc_id={document_id}")
except Exception as e:
task_logger.exception("Unexpected exception")

View File

@@ -1,15 +1,10 @@
import traceback
from datetime import datetime
from datetime import timezone
from http import HTTPStatus
from typing import cast
import redis
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from celery.result import AsyncResult
from celery.states import READY_STATES
from redis import Redis
from sqlalchemy.orm import Session
@@ -19,11 +14,9 @@ from danswer.background.celery.celery_app import task_logger
from danswer.background.celery.celery_redis import celery_get_queue_length
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisConnectorIndexing
from danswer.background.celery.celery_redis import RedisConnectorPruning
from danswer.background.celery.celery_redis import RedisDocumentSet
from danswer.background.celery.celery_redis import RedisUserGroup
from danswer.background.celery.tasks.shared.tasks import RedisConnectorIndexingFenceData
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryQueues
@@ -38,7 +31,6 @@ from danswer.db.connector_credential_pair import get_connector_credential_pair_f
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.document import count_documents_by_needs_sync
from danswer.db.document import get_document
from danswer.db.document import get_document_ids_for_connector_credential_pair
from danswer.db.document import mark_document_as_synced
from danswer.db.document_set import delete_document_set
from danswer.db.document_set import delete_document_set_cc_pair_relationship__no_commit
@@ -47,19 +39,14 @@ from danswer.db.document_set import fetch_document_sets_for_document
from danswer.db.document_set import get_document_set_by_id
from danswer.db.document_set import mark_document_set_as_synced
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import IndexingStatus
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.index_attempt import delete_index_attempts
from danswer.db.index_attempt import get_all_index_attempts_by_status
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.models import DocumentSet
from danswer.db.models import IndexAttempt
from danswer.db.models import UserGroup
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import VespaDocumentFields
from danswer.document_index.interfaces import UpdateRequest
from danswer.redis.redis_pool import get_redis_client
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,
@@ -67,8 +54,6 @@ from danswer.utils.variable_functionality import (
from danswer.utils.variable_functionality import global_version
from danswer.utils.variable_functionality import noop_fallback
logger = setup_logger()
# celery auto associates tasks created inside another task,
# which bloats the result metadata considerably. trail=False prevents this.
@@ -77,11 +62,11 @@ logger = setup_logger()
soft_time_limit=JOB_TIMEOUT,
trail=False,
)
def check_for_vespa_sync_task(*, tenant_id: str | None) -> None:
def check_for_vespa_sync_task(tenant_id: str | None) -> None:
"""Runs periodically to check if any document needs syncing.
Generates sets of tasks for Celery if syncing is needed."""
r = get_redis_client(tenant_id=tenant_id)
r = get_redis_client()
lock_beat = r.lock(
DanswerRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK,
@@ -311,13 +296,11 @@ def monitor_document_set_taskset(
key_bytes: bytes, r: Redis, db_session: Session
) -> None:
fence_key = key_bytes.decode("utf-8")
document_set_id_str = RedisDocumentSet.get_id_from_fence_key(fence_key)
if document_set_id_str is None:
document_set_id = RedisDocumentSet.get_id_from_fence_key(fence_key)
if document_set_id is None:
task_logger.warning(f"could not parse document set id from {fence_key}")
return
document_set_id = int(document_set_id_str)
rds = RedisDocumentSet(document_set_id)
fence_value = r.get(rds.fence_key)
@@ -332,8 +315,7 @@ def monitor_document_set_taskset(
count = cast(int, r.scard(rds.taskset_key))
task_logger.info(
f"Document set sync progress: document_set_id={document_set_id} "
f"remaining={count} initial={initial_count}"
f"Document set sync progress: document_set_id={document_set_id} remaining={count} initial={initial_count}"
)
if count > 0:
return
@@ -359,17 +341,13 @@ def monitor_document_set_taskset(
r.delete(rds.fence_key)
def monitor_connector_deletion_taskset(
key_bytes: bytes, r: Redis, tenant_id: str | None
) -> None:
def monitor_connector_deletion_taskset(key_bytes: bytes, r: Redis) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnectorDeletion.get_id_from_fence_key(fence_key)
if cc_pair_id_str is None:
cc_pair_id = RedisConnectorDeletion.get_id_from_fence_key(fence_key)
if cc_pair_id is None:
task_logger.warning(f"could not parse cc_pair_id from {fence_key}")
return
cc_pair_id = int(cc_pair_id_str)
rcd = RedisConnectorDeletion(cc_pair_id)
fence_value = r.get(rcd.fence_key)
@@ -384,36 +362,25 @@ def monitor_connector_deletion_taskset(
count = cast(int, r.scard(rcd.taskset_key))
task_logger.info(
f"Connector deletion progress: cc_pair={cc_pair_id} remaining={count} initial={initial_count}"
f"Connector deletion progress: cc_pair_id={cc_pair_id} remaining={count} initial={initial_count}"
)
if count > 0:
return
with get_session_with_tenant(tenant_id) as db_session:
with Session(get_sqlalchemy_engine()) as db_session:
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
if not cc_pair:
task_logger.warning(
f"Connector deletion - cc_pair not found: cc_pair={cc_pair_id}"
f"monitor_connector_deletion_taskset - cc_pair_id not found: cc_pair_id={cc_pair_id}"
)
return
try:
doc_ids = get_document_ids_for_connector_credential_pair(
db_session, cc_pair.connector_id, cc_pair.credential_id
)
if len(doc_ids) > 0:
# if this happens, documents somehow got added while deletion was in progress,
# likely a bug in the gating that should stop pruning and indexing work before deletion starts
task_logger.warning(
f"Connector deletion - documents still found after taskset completion: "
f"cc_pair={cc_pair_id} num={len(doc_ids)}"
)
# clean up the rest of the related Postgres entities
# index attempts
delete_index_attempts(
db_session=db_session,
cc_pair_id=cc_pair_id,
cc_pair_id=cc_pair.id,
)
# document sets
@@ -430,7 +397,7 @@ def monitor_connector_deletion_taskset(
noop_fallback,
)
cleanup_user_groups(
cc_pair_id=cc_pair_id,
cc_pair_id=cc_pair.id,
db_session=db_session,
)
@@ -452,21 +419,20 @@ def monitor_connector_deletion_taskset(
db_session.delete(connector)
db_session.commit()
except Exception as e:
db_session.rollback()
stack_trace = traceback.format_exc()
error_message = f"Error: {str(e)}\n\nStack Trace:\n{stack_trace}"
add_deletion_failure_message(db_session, cc_pair_id, error_message)
add_deletion_failure_message(db_session, cc_pair.id, error_message)
task_logger.exception(
f"Failed to run connector_deletion. "
f"cc_pair={cc_pair_id} connector={cc_pair.connector_id} credential={cc_pair.credential_id}"
f"cc_pair_id={cc_pair_id} connector_id={cc_pair.connector_id} credential_id={cc_pair.credential_id}"
)
raise e
task_logger.info(
f"Successfully deleted cc_pair: "
f"cc_pair={cc_pair_id} "
f"connector={cc_pair.connector_id} "
f"credential={cc_pair.credential_id} "
f"cc_pair_id={cc_pair_id} "
f"connector_id={cc_pair.connector_id} "
f"credential_id={cc_pair.credential_id} "
f"docs_deleted={initial_count}"
)
@@ -478,15 +444,13 @@ def monitor_ccpair_pruning_taskset(
key_bytes: bytes, r: Redis, db_session: Session
) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnectorPruning.get_id_from_fence_key(fence_key)
if cc_pair_id_str is None:
cc_pair_id = RedisConnectorPruning.get_id_from_fence_key(fence_key)
if cc_pair_id is None:
task_logger.warning(
f"monitor_ccpair_pruning_taskset: could not parse cc_pair_id from {fence_key}"
f"monitor_connector_pruning_taskset: could not parse cc_pair_id from {fence_key}"
)
return
cc_pair_id = int(cc_pair_id_str)
rcp = RedisConnectorPruning(cc_pair_id)
fence_value = r.get(rcp.fence_key)
@@ -510,7 +474,7 @@ def monitor_ccpair_pruning_taskset(
if count > 0:
return
mark_ccpair_as_pruned(int(cc_pair_id), db_session)
mark_ccpair_as_pruned(cc_pair_id, db_session)
task_logger.info(
f"Successfully pruned connector credential pair. cc_pair_id={cc_pair_id}"
)
@@ -521,129 +485,16 @@ def monitor_ccpair_pruning_taskset(
r.delete(rcp.fence_key)
def monitor_ccpair_indexing_taskset(
key_bytes: bytes, r: Redis, db_session: Session
) -> None:
# if the fence doesn't exist, there's nothing to do
fence_key = key_bytes.decode("utf-8")
composite_id = RedisConnectorIndexing.get_id_from_fence_key(fence_key)
if composite_id is None:
task_logger.warning(
f"monitor_ccpair_indexing_taskset: could not parse composite_id from {fence_key}"
)
return
# parse out metadata and initialize the helper class with it
parts = composite_id.split("/")
if len(parts) != 2:
return
cc_pair_id = int(parts[0])
search_settings_id = int(parts[1])
rci = RedisConnectorIndexing(cc_pair_id, search_settings_id)
# read related data and evaluate/print task progress
fence_value = cast(bytes, r.get(rci.fence_key))
if fence_value is None:
return
try:
fence_json = fence_value.decode("utf-8")
fence_data = RedisConnectorIndexingFenceData.model_validate_json(
cast(str, fence_json)
)
except ValueError:
task_logger.exception(
"monitor_ccpair_indexing_taskset: fence_data not decodeable."
)
raise
elapsed_submitted = datetime.now(timezone.utc) - fence_data.submitted
generator_progress_value = r.get(rci.generator_progress_key)
if generator_progress_value is not None:
try:
progress_count = int(cast(int, generator_progress_value))
task_logger.info(
f"Connector indexing progress: cc_pair_id={cc_pair_id} "
f"search_settings_id={search_settings_id} "
f"progress={progress_count} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
)
except ValueError:
task_logger.error(
"monitor_ccpair_indexing_taskset: generator_progress_value is not an integer."
)
# Read result state BEFORE generator_complete_key to avoid a race condition
result: AsyncResult = AsyncResult(fence_data.celery_task_id)
result_state = result.state
generator_complete_value = r.get(rci.generator_complete_key)
if generator_complete_value is None:
if result_state in READY_STATES:
# IF the task state is READY, THEN generator_complete should be set
# if it isn't, then the worker crashed
task_logger.info(
f"Connector indexing aborted: "
f"cc_pair_id={cc_pair_id} "
f"search_settings_id={search_settings_id} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
)
index_attempt = get_index_attempt(db_session, fence_data.index_attempt_id)
if index_attempt:
mark_attempt_failed(
index_attempt=index_attempt,
db_session=db_session,
failure_reason="Connector indexing aborted or exceptioned.",
)
r.delete(rci.generator_lock_key)
r.delete(rci.taskset_key)
r.delete(rci.generator_progress_key)
r.delete(rci.generator_complete_key)
r.delete(rci.fence_key)
return
status_enum = HTTPStatus.INTERNAL_SERVER_ERROR
try:
status_value = int(cast(int, generator_complete_value))
status_enum = HTTPStatus(status_value)
except ValueError:
task_logger.error(
f"monitor_ccpair_indexing_taskset: "
f"generator_complete_value=f{generator_complete_value} could not be parsed."
)
task_logger.info(
f"Connector indexing finished: cc_pair_id={cc_pair_id} "
f"search_settings_id={search_settings_id} "
f"status={status_enum.name} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
)
r.delete(rci.generator_lock_key)
r.delete(rci.taskset_key)
r.delete(rci.generator_progress_key)
r.delete(rci.generator_complete_key)
r.delete(rci.fence_key)
@shared_task(name="monitor_vespa_sync", soft_time_limit=300, bind=True)
def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
def monitor_vespa_sync(self: Task, tenant_id: str | None) -> None:
"""This is a celery beat task that monitors and finalizes metadata sync tasksets.
It scans for fence values and then gets the counts of any associated tasksets.
If the count is 0, that means all tasks finished and we should clean up.
This task lock timeout is CELERY_METADATA_SYNC_BEAT_LOCK_TIMEOUT seconds, so don't
do anything too expensive in this function!
Returns True if the task actually did work, False if it exited early (e.g. the beat lock was already held).
"""
r = get_redis_client(tenant_id=tenant_id)
r = get_redis_client()
lock_beat: redis.lock.Lock = r.lock(
DanswerRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK,
@@ -653,14 +504,11 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
try:
# prevent overlapping tasks
if not lock_beat.acquire(blocking=False):
return False
return
# print current queue lengths
r_celery = self.app.broker_connection().channel().client # type: ignore
n_celery = celery_get_queue_length("celery", r)
n_indexing = celery_get_queue_length(
DanswerCeleryQueues.CONNECTOR_INDEXING, r_celery
)
n_sync = celery_get_queue_length(
DanswerCeleryQueues.VESPA_METADATA_SYNC, r_celery
)
@@ -672,11 +520,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
)
task_logger.info(
f"Queue lengths: celery={n_celery} "
f"indexing={n_indexing} "
f"sync={n_sync} "
f"deletion={n_deletion} "
f"pruning={n_pruning}"
f"Queue lengths: celery={n_celery} sync={n_sync} deletion={n_deletion} pruning={n_pruning}"
)
lock_beat.reacquire()
@@ -685,7 +529,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
monitor_connector_deletion_taskset(key_bytes, r, tenant_id)
monitor_connector_deletion_taskset(key_bytes, r)
with get_session_with_tenant(tenant_id) as db_session:
lock_beat.reacquire()
@@ -707,29 +551,6 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
for key_bytes in r.scan_iter(RedisConnectorPruning.FENCE_PREFIX + "*"):
monitor_ccpair_pruning_taskset(key_bytes, r, db_session)
# do some cleanup before clearing fences
# check the db for any outstanding index attempts
attempts: list[IndexAttempt] = []
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.NOT_STARTED, db_session)
)
attempts.extend(
get_all_index_attempts_by_status(IndexingStatus.IN_PROGRESS, db_session)
)
for a in attempts:
# if attempts exist in the db but we don't detect them in redis, mark them as failed
rci = RedisConnectorIndexing(
a.connector_credential_pair_id, a.search_settings_id
)
failure_reason = f"Unknown index attempt {a.id}. Might be left over from a process restart."
if not r.exists(rci.fence_key):
mark_attempt_failed(a, db_session, failure_reason=failure_reason)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisConnectorIndexing.FENCE_PREFIX + "*"):
monitor_ccpair_indexing_taskset(key_bytes, r, db_session)
# uncomment for debugging if needed
# r_celery = celery_app.broker_connection().channel().client
# length = celery_get_queue_length(DanswerCeleryQueues.VESPA_METADATA_SYNC, r_celery)
@@ -742,8 +563,6 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
if lock_beat.owned():
lock_beat.release()
return True
@shared_task(
name="vespa_metadata_sync_task",
@@ -776,24 +595,20 @@ def vespa_metadata_sync_task(
doc_access = get_access_for_document(
document_id=document_id, db_session=db_session
)
fields = VespaDocumentFields(
update_request = UpdateRequest(
document_ids=[document_id],
document_sets=update_doc_sets,
access=doc_access,
boost=doc.boost,
hidden=doc.hidden,
)
# update Vespa. OK if doc doesn't exist. Raises exception otherwise.
chunks_affected = document_index.update_single(document_id, fields=fields)
# update Vespa
document_index.update(update_requests=[update_request])
# update db last. Worst case = we crash right before this and
# the sync might repeat again later
mark_document_as_synced(document_id, db_session)
task_logger.info(
f"document_id={document_id} action=sync chunks={chunks_affected}"
)
except SoftTimeLimitExceeded:
task_logger.info(f"SoftTimeLimitExceeded exception. doc_id={document_id}")
except Exception as e:
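Stepping back, every monitor in this file follows the shape sketched below: scan for fence keys, check how many subtasks remain in the matching taskset, and finalize once the set is empty (the key layout here is an assumption for illustration):

from typing import cast
from redis import Redis

def monitor_tasksets(r: Redis, fence_prefix: str, taskset_prefix: str) -> None:
    for key_bytes in r.scan_iter(fence_prefix + "*"):
        fence_key = key_bytes.decode("utf-8")
        object_id = fence_key[len(fence_prefix) + 1 :]  # assumes "<prefix>_<id>" layout
        remaining = cast(int, r.scard(f"{taskset_prefix}_{object_id}"))
        if remaining > 0:
            continue  # subtasks still in flight; check again on the next beat
        # all subtasks finished: this is where the DB would be marked as synced
        r.delete(f"{taskset_prefix}_{object_id}")
        r.delete(fence_key)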

View File

@@ -1,10 +1,10 @@
import time
import traceback
from collections.abc import Callable
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from sqlalchemy import text
from sqlalchemy.orm import Session
from danswer.background.indexing.checkpointing import get_time_windows_for_index_attempt
@@ -20,10 +20,11 @@ from danswer.db.connector_credential_pair import get_last_successful_attempt_tim
from danswer.db.connector_credential_pair import update_connector_credential_pair
from danswer.db.engine import get_session_with_tenant
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.index_attempt import mark_attempt_in_progress
from danswer.db.index_attempt import mark_attempt_partially_succeeded
from danswer.db.index_attempt import mark_attempt_succeeded
from danswer.db.index_attempt import transition_attempt_to_in_progress
from danswer.db.index_attempt import update_docs_indexed
from danswer.db.models import IndexAttempt
from danswer.db.models import IndexingStatus
@@ -64,7 +65,6 @@ def _get_connector_runner(
input_type=task,
connector_specific_config=attempt.connector_credential_pair.connector.connector_specific_config,
credential=attempt.connector_credential_pair.credential,
tenant_id=tenant_id,
)
except Exception as e:
logger.exception(f"Unable to instantiate connector due to {e}")
@@ -89,18 +89,12 @@ def _get_connector_runner(
def _run_indexing(
db_session: Session,
index_attempt: IndexAttempt,
tenant_id: str | None,
progress_callback: Callable[[int], None] | None = None,
db_session: Session, index_attempt: IndexAttempt, tenant_id: str | None
) -> None:
"""
1. Get documents which are either new or updated from specified application
2. Embed and index these documents into the chosen datastore (vespa)
3. Updates Postgres to record the indexed documents + the outcome of this run
TODO: do not change index attempt statuses here ... instead, set signals in redis
and allow the monitor function to clean them up
"""
start_time = time.time()
@@ -206,7 +200,7 @@ def _run_indexing(
# index being built. We want to populate it even for paused connectors
# Often paused connectors are sources that aren't updated frequently but the
# contents still need to be initially pulled.
db_session.refresh(db_cc_pair)
db_session.refresh(db_connector)
if (
(
db_cc_pair.status == ConnectorCredentialPairStatus.PAUSED
@@ -243,8 +237,6 @@ def _run_indexing(
logger.debug(f"Indexing batch of documents: {batch_description}")
index_attempt_md.batch_num = batch_num + 1 # use 1-index for this
# real work happens here!
new_docs, total_batch_chunks = indexing_pipeline(
document_batch=doc_batch,
index_attempt_metadata=index_attempt_md,
@@ -263,9 +255,6 @@ def _run_indexing(
# be inaccurate
db_session.commit()
if progress_callback:
progress_callback(len(doc_batch))
# This new value is updated every batch, so UI can refresh per batch update
update_docs_indexed(
db_session=db_session,
@@ -389,12 +378,47 @@ def _run_indexing(
)
def _prepare_index_attempt(
db_session: Session, index_attempt_id: int, tenant_id: str | None
) -> IndexAttempt:
# make sure that the index attempt can't change in between checking the
# status and marking it as in_progress. This setting will be discarded
# after the next commit:
# https://docs.sqlalchemy.org/en/20/orm/session_transaction.html#setting-isolation-for-individual-transactions
db_session.connection(execution_options={"isolation_level": "SERIALIZABLE"}) # type: ignore
if tenant_id is not None:
# Explicitly set the search path for the given tenant
db_session.execute(text(f'SET search_path TO "{tenant_id}"'))
# Verify the search path was set correctly
result = db_session.execute(text("SHOW search_path"))
current_search_path = result.scalar()
logger.info(f"Current search path set to: {current_search_path}")
attempt = get_index_attempt(
db_session=db_session,
index_attempt_id=index_attempt_id,
)
if attempt is None:
raise RuntimeError(f"Unable to find IndexAttempt for ID '{index_attempt_id}'")
if attempt.status != IndexingStatus.NOT_STARTED:
raise RuntimeError(
f"Indexing attempt with ID '{index_attempt_id}' is not in NOT_STARTED status. "
f"Current status is '{attempt.status}'."
)
# only commit once, to make sure this all happens in a single transaction
mark_attempt_in_progress(attempt, db_session)
return attempt
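The search_path handling above is the schema-per-tenant pattern in miniature; a hedged sketch of the same idea in isolation (the DSN and helper name are placeholders, and tenant ids are assumed to be validated upstream):

from sqlalchemy import create_engine, text
from sqlalchemy.orm import Session

engine = create_engine("postgresql+psycopg2://user:pass@localhost/danswer")  # placeholder DSN

def open_session_for_tenant(tenant_id: str) -> Session:
    session = Session(engine)
    # each tenant lives in its own Postgres schema; quoting guards unusual identifiers
    session.execute(text(f'SET search_path TO "{tenant_id}"'))
    return session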
def run_indexing_entrypoint(
index_attempt_id: int,
tenant_id: str | None,
connector_credential_pair_id: int,
is_ee: bool = False,
progress_callback: Callable[[int], None] | None = None,
) -> None:
try:
if is_ee:
@@ -406,7 +430,7 @@ def run_indexing_entrypoint(
index_attempt_id, connector_credential_pair_id
)
with get_session_with_tenant(tenant_id) as db_session:
attempt = transition_attempt_to_in_progress(index_attempt_id, db_session)
attempt = _prepare_index_attempt(db_session, index_attempt_id, tenant_id)
logger.info(
f"Indexing starting for tenant {tenant_id}: "
@@ -417,7 +441,7 @@ def run_indexing_entrypoint(
f"credentials='{attempt.connector_credential_pair.connector_id}'"
)
_run_indexing(db_session, attempt, tenant_id, progress_callback)
_run_indexing(db_session, attempt, tenant_id)
logger.info(
f"Indexing finished for tenant {tenant_id}: "

File diff suppressed because it is too large

View File

@@ -1,8 +1,6 @@
import re
from typing import cast
from uuid import UUID
from fastapi.datastructures import Headers
from sqlalchemy.orm import Session
from danswer.chat.models import CitationInfo
@@ -35,7 +33,7 @@ def llm_doc_from_inference_section(inference_section: InferenceSection) -> LlmDo
def create_chat_chain(
chat_session_id: UUID,
chat_session_id: int,
db_session: Session,
prefetch_tool_calls: bool = True,
# Optional id at which we finish processing
@@ -168,31 +166,3 @@ def reorganize_citations(
new_citation_info[citation.citation_num] = citation
return new_answer, list(new_citation_info.values())
def extract_headers(
headers: dict[str, str] | Headers, pass_through_headers: list[str] | None
) -> dict[str, str]:
"""
Extract headers specified in pass_through_headers from input headers.
Handles both dict and FastAPI Headers objects, accounting for lowercase keys.
Args:
    headers: Input headers as dict or Headers object.
    pass_through_headers: Header keys to extract; if None or empty, nothing is extracted.
Returns:
dict: Filtered headers based on pass_through_headers.
"""
if not pass_through_headers:
return {}
extracted_headers: dict[str, str] = {}
for key in pass_through_headers:
if key in headers:
extracted_headers[key] = headers[key]
else:
# fastapi makes all header keys lowercase, handling that here
lowercase_key = key.lower()
if lowercase_key in headers:
extracted_headers[lowercase_key] = headers[lowercase_key]
return extracted_headers
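A small usage sketch of extract_headers (header values are made up):

# FastAPI lowercases header keys, so this dict mimics what a request provides
raw_headers = {"authorization": "Bearer abc", "x-custom-header": "42"}
print(extract_headers(raw_headers, ["X-Custom-Header", "Missing-Header"]))
# -> {"x-custom-header": "42"}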

View File

@@ -105,7 +105,6 @@ from danswer.tools.tool import ToolResponse
from danswer.tools.tool_runner import ToolCallFinalResult
from danswer.tools.utils import compute_all_tool_tokens
from danswer.tools.utils import explicit_tool_calling_supported
from danswer.utils.headers import header_dict_to_header_list
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_generator_function_time
@@ -277,7 +276,6 @@ def stream_chat_message_objects(
# on the `new_msg_req.message`. Currently, requires a state where the last message is a
use_existing_user_message: bool = False,
litellm_additional_headers: dict[str, str] | None = None,
custom_tool_additional_headers: dict[str, str] | None = None,
is_connected: Callable[[], bool] | None = None,
enforce_chat_session_id_for_search_docs: bool = True,
) -> ChatPacketStream:
@@ -641,12 +639,7 @@ def stream_chat_message_objects(
chat_session_id=chat_session_id,
message_id=user_message.id if user_message else None,
),
custom_headers=(db_tool_model.custom_headers or [])
+ (
header_dict_to_header_list(
custom_tool_additional_headers or {}
)
),
custom_headers=db_tool_model.custom_headers,
),
)
@@ -869,7 +862,6 @@ def stream_chat_message(
user: User | None,
use_existing_user_message: bool = False,
litellm_additional_headers: dict[str, str] | None = None,
custom_tool_additional_headers: dict[str, str] | None = None,
is_connected: Callable[[], bool] | None = None,
) -> Iterator[str]:
with get_session_context_manager() as db_session:
@@ -879,7 +871,6 @@ def stream_chat_message(
db_session=db_session,
use_existing_user_message=use_existing_user_message,
litellm_additional_headers=litellm_additional_headers,
custom_tool_additional_headers=custom_tool_additional_headers,
is_connected=is_connected,
)
for obj in objects:

View File

@@ -53,6 +53,7 @@ MASK_CREDENTIAL_PREFIX = (
os.environ.get("MASK_CREDENTIAL_PREFIX", "True").lower() != "false"
)
SESSION_EXPIRE_TIME_SECONDS = int(
os.environ.get("SESSION_EXPIRE_TIME_SECONDS") or 86400 * 7
) # 7 days
@@ -115,16 +116,10 @@ VESPA_HOST = os.environ.get("VESPA_HOST") or "localhost"
VESPA_CONFIG_SERVER_HOST = os.environ.get("VESPA_CONFIG_SERVER_HOST") or VESPA_HOST
VESPA_PORT = os.environ.get("VESPA_PORT") or "8081"
VESPA_TENANT_PORT = os.environ.get("VESPA_TENANT_PORT") or "19071"
VESPA_CLOUD_URL = os.environ.get("VESPA_CLOUD_URL", "")
# The default below is for dockerized deployment
VESPA_DEPLOYMENT_ZIP = (
os.environ.get("VESPA_DEPLOYMENT_ZIP") or "/app/danswer/vespa-app.zip"
)
VESPA_CLOUD_CERT_PATH = os.environ.get("VESPA_CLOUD_CERT_PATH")
VESPA_CLOUD_KEY_PATH = os.environ.get("VESPA_CLOUD_KEY_PATH")
# Number of documents in a batch during indexing (further batching done by chunks before passing to bi-encoder)
try:
INDEX_BATCH_SIZE = int(os.environ.get("INDEX_BATCH_SIZE", 16))
@@ -410,8 +405,6 @@ VESPA_REQUEST_TIMEOUT = int(os.environ.get("VESPA_REQUEST_TIMEOUT") or "5")
SYSTEM_RECURSION_LIMIT = int(os.environ.get("SYSTEM_RECURSION_LIMIT") or "1000")
PARSE_WITH_TRAFILATURA = os.environ.get("PARSE_WITH_TRAFILATURA", "").lower() == "true"
#####
# Enterprise Edition Configs
#####
@@ -430,31 +423,11 @@ AZURE_DALLE_API_BASE = os.environ.get("AZURE_DALLE_API_BASE")
AZURE_DALLE_DEPLOYMENT_NAME = os.environ.get("AZURE_DALLE_DEPLOYMENT_NAME")
# Cloud configuration
# Multi-tenancy configuration
MULTI_TENANT = os.environ.get("MULTI_TENANT", "").lower() == "true"
SECRET_JWT_KEY = os.environ.get("SECRET_JWT_KEY", "")
# Use managed Vespa (Vespa Cloud). If set, must also set VESPA_CLOUD_URL, VESPA_CLOUD_CERT_PATH and VESPA_CLOUD_KEY_PATH
MANAGED_VESPA = os.environ.get("MANAGED_VESPA", "").lower() == "true"
DATA_PLANE_SECRET = os.environ.get("DATA_PLANE_SECRET", "")
EXPECTED_API_KEY = os.environ.get("EXPECTED_API_KEY", "")
ENABLE_EMAIL_INVITES = os.environ.get("ENABLE_EMAIL_INVITES", "").lower() == "true"
# Security and authentication
SECRET_JWT_KEY = os.environ.get(
"SECRET_JWT_KEY", ""
) # Used for encryption of the JWT token for user's tenant context
DATA_PLANE_SECRET = os.environ.get(
"DATA_PLANE_SECRET", ""
) # Used for secure communication between the control and data plane
EXPECTED_API_KEY = os.environ.get(
"EXPECTED_API_KEY", ""
) # Additional security check for the control plane API
# API configuration
CONTROL_PLANE_API_BASE_URL = os.environ.get(
"CONTROL_PLANE_API_BASE_URL", "http://localhost:8082"
)
# JWT configuration
JWT_ALGORITHM = "HS256"

View File

@@ -42,8 +42,6 @@ POSTGRES_CELERY_BEAT_APP_NAME = "celery_beat"
POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME = "celery_worker_primary"
POSTGRES_CELERY_WORKER_LIGHT_APP_NAME = "celery_worker_light"
POSTGRES_CELERY_WORKER_HEAVY_APP_NAME = "celery_worker_heavy"
POSTGRES_CELERY_WORKER_INDEXING_APP_NAME = "celery_worker_indexing"
POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME = "celery_worker_indexing_child"
POSTGRES_PERMISSIONS_APP_NAME = "permissions"
POSTGRES_UNKNOWN_APP_NAME = "unknown"
POSTGRES_DEFAULT_SCHEMA = "public"
@@ -75,16 +73,6 @@ KV_CUSTOM_ANALYTICS_SCRIPT_KEY = "__custom_analytics_script__"
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT = 60
CELERY_PRIMARY_WORKER_LOCK_TIMEOUT = 120
# needs to be long enough to cover the maximum time it takes to download an object
# if we can get callbacks as object bytes download, we could lower this a lot.
CELERY_INDEXING_LOCK_TIMEOUT = 60 * 60 # 60 min
# needs to be long enough to cover the maximum time it takes to download an object
# if we can get callbacks as object bytes download, we could lower this a lot.
CELERY_PRUNING_LOCK_TIMEOUT = 300 # 5 min
DANSWER_REDIS_FUNCTION_LOCK_PREFIX = "da_function_lock:"
class DocumentSource(str, Enum):
# Special case, document passed in via Danswer APIs without specifying a source type
@@ -130,12 +118,8 @@ class DocumentSource(str, Enum):
NOT_APPLICABLE = "not_applicable"
DocumentSourceRequiringTenantContext: list[DocumentSource] = [DocumentSource.FILE]
class NotificationType(str, Enum):
REINDEX = "reindex"
PERSONA_SHARED = "persona_shared"
class BlobType(str, Enum):
@@ -160,9 +144,6 @@ class AuthType(str, Enum):
OIDC = "oidc"
SAML = "saml"
# google auth and basic
CLOUD = "cloud"
class SessionType(str, Enum):
CHAT = "Chat"
@@ -212,19 +193,14 @@ class DanswerCeleryQueues:
VESPA_METADATA_SYNC = "vespa_metadata_sync"
CONNECTOR_DELETION = "connector_deletion"
CONNECTOR_PRUNING = "connector_pruning"
CONNECTOR_INDEXING = "connector_indexing"
class DanswerRedisLocks:
PRIMARY_WORKER = "da_lock:primary_worker"
CHECK_VESPA_SYNC_BEAT_LOCK = "da_lock:check_vespa_sync_beat"
MONITOR_VESPA_SYNC_BEAT_LOCK = "da_lock:monitor_vespa_sync_beat"
CHECK_CONNECTOR_DELETION_BEAT_LOCK = "da_lock:check_connector_deletion_beat"
CHECK_PRUNE_BEAT_LOCK = "da_lock:check_prune_beat"
CHECK_INDEXING_BEAT_LOCK = "da_lock:check_indexing_beat"
MONITOR_VESPA_SYNC_BEAT_LOCK = "da_lock:monitor_vespa_sync_beat"
PRUNING_LOCK_PREFIX = "da_lock:pruning"
INDEXING_METADATA_PREFIX = "da_metadata:indexing"
class DanswerCeleryPriority(int, Enum):

View File

@@ -1,22 +0,0 @@
import json
import os
# if specified, will pass through request headers to API calls made by custom tools
CUSTOM_TOOL_PASS_THROUGH_HEADERS: list[str] | None = None
_CUSTOM_TOOL_PASS_THROUGH_HEADERS_RAW = os.environ.get(
"CUSTOM_TOOL_PASS_THROUGH_HEADERS"
)
if _CUSTOM_TOOL_PASS_THROUGH_HEADERS_RAW:
try:
CUSTOM_TOOL_PASS_THROUGH_HEADERS = json.loads(
_CUSTOM_TOOL_PASS_THROUGH_HEADERS_RAW
)
except Exception:
# need to import here to avoid circular imports
from danswer.utils.logger import setup_logger
logger = setup_logger()
logger.error(
"Failed to parse CUSTOM_TOOL_PASS_THROUGH_HEADERS, must be a valid JSON object"
)
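The annotation above (list[str] | None) suggests the environment value is a JSON list of header names; a hedged example with illustrative names:

# e.g. CUSTOM_TOOL_PASS_THROUGH_HEADERS='["Authorization", "X-Trace-Id"]'
import json
import os

raw = os.environ.get("CUSTOM_TOOL_PASS_THROUGH_HEADERS", '["Authorization"]')
custom_tool_pass_through_headers = json.loads(raw)  # -> ["Authorization"]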

View File

@@ -6,8 +6,7 @@ from datetime import datetime
from datetime import timezone
from functools import lru_cache
from typing import Any
from urllib.parse import parse_qs
from urllib.parse import urlparse
from typing import cast
import bs4
from atlassian import Confluence # type:ignore
@@ -33,6 +32,7 @@ from danswer.connectors.confluence.rate_limit_handler import (
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.models import ConnectorMissingCredentialError
from danswer.connectors.models import Document
@@ -56,40 +56,8 @@ NO_PARENT_OR_NO_PERMISSIONS_ERROR_STR = (
)
class DanswerConfluence(Confluence):
"""
This is a custom Confluence class that overrides the default Confluence class to add a custom CQL method.
This is necessary because the default Confluence class does not properly support cql expansions.
"""
def __init__(self, url: str, *args: Any, **kwargs: Any) -> None:
super(DanswerConfluence, self).__init__(url, *args, **kwargs)
def danswer_cql(
self,
cql: str,
expand: str | None = None,
cursor: str | None = None,
limit: int = 500,
include_archived_spaces: bool = False,
) -> dict[str, Any]:
url_suffix = f"rest/api/content/search?cql={cql}"
if expand:
url_suffix += f"&expand={expand}"
if cursor:
url_suffix += f"&cursor={cursor}"
url_suffix += f"&limit={limit}"
if include_archived_spaces:
url_suffix += "&includeArchivedSpaces=true"
try:
response = self.get(url_suffix)
return response
except Exception as e:
raise e
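A hedged sketch of the URL suffix the custom CQL helper above assembles, using illustrative query values:

# Stand-in values; mirrors the string-building in danswer_cql above.
cql = "type=page and space='ENG'"
expand = "body.storage.value,version,space"
cursor = None
limit = 50

url_suffix = f"rest/api/content/search?cql={cql}"
if expand:
    url_suffix += f"&expand={expand}"
if cursor:
    url_suffix += f"&cursor={cursor}"
url_suffix += f"&limit={limit}"
# -> "rest/api/content/search?cql=type=page and space='ENG'&expand=body.storage.value,version,space&limit=50"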
@lru_cache()
def _get_user(user_id: str, confluence_client: DanswerConfluence) -> str:
def _get_user(user_id: str, confluence_client: Confluence) -> str:
"""Get Confluence Display Name based on the account-id or userkey value
Args:
@@ -105,7 +73,6 @@ def _get_user(user_id: str, confluence_client: DanswerConfluence) -> str:
confluence_client.get_user_details_by_accountid
)
try:
logger.info(f"_get_user - get_user_details_by_accountid: id={user_id}")
return get_user_details_by_accountid(user_id).get("displayName", user_not_found)
except Exception as e:
logger.warning(
@@ -114,7 +81,7 @@ def _get_user(user_id: str, confluence_client: DanswerConfluence) -> str:
return user_not_found
def parse_html_page(text: str, confluence_client: DanswerConfluence) -> str:
def parse_html_page(text: str, confluence_client: Confluence) -> str:
"""Parse a Confluence html page and replace the 'user Id' by the real
User Display Name
@@ -145,7 +112,7 @@ def parse_html_page(text: str, confluence_client: DanswerConfluence) -> str:
def _comment_dfs(
comments_str: str,
comment_pages: Collection[dict[str, Any]],
confluence_client: DanswerConfluence,
confluence_client: Confluence,
) -> str:
get_page_child_by_type = make_confluence_call_handle_rate_limit(
confluence_client.get_page_child_by_type
@@ -157,9 +124,6 @@ def _comment_dfs(
comment_html, confluence_client
)
try:
logger.info(
f"_comment_dfs - get_page_by_child_type: id={comment_page['id']}"
)
child_comment_pages = get_page_child_by_type(
comment_page["id"],
type="comment",
@@ -199,103 +163,130 @@ class RecursiveIndexer:
index_recursively: bool,
origin_page_id: str,
) -> None:
self.batch_size = batch_size
self.batch_size = 1
# batch_size
self.confluence_client = confluence_client
self.index_recursively = index_recursively
self.origin_page_id = origin_page_id
self.pages = self.recurse_children_pages(self.origin_page_id)
self.pages = self.recurse_children_pages(0, self.origin_page_id)
def get_origin_page(self) -> list[dict[str, Any]]:
return [self._fetch_origin_page()]
def get_pages(self) -> list[dict[str, Any]]:
return self.pages
def get_pages(self, ind: int, size: int) -> list[dict]:
if ind * size > len(self.pages):
return []
return self.pages[ind * size : (ind + 1) * size]
def _fetch_origin_page(self) -> dict[str, Any]:
def _fetch_origin_page(
self,
) -> dict[str, Any]:
get_page_by_id = make_confluence_call_handle_rate_limit(
self.confluence_client.get_page_by_id
)
try:
logger.info(
f"_fetch_origin_page - get_page_by_id: id={self.origin_page_id}"
)
origin_page = get_page_by_id(
self.origin_page_id, expand="body.storage.value,version,space"
self.origin_page_id, expand="body.storage.value,version"
)
return origin_page
except Exception:
logger.exception(
f"Appending origin page with id {self.origin_page_id} failed."
except Exception as e:
logger.warning(
f"Appending orgin page with id {self.origin_page_id} failed: {e}"
)
return {}
def recurse_children_pages(
self,
start_ind: int,
page_id: str,
) -> list[dict[str, Any]]:
pages: list[dict[str, Any]] = []
queue: list[str] = [page_id]
visited_pages: set[str] = set()
current_level_pages: list[dict[str, Any]] = []
next_level_pages: list[dict[str, Any]] = []
get_page_by_id = make_confluence_call_handle_rate_limit(
self.confluence_client.get_page_by_id
)
# Initial fetch of first level children
index = start_ind
while batch := self._fetch_single_depth_child_pages(
index, self.batch_size, page_id
):
current_level_pages.extend(batch)
index += len(batch)
pages.extend(current_level_pages)
# Recursively index children and children's children, etc.
while current_level_pages:
for child in current_level_pages:
child_index = 0
while child_batch := self._fetch_single_depth_child_pages(
child_index, self.batch_size, child["id"]
):
next_level_pages.extend(child_batch)
child_index += len(child_batch)
pages.extend(next_level_pages)
current_level_pages = next_level_pages
next_level_pages = []
try:
origin_page = self._fetch_origin_page()
pages.append(origin_page)
except Exception as e:
logger.warning(f"Appending origin page with id {page_id} failed: {e}")
return pages
def _fetch_single_depth_child_pages(
self, start_ind: int, batch_size: int, page_id: str
) -> list[dict[str, Any]]:
child_pages: list[dict[str, Any]] = []
get_page_child_by_type = make_confluence_call_handle_rate_limit(
self.confluence_client.get_page_child_by_type
)
while queue:
current_page_id = queue.pop(0)
if current_page_id in visited_pages:
continue
visited_pages.add(current_page_id)
try:
child_page = get_page_child_by_type(
page_id,
type="page",
start=start_ind,
limit=batch_size,
expand="body.storage.value,version",
)
try:
# Fetch the page itself
logger.info(
f"recurse_children_pages - get_page_by_id: id={current_page_id}"
)
page = get_page_by_id(
current_page_id, expand="body.storage.value,version,space"
)
pages.append(page)
except Exception:
logger.exception(f"Failed to fetch page {current_page_id}.")
continue
child_pages.extend(child_page)
return child_pages
if not self.index_recursively:
continue
except Exception:
logger.warning(
f"Batch failed with page {page_id} at offset {start_ind} "
f"with size {batch_size}, processing pages individually..."
)
# Fetch child pages
start = 0
while True:
logger.info(
f"recurse_children_pages - get_page_by_child_type: id={current_page_id}"
)
child_pages_response = get_page_child_by_type(
current_page_id,
type="page",
start=start,
limit=self.batch_size,
expand="",
)
if not child_pages_response:
break
for child_page in child_pages_response:
child_page_id = child_page["id"]
queue.append(child_page_id)
start += len(child_pages_response)
for i in range(batch_size):
ind = start_ind + i
try:
child_page = get_page_child_by_type(
page_id,
type="page",
start=ind,
limit=1,
expand="body.storage.value,version",
)
child_pages.extend(child_page)
except Exception as e:
logger.warning(f"Page {page_id} at offset {ind} failed: {e}")
raise e
return pages
return child_pages
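For orientation, a minimal hedged sketch of the breadth-first traversal pattern used by recurse_children_pages above; the fetch helper is an illustrative stand-in for the rate-limited Confluence calls:

from collections import deque

def bfs_page_ids(origin_id: str, fetch_child_ids) -> list[str]:
    # fetch_child_ids(page_id) -> list of child page ids (stand-in for
    # get_page_child_by_type in the connector above)
    visited: set[str] = set()
    order: list[str] = []
    queue = deque([origin_id])
    while queue:
        page_id = queue.popleft()
        if page_id in visited:
            continue
        visited.add(page_id)
        order.append(page_id)
        queue.extend(fetch_child_ids(page_id))
    return order

assert bfs_page_ids("root", lambda pid: {"root": ["a", "b"], "a": ["c"]}.get(pid, [])) == ["root", "a", "b", "c"]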
class ConfluenceConnector(LoadConnector, PollConnector):
def __init__(
self,
wiki_base: str,
space: str,
is_cloud: bool,
space: str = "",
page_id: str = "",
index_recursively: bool = True,
batch_size: int = INDEX_BATCH_SIZE,
@@ -304,167 +295,104 @@ class ConfluenceConnector(LoadConnector, PollConnector):
# skip it. This is generally used to avoid indexing extra sensitive
# pages.
labels_to_skip: list[str] = CONFLUENCE_CONNECTOR_LABELS_TO_SKIP,
cql_query: str | None = None,
) -> None:
self.batch_size = batch_size
self.continue_on_failure = continue_on_failure
self.labels_to_skip = set(labels_to_skip)
self.recursive_indexer: RecursiveIndexer | None = None
self.index_recursively = False if cql_query else index_recursively
self.index_recursively = index_recursively
# Remove trailing slash from wiki_base if present
self.wiki_base = wiki_base.rstrip("/")
self.page_id = "" if cql_query else page_id
self.space_level_scan = bool(not self.page_id)
self.space = space
self.page_id = page_id
self.is_cloud = is_cloud
self.confluence_client: DanswerConfluence | None = None
self.space_level_scan = False
self.confluence_client: Confluence | None = None
# if a cql_query is provided, we will use it to fetch the pages
# if no cql_query is provided, we will use the space to fetch the pages
# if no space is provided and no cql_query, we will default to fetching all pages, regardless of space
if cql_query:
self.cql_query = cql_query
elif space:
self.cql_query = f"type=page and space='{space}'"
else:
self.cql_query = "type=page"
if self.page_id is None or self.page_id == "":
self.space_level_scan = True
logger.info(
f"wiki_base: {self.wiki_base}, space: {space}, page_id: {self.page_id},"
+ f" space_level_scan: {self.space_level_scan}, index_recursively: {self.index_recursively},"
+ f" cql_query: {self.cql_query}"
f"wiki_base: {self.wiki_base}, space: {self.space}, page_id: {self.page_id},"
+ f" space_level_scan: {self.space_level_scan}, index_recursively: {self.index_recursively}"
)
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
username = credentials["confluence_username"]
access_token = credentials["confluence_access_token"]
# see https://github.com/atlassian-api/atlassian-python-api/blob/master/atlassian/rest_client.py
# for a list of other hidden constructor args
self.confluence_client = DanswerConfluence(
self.confluence_client = Confluence(
url=self.wiki_base,
# passing in username causes issues for Confluence data center
username=username if self.is_cloud else None,
password=access_token if self.is_cloud else None,
token=access_token if not self.is_cloud else None,
backoff_and_retry=True,
max_backoff_retries=60,
max_backoff_seconds=60,
)
return None
def _fetch_pages(
self,
cursor: str | None,
) -> tuple[list[dict[str, Any]], str | None]:
if self.confluence_client is None:
raise Exception("Confluence client is not initialized")
def _fetch_space(
cursor: str | None, batch_size: int
) -> tuple[list[dict[str, Any]], str | None]:
if not self.confluence_client:
raise Exception("Confluence client is not initialized")
get_all_pages = make_confluence_call_handle_rate_limit(
self.confluence_client.danswer_cql
confluence_client: Confluence,
start_ind: int,
) -> list[dict[str, Any]]:
def _fetch_space(start_ind: int, batch_size: int) -> list[dict[str, Any]]:
get_all_pages_from_space = make_confluence_call_handle_rate_limit(
confluence_client.get_all_pages_from_space
)
include_archived_spaces = (
CONFLUENCE_CONNECTOR_INDEX_ARCHIVED_PAGES
if not self.is_cloud
else False
)
try:
logger.info(
f"_fetch_space - get_all_pages: cursor={cursor} limit={batch_size}"
)
response = get_all_pages(
cql=self.cql_query,
cursor=cursor,
return get_all_pages_from_space(
self.space,
start=start_ind,
limit=batch_size,
expand="body.storage.value,version,space",
include_archived_spaces=include_archived_spaces,
status=(
None if CONFLUENCE_CONNECTOR_INDEX_ARCHIVED_PAGES else "current"
),
expand="body.storage.value,version",
)
pages = response.get("results", [])
next_cursor = None
if "_links" in response and "next" in response["_links"]:
next_link = response["_links"]["next"]
parsed_url = urlparse(next_link)
query_params = parse_qs(parsed_url.query)
cursor_list = query_params.get("cursor", [])
if cursor_list:
next_cursor = cursor_list[0]
return pages, next_cursor
except Exception:
logger.warning(
f"Batch failed with cql {self.cql_query} with cursor {cursor} "
f"and size {batch_size}, processing pages individually..."
f"Batch failed with space {self.space} at offset {start_ind} "
f"with size {batch_size}, processing pages individually..."
)
view_pages: list[dict[str, Any]] = []
for _ in range(self.batch_size):
for i in range(self.batch_size):
try:
logger.info(
f"_fetch_space - get_all_pages: cursor={cursor} limit=1"
# Could be that one of the pages here failed due to this bug:
# https://jira.atlassian.com/browse/CONFCLOUD-76433
view_pages.extend(
get_all_pages_from_space(
self.space,
start=start_ind + i,
limit=1,
status=(
None
if CONFLUENCE_CONNECTOR_INDEX_ARCHIVED_PAGES
else "current"
),
expand="body.storage.value,version",
)
)
response = get_all_pages(
cql=self.cql_query,
cursor=cursor,
limit=1,
expand="body.view.value,version,space",
include_archived_spaces=include_archived_spaces,
)
pages = response.get("results", [])
view_pages.extend(pages)
if "_links" in response and "next" in response["_links"]:
next_link = response["_links"]["next"]
parsed_url = urlparse(next_link)
query_params = parse_qs(parsed_url.query)
cursor_list = query_params.get("cursor", [])
if cursor_list:
cursor = cursor_list[0]
else:
cursor = None
else:
cursor = None
break
except HTTPError as e:
logger.warning(
f"Page failed with cql {self.cql_query} with cursor {cursor}, "
f"Page failed with space {self.space} at offset {start_ind + i}, "
f"trying alternative expand option: {e}"
)
logger.info(
f"_fetch_space - get_all_pages - trying alternative expand: cursor={cursor} limit=1"
# Use view instead, which captures most info but is less complete
view_pages.extend(
get_all_pages_from_space(
self.space,
start=start_ind + i,
limit=1,
expand="body.view.value,version",
)
)
response = get_all_pages(
cql=self.cql_query,
cursor=cursor,
limit=1,
expand="body.view.value,version,space",
)
pages = response.get("results", [])
view_pages.extend(pages)
if "_links" in response and "next" in response["_links"]:
next_link = response["_links"]["next"]
parsed_url = urlparse(next_link)
query_params = parse_qs(parsed_url.query)
cursor_list = query_params.get("cursor", [])
if cursor_list:
cursor = cursor_list[0]
else:
cursor = None
else:
cursor = None
break
return view_pages, cursor
def _fetch_page() -> tuple[list[dict[str, Any]], str | None]:
if self.confluence_client is None:
raise Exception("Confluence client is not initialized")
return view_pages
def _fetch_page(start_ind: int, batch_size: int) -> list[dict[str, Any]]:
if self.recursive_indexer is None:
self.recursive_indexer = RecursiveIndexer(
origin_page_id=self.page_id,
@@ -473,22 +401,41 @@ class ConfluenceConnector(LoadConnector, PollConnector):
index_recursively=self.index_recursively,
)
pages = self.recursive_indexer.get_pages()
return pages, None # Since we fetched all pages, no cursor
if self.index_recursively:
return self.recursive_indexer.get_pages(start_ind, batch_size)
else:
return self.recursive_indexer.get_origin_page()
pages: list[dict[str, Any]] = []
try:
pages, next_cursor = (
_fetch_space(cursor, self.batch_size)
pages = (
_fetch_space(start_ind, self.batch_size)
if self.space_level_scan
else _fetch_page()
else _fetch_page(start_ind, self.batch_size)
)
return pages, next_cursor
return pages
except Exception as e:
if not self.continue_on_failure:
raise e
logger.exception("Ran into exception when fetching pages from Confluence")
return [], None
# error checking phase, only reachable if `self.continue_on_failure=True`
for i in range(self.batch_size):
try:
pages = (
_fetch_space(start_ind, self.batch_size)
if self.space_level_scan
else _fetch_page(start_ind, self.batch_size)
)
return pages
except Exception:
logger.exception(
"Ran into exception when fetching pages from Confluence"
)
return pages
def _fetch_comments(self, confluence_client: Confluence, page_id: str) -> str:
get_page_child_by_type = make_confluence_call_handle_rate_limit(
@@ -496,22 +443,24 @@ class ConfluenceConnector(LoadConnector, PollConnector):
)
try:
logger.info(f"_fetch_comments - get_page_child_by_type: id={page_id}")
comment_pages = list(
comment_pages = cast(
Collection[dict[str, Any]],
get_page_child_by_type(
page_id,
type="comment",
start=None,
limit=None,
expand="body.storage.value",
)
),
)
return _comment_dfs("", comment_pages, confluence_client)
except Exception as e:
if not self.continue_on_failure:
raise e
logger.exception("Fetching comments from Confluence exceptioned")
logger.exception(
"Ran into exception when fetching comments from Confluence"
)
return ""
def _fetch_labels(self, confluence_client: Confluence, page_id: str) -> list[str]:
@@ -519,14 +468,13 @@ class ConfluenceConnector(LoadConnector, PollConnector):
confluence_client.get_page_labels
)
try:
logger.info(f"_fetch_labels - get_page_labels: id={page_id}")
labels_response = get_page_labels(page_id)
return [label["name"] for label in labels_response["results"]]
except Exception as e:
if not self.continue_on_failure:
raise e
logger.exception("Fetching labels from Confluence exceptioned")
logger.exception("Ran into exception when fetching labels from Confluence")
return []
@classmethod
@@ -563,7 +511,6 @@ class ConfluenceConnector(LoadConnector, PollConnector):
)
return None
logger.info(f"_attachment_to_content - _session.get: link={download_link}")
response = confluence_client._session.get(download_link)
if response.status_code != 200:
logger.warning(
@@ -587,22 +534,22 @@ class ConfluenceConnector(LoadConnector, PollConnector):
return extracted_text
def _fetch_attachments(
self, confluence_client: Confluence, page_id: str, files_in_use: list[str]
self, confluence_client: Confluence, page_id: str, files_in_used: list[str]
) -> tuple[str, list[dict[str, Any]]]:
unused_attachments: list[dict[str, Any]] = []
files_attachment_content: list[str] = []
unused_attachments: list = []
get_attachments_from_content = make_confluence_call_handle_rate_limit(
confluence_client.get_attachments_from_content
)
files_attachment_content: list = []
try:
expand = "history.lastUpdated,metadata.labels"
attachments_container = get_attachments_from_content(
page_id, start=None, limit=None, expand=expand
page_id, start=0, limit=500, expand=expand
)
for attachment in attachments_container.get("results", []):
if attachment["title"] not in files_in_use:
for attachment in attachments_container["results"]:
if attachment["title"] not in files_in_used:
unused_attachments.append(attachment)
continue
@@ -620,33 +567,36 @@ class ConfluenceConnector(LoadConnector, PollConnector):
f"User does not have access to attachments on page '{page_id}'"
)
return "", []
if not self.continue_on_failure:
raise e
logger.exception("Fetching attachments from Confluence exceptioned.")
logger.exception(
f"Ran into exception when fetching attachments from Confluence: {e}"
)
return "\n".join(files_attachment_content), unused_attachments
def _get_doc_batch(
self, cursor: str | None
) -> tuple[list[Any], str | None, list[dict[str, Any]]]:
if self.confluence_client is None:
raise Exception("Confluence client is not initialized")
doc_batch: list[Any] = []
self, start_ind: int, time_filter: Callable[[datetime], bool] | None = None
) -> tuple[list[Document], list[dict[str, Any]], int]:
doc_batch: list[Document] = []
unused_attachments: list[dict[str, Any]] = []
batch, next_cursor = self._fetch_pages(cursor)
if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
batch = self._fetch_pages(self.confluence_client, start_ind)
for page in batch:
last_modified = _datetime_from_string(page["version"]["when"])
author = page["version"].get("by", {}).get("email")
author = cast(str | None, page["version"].get("by", {}).get("email"))
if time_filter and not time_filter(last_modified):
continue
page_id = page["id"]
if self.labels_to_skip or not CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING:
page_labels = self._fetch_labels(self.confluence_client, page_id)
else:
page_labels = []
# check disallowed labels
if self.labels_to_skip:
@@ -656,6 +606,7 @@ class ConfluenceConnector(LoadConnector, PollConnector):
f"Page with ID '{page_id}' has a label which has been "
f"designated as disallowed: {label_intersection}. Skipping."
)
continue
page_html = (
@@ -670,18 +621,16 @@ class ConfluenceConnector(LoadConnector, PollConnector):
continue
page_text = parse_html_page(page_html, self.confluence_client)
files_in_use = get_used_attachments(page_html)
files_in_used = get_used_attachments(page_html)
attachment_text, unused_page_attachments = self._fetch_attachments(
self.confluence_client, page_id, files_in_use
self.confluence_client, page_id, files_in_used
)
unused_attachments.extend(unused_page_attachments)
page_text += "\n" + attachment_text if attachment_text else ""
comments_text = self._fetch_comments(self.confluence_client, page_id)
page_text += comments_text
doc_metadata: dict[str, str | list[str]] = {
"Wiki Space Name": page["space"]["name"]
}
doc_metadata: dict[str, str | list[str]] = {"Wiki Space Name": self.space}
if not CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING and page_labels:
doc_metadata["labels"] = page_labels
@@ -700,8 +649,8 @@ class ConfluenceConnector(LoadConnector, PollConnector):
)
return (
doc_batch,
next_cursor,
unused_attachments,
len(batch),
)
def _get_attachment_batch(
@@ -709,8 +658,8 @@ class ConfluenceConnector(LoadConnector, PollConnector):
start_ind: int,
attachments: list[dict[str, Any]],
time_filter: Callable[[datetime], bool] | None = None,
) -> tuple[list[Any], int]:
doc_batch: list[Any] = []
) -> tuple[list[Document], int]:
doc_batch: list[Document] = []
if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
@@ -738,7 +687,7 @@ class ConfluenceConnector(LoadConnector, PollConnector):
creator_email = attachment["history"]["createdBy"].get("email")
comment = attachment["metadata"].get("comment", "")
doc_metadata: dict[str, Any] = {"comment": comment}
doc_metadata: dict[str, str | list[str]] = {"comment": comment}
attachment_labels: list[str] = []
if not CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING:
@@ -765,36 +714,30 @@ class ConfluenceConnector(LoadConnector, PollConnector):
return doc_batch, end_ind - start_ind
def _handle_batch_retrieval(
self,
start: float | None = None,
end: float | None = None,
) -> GenerateDocumentsOutput:
start_time = datetime.fromtimestamp(start, tz=timezone.utc) if start else None
end_time = datetime.fromtimestamp(end, tz=timezone.utc) if end else None
def load_from_state(self) -> GenerateDocumentsOutput:
unused_attachments = []
unused_attachments: list[dict[str, Any]] = []
cursor = None
if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
start_ind = 0
while True:
doc_batch, cursor, new_unused_attachments = self._get_doc_batch(cursor)
unused_attachments.extend(new_unused_attachments)
doc_batch, unused_attachments_batch, num_pages = self._get_doc_batch(
start_ind
)
unused_attachments.extend(unused_attachments_batch)
start_ind += num_pages
if doc_batch:
yield doc_batch
if not cursor:
if num_pages < self.batch_size:
break
# Process attachments if any
start_ind = 0
while True:
attachment_batch, num_attachments = self._get_attachment_batch(
start_ind=start_ind,
attachments=unused_attachments,
time_filter=(lambda t: start_time <= t <= end_time)
if start_time and end_time
else None,
start_ind, unused_attachments
)
start_ind += num_attachments
if attachment_batch:
yield attachment_batch
@@ -802,11 +745,44 @@ class ConfluenceConnector(LoadConnector, PollConnector):
if num_attachments < self.batch_size:
break
def load_from_state(self) -> GenerateDocumentsOutput:
return self._handle_batch_retrieval()
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
unused_attachments = []
def poll_source(self, start: float, end: float) -> GenerateDocumentsOutput:
return self._handle_batch_retrieval(start=start, end=end)
if self.confluence_client is None:
raise ConnectorMissingCredentialError("Confluence")
start_time = datetime.fromtimestamp(start, tz=timezone.utc)
end_time = datetime.fromtimestamp(end, tz=timezone.utc)
start_ind = 0
while True:
doc_batch, unused_attachments_batch, num_pages = self._get_doc_batch(
start_ind, time_filter=lambda t: start_time <= t <= end_time
)
unused_attachments.extend(unused_attachments_batch)
start_ind += num_pages
if doc_batch:
yield doc_batch
if num_pages < self.batch_size:
break
start_ind = 0
while True:
attachment_batch, num_attachments = self._get_attachment_batch(
start_ind,
unused_attachments,
time_filter=lambda t: start_time <= t <= end_time,
)
start_ind += num_attachments
if attachment_batch:
yield attachment_batch
if num_attachments < self.batch_size:
break
if __name__ == "__main__":

View File

@@ -1,4 +1,3 @@
import math
import time
from collections.abc import Callable
from typing import Any
@@ -22,198 +21,62 @@ class ConfluenceRateLimitError(Exception):
pass
# commenting out while we try using confluence's rate limiter instead
# # https://developer.atlassian.com/cloud/confluence/rate-limiting/
# def make_confluence_call_handle_rate_limit(confluence_call: F) -> F:
# def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
# max_retries = 5
# starting_delay = 5
# backoff = 2
# # max_delay is used when the server doesn't hand back "Retry-After"
# # and we have to decide the retry delay ourselves
# max_delay = 30 # Atlassian uses max_delay = 30 in their examples
# # max_retry_after is used when we do get a "Retry-After" header
# max_retry_after = 300 # should we really cap the maximum retry delay?
# NEXT_RETRY_KEY = BaseConnector.REDIS_KEY_PREFIX + "confluence_next_retry"
# # for testing purposes, rate limiting is written to fall back to a simpler
# # rate limiting approach when redis is not available
# r = get_redis_client(tenant_id=tenant_id)
# for attempt in range(max_retries):
# try:
# # if multiple connectors are waiting for the next attempt, there could be an issue
# # where many connectors are "released" onto the server at the same time.
# # That's not ideal ... but coming up with a mechanism for queueing
# # all of these connectors is a bigger problem that we want to take on
# # right now
# try:
# next_attempt = r.get(NEXT_RETRY_KEY)
# if next_attempt is None:
# next_attempt = 0
# else:
# next_attempt = int(cast(int, next_attempt))
# # TODO: all connectors need to be interruptible moving forward
# while time.monotonic() < next_attempt:
# time.sleep(1)
# except ConnectionError:
# pass
# return confluence_call(*args, **kwargs)
# except HTTPError as e:
# # Check if the response or headers are None to avoid potential AttributeError
# if e.response is None or e.response.headers is None:
# logger.warning("HTTPError with `None` as response or as headers")
# raise e
# retry_after_header = e.response.headers.get("Retry-After")
# if (
# e.response.status_code == 429
# or RATE_LIMIT_MESSAGE_LOWERCASE in e.response.text.lower()
# ):
# retry_after = None
# if retry_after_header is not None:
# try:
# retry_after = int(retry_after_header)
# except ValueError:
# pass
# if retry_after is not None:
# if retry_after > max_retry_after:
# logger.warning(
# f"Clamping retry_after from {retry_after} to {max_delay} seconds..."
# )
# retry_after = max_delay
# logger.warning(
# f"Rate limit hit. Retrying after {retry_after} seconds..."
# )
# try:
# r.set(
# NEXT_RETRY_KEY,
# math.ceil(time.monotonic() + retry_after),
# )
# except ConnectionError:
# pass
# else:
# logger.warning(
# "Rate limit hit. Retrying with exponential backoff..."
# )
# delay = min(starting_delay * (backoff**attempt), max_delay)
# delay_until = math.ceil(time.monotonic() + delay)
# try:
# r.set(NEXT_RETRY_KEY, delay_until)
# except ConnectionError:
# while time.monotonic() < delay_until:
# time.sleep(1)
# else:
# # re-raise, let caller handle
# raise
# except AttributeError as e:
# # Some error within the Confluence library, unclear why it fails.
# # Users reported it to be intermittent, so just retry
# logger.warning(f"Confluence Internal Error, retrying... {e}")
# delay = min(starting_delay * (backoff**attempt), max_delay)
# delay_until = math.ceil(time.monotonic() + delay)
# try:
# r.set(NEXT_RETRY_KEY, delay_until)
# except ConnectionError:
# while time.monotonic() < delay_until:
# time.sleep(1)
# if attempt == max_retries - 1:
# raise e
# return cast(F, wrapped_call)
def _handle_http_error(e: HTTPError, attempt: int) -> int:
MIN_DELAY = 2
MAX_DELAY = 60
STARTING_DELAY = 5
BACKOFF = 2
# Check if the response or headers are None to avoid potential AttributeError
if e.response is None or e.response.headers is None:
logger.warning("HTTPError with `None` as response or as headers")
raise e
if (
e.response.status_code != 429
and RATE_LIMIT_MESSAGE_LOWERCASE not in e.response.text.lower()
):
raise e
retry_after = None
retry_after_header = e.response.headers.get("Retry-After")
if retry_after_header is not None:
try:
retry_after = int(retry_after_header)
if retry_after > MAX_DELAY:
logger.warning(
f"Clamping retry_after from {retry_after} to {MAX_DELAY} seconds..."
)
retry_after = MAX_DELAY
if retry_after < MIN_DELAY:
retry_after = MIN_DELAY
except ValueError:
pass
if retry_after is not None:
logger.warning(
f"Rate limiting with retry header. Retrying after {retry_after} seconds..."
)
delay = retry_after
else:
logger.warning(
"Rate limiting without retry header. Retrying with exponential backoff..."
)
delay = min(STARTING_DELAY * (BACKOFF**attempt), MAX_DELAY)
delay_until = math.ceil(time.monotonic() + delay)
return delay_until
# https://developer.atlassian.com/cloud/confluence/rate-limiting/
# this uses the native rate limiting option provided by the
# confluence client and otherwise applies a simpler set of error handling
def make_confluence_call_handle_rate_limit(confluence_call: F) -> F:
def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
MAX_RETRIES = 5
TIMEOUT = 3600
timeout_at = time.monotonic() + TIMEOUT
for attempt in range(MAX_RETRIES):
if time.monotonic() > timeout_at:
raise TimeoutError(
f"Confluence call attempts took longer than {TIMEOUT} seconds."
)
max_retries = 5
starting_delay = 5
backoff = 2
max_delay = 600
for attempt in range(max_retries):
try:
# we're relying more on the client to rate limit itself
# and applying our own retries in a more specific set of circumstances
return confluence_call(*args, **kwargs)
except HTTPError as e:
delay_until = _handle_http_error(e, attempt)
while time.monotonic() < delay_until:
# in the future, check a signal here to exit
time.sleep(1)
# Check if the response or headers are None to avoid potential AttributeError
if e.response is None or e.response.headers is None:
logger.warning("HTTPError with `None` as response or as headers")
raise e
retry_after_header = e.response.headers.get("Retry-After")
if (
e.response.status_code == 429
or RATE_LIMIT_MESSAGE_LOWERCASE in e.response.text.lower()
):
retry_after = None
if retry_after_header is not None:
try:
retry_after = int(retry_after_header)
except ValueError:
pass
if retry_after is not None:
if retry_after > 600:
logger.warning(
f"Clamping retry_after from {retry_after} to {max_delay} seconds..."
)
retry_after = max_delay
logger.warning(
f"Rate limit hit. Retrying after {retry_after} seconds..."
)
time.sleep(retry_after)
else:
logger.warning(
"Rate limit hit. Retrying with exponential backoff..."
)
delay = min(starting_delay * (backoff**attempt), max_delay)
time.sleep(delay)
else:
# re-raise, let caller handle
raise
except AttributeError as e:
# Some error within the Confluence library, unclear why it fails.
# Users reported it to be intermittent, so just retry
if attempt == MAX_RETRIES - 1:
logger.warning(f"Confluence Internal Error, retrying... {e}")
delay = min(starting_delay * (backoff**attempt), max_delay)
time.sleep(delay)
if attempt == max_retries - 1:
raise e
logger.exception(
"Confluence Client raised an AttributeError. Retrying..."
)
time.sleep(5)
return cast(F, wrapped_call)
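For reference, the backoff schedule implied by _handle_http_error above when no Retry-After header is present (constants copied from the code; a sketch, not an exhaustive description of the retry behavior):

STARTING_DELAY = 5
BACKOFF = 2
MAX_DELAY = 60
MAX_RETRIES = 5

# attempt -> delay in seconds: min(5 * 2**attempt, 60)
schedule = [min(STARTING_DELAY * (BACKOFF**attempt), MAX_DELAY) for attempt in range(MAX_RETRIES)]
assert schedule == [5, 10, 20, 40, 60]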

View File

@@ -4,7 +4,6 @@ from typing import Type
from sqlalchemy.orm import Session
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import DocumentSourceRequiringTenantContext
from danswer.connectors.asana.connector import AsanaConnector
from danswer.connectors.axero.connector import AxeroConnector
from danswer.connectors.blob.connector import BlobStorageConnector
@@ -135,13 +134,8 @@ def instantiate_connector(
input_type: InputType,
connector_specific_config: dict[str, Any],
credential: Credential,
tenant_id: str | None = None,
) -> BaseConnector:
connector_class = identify_connector_class(source, input_type)
if source in DocumentSourceRequiringTenantContext:
connector_specific_config["tenant_id"] = tenant_id
connector = connector_class(**connector_specific_config)
new_credentials = connector.load_credentials(credential.credential_json)

View File

@@ -10,14 +10,13 @@ from sqlalchemy.orm import Session
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import POSTGRES_DEFAULT_SCHEMA
from danswer.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.db.engine import get_session_with_tenant
from danswer.db.engine import get_sqlalchemy_engine
from danswer.file_processing.extract_file_text import check_file_ext_is_valid
from danswer.file_processing.extract_file_text import detect_encoding
from danswer.file_processing.extract_file_text import extract_file_text
@@ -28,7 +27,6 @@ from danswer.file_processing.extract_file_text import read_pdf_file
from danswer.file_processing.extract_file_text import read_text_file
from danswer.file_store.file_store import get_default_file_store
from danswer.utils.logger import setup_logger
from shared_configs.configs import current_tenant_id
logger = setup_logger()
@@ -161,12 +159,10 @@ class LocalFileConnector(LoadConnector):
def __init__(
self,
file_locations: list[Path | str],
tenant_id: str = POSTGRES_DEFAULT_SCHEMA,
batch_size: int = INDEX_BATCH_SIZE,
) -> None:
self.file_locations = [Path(file_location) for file_location in file_locations]
self.batch_size = batch_size
self.tenant_id = tenant_id
self.pdf_pass: str | None = None
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
@@ -175,9 +171,7 @@ class LocalFileConnector(LoadConnector):
def load_from_state(self) -> GenerateDocumentsOutput:
documents: list[Document] = []
token = current_tenant_id.set(self.tenant_id)
with get_session_with_tenant(self.tenant_id) as db_session:
with Session(get_sqlalchemy_engine()) as db_session:
for file_path in self.file_locations:
current_datetime = datetime.now(timezone.utc)
files = _read_files_and_metadata(
@@ -199,8 +193,6 @@ class LocalFileConnector(LoadConnector):
if documents:
yield documents
current_tenant_id.reset(token)
if __name__ == "__main__":
connector = LocalFileConnector(file_locations=[os.environ["TEST_FILE"]])

View File

@@ -462,34 +462,8 @@ class GoogleDriveConnector(LoadConnector, PollConnector):
for permission in file["permissions"]
):
continue
try:
text_contents = extract_text(file, service) or ""
except HttpError as e:
reason = (
e.error_details[0]["reason"]
if e.error_details
else e.reason
)
message = (
e.error_details[0]["message"]
if e.error_details
else e.reason
)
# these errors don't represent a failure in the connector, but simply files
# that can't / shouldn't be indexed
ERRORS_TO_CONTINUE_ON = [
"cannotExportFile",
"exportSizeLimitExceeded",
"cannotDownloadFile",
]
if e.status_code == 403 and reason in ERRORS_TO_CONTINUE_ON:
logger.warning(
f"Could not export file '{file['name']}' due to '{message}', skipping..."
)
continue
raise
text_contents = extract_text(file, service) or ""
doc_batch.append(
Document(

View File

@@ -11,8 +11,6 @@ GenerateDocumentsOutput = Iterator[list[Document]]
class BaseConnector(abc.ABC):
REDIS_KEY_PREFIX = "da_connector_data:"
@abc.abstractmethod
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
raise NotImplementedError

View File

@@ -45,7 +45,8 @@ class FamilyFileGeneratorInMemory(generate_family_file.FamilyFileGenerator):
if any(x not in generate_family_file.NAME_CHARACTERS for x in name):
raise ValueError(
f'ERROR: Name of family "{name}" must be ASCII letters and digits [a-zA-Z0-9]',
'ERROR: Name of family "{}" must be ASCII letters and digits [a-zA-Z0-9]',
name,
)
if isinstance(dointerwiki, bool):

View File

@@ -3,7 +3,6 @@ from __future__ import annotations
import datetime
import itertools
from collections.abc import Generator
from collections.abc import Iterator
from typing import Any
from typing import ClassVar
@@ -20,9 +19,6 @@ from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.mediawiki.family import family_class_dispatch
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.utils.logger import setup_logger
logger = setup_logger()
def pywikibot_timestamp_to_utc_datetime(
@@ -78,7 +74,7 @@ def get_doc_from_page(
sections=sections,
semantic_identifier=page.title(),
metadata={"categories": [category.title() for category in page.categories()]},
id=f"MEDIAWIKI_{page.pageid}_{page.full_url()}",
id=page.pageid,
)
@@ -121,18 +117,13 @@ class MediaWikiConnector(LoadConnector, PollConnector):
# short names can only have ascii letters and digits
self.family = family_class_dispatch(hostname, "WikipediaConnector")()
self.family = family_class_dispatch(hostname, "Wikipedia Connector")()
self.site = pywikibot.Site(fam=self.family, code=language_code)
self.categories = [
pywikibot.Category(self.site, f"Category:{category.replace(' ', '_')}")
for category in categories
]
self.pages = []
for page in pages:
if not page:
continue
self.pages.append(pywikibot.Page(self.site, page))
self.pages = [pywikibot.Page(self.site, page) for page in pages]
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
"""Load credentials for a MediaWiki site.
@@ -178,13 +169,8 @@ class MediaWikiConnector(LoadConnector, PollConnector):
]
# Since we can specify both individual pages and categories, we need to iterate over all of them.
all_pages: Iterator[pywikibot.Page] = itertools.chain(
self.pages, *category_pages
)
all_pages = itertools.chain(self.pages, *category_pages)
for page in all_pages:
logger.info(
f"MediaWikiConnector: title='{page.title()}' url={page.full_url()}"
)
doc_batch.append(
get_doc_from_page(page, self.site, self.document_source_type)
)

View File

@@ -29,9 +29,6 @@ logger = setup_logger()
_NOTION_CALL_TIMEOUT = 30 # 30 seconds
# TODO: Tables need to be ingested, Pages need to have their metadata ingested
@dataclass
class NotionPage:
"""Represents a Notion Page object"""
@@ -43,8 +40,6 @@ class NotionPage:
properties: dict[str, Any]
url: str
database_name: str | None # Only applicable to the database type page (wiki)
def __init__(self, **kwargs: dict[str, Any]) -> None:
names = set([f.name for f in fields(self)])
for k, v in kwargs.items():
@@ -52,17 +47,6 @@ class NotionPage:
setattr(self, k, v)
@dataclass
class NotionBlock:
"""Represents a Notion Block object"""
id: str # Used for the URL
text: str
# In a plaintext representation of the page, how this block should be joined
# with the existing text up to this point, separated out from text for clarity
prefix: str
@dataclass
class NotionSearchResponse:
"""Represents the response from the Notion Search API"""
@@ -78,6 +62,7 @@ class NotionSearchResponse:
setattr(self, k, v)
# TODO - Add the ability to optionally limit to specific Notion databases
class NotionConnector(LoadConnector, PollConnector):
"""Notion Page connector that reads all Notion pages
this integration has been granted access to.
@@ -141,46 +126,20 @@ class NotionConnector(LoadConnector, PollConnector):
@retry(tries=3, delay=1, backoff=2)
def _fetch_page(self, page_id: str) -> NotionPage:
"""Fetch a page from its ID via the Notion API, retry with database if page fetch fails."""
"""Fetch a page from it's ID via the Notion API."""
logger.debug(f"Fetching page for ID '{page_id}'")
page_url = f"https://api.notion.com/v1/pages/{page_id}"
block_url = f"https://api.notion.com/v1/pages/{page_id}"
res = rl_requests.get(
page_url,
block_url,
headers=self.headers,
timeout=_NOTION_CALL_TIMEOUT,
)
try:
res.raise_for_status()
except Exception as e:
logger.warning(
f"Failed to fetch page, trying database for ID '{page_id}'. Exception: {e}"
)
# Try fetching as a database if the page fetch fails; this happens when the page
# is set to a wiki, since it then becomes a database from the Notion perspective
return self._fetch_database_as_page(page_id)
return NotionPage(**res.json())
@retry(tries=3, delay=1, backoff=2)
def _fetch_database_as_page(self, database_id: str) -> NotionPage:
"""Attempt to fetch a database as a page."""
logger.debug(f"Fetching database for ID '{database_id}' as a page")
database_url = f"https://api.notion.com/v1/databases/{database_id}"
res = rl_requests.get(
database_url,
headers=self.headers,
timeout=_NOTION_CALL_TIMEOUT,
)
try:
res.raise_for_status()
except Exception as e:
logger.exception(f"Error fetching database as page - {res.json()}")
logger.exception(f"Error fetching page - {res.json()}")
raise e
database_name = res.json().get("title")
database_name = (
database_name[0].get("text", {}).get("content") if database_name else None
)
return NotionPage(**res.json(), database_name=database_name)
return NotionPage(**res.json())
@retry(tries=3, delay=1, backoff=2)
def _fetch_database(
@@ -212,75 +171,8 @@ class NotionConnector(LoadConnector, PollConnector):
raise e
return res.json()
@staticmethod
def _properties_to_str(properties: dict[str, Any]) -> str:
"""Converts Notion properties to a string"""
def _recurse_properties(inner_dict: dict[str, Any]) -> str | None:
while "type" in inner_dict:
type_name = inner_dict["type"]
inner_dict = inner_dict[type_name]
# If the innermost layer is None, the value is not set
if not inner_dict:
return None
if isinstance(inner_dict, list):
list_properties = [
_recurse_properties(item) for item in inner_dict if item
]
return (
", ".join(
[
list_property
for list_property in list_properties
if list_property
]
)
or None
)
# TODO there may be more types to handle here
if "name" in inner_dict:
return inner_dict["name"]
if "content" in inner_dict:
return inner_dict["content"]
start = inner_dict.get("start")
end = inner_dict.get("end")
if start is not None:
if end is not None:
return f"{start} - {end}"
return start
elif end is not None:
return f"Until {end}"
if "id" in inner_dict:
# This is not useful to index, it's a reference to another Notion object
# and this ID value in plaintext is useless outside of the Notion context
logger.debug("Skipping Notion object id field property")
return None
logger.debug(f"Unreadable property from innermost prop: {inner_dict}")
return None
result = ""
for prop_name, prop in properties.items():
if not prop:
continue
inner_value = _recurse_properties(prop)
# Not a perfect way to format Notion database tables but there's no perfect representation
# since this must be represented as plaintext
if inner_value:
result += f"{prop_name}: {inner_value}\t"
return result
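To make the property flattening above concrete, a simplified hedged mirror of its recursion; the Notion field shapes shown are assumptions, and this is not the production method:

def flatten_notion_properties(properties: dict) -> str:
    # Simplified version of _recurse_properties: unwrap nested "type" layers,
    # then pick a readable value.
    def recurse(inner):
        while isinstance(inner, dict) and "type" in inner:
            inner = inner[inner["type"]]
        if not inner or not isinstance(inner, dict):
            return None
        return inner.get("name") or inner.get("content") or inner.get("start")

    parts = []
    for prop_name, prop in properties.items():
        value = recurse(prop)
        if value:
            parts.append(f"{prop_name}: {value}")
    return "\t".join(parts)

example = {
    "Status": {"type": "select", "select": {"name": "In Progress"}},
    "Due": {"type": "date", "date": {"start": "2024-10-01", "end": None}},
}
assert flatten_notion_properties(example) == "Status: In Progress\tDue: 2024-10-01"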
def _read_pages_from_database(
self, database_id: str
) -> tuple[list[NotionBlock], list[str]]:
"""Returns a list of top level blocks and all page IDs in the database"""
result_blocks: list[NotionBlock] = []
def _read_pages_from_database(self, database_id: str) -> list[str]:
"""Returns a list of all page IDs in the database"""
result_pages: list[str] = []
cursor = None
while True:
@@ -289,34 +181,29 @@ class NotionConnector(LoadConnector, PollConnector):
for result in data["results"]:
obj_id = result["id"]
obj_type = result["object"]
text = self._properties_to_str(result.get("properties", {}))
if text:
result_blocks.append(NotionBlock(id=obj_id, text=text, prefix="\n"))
if self.recursive_index_enabled:
if obj_type == "page":
logger.debug(
f"Found page with ID '{obj_id}' in database '{database_id}'"
)
result_pages.append(result["id"])
elif obj_type == "database":
logger.debug(
f"Found database with ID '{obj_id}' in database '{database_id}'"
)
# The inner contents are ignored at this level
_, child_pages = self._read_pages_from_database(obj_id)
result_pages.extend(child_pages)
if obj_type == "page":
logger.debug(
f"Found page with ID '{obj_id}' in database '{database_id}'"
)
result_pages.append(result["id"])
elif obj_type == "database":
logger.debug(
f"Found database with ID '{obj_id}' in database '{database_id}'"
)
result_pages.extend(self._read_pages_from_database(obj_id))
if data["next_cursor"] is None:
break
cursor = data["next_cursor"]
return result_blocks, result_pages
return result_pages
def _read_blocks(self, base_block_id: str) -> tuple[list[NotionBlock], list[str]]:
"""Reads all child blocks for the specified block, returns a list of blocks and child page ids"""
result_blocks: list[NotionBlock] = []
def _read_blocks(
self, base_block_id: str
) -> tuple[list[tuple[str, str]], list[str]]:
"""Reads all child blocks for the specified block"""
result_lines: list[tuple[str, str]] = []
child_pages: list[str] = []
cursor = None
while True:
@@ -324,7 +211,7 @@ class NotionConnector(LoadConnector, PollConnector):
# this happens when a block is not shared with the integration
if data is None:
return result_blocks, child_pages
return result_lines, child_pages
for result in data["results"]:
logger.debug(
@@ -368,70 +255,46 @@ class NotionConnector(LoadConnector, PollConnector):
if result["has_children"]:
if result_type == "child_page":
# Child pages will not be included at this top level, it will be a separate document
child_pages.append(result_block_id)
else:
logger.debug(f"Entering sub-block: {result_block_id}")
subblocks, subblock_child_pages = self._read_blocks(
subblock_result_lines, subblock_child_pages = self._read_blocks(
result_block_id
)
logger.debug(f"Finished sub-block: {result_block_id}")
result_blocks.extend(subblocks)
result_lines.extend(subblock_result_lines)
child_pages.extend(subblock_child_pages)
if result_type == "child_database":
inner_blocks, inner_child_pages = self._read_pages_from_database(
result_block_id
)
# A database on a page often looks like a table; we include it in the page's contents,
# but its children (cells) should be processed as separate Documents
result_blocks.extend(inner_blocks)
if result_type == "child_database" and self.recursive_index_enabled:
child_pages.extend(self._read_pages_from_database(result_block_id))
if self.recursive_index_enabled:
child_pages.extend(inner_child_pages)
if cur_result_text_arr:
new_block = NotionBlock(
id=result_block_id,
text="\n".join(cur_result_text_arr),
prefix="\n",
)
result_blocks.append(new_block)
cur_result_text = "\n".join(cur_result_text_arr)
if cur_result_text:
result_lines.append((cur_result_text, result_block_id))
if data["next_cursor"] is None:
break
cursor = data["next_cursor"]
return result_blocks, child_pages
return result_lines, child_pages
def _read_page_title(self, page: NotionPage) -> str | None:
def _read_page_title(self, page: NotionPage) -> str:
"""Extracts the title from a Notion page"""
page_title = None
if hasattr(page, "database_name") and page.database_name:
return page.database_name
for _, prop in page.properties.items():
if prop["type"] == "title" and len(prop["title"]) > 0:
page_title = " ".join([t["plain_text"] for t in prop["title"]]).strip()
break
if page_title is None:
page_title = f"Untitled Page [{page.id}]"
return page_title
def _read_pages(
self,
pages: list[NotionPage],
) -> Generator[Document, None, None]:
"""Reads pages for rich text content and generates Documents
Note that a page which is turned into a "wiki" becomes a database, but neither top-level pages nor
top-level databases appear to have any properties associated with them.
Pages that are part of a database can have properties, which correspond to the values of that page's
row in the "database" table in which it exists.
This is not clearly outlined in the Notion API docs, but it is observable empirically.
https://developers.notion.com/docs/working-with-page-content
"""
"""Reads pages for rich text content and generates Documents"""
all_child_page_ids: list[str] = []
for page in pages:
if page.id in self.indexed_pages:
@@ -441,23 +304,18 @@ class NotionConnector(LoadConnector, PollConnector):
logger.info(f"Reading page with ID '{page.id}', with url {page.url}")
page_blocks, child_page_ids = self._read_blocks(page.id)
all_child_page_ids.extend(child_page_ids)
if not page_blocks:
continue
page_title = (
self._read_page_title(page) or f"Untitled Page with ID {page.id}"
)
page_title = self._read_page_title(page)
yield (
Document(
id=page.id,
sections=[
# Will add title to the first section later in processing
sections=[Section(link=page.url, text="")]
+ [
Section(
link=f"{page.url}#{block.id.replace('-', '')}",
text=block.prefix + block.text,
link=f"{page.url}#{block_id.replace('-', '')}",
text=block_text,
)
for block in page_blocks
for block_text, block_id in page_blocks
],
source=DocumentSource.NOTION,
semantic_identifier=page_title,

View File

@@ -128,9 +128,6 @@ def get_internal_links(
if not href:
continue
# Account for malformed backslashes in URLs
href = href.replace("\\", "/")
if should_ignore_pound and "#" in href:
href = href.split("#")[0]

View File

@@ -1,66 +0,0 @@
from mistune import Markdown # type: ignore
from mistune import Renderer # type: ignore
def format_slack_message(message: str | None) -> str:
renderer = Markdown(renderer=SlackRenderer())
return renderer.render(message)
class SlackRenderer(Renderer):
SPECIALS: dict[str, str] = {"&": "&amp;", "<": "&lt;", ">": "&gt;"}
def escape_special(self, text: str) -> str:
for special, replacement in self.SPECIALS.items():
text = text.replace(special, replacement)
return text
def header(self, text: str, level: int, raw: str | None = None) -> str:
return f"*{text}*\n"
def emphasis(self, text: str) -> str:
return f"_{text}_"
def double_emphasis(self, text: str) -> str:
return f"*{text}*"
def strikethrough(self, text: str) -> str:
return f"~{text}~"
def list(self, body: str, ordered: bool = True) -> str:
lines = body.split("\n")
count = 0
for i, line in enumerate(lines):
if line.startswith("li: "):
count += 1
prefix = f"{count}. " if ordered else ""
lines[i] = f"{prefix}{line[4:]}"
return "\n".join(lines)
def list_item(self, text: str) -> str:
return f"li: {text}\n"
def link(self, link: str, title: str | None, content: str | None) -> str:
escaped_link = self.escape_special(link)
if content:
return f"<{escaped_link}|{content}>"
if title:
return f"<{escaped_link}|{title}>"
return f"<{escaped_link}>"
def image(self, src: str, title: str | None, text: str | None) -> str:
escaped_src = self.escape_special(src)
display_text = title or text
return f"<{escaped_src}|{display_text}>" if display_text else f"<{escaped_src}>"
def codespan(self, text: str) -> str:
return f"`{text}`"
def block_code(self, text: str, lang: str | None) -> str:
return f"```\n{text}\n```\n"
def paragraph(self, text: str) -> str:
return f"{text}\n"
def autolink(self, link: str, is_email: bool) -> str:
return link if is_email else self.link(link, None, None)
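A brief hedged illustration of the Slack-specific escaping and link syntax implemented above (a standalone mirror of escape_special and link, not a call into the renderer itself):

SPECIALS = {"&": "&amp;", "<": "&lt;", ">": "&gt;"}

def escape_special(text: str) -> str:
    for special, replacement in SPECIALS.items():
        text = text.replace(special, replacement)
    return text

escaped = escape_special("https://example.com/a?b=1&c=2")
slack_link = f"<{escaped}|docs>"  # Slack mrkdwn link syntax, as produced by SlackRenderer.link
assert slack_link == "<https://example.com/a?b=1&amp;c=2|docs>"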

View File

@@ -4,7 +4,9 @@ from typing import cast
from slack_sdk import WebClient
from slack_sdk.models.blocks import SectionBlock
from slack_sdk.models.views import View
from slack_sdk.socket_mode import SocketModeClient
from slack_sdk.socket_mode.request import SocketModeRequest
from sqlalchemy.orm import Session
from danswer.configs.constants import MessageType
from danswer.configs.constants import SearchFeedbackType
@@ -33,22 +35,20 @@ from danswer.danswerbot.slack.utils import get_channel_name_from_id
from danswer.danswerbot.slack.utils import get_feedback_visibility
from danswer.danswerbot.slack.utils import read_slack_thread
from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.danswerbot.slack.utils import TenantSocketModeClient
from danswer.danswerbot.slack.utils import update_emote_react
from danswer.db.engine import get_session_with_tenant
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.feedback import create_chat_message_feedback
from danswer.db.feedback import create_doc_retrieval_feedback
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.utils.logger import setup_logger
logger = setup_logger()
def handle_doc_feedback_button(
req: SocketModeRequest,
client: TenantSocketModeClient,
client: SocketModeClient,
) -> None:
if not (actions := req.payload.get("actions")):
logger.error("Missing actions. Unable to build the source feedback view")
@@ -81,7 +81,7 @@ def handle_doc_feedback_button(
def handle_generate_answer_button(
req: SocketModeRequest,
client: TenantSocketModeClient,
client: SocketModeClient,
) -> None:
channel_id = req.payload["channel"]["id"]
channel_name = req.payload["channel"]["name"]
@@ -116,7 +116,7 @@ def handle_generate_answer_button(
thread_ts=thread_ts,
)
with get_session_with_tenant(client.tenant_id) as db_session:
with Session(get_sqlalchemy_engine()) as db_session:
slack_bot_config = get_slack_bot_config_for_channel(
channel_name=channel_name, db_session=db_session
)
@@ -136,7 +136,6 @@ def handle_generate_answer_button(
slack_bot_config=slack_bot_config,
receiver_ids=None,
client=client.web_client,
tenant_id=client.tenant_id,
channel=channel_id,
logger=logger,
feedback_reminder_id=None,
@@ -151,11 +150,12 @@ def handle_slack_feedback(
user_id_to_post_confirmation: str,
channel_id_to_post_confirmation: str,
thread_ts_to_post_confirmation: str,
tenant_id: str | None,
) -> None:
engine = get_sqlalchemy_engine()
message_id, doc_id, doc_rank = decompose_action_id(feedback_id)
with get_session_with_tenant(tenant_id) as db_session:
with Session(engine) as db_session:
if feedback_type in [LIKE_BLOCK_ACTION_ID, DISLIKE_BLOCK_ACTION_ID]:
create_chat_message_feedback(
is_positive=feedback_type == LIKE_BLOCK_ACTION_ID,
@@ -232,7 +232,7 @@ def handle_slack_feedback(
def handle_followup_button(
req: SocketModeRequest,
client: TenantSocketModeClient,
client: SocketModeClient,
) -> None:
action_id = None
if actions := req.payload.get("actions"):
@@ -252,7 +252,7 @@ def handle_followup_button(
tag_ids: list[str] = []
group_ids: list[str] = []
with get_session_with_tenant(client.tenant_id) as db_session:
with Session(get_sqlalchemy_engine()) as db_session:
channel_name, is_dm = get_channel_name_from_id(
client=client.web_client, channel_id=channel_id
)
@@ -295,7 +295,7 @@ def handle_followup_button(
def get_clicker_name(
req: SocketModeRequest,
client: TenantSocketModeClient,
client: SocketModeClient,
) -> str:
clicker_name = req.payload.get("user", {}).get("name", "Someone")
clicker_real_name = None
@@ -316,7 +316,7 @@ def get_clicker_name(
def handle_followup_resolved_button(
req: SocketModeRequest,
client: TenantSocketModeClient,
client: SocketModeClient,
immediate: bool = False,
) -> None:
channel_id = req.payload["container"]["channel_id"]

View File

@@ -2,6 +2,7 @@ import datetime
from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from sqlalchemy.orm import Session
from danswer.configs.danswerbot_configs import DANSWER_BOT_FEEDBACK_REMINDER
from danswer.configs.danswerbot_configs import DANSWER_REACT_EMOJI
@@ -18,7 +19,7 @@ from danswer.danswerbot.slack.utils import fetch_user_ids_from_groups
from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.danswerbot.slack.utils import slack_usage_report
from danswer.danswerbot.slack.utils import update_emote_react
from danswer.db.engine import get_session_with_tenant
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import SlackBotConfig
from danswer.db.users import add_non_web_user_if_not_exists
from danswer.utils.logger import setup_logger
@@ -109,7 +110,6 @@ def handle_message(
slack_bot_config: SlackBotConfig | None,
client: WebClient,
feedback_reminder_id: str | None,
tenant_id: str | None,
) -> bool:
"""Potentially respond to the user message depending on filters and if an answer was generated
@@ -135,9 +135,7 @@ def handle_message(
action = "slack_tag_message"
elif is_bot_dm:
action = "slack_dm_message"
slack_usage_report(
action=action, sender_id=sender_id, client=client, tenant_id=tenant_id
)
slack_usage_report(action=action, sender_id=sender_id, client=client)
document_set_names: list[str] | None = None
persona = slack_bot_config.persona if slack_bot_config else None
@@ -211,7 +209,7 @@ def handle_message(
except SlackApiError as e:
logger.error(f"Was not able to react to user message due to: {e}")
with get_session_with_tenant(tenant_id) as db_session:
with Session(get_sqlalchemy_engine()) as db_session:
if message_info.email:
add_non_web_user_if_not_exists(db_session, message_info.email)
@@ -237,6 +235,5 @@ def handle_message(
channel=channel,
logger=logger,
feedback_reminder_id=feedback_reminder_id,
tenant_id=tenant_id,
)
return issue_with_regular_answer

View File

@@ -9,6 +9,7 @@ from retry import retry
from slack_sdk import WebClient
from slack_sdk.models.blocks import DividerBlock
from slack_sdk.models.blocks import SectionBlock
from sqlalchemy.orm import Session
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.danswerbot_configs import DANSWER_BOT_ANSWER_GENERATION_TIMEOUT
@@ -26,13 +27,12 @@ from danswer.danswerbot.slack.blocks import build_follow_up_block
from danswer.danswerbot.slack.blocks import build_qa_response_blocks
from danswer.danswerbot.slack.blocks import build_sources_blocks
from danswer.danswerbot.slack.blocks import get_restate_blocks
from danswer.danswerbot.slack.formatting import format_slack_message
from danswer.danswerbot.slack.handlers.utils import send_team_member_message
from danswer.danswerbot.slack.models import SlackMessageInfo
from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.danswerbot.slack.utils import SlackRateLimiter
from danswer.danswerbot.slack.utils import update_emote_react
from danswer.db.engine import get_session_with_tenant
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import Persona
from danswer.db.models import SlackBotConfig
from danswer.db.models import SlackBotResponseType
@@ -87,7 +87,6 @@ def handle_regular_answer(
channel: str,
logger: DanswerLoggingAdapter,
feedback_reminder_id: str | None,
tenant_id: str | None,
num_retries: int = DANSWER_BOT_NUM_RETRIES,
answer_generation_timeout: int = DANSWER_BOT_ANSWER_GENERATION_TIMEOUT,
thread_context_percent: float = DANSWER_BOT_TARGET_CHUNK_PERCENTAGE,
@@ -104,7 +103,8 @@ def handle_regular_answer(
user = None
if message_info.is_bot_dm:
if message_info.email:
with get_session_with_tenant(tenant_id) as db_session:
engine = get_sqlalchemy_engine()
with Session(engine) as db_session:
user = get_user_by_email(message_info.email, db_session)
document_set_names: list[str] | None = None
@@ -151,7 +151,7 @@ def handle_regular_answer(
max_document_tokens: int | None = None
max_history_tokens: int | None = None
with get_session_with_tenant(tenant_id) as db_session:
with Session(get_sqlalchemy_engine()) as db_session:
if len(new_message_request.messages) > 1:
if new_message_request.persona_config:
raise RuntimeError("Slack bot does not support persona config")
@@ -245,7 +245,7 @@ def handle_regular_answer(
)
# Always apply reranking settings if it exists, this is the non-streaming flow
with get_session_with_tenant(tenant_id) as db_session:
with Session(get_sqlalchemy_engine()) as db_session:
saved_search_settings = get_current_search_settings(db_session)
# This includes throwing out answer via reflexion
@@ -412,11 +412,10 @@ def handle_regular_answer(
# If called with the DanswerBot slash command, the question is lost so we have to reshow it
restate_question_block = get_restate_blocks(messages[-1].message, is_bot_msg)
formatted_answer = format_slack_message(answer.answer) if answer.answer else None
answer_blocks = build_qa_response_blocks(
message_id=answer.chat_message_id,
answer=formatted_answer,
answer=answer.answer,
quotes=answer.quotes.quotes if answer.quotes else None,
source_filters=retrieval_info.applied_source_filters,
time_cutoff=retrieval_info.applied_time_cutoff,

View File

@@ -4,10 +4,11 @@ from typing import Any
from typing import cast
from slack_sdk import WebClient
from slack_sdk.socket_mode import SocketModeClient
from slack_sdk.socket_mode.request import SocketModeRequest
from slack_sdk.socket_mode.response import SocketModeResponse
from sqlalchemy.orm import Session
from danswer.background.celery.celery_app import get_all_tenant_ids
from danswer.configs.constants import MessageType
from danswer.configs.danswerbot_configs import DANSWER_BOT_REPHRASE_MESSAGE
from danswer.configs.danswerbot_configs import DANSWER_BOT_RESPOND_EVERY_CHANNEL
@@ -46,8 +47,7 @@ from danswer.danswerbot.slack.utils import read_slack_thread
from danswer.danswerbot.slack.utils import remove_danswer_bot_tag
from danswer.danswerbot.slack.utils import rephrase_slack_message
from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.danswerbot.slack.utils import TenantSocketModeClient
from danswer.db.engine import get_session_with_tenant
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.search_settings import get_current_search_settings
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
@@ -80,7 +80,7 @@ _SLACK_GREETINGS_TO_IGNORE = {
_OFFICIAL_SLACKBOT_USER_ID = "USLACKBOT"
def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -> bool:
def prefilter_requests(req: SocketModeRequest, client: SocketModeClient) -> bool:
"""True to keep going, False to ignore this Slack request"""
if req.type == "events_api":
# Verify channel is valid
@@ -153,7 +153,8 @@ def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -
client=client.web_client, channel_id=channel
)
with get_session_with_tenant(client.tenant_id) as db_session:
engine = get_sqlalchemy_engine()
with Session(engine) as db_session:
slack_bot_config = get_slack_bot_config_for_channel(
channel_name=channel_name, db_session=db_session
)
@@ -220,7 +221,7 @@ def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -
return True
def process_feedback(req: SocketModeRequest, client: TenantSocketModeClient) -> None:
def process_feedback(req: SocketModeRequest, client: SocketModeClient) -> None:
if actions := req.payload.get("actions"):
action = cast(dict[str, Any], actions[0])
feedback_type = cast(str, action.get("action_id"))
@@ -242,7 +243,6 @@ def process_feedback(req: SocketModeRequest, client: TenantSocketModeClient) ->
user_id_to_post_confirmation=user_id,
channel_id_to_post_confirmation=channel_id,
thread_ts_to_post_confirmation=thread_ts,
tenant_id=client.tenant_id,
)
query_event_id, _, _ = decompose_action_id(feedback_id)
@@ -250,7 +250,7 @@ def process_feedback(req: SocketModeRequest, client: TenantSocketModeClient) ->
def build_request_details(
req: SocketModeRequest, client: TenantSocketModeClient
req: SocketModeRequest, client: SocketModeClient
) -> SlackMessageInfo:
if req.type == "events_api":
event = cast(dict[str, Any], req.payload["event"])
@@ -329,7 +329,7 @@ def build_request_details(
def apologize_for_fail(
details: SlackMessageInfo,
client: TenantSocketModeClient,
client: SocketModeClient,
) -> None:
respond_in_thread(
client=client.web_client,
@@ -341,7 +341,7 @@ def apologize_for_fail(
def process_message(
req: SocketModeRequest,
client: TenantSocketModeClient,
client: SocketModeClient,
respond_every_channel: bool = DANSWER_BOT_RESPOND_EVERY_CHANNEL,
notify_no_answer: bool = NOTIFY_SLACKBOT_NO_ANSWER,
) -> None:
@@ -357,7 +357,8 @@ def process_message(
client=client.web_client, channel_id=channel
)
with get_session_with_tenant(client.tenant_id) as db_session:
engine = get_sqlalchemy_engine()
with Session(engine) as db_session:
slack_bot_config = get_slack_bot_config_for_channel(
channel_name=channel_name, db_session=db_session
)
@@ -389,7 +390,6 @@ def process_message(
slack_bot_config=slack_bot_config,
client=client.web_client,
feedback_reminder_id=feedback_reminder_id,
tenant_id=client.tenant_id,
)
if failed:
@@ -404,12 +404,12 @@ def process_message(
apologize_for_fail(details, client)
def acknowledge_message(req: SocketModeRequest, client: TenantSocketModeClient) -> None:
def acknowledge_message(req: SocketModeRequest, client: SocketModeClient) -> None:
response = SocketModeResponse(envelope_id=req.envelope_id)
client.send_socket_mode_response(response)
def action_routing(req: SocketModeRequest, client: TenantSocketModeClient) -> None:
def action_routing(req: SocketModeRequest, client: SocketModeClient) -> None:
if actions := req.payload.get("actions"):
action = cast(dict[str, Any], actions[0])
@@ -429,13 +429,13 @@ def action_routing(req: SocketModeRequest, client: TenantSocketModeClient) -> No
return handle_generate_answer_button(req, client)
def view_routing(req: SocketModeRequest, client: TenantSocketModeClient) -> None:
def view_routing(req: SocketModeRequest, client: SocketModeClient) -> None:
if view := req.payload.get("view"):
if view["callback_id"] == VIEW_DOC_FEEDBACK_ID:
return process_feedback(req, client)
def process_slack_event(client: TenantSocketModeClient, req: SocketModeRequest) -> None:
def process_slack_event(client: SocketModeClient, req: SocketModeRequest) -> None:
# Always respond right away, if Slack doesn't receive these frequently enough
# it will assume the Bot is DEAD!!! :(
acknowledge_message(req, client)
@@ -453,24 +453,21 @@ def process_slack_event(client: TenantSocketModeClient, req: SocketModeRequest)
logger.error(f"Slack request payload: {req.payload}")
def _get_socket_client(
slack_bot_tokens: SlackBotTokens, tenant_id: str | None
) -> TenantSocketModeClient:
def _get_socket_client(slack_bot_tokens: SlackBotTokens) -> SocketModeClient:
# For more info on how to set this up, checkout the docs:
# https://docs.danswer.dev/slack_bot_setup
return TenantSocketModeClient(
return SocketModeClient(
# This app-level token will be used only for establishing a connection
app_token=slack_bot_tokens.app_token,
web_client=WebClient(token=slack_bot_tokens.bot_token),
tenant_id=tenant_id,
)
def _initialize_socket_client(socket_client: TenantSocketModeClient) -> None:
def _initialize_socket_client(socket_client: SocketModeClient) -> None:
socket_client.socket_mode_request_listeners.append(process_slack_event) # type: ignore
# Establish a WebSocket connection to the Socket Mode servers
logger.notice(f"Listening for messages from Slack {socket_client.tenant_id }...")
logger.notice("Listening for messages from Slack...")
socket_client.connect()
@@ -484,8 +481,8 @@ def _initialize_socket_client(socket_client: TenantSocketModeClient) -> None:
# NOTE: we are using Web Sockets so that you can run this from within a firewalled VPC
# without issue.
if __name__ == "__main__":
slack_bot_tokens: dict[str | None, SlackBotTokens] = {}
socket_clients: dict[str | None, TenantSocketModeClient] = {}
slack_bot_tokens: SlackBotTokens | None = None
socket_client: SocketModeClient | None = None
set_is_ee_based_on_env_variable()
@@ -494,59 +491,46 @@ if __name__ == "__main__":
while True:
try:
tenant_ids = get_all_tenant_ids() # Function to retrieve all tenant IDs
latest_slack_bot_tokens = fetch_tokens()
for tenant_id in tenant_ids:
with get_session_with_tenant(tenant_id) as db_session:
try:
latest_slack_bot_tokens = fetch_tokens()
if latest_slack_bot_tokens != slack_bot_tokens:
if slack_bot_tokens is not None:
logger.notice("Slack Bot tokens have changed - reconnecting")
else:
# This happens on the very first time the listener process comes up
# or the tokens have updated (set up for the first time)
with Session(get_sqlalchemy_engine()) as db_session:
search_settings = get_current_search_settings(db_session)
embedding_model = EmbeddingModel.from_db_model(
search_settings=search_settings,
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
if (
tenant_id not in slack_bot_tokens
or latest_slack_bot_tokens != slack_bot_tokens[tenant_id]
):
if tenant_id in slack_bot_tokens:
logger.notice(
f"Slack Bot tokens have changed for tenant {tenant_id} - reconnecting"
)
else:
# Initial setup for this tenant
search_settings = get_current_search_settings(
db_session
)
embedding_model = EmbeddingModel.from_db_model(
search_settings=search_settings,
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
warm_up_bi_encoder(embedding_model=embedding_model)
warm_up_bi_encoder(
embedding_model=embedding_model,
)
slack_bot_tokens[tenant_id] = latest_slack_bot_tokens
slack_bot_tokens = latest_slack_bot_tokens
# potentially may cause a message to be dropped, but it is complicated
# to avoid + (1) if the user is changing tokens, they are likely okay with some
# "migration downtime" and (2) if a single message is lost it is okay
# as this should be a very rare occurrence
if socket_client:
socket_client.close()
# potentially may cause a message to be dropped, but it is complicated
# to avoid + (1) if the user is changing tokens, they are likely okay with some
# "migration downtime" and (2) if a single message is lost it is okay
# as this should be a very rare occurrence
if tenant_id in socket_clients:
socket_clients[tenant_id].close()
socket_client = _get_socket_client(slack_bot_tokens)
_initialize_socket_client(socket_client)
socket_client = _get_socket_client(
latest_slack_bot_tokens, tenant_id
)
_initialize_socket_client(socket_client)
socket_clients[tenant_id] = socket_client
except KvKeyNotFoundError:
logger.debug(f"Missing Slack Bot tokens for tenant {tenant_id}")
if tenant_id in socket_clients:
socket_clients[tenant_id].disconnect()
del socket_clients[tenant_id]
del slack_bot_tokens[tenant_id]
# Wait before checking for updates
# Let the handlers run in the background + re-check for token updates every 60 seconds
Event().wait(timeout=60)
except Exception:
logger.exception("An error occurred outside of main event loop")
except KvKeyNotFoundError:
# try again every 30 seconds. This is needed since the user may add tokens
# via the UI at any point in the programs lifecycle - if we just allow it to
# fail, then the user will need to restart the containers after adding tokens
logger.debug(
"Missing Slack Bot tokens - waiting 60 seconds and trying again"
)
if socket_client:
socket_client.disconnect()
time.sleep(60)

View File

@@ -12,7 +12,7 @@ from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from slack_sdk.models.blocks import Block
from slack_sdk.models.metadata import Metadata
from slack_sdk.socket_mode import SocketModeClient
from sqlalchemy.orm import Session
from danswer.configs.app_configs import DISABLE_TELEMETRY
from danswer.configs.constants import ID_SEPARATOR
@@ -31,7 +31,7 @@ from danswer.connectors.slack.utils import make_slack_api_rate_limited
from danswer.connectors.slack.utils import SlackTextCleaner
from danswer.danswerbot.slack.constants import FeedbackVisibility
from danswer.danswerbot.slack.tokens import fetch_tokens
from danswer.db.engine import get_session_with_tenant
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.users import get_user_by_email
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_default_llms
@@ -489,9 +489,7 @@ def read_slack_thread(
return thread_messages
def slack_usage_report(
action: str, sender_id: str | None, client: WebClient, tenant_id: str | None
) -> None:
def slack_usage_report(action: str, sender_id: str | None, client: WebClient) -> None:
if DISABLE_TELEMETRY:
return
@@ -503,7 +501,7 @@ def slack_usage_report(
logger.warning("Unable to find sender email")
if sender_email is not None:
with get_session_with_tenant(tenant_id) as db_session:
with Session(get_sqlalchemy_engine()) as db_session:
danswer_user = get_user_by_email(email=sender_email, db_session=db_session)
optional_telemetry(
@@ -579,9 +577,3 @@ def get_feedback_visibility() -> FeedbackVisibility:
return FeedbackVisibility(DANSWER_BOT_FEEDBACK_VISIBILITY.lower())
except ValueError:
return FeedbackVisibility.PRIVATE
class TenantSocketModeClient(SocketModeClient):
def __init__(self, tenant_id: str | None, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)
self.tenant_id = tenant_id

View File

@@ -10,9 +10,7 @@ from fastapi_users_db_sqlalchemy.access_token import SQLAlchemyAccessTokenDataba
from sqlalchemy import func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import Session
from danswer.auth.invited_users import get_invited_users
from danswer.auth.schemas import UserRole
from danswer.db.engine import get_async_session
from danswer.db.engine import get_async_session_with_tenant
@@ -35,20 +33,10 @@ def get_default_admin_user_emails() -> list[str]:
return get_default_admin_user_emails_fn()
def get_total_users(db_session: Session) -> int:
"""
Returns the total number of users in the system.
This is the sum of users and invited users.
"""
user_count = db_session.query(User).count()
invited_users = len(get_invited_users())
return user_count + invited_users
async def get_user_count() -> int:
async with get_async_session_with_tenant() as session:
async with get_async_session_with_tenant() as asession:
stmt = select(func.count(User.id))
result = await session.execute(stmt)
result = await asession.execute(stmt)
user_count = result.scalar()
if user_count is None:
raise RuntimeError("Was not able to fetch the user count.")

View File

@@ -43,7 +43,7 @@ logger = setup_logger()
def get_chat_session_by_id(
chat_session_id: UUID,
chat_session_id: int,
user_id: UUID | None,
db_session: Session,
include_deleted: bool = False,
@@ -87,9 +87,9 @@ def get_chat_sessions_by_slack_thread_id(
def get_valid_messages_from_query_sessions(
chat_session_ids: list[UUID],
chat_session_ids: list[int],
db_session: Session,
) -> dict[UUID, str]:
) -> dict[int, str]:
user_message_subquery = (
select(
ChatMessage.chat_session_id, func.min(ChatMessage.id).label("user_msg_id")
@@ -196,7 +196,7 @@ def delete_orphaned_search_docs(db_session: Session) -> None:
def delete_messages_and_files_from_chat_session(
chat_session_id: UUID, db_session: Session
chat_session_id: int, db_session: Session
) -> None:
# Select messages older than cutoff_time with files
messages_with_files = db_session.execute(
@@ -253,7 +253,7 @@ def create_chat_session(
def update_chat_session(
db_session: Session,
user_id: UUID | None,
chat_session_id: UUID,
chat_session_id: int,
description: str | None = None,
sharing_status: ChatSessionSharedStatus | None = None,
) -> ChatSession:
@@ -276,7 +276,7 @@ def update_chat_session(
def delete_chat_session(
user_id: UUID | None,
chat_session_id: UUID,
chat_session_id: int,
db_session: Session,
hard_delete: bool = HARD_DELETE_CHATS,
) -> None:
@@ -337,7 +337,7 @@ def get_chat_message(
def get_chat_messages_by_sessions(
chat_session_ids: list[UUID],
chat_session_ids: list[int],
user_id: UUID | None,
db_session: Session,
skip_permission_check: bool = False,
@@ -370,7 +370,7 @@ def get_search_docs_for_chat_message(
def get_chat_messages_by_session(
chat_session_id: UUID,
chat_session_id: int,
user_id: UUID | None,
db_session: Session,
skip_permission_check: bool = False,
@@ -397,7 +397,7 @@ def get_chat_messages_by_session(
def get_or_create_root_message(
chat_session_id: UUID,
chat_session_id: int,
db_session: Session,
) -> ChatMessage:
try:
@@ -433,7 +433,7 @@ def get_or_create_root_message(
def reserve_message_id(
db_session: Session,
chat_session_id: UUID,
chat_session_id: int,
parent_message: int,
message_type: MessageType,
) -> int:
@@ -460,7 +460,7 @@ def reserve_message_id(
def create_new_chat_message(
chat_session_id: UUID,
chat_session_id: int,
parent_message: ChatMessage,
message: str,
prompt_id: int | None,

View File

@@ -248,7 +248,7 @@ def create_initial_default_connector(db_session: Session) -> None:
logger.warning(
"Default connector does not have expected values. Updating to proper state."
)
# Ensure default connector has correct values
# Ensure default connector has correct valuesg
default_connector.source = DocumentSource.INGESTION_API
default_connector.input_type = InputType.LOAD_STATE
default_connector.refresh_freq = None

View File

@@ -241,7 +241,8 @@ def create_credential(
curator_public=credential_data.curator_public,
)
db_session.add(credential)
db_session.flush() # This ensures the credential gets an IDcredentials
db_session.flush() # This ensures the credential gets an ID
_relate_credential_to_user_groups__no_commit(
db_session=db_session,
credential_id=credential.id,

View File

@@ -398,7 +398,7 @@ def mark_document_set_as_to_be_deleted(
def delete_document_set_cc_pair_relationship__no_commit(
connector_id: int, credential_id: int, db_session: Session
) -> int:
) -> None:
"""Deletes all rows from DocumentSet__ConnectorCredentialPair where the
connector_credential_pair_id matches the given cc_pair_id."""
delete_stmt = delete(DocumentSet__ConnectorCredentialPair).where(
@@ -409,8 +409,7 @@ def delete_document_set_cc_pair_relationship__no_commit(
== ConnectorCredentialPair.id,
)
)
result = db_session.execute(delete_stmt)
return result.rowcount # type: ignore
db_session.execute(delete_stmt)
def fetch_document_sets(

View File

@@ -295,32 +295,30 @@ async def get_async_session_with_tenant(
def get_session_with_tenant(
tenant_id: str | None = None,
) -> Generator[Session, None, None]:
"""Generate a database session bound to a connection with the appropriate tenant schema set."""
"""Generate a database session with the appropriate tenant schema set."""
engine = get_sqlalchemy_engine()
if tenant_id is None:
tenant_id = current_tenant_id.get()
else:
current_tenant_id.set(tenant_id)
event.listen(engine, "checkout", set_search_path_on_checkout)
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID")
# Establish a raw connection
# Establish a raw connection without starting a transaction
with engine.connect() as connection:
# Access the raw DBAPI connection and set the search_path
# Access the raw DBAPI connection
dbapi_connection = connection.connection
# Set the search_path outside of any transaction
# Execute SET search_path outside of any transaction
cursor = dbapi_connection.cursor()
try:
cursor.execute(f'SET search_path = "{tenant_id}"')
cursor.execute(f'SET search_path TO "{tenant_id}"')
# Optionally verify the search_path was set correctly
cursor.execute("SHOW search_path")
cursor.fetchone()
finally:
cursor.close()
# Bind the session to the connection
# Proceed to create a session using the connection
with Session(bind=connection, expire_on_commit=False) as session:
try:
yield session
@@ -334,18 +332,6 @@ def get_session_with_tenant(
cursor.close()
def set_search_path_on_checkout(
dbapi_conn: Any, connection_record: Any, connection_proxy: Any
) -> None:
tenant_id = current_tenant_id.get()
if tenant_id and is_valid_schema_name(tenant_id):
with dbapi_conn.cursor() as cursor:
cursor.execute(f'SET search_path TO "{tenant_id}"')
logger.debug(
f"Set search_path to {tenant_id} for connection {connection_record}"
)
def get_session_generator_with_tenant(
tenant_id: str | None = None,
) -> Generator[Session, None, None]:
@@ -360,7 +346,6 @@ def get_session() -> Generator[Session, None, None]:
raise HTTPException(status_code=401, detail="User must authenticate")
engine = get_sqlalchemy_engine()
with Session(engine, expire_on_commit=False) as session:
if MULTI_TENANT:
if not is_valid_schema_name(tenant_id):

View File

@@ -1,6 +1,4 @@
from collections.abc import Sequence
from datetime import datetime
from datetime import timezone
from sqlalchemy import and_
from sqlalchemy import delete
@@ -21,6 +19,8 @@ from danswer.db.models import SearchSettings
from danswer.server.documents.models import ConnectorCredentialPair
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
from danswer.utils.logger import setup_logger
from danswer.utils.telemetry import optional_telemetry
from danswer.utils.telemetry import RecordType
logger = setup_logger()
@@ -66,7 +66,7 @@ def create_index_attempt(
return new_attempt.id
def get_in_progress_index_attempts(
def get_inprogress_index_attempts(
connector_id: int | None,
db_session: Session,
) -> list[IndexAttempt]:
@@ -81,15 +81,13 @@ def get_in_progress_index_attempts(
return list(incomplete_attempts.all())
def get_all_index_attempts_by_status(
status: IndexingStatus, db_session: Session
) -> list[IndexAttempt]:
def get_not_started_index_attempts(db_session: Session) -> list[IndexAttempt]:
"""This eagerly loads the connector and credential so that the db_session can be expired
before running long-living indexing jobs, which causes increasing memory usage.
Results are ordered by time_created (oldest to newest)."""
stmt = select(IndexAttempt)
stmt = stmt.where(IndexAttempt.status == status)
stmt = stmt.where(IndexAttempt.status == IndexingStatus.NOT_STARTED)
stmt = stmt.order_by(IndexAttempt.time_created)
stmt = stmt.options(
joinedload(IndexAttempt.connector_credential_pair).joinedload(
@@ -103,92 +101,31 @@ def get_all_index_attempts_by_status(
return list(new_attempts.all())
def transition_attempt_to_in_progress(
index_attempt_id: int,
db_session: Session,
) -> IndexAttempt:
"""Locks the row when we try to update"""
try:
attempt = db_session.execute(
select(IndexAttempt)
.where(IndexAttempt.id == index_attempt_id)
.with_for_update()
).scalar_one()
if attempt is None:
raise RuntimeError(
f"Unable to find IndexAttempt for ID '{index_attempt_id}'"
)
if attempt.status != IndexingStatus.NOT_STARTED:
raise RuntimeError(
f"Indexing attempt with ID '{index_attempt_id}' is not in NOT_STARTED status. "
f"Current status is '{attempt.status}'."
)
attempt.status = IndexingStatus.IN_PROGRESS
attempt.time_started = attempt.time_started or func.now() # type: ignore
db_session.commit()
return attempt
except Exception:
db_session.rollback()
logger.exception("transition_attempt_to_in_progress exceptioned.")
raise
def mark_attempt_in_progress(
index_attempt: IndexAttempt,
db_session: Session,
) -> None:
try:
attempt = db_session.execute(
select(IndexAttempt)
.where(IndexAttempt.id == index_attempt.id)
.with_for_update()
).scalar_one()
attempt.status = IndexingStatus.IN_PROGRESS
attempt.time_started = index_attempt.time_started or func.now() # type: ignore
db_session.commit()
except Exception:
db_session.rollback()
raise
index_attempt.status = IndexingStatus.IN_PROGRESS
index_attempt.time_started = index_attempt.time_started or func.now() # type: ignore
db_session.commit()
def mark_attempt_succeeded(
index_attempt: IndexAttempt,
db_session: Session,
) -> None:
try:
attempt = db_session.execute(
select(IndexAttempt)
.where(IndexAttempt.id == index_attempt.id)
.with_for_update()
).scalar_one()
attempt.status = IndexingStatus.SUCCESS
db_session.commit()
except Exception:
db_session.rollback()
raise
index_attempt.status = IndexingStatus.SUCCESS
db_session.add(index_attempt)
db_session.commit()
def mark_attempt_partially_succeeded(
index_attempt: IndexAttempt,
db_session: Session,
) -> None:
try:
attempt = db_session.execute(
select(IndexAttempt)
.where(IndexAttempt.id == index_attempt.id)
.with_for_update()
).scalar_one()
attempt.status = IndexingStatus.COMPLETED_WITH_ERRORS
db_session.commit()
except Exception:
db_session.rollback()
raise
index_attempt.status = IndexingStatus.COMPLETED_WITH_ERRORS
db_session.add(index_attempt)
db_session.commit()
def mark_attempt_failed(
@@ -197,22 +134,14 @@ def mark_attempt_failed(
failure_reason: str = "Unknown",
full_exception_trace: str | None = None,
) -> None:
try:
attempt = db_session.execute(
select(IndexAttempt)
.where(IndexAttempt.id == index_attempt.id)
.with_for_update()
).scalar_one()
index_attempt.status = IndexingStatus.FAILED
index_attempt.error_msg = failure_reason
index_attempt.full_exception_trace = full_exception_trace
db_session.add(index_attempt)
db_session.commit()
if not attempt.time_started:
attempt.time_started = datetime.now(timezone.utc)
attempt.status = IndexingStatus.FAILED
attempt.error_msg = failure_reason
attempt.full_exception_trace = full_exception_trace
db_session.commit()
except Exception:
db_session.rollback()
raise
source = index_attempt.connector_credential_pair.connector.source
optional_telemetry(record_type=RecordType.FAILURE, data={"connector": source})
def update_docs_indexed(
@@ -506,13 +435,14 @@ def cancel_indexing_attempts_for_ccpair(
db_session.execute(stmt)
db_session.commit()
def cancel_indexing_attempts_past_model(
db_session: Session,
) -> None:
"""Stops all indexing attempts that are in progress or not started for
any embedding model that is not present/future"""
db_session.execute(
update(IndexAttempt)
.where(
@@ -525,6 +455,8 @@ def cancel_indexing_attempts_past_model(
.values(status=IndexingStatus.FAILED)
)
db_session.commit()
def count_unique_cc_pairs_with_successful_index_attempts(
search_settings_id: int | None,

View File

@@ -5,12 +5,9 @@ from typing import Any
from typing import Literal
from typing import NotRequired
from typing import Optional
from uuid import uuid4
from typing_extensions import TypedDict # noreorder
from uuid import UUID
from sqlalchemy.dialects.postgresql import UUID as PGUUID
from fastapi_users_db_sqlalchemy import SQLAlchemyBaseOAuthAccountTableUUID
from fastapi_users_db_sqlalchemy import SQLAlchemyBaseUserTableUUID
from fastapi_users_db_sqlalchemy.access_token import SQLAlchemyBaseAccessTokenTableUUID
@@ -60,7 +57,6 @@ from danswer.llm.override_models import PromptOverride
from danswer.search.enums import RecencyBiasSetting
from danswer.utils.encryption import decrypt_bytes_to_string
from danswer.utils.encryption import encrypt_string_to_bytes
from danswer.utils.headers import HeaderItemDict
from shared_configs.enums import EmbeddingProvider
from shared_configs.enums import RerankerProvider
@@ -235,9 +231,6 @@ class Notification(Base):
first_shown: Mapped[datetime.datetime] = mapped_column(DateTime(timezone=True))
user: Mapped[User] = relationship("User", back_populates="notifications")
additional_data: Mapped[dict | None] = mapped_column(
postgresql.JSONB(), nullable=True
)
"""
@@ -622,7 +615,6 @@ class SearchSettings(Base):
normalize: Mapped[bool] = mapped_column(Boolean)
query_prefix: Mapped[str | None] = mapped_column(String, nullable=True)
passage_prefix: Mapped[str | None] = mapped_column(String, nullable=True)
status: Mapped[IndexModelStatus] = mapped_column(
Enum(IndexModelStatus, native_enum=False)
)
@@ -678,20 +670,6 @@ class SearchSettings(Base):
return f"<EmbeddingModel(model_name='{self.model_name}', status='{self.status}',\
cloud_provider='{self.cloud_provider.provider_type if self.cloud_provider else 'None'}')>"
@property
def api_version(self) -> str | None:
return (
self.cloud_provider.api_version if self.cloud_provider is not None else None
)
@property
def deployment_name(self) -> str | None:
return (
self.cloud_provider.deployment_name
if self.cloud_provider is not None
else None
)
@property
def api_url(self) -> str | None:
return self.cloud_provider.api_url if self.cloud_provider is not None else None
@@ -927,9 +905,7 @@ class ToolCall(Base):
class ChatSession(Base):
__tablename__ = "chat_session"
id: Mapped[UUID] = mapped_column(
PGUUID(as_uuid=True), primary_key=True, default=uuid4
)
id: Mapped[int] = mapped_column(primary_key=True)
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
)
@@ -999,9 +975,7 @@ class ChatMessage(Base):
__tablename__ = "chat_message"
id: Mapped[int] = mapped_column(primary_key=True)
chat_session_id: Mapped[UUID] = mapped_column(
PGUUID(as_uuid=True), ForeignKey("chat_session.id")
)
chat_session_id: Mapped[int] = mapped_column(ForeignKey("chat_session.id"))
alternate_assistant_id = mapped_column(
Integer, ForeignKey("persona.id"), nullable=True
@@ -1190,9 +1164,6 @@ class CloudEmbeddingProvider(Base):
)
api_url: Mapped[str | None] = mapped_column(String, nullable=True)
api_key: Mapped[str | None] = mapped_column(EncryptedString())
api_version: Mapped[str | None] = mapped_column(String, nullable=True)
deployment_name: Mapped[str | None] = mapped_column(String, nullable=True)
search_settings: Mapped[list["SearchSettings"]] = relationship(
"SearchSettings",
back_populates="cloud_provider",
@@ -1292,7 +1263,7 @@ class Tool(Base):
openapi_schema: Mapped[dict[str, Any] | None] = mapped_column(
postgresql.JSONB(), nullable=True
)
custom_headers: Mapped[list[HeaderItemDict] | None] = mapped_column(
custom_headers: Mapped[list[dict[str, str]] | None] = mapped_column(
postgresql.JSONB(), nullable=True
)
# user who created / owns the tool. Will be None for built-in tools.
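For orientation, a minimal sketch of the native-UUID primary key variant of `ChatSession.id` that appears in the hunk above (SQLAlchemy 2.x declarative style assumed; the class and table names are placeholders).

```python
import uuid
from uuid import UUID

from sqlalchemy.dialects.postgresql import UUID as PGUUID
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class ChatSessionSketch(Base):
    __tablename__ = "chat_session_sketch"

    # Client-generated uuid4 primary key stored as a native Postgres UUID column,
    # mirroring the PGUUID(as_uuid=True) mapping shown in the diff.
    id: Mapped[UUID] = mapped_column(
        PGUUID(as_uuid=True), primary_key=True, default=uuid.uuid4
    )
```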

View File

@@ -1,5 +1,3 @@
from uuid import UUID
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.sql import func
@@ -10,37 +8,16 @@ from danswer.db.models import User
def create_notification(
user_id: UUID | None,
user: User | None,
notif_type: NotificationType,
db_session: Session,
additional_data: dict | None = None,
) -> Notification:
# Check if an undismissed notification of the same type and data exists
existing_notification = (
db_session.query(Notification)
.filter_by(
user_id=user_id,
notif_type=notif_type,
dismissed=False,
)
.filter(Notification.additional_data == additional_data)
.first()
)
if existing_notification:
# Update the last_shown timestamp
existing_notification.last_shown = func.now()
db_session.commit()
return existing_notification
# Create a new notification if none exists
notification = Notification(
user_id=user_id,
user_id=user.id if user else None,
notif_type=notif_type,
dismissed=False,
last_shown=func.now(),
first_shown=func.now(),
additional_data=additional_data,
)
db_session.add(notification)
db_session.commit()

View File

@@ -12,7 +12,7 @@ from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
from danswer.configs.model_configs import OLD_DEFAULT_DOCUMENT_ENCODER_MODEL
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
from danswer.db.engine import get_session_with_tenant
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.llm import fetch_embedding_provider
from danswer.db.models import CloudEmbeddingProvider
from danswer.db.models import IndexAttempt
@@ -152,7 +152,7 @@ def get_all_search_settings(db_session: Session) -> list[SearchSettings]:
def get_multilingual_expansion(db_session: Session | None = None) -> list[str]:
if db_session is None:
with get_session_with_tenant() as db_session:
with Session(get_sqlalchemy_engine()) as db_session:
search_settings = get_current_search_settings(db_session)
else:
search_settings = get_current_search_settings(db_session)

View File

@@ -1,6 +1,5 @@
from sqlalchemy.orm import Session
from danswer.configs.app_configs import MULTI_TENANT
from danswer.configs.constants import KV_REINDEX_KEY
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.connector_credential_pair import resync_cc_pair
@@ -9,18 +8,16 @@ from danswer.db.index_attempt import cancel_indexing_attempts_past_model
from danswer.db.index_attempt import (
count_unique_cc_pairs_with_successful_index_attempts,
)
from danswer.db.models import SearchSettings
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.db.search_settings import update_search_settings_status
from danswer.key_value_store.factory import get_kv_store
from danswer.utils.logger import setup_logger
logger = setup_logger()
def check_index_swap(db_session: Session) -> SearchSettings | None:
def check_index_swap(db_session: Session) -> None:
"""Get count of cc-pairs and count of successful index_attempts for the
new model grouped by connector + credential, if it's the same, then assume
new index is done building. If so, swap the indices and expire the old one."""
@@ -30,7 +27,7 @@ def check_index_swap(db_session: Session) -> SearchSettings | None:
search_settings = get_secondary_search_settings(db_session)
if not search_settings:
return None
return
unique_cc_indexings = count_unique_cc_pairs_with_successful_index_attempts(
search_settings_id=search_settings.id, db_session=db_session
@@ -66,7 +63,3 @@ def check_index_swap(db_session: Session) -> SearchSettings | None:
# Recount aggregates
for cc_pair in all_cc_pairs:
resync_cc_pair(cc_pair, db_session=db_session)
if MULTI_TENANT:
return now_old_search_settings
return None

View File

@@ -1,5 +1,4 @@
from typing import Any
from typing import cast
from uuid import UUID
from sqlalchemy import select
@@ -7,7 +6,6 @@ from sqlalchemy.orm import Session
from danswer.db.models import Tool
from danswer.server.features.tool.models import Header
from danswer.utils.headers import HeaderItemDict
from danswer.utils.logger import setup_logger
logger = setup_logger()
@@ -69,9 +67,7 @@ def update_tool(
if user_id is not None:
tool.user_id = user_id
if custom_headers is not None:
tool.custom_headers = [
cast(HeaderItemDict, header.model_dump()) for header in custom_headers
]
tool.custom_headers = [header.dict() for header in custom_headers]
db_session.commit()
return tool

View File

@@ -1,6 +1,5 @@
from sqlalchemy.orm import Session
from danswer.configs.app_configs import MULTI_TENANT
from danswer.db.search_settings import get_current_search_settings
from danswer.document_index.interfaces import DocumentIndex
from danswer.document_index.vespa.index import VespaIndex
@@ -15,9 +14,7 @@ def get_default_document_index(
index both need to be updated, updates are applied to both indices"""
# Currently only supporting Vespa
return VespaIndex(
index_name=primary_index_name,
secondary_index_name=secondary_index_name,
multitenant=MULTI_TENANT,
index_name=primary_index_name, secondary_index_name=secondary_index_name
)

View File

@@ -127,17 +127,6 @@ class Verifiable(abc.ABC):
"""
raise NotImplementedError
@staticmethod
@abc.abstractmethod
def register_multitenant_indices(
indices: list[str],
embedding_dims: list[int],
) -> None:
"""
Register multitenant indices with the document index.
"""
raise NotImplementedError
class Indexable(abc.ABC):
"""

View File

@@ -1,6 +1,5 @@
schema DANSWER_CHUNK_NAME {
document DANSWER_CHUNK_NAME {
TENANT_ID_REPLACEMENT
# Not to be confused with the UUID generated for this chunk which is called documentid by default
field document_id type string {
indexing: summary | attribute

View File

@@ -7,13 +7,11 @@ from datetime import timezone
from typing import Any
from typing import cast
import httpx
import requests
from retry import retry
from danswer.configs.app_configs import LOG_VESPA_TIMING_INFORMATION
from danswer.document_index.interfaces import VespaChunkRequest
from danswer.document_index.vespa.shared_utils.utils import get_vespa_http_client
from danswer.document_index.vespa.shared_utils.vespa_request_builders import (
build_vespa_filters,
)
@@ -295,12 +293,13 @@ def query_vespa(
if LOG_VESPA_TIMING_INFORMATION
else {},
)
try:
with get_vespa_http_client() as http_client:
response = http_client.post(SEARCH_ENDPOINT, json=params)
response.raise_for_status()
except httpx.HTTPError as e:
response = requests.post(
SEARCH_ENDPOINT,
json=params,
)
response.raise_for_status()
except requests.HTTPError as e:
request_info = f"Headers: {response.request.headers}\nPayload: {params}"
response_info = (
f"Status Code: {response.status_code}\n"
@@ -313,10 +312,9 @@ def query_vespa(
f"{response_info}\n"
f"Exception: {e}"
)
raise httpx.HTTPError(error_base) from e
raise requests.HTTPError(error_base) from e
response_json: dict[str, Any] = response.json()
if LOG_VESPA_TIMING_INFORMATION:
logger.debug("Vespa timing info: %s", response_json.get("timing"))
hits = response_json["root"].get("children", [])

View File

@@ -4,20 +4,18 @@ import logging
import os
import re
import time
import urllib
import zipfile
from dataclasses import dataclass
from datetime import datetime
from datetime import timedelta
from typing import BinaryIO
from typing import cast
from typing import List
import httpx # type: ignore
import requests # type: ignore
import httpx
import requests
from danswer.configs.app_configs import DOCUMENT_INDEX_NAME
from danswer.configs.app_configs import MULTI_TENANT
from danswer.configs.app_configs import VESPA_REQUEST_TIMEOUT
from danswer.configs.chat_configs import DOC_TIME_DECAY
from danswer.configs.chat_configs import NUM_RETURNED_HITS
from danswer.configs.chat_configs import TITLE_CONTENT_RATIO
@@ -42,7 +40,6 @@ from danswer.document_index.vespa.indexing_utils import clean_chunk_id_copy
from danswer.document_index.vespa.indexing_utils import (
get_existing_documents_from_chunks,
)
from danswer.document_index.vespa.shared_utils.utils import get_vespa_http_client
from danswer.document_index.vespa.shared_utils.utils import (
replace_invalid_doc_id_characters,
)
@@ -61,8 +58,6 @@ from danswer.document_index.vespa_constants import DOCUMENT_SETS
from danswer.document_index.vespa_constants import HIDDEN
from danswer.document_index.vespa_constants import NUM_THREADS
from danswer.document_index.vespa_constants import SEARCH_THREAD_NUMBER_PAT
from danswer.document_index.vespa_constants import TENANT_ID_PAT
from danswer.document_index.vespa_constants import TENANT_ID_REPLACEMENT
from danswer.document_index.vespa_constants import VESPA_APPLICATION_ENDPOINT
from danswer.document_index.vespa_constants import VESPA_DIM_REPLACEMENT_PAT
from danswer.document_index.vespa_constants import VESPA_TIMEOUT
@@ -75,7 +70,6 @@ from danswer.utils.batching import batch_generator
from danswer.utils.logger import setup_logger
from shared_configs.model_server_models import Embedding
logger = setup_logger()
# Set the logging level to WARNING to ignore INFO and DEBUG logs
@@ -99,7 +93,7 @@ def in_memory_zip_from_file_bytes(file_contents: dict[str, bytes]) -> BinaryIO:
return zip_buffer
def _create_document_xml_lines(doc_names: list[str | None] | list[str]) -> str:
def _create_document_xml_lines(doc_names: list[str | None]) -> str:
doc_lines = [
f'<document type="{doc_name}" mode="index" />'
for doc_name in doc_names
@@ -124,28 +118,15 @@ def add_ngrams_to_schema(schema_content: str) -> str:
class VespaIndex(DocumentIndex):
def __init__(
self,
index_name: str,
secondary_index_name: str | None,
multitenant: bool = False,
) -> None:
def __init__(self, index_name: str, secondary_index_name: str | None) -> None:
self.index_name = index_name
self.secondary_index_name = secondary_index_name
self.multitenant = multitenant
self.http_client = get_vespa_http_client()
def ensure_indices_exist(
self,
index_embedding_dim: int,
secondary_index_embedding_dim: int | None,
) -> None:
if MULTI_TENANT:
logger.info(
"Skipping Vespa index seup for multitenant (would wipe all indices)"
)
return None
deploy_url = f"{VESPA_APPLICATION_ENDPOINT}/tenant/default/prepareandactivate"
logger.info(f"Deploying Vespa application package to {deploy_url}")
@@ -193,14 +174,10 @@ class VespaIndex(DocumentIndex):
with open(schema_file, "r") as schema_f:
schema_template = schema_f.read()
schema_template = schema_template.replace(TENANT_ID_PAT, "")
schema = schema_template.replace(
DANSWER_CHUNK_REPLACEMENT_PAT, self.index_name
).replace(VESPA_DIM_REPLACEMENT_PAT, str(index_embedding_dim))
schema = add_ngrams_to_schema(schema) if needs_reindexing else schema
schema = schema.replace(TENANT_ID_PAT, "")
zip_dict[f"schemas/{schema_names[0]}.sd"] = schema.encode("utf-8")
if self.secondary_index_name:
@@ -218,91 +195,6 @@ class VespaIndex(DocumentIndex):
f"Failed to prepare Vespa Danswer Index. Response: {response.text}"
)
@staticmethod
def register_multitenant_indices(
indices: list[str],
embedding_dims: list[int],
) -> None:
if not MULTI_TENANT:
raise ValueError("Multi-tenant is not enabled")
deploy_url = f"{VESPA_APPLICATION_ENDPOINT}/tenant/default/prepareandactivate"
logger.info(f"Deploying Vespa application package to {deploy_url}")
vespa_schema_path = os.path.join(
os.getcwd(), "danswer", "document_index", "vespa", "app_config"
)
schema_file = os.path.join(vespa_schema_path, "schemas", "danswer_chunk.sd")
services_file = os.path.join(vespa_schema_path, "services.xml")
overrides_file = os.path.join(vespa_schema_path, "validation-overrides.xml")
with open(services_file, "r") as services_f:
services_template = services_f.read()
# Generate schema names from index settings
schema_names = [index_name for index_name in indices]
full_schemas = schema_names
doc_lines = _create_document_xml_lines(full_schemas)
services = services_template.replace(DOCUMENT_REPLACEMENT_PAT, doc_lines)
services = services.replace(
SEARCH_THREAD_NUMBER_PAT, str(VESPA_SEARCHER_THREADS)
)
kv_store = get_kv_store()
needs_reindexing = False
try:
needs_reindexing = cast(bool, kv_store.load(KV_REINDEX_KEY))
except Exception:
logger.debug("Could not load the reindexing flag. Using ngrams")
with open(overrides_file, "r") as overrides_f:
overrides_template = overrides_f.read()
# Vespa requires an override to erase data including the indices we're no longer using
# It also has a 30 day cap from the current date, so we set it to 7 days dynamically
now = datetime.now()
date_in_7_days = now + timedelta(days=7)
formatted_date = date_in_7_days.strftime("%Y-%m-%d")
overrides = overrides_template.replace(DATE_REPLACEMENT, formatted_date)
zip_dict = {
"services.xml": services.encode("utf-8"),
"validation-overrides.xml": overrides.encode("utf-8"),
}
with open(schema_file, "r") as schema_f:
schema_template = schema_f.read()
for i, index_name in enumerate(indices):
embedding_dim = embedding_dims[i]
logger.info(
f"Creating index: {index_name} with embedding dimension: {embedding_dim}"
)
schema = schema_template.replace(
DANSWER_CHUNK_REPLACEMENT_PAT, index_name
).replace(VESPA_DIM_REPLACEMENT_PAT, str(embedding_dim))
schema = schema.replace(
TENANT_ID_PAT, TENANT_ID_REPLACEMENT if MULTI_TENANT else ""
)
schema = add_ngrams_to_schema(schema) if needs_reindexing else schema
zip_dict[f"schemas/{index_name}.sd"] = schema.encode("utf-8")
zip_file = in_memory_zip_from_file_bytes(zip_dict)
headers = {"Content-Type": "application/zip"}
response = requests.post(deploy_url, headers=headers, data=zip_file)
if response.status_code != 200:
raise RuntimeError(
f"Failed to prepare Vespa Danswer Indexes. Response: {response.text}"
)
def index(
self,
chunks: list[DocMetadataAwareIndexChunk],
@@ -320,7 +212,7 @@ class VespaIndex(DocumentIndex):
# indexing / updates / deletes since we have to make a large volume of requests.
with (
concurrent.futures.ThreadPoolExecutor(max_workers=NUM_THREADS) as executor,
get_vespa_http_client() as http_client,
httpx.Client(http2=True, timeout=VESPA_REQUEST_TIMEOUT) as http_client,
):
# Check for existing documents, existing documents need to have all of their chunks deleted
# prior to indexing as the document size (num chunks) may have shrunk
@@ -348,7 +240,6 @@ class VespaIndex(DocumentIndex):
chunks=chunk_batch,
index_name=self.index_name,
http_client=http_client,
multitenant=self.multitenant,
executor=executor,
)
@@ -383,10 +274,9 @@ class VespaIndex(DocumentIndex):
# NOTE: using `httpx` here since `requests` doesn't support HTTP2. This is beneficial for
# indexing / updates / deletes since we have to make a large volume of requests.
with (
concurrent.futures.ThreadPoolExecutor(max_workers=NUM_THREADS) as executor,
get_vespa_http_client() as http_client,
httpx.Client(http2=True, timeout=VESPA_REQUEST_TIMEOUT) as http_client,
):
for update_batch in batch_generator(updates, batch_size):
future_to_document_id = {
@@ -530,7 +420,7 @@ class VespaIndex(DocumentIndex):
if self.secondary_index_name:
index_names.append(self.secondary_index_name)
with get_vespa_http_client() as http_client:
with httpx.Client(http2=True, timeout=VESPA_REQUEST_TIMEOUT) as http_client:
for index_name in index_names:
params = httpx.QueryParams(
{
@@ -586,7 +476,7 @@ class VespaIndex(DocumentIndex):
# NOTE: using `httpx` here since `requests` doesn't support HTTP2. This is beneficial for
# indexing / updates / deletes since we have to make a large volume of requests.
with get_vespa_http_client() as http_client:
with httpx.Client(http2=True, timeout=VESPA_REQUEST_TIMEOUT) as http_client:
index_names = [self.index_name]
if self.secondary_index_name:
index_names.append(self.secondary_index_name)
@@ -614,7 +504,7 @@ class VespaIndex(DocumentIndex):
if self.secondary_index_name:
index_names.append(self.secondary_index_name)
with get_vespa_http_client() as http_client:
with httpx.Client(http2=True, timeout=VESPA_REQUEST_TIMEOUT) as http_client:
for index_name in index_names:
params = httpx.QueryParams(
{
@@ -754,158 +644,3 @@ class VespaIndex(DocumentIndex):
}
return query_vespa(params)
@classmethod
def delete_entries_by_tenant_id(cls, tenant_id: str, index_name: str) -> None:
"""
Deletes all entries in the specified index with the given tenant_id.
Parameters:
tenant_id (str): The tenant ID whose documents are to be deleted.
index_name (str): The name of the index from which to delete documents.
"""
logger.info(
f"Deleting entries with tenant_id: {tenant_id} from index: {index_name}"
)
# Step 1: Retrieve all document IDs with the given tenant_id
document_ids = cls._get_all_document_ids_by_tenant_id(tenant_id, index_name)
if not document_ids:
logger.info(
f"No documents found with tenant_id: {tenant_id} in index: {index_name}"
)
return
# Step 2: Delete documents in batches
delete_requests = [
_VespaDeleteRequest(document_id=doc_id, index_name=index_name)
for doc_id in document_ids
]
cls._apply_deletes_batched(delete_requests)
@classmethod
def _get_all_document_ids_by_tenant_id(
cls, tenant_id: str, index_name: str
) -> List[str]:
"""
Retrieves all document IDs with the specified tenant_id, handling pagination.
Parameters:
tenant_id (str): The tenant ID to search for.
index_name (str): The name of the index to search in.
Returns:
List[str]: A list of document IDs matching the tenant_id.
"""
offset = 0
limit = 1000 # Vespa's maximum hits per query
document_ids = []
logger.debug(
f"Starting document ID retrieval for tenant_id: {tenant_id} in index: {index_name}"
)
while True:
# Construct the query to fetch document IDs
query_params = {
"yql": f'select id from sources * where tenant_id contains "{tenant_id}";',
"offset": str(offset),
"hits": str(limit),
"timeout": "10s",
"format": "json",
"summary": "id",
}
url = f"{VESPA_APPLICATION_ENDPOINT}/search/"
logger.debug(
f"Querying for document IDs with tenant_id: {tenant_id}, offset: {offset}"
)
with get_vespa_http_client(no_timeout=True) as http_client:
response = http_client.get(url, params=query_params)
response.raise_for_status()
search_result = response.json()
hits = search_result.get("root", {}).get("children", [])
if not hits:
break
for hit in hits:
doc_id = hit.get("id")
if doc_id:
document_ids.append(doc_id)
offset += limit # Move to the next page
logger.debug(
f"Retrieved {len(document_ids)} document IDs for tenant_id: {tenant_id}"
)
return document_ids
@classmethod
def _apply_deletes_batched(
cls,
delete_requests: List["_VespaDeleteRequest"],
batch_size: int = BATCH_SIZE,
) -> None:
"""
Deletes documents in batches using multiple threads.
Parameters:
delete_requests (List[_VespaDeleteRequest]): The list of delete requests.
batch_size (int): The number of documents to delete in each batch.
"""
def _delete_document(
delete_request: "_VespaDeleteRequest", http_client: httpx.Client
) -> None:
logger.debug(f"Deleting document with ID {delete_request.document_id}")
response = http_client.delete(
delete_request.url,
headers={"Content-Type": "application/json"},
)
response.raise_for_status()
logger.debug(f"Starting batch deletion for {len(delete_requests)} documents")
with concurrent.futures.ThreadPoolExecutor(max_workers=NUM_THREADS) as executor:
with get_vespa_http_client(no_timeout=True) as http_client:
for batch_start in range(0, len(delete_requests), batch_size):
batch = delete_requests[batch_start : batch_start + batch_size]
future_to_document_id = {
executor.submit(
_delete_document,
delete_request,
http_client,
): delete_request.document_id
for delete_request in batch
}
for future in concurrent.futures.as_completed(
future_to_document_id
):
doc_id = future_to_document_id[future]
try:
future.result()
logger.debug(f"Successfully deleted document: {doc_id}")
except httpx.HTTPError as e:
logger.error(f"Failed to delete document {doc_id}: {e}")
# Optionally, implement retry logic or error handling here
logger.info("Batch deletion completed")
class _VespaDeleteRequest:
def __init__(self, document_id: str, index_name: str) -> None:
self.document_id = document_id
# Encode the document ID to ensure it's safe for use in the URL
encoded_doc_id = urllib.parse.quote_plus(self.document_id)
self.url = (
f"{VESPA_APPLICATION_ENDPOINT}/document/v1/"
f"{index_name}/{index_name}/docid/{encoded_doc_id}"
)
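
For reference, a minimal standalone sketch of the paginate-then-batch-delete flow implemented by the methods above. The base URL, index name, worker count, and function names here are illustrative assumptions; the real code routes everything through get_vespa_http_client and the module constants.

import concurrent.futures
import urllib.parse

import httpx

VESPA_APP = "http://localhost:8081"  # assumption: a single reachable Vespa base URL
INDEX_NAME = "danswer_chunk"         # assumption: schema/index name


def fetch_ids_for_tenant(client: httpx.Client, tenant_id: str) -> list[str]:
    """Page through /search/ results, up to 1000 hits at a time, collecting document ids."""
    ids: list[str] = []
    offset, limit = 0, 1000
    while True:
        params = {
            "yql": f'select id from sources * where tenant_id contains "{tenant_id}";',
            "offset": str(offset),
            "hits": str(limit),
        }
        resp = client.get(f"{VESPA_APP}/search/", params=params)
        resp.raise_for_status()
        hits = resp.json().get("root", {}).get("children", [])
        if not hits:
            return ids
        ids.extend(hit["id"] for hit in hits if hit.get("id"))
        offset += limit  # advance to the next page


def delete_ids(client: httpx.Client, ids: list[str], max_workers: int = 8) -> None:
    """Issue DELETEs against the document/v1 API concurrently, mirroring _apply_deletes_batched."""

    def _delete(doc_id: str) -> None:
        encoded = urllib.parse.quote_plus(doc_id)
        url = f"{VESPA_APP}/document/v1/{INDEX_NAME}/{INDEX_NAME}/docid/{encoded}"
        client.delete(url).raise_for_status()

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = [pool.submit(_delete, doc_id) for doc_id in ids]
        for future in concurrent.futures.as_completed(futures):
            future.result()  # re-raise any HTTP error from the worker thread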

View File

@@ -37,7 +37,6 @@ from danswer.document_index.vespa_constants import SEMANTIC_IDENTIFIER
from danswer.document_index.vespa_constants import SKIP_TITLE_EMBEDDING
from danswer.document_index.vespa_constants import SOURCE_LINKS
from danswer.document_index.vespa_constants import SOURCE_TYPE
from danswer.document_index.vespa_constants import TENANT_ID
from danswer.document_index.vespa_constants import TITLE
from danswer.document_index.vespa_constants import TITLE_EMBEDDING
from danswer.indexing.models import DocMetadataAwareIndexChunk
@@ -66,8 +65,6 @@ def _does_document_exist(
raise RuntimeError(
f"Unexpected fetch document by ID value from Vespa "
f"with error {doc_fetch_response.status_code}"
f"Index name: {index_name}"
f"Doc chunk id: {doc_chunk_id}"
)
return True
@@ -120,10 +117,7 @@ def get_existing_documents_from_chunks(
@retry(tries=3, delay=1, backoff=2)
def _index_vespa_chunk(
chunk: DocMetadataAwareIndexChunk,
index_name: str,
http_client: httpx.Client,
multitenant: bool,
chunk: DocMetadataAwareIndexChunk, index_name: str, http_client: httpx.Client
) -> None:
json_header = {
"Content-Type": "application/json",
@@ -180,10 +174,6 @@ def _index_vespa_chunk(
BOOST: chunk.boost,
}
if multitenant:
if chunk.tenant_id:
vespa_document_fields[TENANT_ID] = chunk.tenant_id
vespa_url = f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{vespa_chunk_id}"
logger.debug(f'Indexing to URL "{vespa_url}"')
res = http_client.post(
@@ -202,7 +192,6 @@ def batch_index_vespa_chunks(
chunks: list[DocMetadataAwareIndexChunk],
index_name: str,
http_client: httpx.Client,
multitenant: bool,
executor: concurrent.futures.ThreadPoolExecutor | None = None,
) -> None:
external_executor = True
@@ -213,9 +202,7 @@ def batch_index_vespa_chunks(
try:
chunk_index_future = {
executor.submit(
_index_vespa_chunk, chunk, index_name, http_client, multitenant
): chunk
executor.submit(_index_vespa_chunk, chunk, index_name, http_client): chunk
for chunk in chunks
}
for future in concurrent.futures.as_completed(chunk_index_future):

View File

@@ -1,12 +1,4 @@
import re
from typing import cast
import httpx
from danswer.configs.app_configs import MANAGED_VESPA
from danswer.configs.app_configs import VESPA_CLOUD_CERT_PATH
from danswer.configs.app_configs import VESPA_CLOUD_KEY_PATH
from danswer.configs.app_configs import VESPA_REQUEST_TIMEOUT
# NOTE: This does not seem to be used in reality despite the Vespa Docs pointing to this code
# See here for reference: https://docs.vespa.ai/en/documents.html
@@ -53,19 +45,3 @@ def remove_invalid_unicode_chars(text: str) -> str:
"[\x00-\x08\x0b\x0c\x0e-\x1F\uD800-\uDFFF\uFFFE\uFFFF]"
)
return _illegal_xml_chars_RE.sub("", text)
def get_vespa_http_client(no_timeout: bool = False) -> httpx.Client:
"""
Configure and return an HTTP client for communicating with Vespa,
including authentication if needed.
"""
return httpx.Client(
cert=cast(tuple[str, str], (VESPA_CLOUD_CERT_PATH, VESPA_CLOUD_KEY_PATH))
if MANAGED_VESPA
else None,
verify=False if not MANAGED_VESPA else True,
timeout=None if no_timeout else VESPA_REQUEST_TIMEOUT,
http2=True,
)
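
A minimal sketch of the same client-construction choice as get_vespa_http_client above: mutual TLS against a managed deployment, plain HTTP against a self-hosted container. The flag, cert paths, and timeout value are placeholder assumptions, not the real app configs.

import httpx

MANAGED_VESPA = False                      # assumption: normally driven by app config
VESPA_CLOUD_CERT_PATH = "/certs/cert.pem"  # assumption: placeholder paths
VESPA_CLOUD_KEY_PATH = "/certs/key.pem"
REQUEST_TIMEOUT_SECS = 15                  # assumption: placeholder request timeout


def make_vespa_client(no_timeout: bool = False) -> httpx.Client:
    """Client for Vespa: mTLS + certificate verification when managed, plain HTTP otherwise."""
    return httpx.Client(
        cert=(VESPA_CLOUD_CERT_PATH, VESPA_CLOUD_KEY_PATH) if MANAGED_VESPA else None,
        verify=MANAGED_VESPA,
        timeout=None if no_timeout else REQUEST_TIMEOUT_SECS,
        http2=True,
    )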

View File

@@ -12,7 +12,6 @@ from danswer.document_index.vespa_constants import DOCUMENT_SETS
from danswer.document_index.vespa_constants import HIDDEN
from danswer.document_index.vespa_constants import METADATA_LIST
from danswer.document_index.vespa_constants import SOURCE_TYPE
from danswer.document_index.vespa_constants import TENANT_ID
from danswer.search.models import IndexFilters
from danswer.utils.logger import setup_logger
@@ -54,9 +53,6 @@ def build_vespa_filters(filters: IndexFilters, include_hidden: bool = False) ->
filter_str = f"!({HIDDEN}=true) and " if not include_hidden else ""
if filters.tenant_id:
filter_str += f'({TENANT_ID} contains "{filters.tenant_id}") and '
# CAREFUL touching this one, currently there is no second ACL double-check post retrieval
if filters.access_control_list is not None:
filter_str += _build_or_filters(

View File

@@ -1,4 +1,3 @@
from danswer.configs.app_configs import VESPA_CLOUD_URL
from danswer.configs.app_configs import VESPA_CONFIG_SERVER_HOST
from danswer.configs.app_configs import VESPA_HOST
from danswer.configs.app_configs import VESPA_PORT
@@ -10,30 +9,17 @@ DANSWER_CHUNK_REPLACEMENT_PAT = "DANSWER_CHUNK_NAME"
DOCUMENT_REPLACEMENT_PAT = "DOCUMENT_REPLACEMENT"
SEARCH_THREAD_NUMBER_PAT = "SEARCH_THREAD_NUMBER"
DATE_REPLACEMENT = "DATE_REPLACEMENT"
SEARCH_THREAD_NUMBER_PAT = "SEARCH_THREAD_NUMBER"
TENANT_ID_PAT = "TENANT_ID_REPLACEMENT"
TENANT_ID_REPLACEMENT = """field tenant_id type string {
indexing: summary | attribute
rank: filter
attribute: fast-search
}"""
# config server
VESPA_CONFIG_SERVER_URL = (
VESPA_CLOUD_URL or f"http://{VESPA_CONFIG_SERVER_HOST}:{VESPA_TENANT_PORT}"
)
VESPA_CONFIG_SERVER_URL = f"http://{VESPA_CONFIG_SERVER_HOST}:{VESPA_TENANT_PORT}"
VESPA_APPLICATION_ENDPOINT = f"{VESPA_CONFIG_SERVER_URL}/application/v2"
# main search application
VESPA_APP_CONTAINER_URL = VESPA_CLOUD_URL or f"http://{VESPA_HOST}:{VESPA_PORT}"
VESPA_APP_CONTAINER_URL = f"http://{VESPA_HOST}:{VESPA_PORT}"
# danswer_chunk below is defined in vespa/app_configs/schemas/danswer_chunk.sd
DOCUMENT_ID_ENDPOINT = (
f"{VESPA_APP_CONTAINER_URL}/document/v1/default/{{index_name}}/docid"
)
SEARCH_ENDPOINT = f"{VESPA_APP_CONTAINER_URL}/search/"
NUM_THREADS = (
@@ -49,7 +35,7 @@ MAX_OR_CONDITIONS = 10
VESPA_TIMEOUT = "3s"
BATCH_SIZE = 128 # Specific to Vespa
TENANT_ID = "tenant_id"
DOCUMENT_ID = "document_id"
CHUNK_ID = "chunk_id"
BLURB = "blurb"

View File

@@ -208,9 +208,8 @@ def read_pdf_file(
# By user request, keep files that are unreadable just so they
# can be discoverable by title.
return "", metadata
elif pdf_reader.is_encrypted:
logger.warning("No Password available to decrypt pdf, returning empty")
return "", metadata
else:
logger.warning("No Password available to to decrypt pdf")
# Extract metadata from the PDF, removing leading '/' from keys if present
# This standardizes the metadata keys for consistency

View File

@@ -4,17 +4,11 @@ from dataclasses import dataclass
from typing import IO
import bs4
import trafilatura # type: ignore
from trafilatura.settings import use_config # type: ignore
from danswer.configs.app_configs import HTML_BASED_CONNECTOR_TRANSFORM_LINKS_STRATEGY
from danswer.configs.app_configs import PARSE_WITH_TRAFILATURA
from danswer.configs.app_configs import WEB_CONNECTOR_IGNORED_CLASSES
from danswer.configs.app_configs import WEB_CONNECTOR_IGNORED_ELEMENTS
from danswer.file_processing.enums import HtmlBasedConnectorTransformLinksStrategy
from danswer.utils.logger import setup_logger
logger = setup_logger()
MINTLIFY_UNWANTED = ["sticky", "hidden"]
@@ -53,18 +47,6 @@ def format_element_text(element_text: str, link_href: str | None) -> str:
return f"[{element_text_no_newlines}]({link_href})"
def parse_html_with_trafilatura(html_content: str) -> str:
"""Parse HTML content using trafilatura."""
config = use_config()
config.set("DEFAULT", "include_links", "True")
config.set("DEFAULT", "include_tables", "True")
config.set("DEFAULT", "include_images", "True")
config.set("DEFAULT", "include_formatting", "True")
extracted_text = trafilatura.extract(html_content, config=config)
return strip_excessive_newlines_and_spaces(extracted_text) if extracted_text else ""
def format_document_soup(
document: bs4.BeautifulSoup, table_cell_separator: str = "\t"
) -> str:
@@ -201,21 +183,7 @@ def web_html_cleanup(
for undesired_tag in additional_element_types_to_discard:
[tag.extract() for tag in soup.find_all(undesired_tag)]
soup_string = str(soup)
page_text = ""
if PARSE_WITH_TRAFILATURA:
try:
page_text = parse_html_with_trafilatura(soup_string)
if not page_text:
raise ValueError("Empty content returned by trafilatura.")
except Exception as e:
logger.info(f"Trafilatura parsing failed: {e}. Falling back on bs4.")
page_text = format_document_soup(soup)
else:
page_text = format_document_soup(soup)
# 200B is ZeroWidthSpace which we don't care for
cleaned_text = page_text.replace("\u200B", "")
page_text = format_document_soup(soup).replace("\u200B", "")
return ParsedHTML(title=title, cleaned_text=cleaned_text)
return ParsedHTML(title=title, cleaned_text=page_text)
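
A self-contained sketch of the parse-with-trafilatura-then-fall-back-to-bs4 flow shown in this file, using the same config options from the hunk above. The bs4 fallback here is a simple get_text stand-in for format_document_soup, and the function names are illustrative.

import bs4
import trafilatura  # type: ignore
from trafilatura.settings import use_config  # type: ignore


def parse_with_trafilatura(html: str) -> str:
    """Extract main content while keeping links, tables, images, and formatting."""
    config = use_config()
    for option in ("include_links", "include_tables", "include_images", "include_formatting"):
        config.set("DEFAULT", option, "True")
    return trafilatura.extract(html, config=config) or ""


def clean_html(html: str, use_trafilatura: bool = True) -> str:
    """Prefer trafilatura; fall back to a plain bs4 text dump if it fails or returns nothing."""
    if use_trafilatura:
        try:
            text = parse_with_trafilatura(html)
            if text:
                return text.replace("\u200B", "")  # drop zero-width spaces
        except Exception:
            pass  # fall back to bs4, as the code above does after logging
    soup = bs4.BeautifulSoup(html, "html.parser")
    return soup.get_text(separator="\n").replace("\u200B", "")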

View File

@@ -15,7 +15,7 @@ from danswer.indexing.models import DocAwareChunk
from danswer.natural_language_processing.utils import BaseTokenizer
from danswer.utils.logger import setup_logger
from danswer.utils.text_processing import shared_precompare_cleanup
from shared_configs.configs import STRICT_CHUNK_TOKEN_LIMIT
# Not supporting overlaps, we need a clean combination of chunks and it is unclear if overlaps
# actually help quality at all
@@ -158,24 +158,6 @@ class Chunker:
else None
)
def _split_oversized_chunk(self, text: str, content_token_limit: int) -> list[str]:
"""
Splits the text into smaller chunks based on token count to ensure
no chunk exceeds the content_token_limit.
"""
tokens = self.tokenizer.tokenize(text)
chunks = []
start = 0
total_tokens = len(tokens)
while start < total_tokens:
end = min(start + content_token_limit, total_tokens)
token_chunk = tokens[start:end]
# Join the tokens to reconstruct the text
chunk_text = " ".join(token_chunk)
chunks.append(chunk_text)
start = end
return chunks
def _extract_blurb(self, text: str) -> str:
texts = self.blurb_splitter.split_text(text)
if not texts:
@@ -236,42 +218,14 @@ class Chunker:
chunk_text = ""
split_texts = self.chunk_splitter.split_text(section_text)
for i, split_text in enumerate(split_texts):
split_token_count = len(self.tokenizer.tokenize(split_text))
if STRICT_CHUNK_TOKEN_LIMIT:
split_token_count = len(self.tokenizer.tokenize(split_text))
if split_token_count > content_token_limit:
# Further split the oversized chunk
smaller_chunks = self._split_oversized_chunk(
split_text, content_token_limit
)
for i, small_chunk in enumerate(smaller_chunks):
chunks.append(
_create_chunk(
text=small_chunk,
links={0: section_link_text},
is_continuation=(i != 0),
)
)
else:
chunks.append(
_create_chunk(
text=split_text,
links={0: section_link_text},
)
)
else:
chunks.append(
_create_chunk(
text=split_text,
links={0: section_link_text},
is_continuation=(i != 0),
)
chunks.append(
_create_chunk(
text=split_text,
links={0: section_link_text},
is_continuation=(i != 0),
)
)
continue
current_token_count = len(self.tokenizer.tokenize(chunk_text))
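
The STRICT_CHUNK_TOKEN_LIMIT branch in the hunk above re-splits any oversized split by raw token count before building chunks. A minimal sketch of that hard-split step, using a naive whitespace tokenizer as a stand-in for the real BaseTokenizer (so rejoining with spaces is only an approximation for sub-word tokenizers):

def split_oversized(text: str, content_token_limit: int) -> list[str]:
    """Hard-split text into windows of at most content_token_limit tokens."""
    tokens = text.split()  # stand-in tokenizer
    return [
        " ".join(tokens[start : start + content_token_limit])
        for start in range(0, len(tokens), content_token_limit)
    ]


# Every window after the first would be flagged as a continuation chunk,
# mirroring is_continuation=(i != 0) in the code above.
assert split_oversized("one two three four five six seven", content_token_limit=3) == [
    "one two three",
    "four five six",
    "seven",
]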

View File

@@ -32,8 +32,6 @@ class IndexingEmbedder(ABC):
provider_type: EmbeddingProvider | None,
api_key: str | None,
api_url: str | None,
api_version: str | None,
deployment_name: str | None,
heartbeat: Heartbeat | None,
):
self.model_name = model_name
@@ -43,8 +41,6 @@ class IndexingEmbedder(ABC):
self.provider_type = provider_type
self.api_key = api_key
self.api_url = api_url
self.api_version = api_version
self.deployment_name = deployment_name
self.embedding_model = EmbeddingModel(
model_name=model_name,
@@ -54,8 +50,6 @@ class IndexingEmbedder(ABC):
api_key=api_key,
provider_type=provider_type,
api_url=api_url,
api_version=api_version,
deployment_name=deployment_name,
# The below are globally set, this flow always uses the indexing one
server_host=INDEXING_MODEL_SERVER_HOST,
server_port=INDEXING_MODEL_SERVER_PORT,
@@ -81,8 +75,6 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
provider_type: EmbeddingProvider | None = None,
api_key: str | None = None,
api_url: str | None = None,
api_version: str | None = None,
deployment_name: str | None = None,
heartbeat: Heartbeat | None = None,
):
super().__init__(
@@ -93,8 +85,6 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
provider_type,
api_key,
api_url,
api_version,
deployment_name,
heartbeat,
)
@@ -203,7 +193,5 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
provider_type=search_settings.provider_type,
api_key=search_settings.api_key,
api_url=search_settings.api_url,
api_version=search_settings.api_version,
deployment_name=search_settings.deployment_name,
heartbeat=heartbeat,
)

View File

@@ -28,8 +28,7 @@ KV_REDIS_KEY_EXPIRATION = 60 * 60 * 24 # 1 Day
class PgRedisKVStore(KeyValueStore):
def __init__(self) -> None:
tenant_id = current_tenant_id.get()
self.redis_client = get_redis_client(tenant_id=tenant_id)
self.redis_client = get_redis_client()
@contextmanager
def get_session(self) -> Iterator[Session]:

View File

@@ -109,7 +109,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
"arguments": json.dumps(tool_call["args"]),
},
"type": "function",
"index": tool_call.get("index", 0),
"index": 0, # only support a single tool call atm
}
for tool_call in message.tool_calls
]
@@ -158,13 +158,12 @@ def _convert_delta_to_message_chunk(
if tool_calls:
tool_call = tool_calls[0]
tool_name = tool_call.function.name or (curr_msg and curr_msg.name) or ""
idx = tool_call.index
tool_call_chunk = ToolCallChunk(
name=tool_name,
id=tool_call.id,
args=tool_call.function.arguments,
index=idx,
index=0, # only support a single tool call atm
)
return AIMessageChunk(

View File

@@ -7,9 +7,9 @@ from danswer.db.llm import fetch_provider
from danswer.db.models import Persona
from danswer.llm.chat_llm import DefaultMultiLLM
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.headers import build_llm_extra_headers
from danswer.llm.interfaces import LLM
from danswer.llm.override_models import LLMOverride
from danswer.utils.headers import build_llm_extra_headers
def get_main_llm_from_tuple(

View File

@@ -0,0 +1,34 @@
from fastapi.datastructures import Headers
from danswer.configs.model_configs import LITELLM_EXTRA_HEADERS
from danswer.configs.model_configs import LITELLM_PASS_THROUGH_HEADERS
def get_litellm_additional_request_headers(
headers: dict[str, str] | Headers
) -> dict[str, str]:
if not LITELLM_PASS_THROUGH_HEADERS:
return {}
pass_through_headers: dict[str, str] = {}
for key in LITELLM_PASS_THROUGH_HEADERS:
if key in headers:
pass_through_headers[key] = headers[key]
else:
# fastapi makes all header keys lowercase, handling that here
lowercase_key = key.lower()
if lowercase_key in headers:
pass_through_headers[lowercase_key] = headers[lowercase_key]
return pass_through_headers
def build_llm_extra_headers(
additional_headers: dict[str, str] | None = None
) -> dict[str, str]:
extra_headers: dict[str, str] = {}
if additional_headers:
extra_headers.update(additional_headers)
if LITELLM_EXTRA_HEADERS:
extra_headers.update(LITELLM_EXTRA_HEADERS)
return extra_headers
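
For context, a small sketch of how the two helpers in this file compose: pick out whitelisted request headers (accounting for lower-cased header keys) and merge them over statically configured extras. The header names and values below are illustrative assumptions, not real settings.

from fastapi.datastructures import Headers

PASS_THROUGH_HEADERS = ["X-Org-Id"]                     # assumption: example whitelist
STATIC_EXTRA_HEADERS = {"X-Request-Source": "danswer"}  # assumption: example static extras


def pick_pass_through(headers: Headers | dict[str, str]) -> dict[str, str]:
    picked: dict[str, str] = {}
    for key in PASS_THROUGH_HEADERS:
        # incoming header keys may be lower-cased, so check both spellings
        value = headers.get(key) or headers.get(key.lower())
        if value is not None:
            picked[key] = value
    return picked


def build_extra_headers(request_headers: Headers | dict[str, str]) -> dict[str, str]:
    extra = dict(STATIC_EXTRA_HEADERS)
    extra.update(pick_pass_through(request_headers))
    return extra


assert build_extra_headers({"x-org-id": "acme"}) == {
    "X-Request-Source": "danswer",
    "X-Org-Id": "acme",
}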

View File

@@ -5,7 +5,6 @@ from contextlib import asynccontextmanager
from typing import Any
from typing import cast
import sentry_sdk
import uvicorn
from fastapi import APIRouter
from fastapi import FastAPI
@@ -16,8 +15,6 @@ from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from httpx_oauth.clients.google import GoogleOAuth2
from sentry_sdk.integrations.fastapi import FastApiIntegration
from sentry_sdk.integrations.starlette import StarletteIntegration
from sqlalchemy.orm import Session
from danswer import __version__
@@ -57,7 +54,6 @@ from danswer.server.features.input_prompt.api import (
admin_router as admin_input_prompt_router,
)
from danswer.server.features.input_prompt.api import basic_router as input_prompt_router
from danswer.server.features.notifications.api import router as notification_router
from danswer.server.features.persona.api import admin_router as admin_persona_router
from danswer.server.features.persona.api import basic_router as persona_router
from danswer.server.features.prompt.api import basic_router as prompt_router
@@ -85,7 +81,6 @@ from danswer.server.token_rate_limits.api import (
router as token_rate_limit_settings_router,
)
from danswer.setup import setup_danswer
from danswer.setup import setup_multitenant_danswer
from danswer.utils.logger import setup_logger
from danswer.utils.telemetry import get_or_generate_uuid
from danswer.utils.telemetry import optional_telemetry
@@ -94,7 +89,6 @@ from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import global_version
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
from shared_configs.configs import CORS_ALLOWED_ORIGIN
from shared_configs.configs import SENTRY_DSN
logger = setup_logger()
@@ -184,8 +178,6 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
# If we are multi-tenant, we need to only set up initial public tables
with Session(engine) as db_session:
setup_danswer(db_session)
else:
setup_multitenant_danswer()
optional_telemetry(record_type=RecordType.VERSION, data={"version": __version__})
yield
@@ -209,15 +201,6 @@ def get_application() -> FastAPI:
application = FastAPI(
title="Danswer Backend", version=__version__, lifespan=lifespan
)
if SENTRY_DSN:
sentry_sdk.init(
dsn=SENTRY_DSN,
integrations=[StarletteIntegration(), FastApiIntegration()],
traces_sample_rate=0.5,
)
logger.info("Sentry initialized")
else:
logger.debug("Sentry DSN not provided, skipping Sentry initialization")
# Add the custom exception handler
application.add_exception_handler(status.HTTP_400_BAD_REQUEST, log_http_error)
@@ -247,7 +230,6 @@ def get_application() -> FastAPI:
include_router_with_global_prefix_prepended(application, admin_persona_router)
include_router_with_global_prefix_prepended(application, input_prompt_router)
include_router_with_global_prefix_prepended(application, admin_input_prompt_router)
include_router_with_global_prefix_prepended(application, notification_router)
include_router_with_global_prefix_prepended(application, prompt_router)
include_router_with_global_prefix_prepended(application, tool_router)
include_router_with_global_prefix_prepended(application, admin_tool_router)
@@ -269,7 +251,7 @@ def get_application() -> FastAPI:
# Server logs this during auth setup verification step
pass
if AUTH_TYPE == AuthType.BASIC or AUTH_TYPE == AuthType.CLOUD:
elif AUTH_TYPE == AuthType.BASIC:
include_router_with_global_prefix_prepended(
application,
fastapi_users.get_auth_router(auth_backend),
@@ -301,7 +283,7 @@ def get_application() -> FastAPI:
tags=["users"],
)
if AUTH_TYPE == AuthType.GOOGLE_OAUTH or AUTH_TYPE == AuthType.CLOUD:
elif AUTH_TYPE == AuthType.GOOGLE_OAUTH:
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
include_router_with_global_prefix_prepended(
application,
@@ -317,7 +299,6 @@ def get_application() -> FastAPI:
prefix="/auth/oauth",
tags=["auth"],
)
# Need basic auth router for `logout` endpoint
include_router_with_global_prefix_prepended(
application,

View File

@@ -50,26 +50,23 @@ def clean_model_name(model_str: str) -> str:
return model_str.replace("/", "_").replace("-", "_").replace(".", "_")
_WHITELIST = set(
" !\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~\n\t"
)
_INITIAL_FILTER = re.compile(
"["
"\U0000FFF0-\U0000FFFF" # Specials
"\U0001F000-\U0001F9FF" # Emoticons
"\U00002000-\U0000206F" # General Punctuation
"\U00002190-\U000021FF" # Arrows
"\U00002700-\U000027BF" # Dingbats
"\U00000080-\U0000FFFF" # All Unicode characters beyond ASCII
"\U00010000-\U0010FFFF" # All Unicode characters in supplementary planes
"]+",
flags=re.UNICODE,
)
def clean_openai_text(text: str) -> str:
# Remove specific Unicode ranges that might cause issues
# First, remove all weird characters
cleaned = _INITIAL_FILTER.sub("", text)
# Remove any control characters except for newline and tab
cleaned = "".join(ch for ch in cleaned if ch >= " " or ch in "\n\t")
return cleaned
# Then, keep only whitelisted characters
return "".join(char for char in cleaned if char in _WHITELIST)
def build_model_server_url(
@@ -100,8 +97,6 @@ class EmbeddingModel:
provider_type: EmbeddingProvider | None,
retrim_content: bool = False,
heartbeat: Heartbeat | None = None,
api_version: str | None = None,
deployment_name: str | None = None,
) -> None:
self.api_key = api_key
self.provider_type = provider_type
@@ -111,8 +106,6 @@ class EmbeddingModel:
self.model_name = model_name
self.retrim_content = retrim_content
self.api_url = api_url
self.api_version = api_version
self.deployment_name = deployment_name
self.tokenizer = get_tokenizer(
model_name=model_name, provider_type=provider_type
)
@@ -164,8 +157,6 @@ class EmbeddingModel:
embed_request = EmbedRequest(
model_name=self.model_name,
texts=text_batch,
api_version=self.api_version,
deployment_name=self.deployment_name,
max_context_length=max_seq_length,
normalize_embeddings=self.normalize,
api_key=self.api_key,
@@ -248,8 +239,6 @@ class EmbeddingModel:
provider_type=search_settings.provider_type,
api_url=search_settings.api_url,
retrim_content=retrim_content,
api_version=search_settings.api_version,
deployment_name=search_settings.deployment_name,
)
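
The two clean_openai_text variants earlier in this file differ only in the final filtering step: one strips broad Unicode ranges and then control characters, the other keeps only an explicit ASCII whitelist. A small sketch contrasting the two approaches (function names are illustrative):

import re

_WHITELIST = set(
    " !\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    "[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~\n\t"
)
# broad filter: everything beyond ASCII, including the supplementary planes
_NON_ASCII = re.compile("[\U00000080-\U0000FFFF\U00010000-\U0010FFFF]+", flags=re.UNICODE)


def clean_via_ranges(text: str) -> str:
    cleaned = _NON_ASCII.sub("", text)
    # keep printable ASCII plus newline/tab; drop other control characters
    return "".join(ch for ch in cleaned if ch >= " " or ch in "\n\t")


def clean_via_whitelist(text: str) -> str:
    return "".join(ch for ch in text if ch in _WHITELIST)


sample = "Résumé\t✓ done\x07"
assert clean_via_ranges(sample) == "Rsum\t done"     # é, ✓ and BEL are dropped
assert clean_via_whitelist(sample) == "Rsum\t done"  # same result on this input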

View File

@@ -1,7 +1,4 @@
import functools
import threading
from collections.abc import Callable
from typing import Any
from typing import Optional
import redis
@@ -17,72 +14,6 @@ from danswer.configs.app_configs import REDIS_SSL
from danswer.configs.app_configs import REDIS_SSL_CA_CERTS
from danswer.configs.app_configs import REDIS_SSL_CERT_REQS
from danswer.configs.constants import REDIS_SOCKET_KEEPALIVE_OPTIONS
from danswer.utils.logger import setup_logger
logger = setup_logger()
class TenantRedis(redis.Redis):
def __init__(self, tenant_id: str, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.tenant_id: str = tenant_id
def _prefixed(self, key: str | bytes | memoryview) -> str | bytes | memoryview:
prefix: str = f"{self.tenant_id}:"
if isinstance(key, str):
if key.startswith(prefix):
return key
else:
return prefix + key
elif isinstance(key, bytes):
prefix_bytes = prefix.encode()
if key.startswith(prefix_bytes):
return key
else:
return prefix_bytes + key
elif isinstance(key, memoryview):
key_bytes = key.tobytes()
prefix_bytes = prefix.encode()
if key_bytes.startswith(prefix_bytes):
return key
else:
return memoryview(prefix_bytes + key_bytes)
else:
raise TypeError(f"Unsupported key type: {type(key)}")
def _prefix_method(self, method: Callable) -> Callable:
@functools.wraps(method)
def wrapper(*args: Any, **kwargs: Any) -> Any:
if "name" in kwargs:
kwargs["name"] = self._prefixed(kwargs["name"])
elif len(args) > 0:
args = (self._prefixed(args[0]),) + args[1:]
return method(*args, **kwargs)
return wrapper
def __getattribute__(self, item: str) -> Any:
original_attr = super().__getattribute__(item)
methods_to_wrap = [
"lock",
"unlock",
"get",
"set",
"delete",
"exists",
"incrby",
"hset",
"hget",
"getset",
"scan_iter",
"owned",
"reacquire",
"create_lock",
"startswith",
] # Add all methods that need prefixing
if item in methods_to_wrap and callable(original_attr):
return self._prefix_method(original_attr)
return original_attr
class RedisPool:
@@ -101,10 +32,8 @@ class RedisPool:
def _init_pool(self) -> None:
self._pool = RedisPool.create_pool(ssl=REDIS_SSL)
def get_client(self, tenant_id: str | None) -> Redis:
if tenant_id is None:
tenant_id = "public"
return TenantRedis(tenant_id, connection_pool=self._pool)
def get_client(self) -> Redis:
return redis.Redis(connection_pool=self._pool)
@staticmethod
def create_pool(
@@ -155,8 +84,8 @@ class RedisPool:
redis_pool = RedisPool()
def get_redis_client(*, tenant_id: str | None) -> Redis:
return redis_pool.get_client(tenant_id)
def get_redis_client() -> Redis:
return redis_pool.get_client()
# # Usage example
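
The removed TenantRedis wrapper namespaces every key under the tenant id before delegating to redis-py, so multiple tenants can share one Redis instance without key collisions. The prefixing rule in isolation (key names below are just examples):

def tenant_prefixed(tenant_id: str, key: str | bytes) -> str | bytes:
    """Namespace a key under '<tenant_id>:' exactly once, for str or bytes keys."""
    prefix = f"{tenant_id}:"
    if isinstance(key, bytes):
        prefix_bytes = prefix.encode()
        return key if key.startswith(prefix_bytes) else prefix_bytes + key
    return key if key.startswith(prefix) else prefix + key


assert tenant_prefixed("tenant_abc", "usage_count") == "tenant_abc:usage_count"
assert tenant_prefixed("tenant_abc", "tenant_abc:usage_count") == "tenant_abc:usage_count"
assert tenant_prefixed("tenant_abc", b"indexing_fence") == b"tenant_abc:indexing_fence"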

View File

@@ -102,7 +102,6 @@ class BaseFilters(BaseModel):
class IndexFilters(BaseFilters):
access_control_list: list[str] | None
tenant_id: str | None = None
class ChunkMetric(BaseModel):

View File

@@ -1,6 +1,5 @@
from sqlalchemy.orm import Session
from danswer.configs.app_configs import MULTI_TENANT
from danswer.configs.chat_configs import BASE_RECENCY_DECAY
from danswer.configs.chat_configs import CONTEXT_CHUNKS_ABOVE
from danswer.configs.chat_configs import CONTEXT_CHUNKS_BELOW
@@ -10,7 +9,6 @@ from danswer.configs.chat_configs import HYBRID_ALPHA
from danswer.configs.chat_configs import HYBRID_ALPHA_KEYWORD
from danswer.configs.chat_configs import NUM_POSTPROCESSED_RESULTS
from danswer.configs.chat_configs import NUM_RETURNED_HITS
from danswer.db.engine import current_tenant_id
from danswer.db.models import User
from danswer.db.search_settings import get_current_search_settings
from danswer.llm.interfaces import LLM
@@ -162,7 +160,6 @@ def retrieval_preprocessing(
time_cutoff=time_filter or predicted_time_cutoff,
tags=preset_filters.tags, # Tags are never auto-extracted
access_control_list=user_acl_filters,
tenant_id=current_tenant_id.get() if MULTI_TENANT else None,
)
llm_evaluation_type = LLMEvaluationType.BASIC

View File

@@ -4,13 +4,13 @@ from fastapi import FastAPI
from fastapi.dependencies.models import Dependant
from starlette.routing import BaseRoute
from danswer.auth.users import control_plane_dep
from danswer.auth.users import current_admin_user
from danswer.auth.users import current_curator_or_admin_user
from danswer.auth.users import current_user
from danswer.auth.users import current_user_with_expired_token
from danswer.configs.app_configs import APP_API_PREFIX
from danswer.server.danswer_api.ingestion import api_key_dep
from ee.danswer.server.tenants.access import control_plane_dep
PUBLIC_ENDPOINT_SPECS = [

View File

@@ -1,5 +1,4 @@
import math
from datetime import datetime
from http import HTTPStatus
from fastapi import APIRouter
@@ -24,7 +23,6 @@ from danswer.db.connector_credential_pair import (
)
from danswer.db.document import get_document_counts_for_cc_pairs
from danswer.db.engine import current_tenant_id
from danswer.db.engine import get_current_tenant_id
from danswer.db.engine import get_session
from danswer.db.enums import AccessType
from danswer.db.enums import ConnectorCredentialPairStatus
@@ -91,7 +89,6 @@ def get_cc_pair_full_info(
cc_pair_id: int,
user: User | None = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> CCPairFullInfo:
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id, db_session, user, get_editable=False
@@ -138,7 +135,6 @@ def get_cc_pair_full_info(
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
db_session=db_session,
tenant_id=tenant_id,
),
num_docs_indexed=documents_indexed,
is_editable_for_current_user=is_editable_for_current_user,
@@ -158,7 +154,6 @@ def update_cc_pair_status(
user=user,
get_editable=True,
)
if not cc_pair:
raise HTTPException(
status_code=400,
@@ -168,6 +163,7 @@ def update_cc_pair_status(
if status_update_request.status == ConnectorCredentialPairStatus.PAUSED:
cancel_indexing_attempts_for_ccpair(cc_pair_id, db_session)
# Just for good measure
cancel_indexing_attempts_past_model(db_session)
update_connector_credential_pair_from_id(
@@ -176,8 +172,6 @@ def update_cc_pair_status(
status=status_update_request.status,
)
db_session.commit()
@router.put("/admin/cc-pair/{cc_pair_id}/name")
def update_cc_pair_name(
@@ -208,12 +202,12 @@ def update_cc_pair_name(
raise HTTPException(status_code=400, detail="Name must be unique")
@router.get("/admin/cc-pair/{cc_pair_id}/last_pruned")
def get_cc_pair_last_pruned(
@router.get("/admin/cc-pair/{cc_pair_id}/prune")
def get_cc_pair_latest_prune(
cc_pair_id: int,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> datetime | None:
) -> bool:
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id=cc_pair_id,
db_session=db_session,
@@ -223,10 +217,11 @@ def get_cc_pair_last_pruned(
if not cc_pair:
raise HTTPException(
status_code=400,
detail="cc_pair not found for current user's permissions",
detail="Connection not found for current user's permissions",
)
return cc_pair.last_pruned
rcp = RedisConnectorPruning(cc_pair.id)
return rcp.is_pruning(db_session, get_redis_client())
@router.post("/admin/cc-pair/{cc_pair_id}/prune")
@@ -234,7 +229,6 @@ def prune_cc_pair(
cc_pair_id: int,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> StatusResponse[list[int]]:
"""Triggers pruning on a particular cc_pair immediately"""
@@ -250,7 +244,7 @@ def prune_cc_pair(
detail="Connection not found for current user's permissions",
)
r = get_redis_client(tenant_id=tenant_id)
r = get_redis_client()
rcp = RedisConnectorPruning(cc_pair_id)
if rcp.is_pruning(db_session, r):
raise HTTPException(

View File

@@ -15,9 +15,7 @@ from sqlalchemy.orm import Session
from danswer.auth.users import current_admin_user
from danswer.auth.users import current_curator_or_admin_user
from danswer.auth.users import current_user
from danswer.background.celery.celery_redis import RedisConnectorIndexing
from danswer.background.celery.celery_utils import get_deletion_attempt_snapshot
from danswer.background.celery.tasks.indexing.tasks import try_creating_indexing_task
from danswer.configs.app_configs import ENABLED_CONNECTOR_TYPES
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import FileOrigin
@@ -65,22 +63,19 @@ from danswer.db.credentials import delete_google_drive_service_account_credentia
from danswer.db.credentials import fetch_credential_by_id
from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
from danswer.db.document import get_document_counts_for_cc_pairs
from danswer.db.engine import get_current_tenant_id
from danswer.db.engine import get_session
from danswer.db.enums import AccessType
from danswer.db.index_attempt import create_index_attempt
from danswer.db.index_attempt import get_index_attempts_for_cc_pair
from danswer.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
from danswer.db.index_attempt import get_latest_index_attempts
from danswer.db.index_attempt import get_latest_index_attempts_by_status
from danswer.db.models import IndexingStatus
from danswer.db.models import SearchSettings
from danswer.db.models import User
from danswer.db.models import UserRole
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.file_store.file_store import get_default_file_store
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.redis.redis_pool import get_redis_client
from danswer.server.documents.models import AuthStatus
from danswer.server.documents.models import AuthUrl
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
@@ -482,12 +477,9 @@ def get_connector_indexing_status(
get_editable: bool = Query(
False, description="If true, return editable document sets"
),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> list[ConnectorIndexingStatus]:
indexing_statuses: list[ConnectorIndexingStatus] = []
r = get_redis_client(tenant_id=tenant_id)
# NOTE: If the connector is deleting behind the scenes,
# accessing cc_pairs can be inconsistent and members like
# connector or credential may be None.
@@ -539,12 +531,6 @@ def get_connector_indexing_status(
relationship.user_group_id
)
search_settings: SearchSettings | None = None
if not secondary_index:
search_settings = get_current_search_settings(db_session)
else:
search_settings = get_secondary_search_settings(db_session)
for cc_pair in cc_pairs:
# TODO remove this to enable ingestion API
if cc_pair.name == "DefaultCCPair":
@@ -556,12 +542,6 @@ def get_connector_indexing_status(
# This may happen if background deletion is happening
continue
in_progress = False
if search_settings:
rci = RedisConnectorIndexing(cc_pair.id, search_settings.id)
if r.exists(rci.fence_key):
in_progress = True
latest_index_attempt = cc_pair_to_latest_index_attempt.get(
(connector.id, credential.id)
)
@@ -607,7 +587,6 @@ def get_connector_indexing_status(
connector_id=connector.id,
credential_id=credential.id,
db_session=db_session,
tenant_id=tenant_id,
),
is_deletable=check_deletion_attempt_is_allowed(
connector_credential_pair=cc_pair,
@@ -616,7 +595,6 @@ def get_connector_indexing_status(
allow_scheduled=True,
)
is None,
in_progress=in_progress,
)
)
@@ -685,18 +663,15 @@ def create_connector_with_mock_credential(
connector_response = create_connector(
db_session=db_session, connector_data=connector_data
)
mock_credential = CredentialBase(
credential_json={}, admin_public=True, source=connector_data.source
)
credential = create_credential(
mock_credential, user=user, db_session=db_session
)
access_type = (
AccessType.PUBLIC if connector_data.is_public else AccessType.PRIVATE
)
response = add_credential_to_connector(
db_session=db_session,
user=user,
@@ -775,13 +750,7 @@ def connector_run_once(
run_info: RunConnectorRequest,
_: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str = Depends(get_current_tenant_id),
) -> StatusResponse[list[int]]:
"""Used to trigger indexing on a set of cc_pairs associated with a
single connector."""
r = get_redis_client(tenant_id=tenant_id)
connector_id = run_info.connector_id
specified_credential_ids = run_info.credential_ids
@@ -835,24 +804,16 @@ def connector_run_once(
if credential_id not in skipped_credentials
]
index_attempt_ids = []
for cc_pair in connector_credential_pairs:
if cc_pair is not None:
attempt_id = try_creating_indexing_task(
cc_pair,
search_settings,
run_info.from_beginning,
db_session,
r,
tenant_id,
)
if attempt_id:
logger.info(
f"try_creating_indexing_task succeeded: cc_pair={cc_pair.id} attempt_id={attempt_id}"
)
index_attempt_ids.append(attempt_id)
else:
logger.info(f"try_creating_indexing_task failed: cc_pair={cc_pair.id}")
index_attempt_ids = [
create_index_attempt(
connector_credential_pair_id=connector_credential_pair.id,
search_settings_id=search_settings.id,
from_beginning=run_info.from_beginning,
db_session=db_session,
)
for connector_credential_pair in connector_credential_pairs
if connector_credential_pair is not None
]
if not index_attempt_ids:
raise HTTPException(

View File

@@ -307,10 +307,6 @@ class ConnectorIndexingStatus(BaseModel):
deletion_attempt: DeletionAttemptSnapshot | None
is_deletable: bool
# index attempt in db can be marked successful while celery/redis
# is still running/cleaning up
in_progress: bool
class ConnectorCredentialPairIdentifier(BaseModel):
connector_id: int

View File

@@ -1,5 +1,3 @@
from uuid import UUID
from pydantic import BaseModel
from danswer.server.query_and_chat.models import ChatSessionDetails
@@ -25,7 +23,7 @@ class FolderUpdateRequest(BaseModel):
class FolderChatSessionRequest(BaseModel):
chat_session_id: UUID
chat_session_id: int
class DeleteFolderOptions(BaseModel):

View File

@@ -1,47 +0,0 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session
from danswer.auth.users import current_user
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.db.notification import dismiss_notification
from danswer.db.notification import get_notification_by_id
from danswer.db.notification import get_notifications
from danswer.server.settings.models import Notification as NotificationModel
from danswer.utils.logger import setup_logger
logger = setup_logger()
router = APIRouter(prefix="/notifications")
@router.get("")
def get_notifications_api(
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> list[NotificationModel]:
notifications = [
NotificationModel.from_model(notif)
for notif in get_notifications(user, db_session, include_dismissed=False)
]
return notifications
@router.post("/{notification_id}/dismiss")
def dismiss_notification_endpoint(
notification_id: int,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
try:
notification = get_notification_by_id(notification_id, user, db_session)
except PermissionError:
raise HTTPException(
status_code=403, detail="Not authorized to dismiss this notification"
)
except ValueError:
raise HTTPException(status_code=404, detail="Notification not found")
dismiss_notification(notification, db_session)

View File

@@ -13,10 +13,8 @@ from danswer.auth.users import current_admin_user
from danswer.auth.users import current_curator_or_admin_user
from danswer.auth.users import current_user
from danswer.configs.constants import FileOrigin
from danswer.configs.constants import NotificationType
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.db.notification import create_notification
from danswer.db.persona import create_update_persona
from danswer.db.persona import get_persona_by_id
from danswer.db.persona import get_personas
@@ -30,7 +28,6 @@ from danswer.file_store.file_store import get_default_file_store
from danswer.file_store.models import ChatFileType
from danswer.llm.answering.prompts.utils import build_dummy_prompt
from danswer.server.features.persona.models import CreatePersonaRequest
from danswer.server.features.persona.models import PersonaSharedNotificationData
from danswer.server.features.persona.models import PersonaSnapshot
from danswer.server.features.persona.models import PromptTemplateResponse
from danswer.server.models import DisplayPriorityRequest
@@ -186,12 +183,11 @@ class PersonaShareRequest(BaseModel):
user_ids: list[UUID]
# We notify each user when a persona is shared with them
@basic_router.patch("/{persona_id}/share")
def share_persona(
persona_id: int,
persona_share_request: PersonaShareRequest,
user: User = Depends(current_user),
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
update_persona_shared_users(
@@ -201,18 +197,6 @@ def share_persona(
db_session=db_session,
)
for user_id in persona_share_request.user_ids:
# Don't notify the user that they have access to their own persona
if user_id != user.id:
create_notification(
user_id=user_id,
notif_type=NotificationType.PERSONA_SHARED,
db_session=db_session,
additional_data=PersonaSharedNotificationData(
persona_id=persona_id,
).model_dump(),
)
@basic_router.delete("/{persona_id}")
def delete_persona(
@@ -232,31 +216,23 @@ def list_personas(
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
include_deleted: bool = False,
persona_ids: list[int] = Query(None),
) -> list[PersonaSnapshot]:
personas = get_personas(
user=user,
include_deleted=include_deleted,
db_session=db_session,
get_editable=False,
joinedload_all=True,
)
if persona_ids:
personas = [p for p in personas if p.id in persona_ids]
# Filter out personas with unavailable tools
personas = [
p
for p in personas
return [
PersonaSnapshot.from_model(persona)
for persona in get_personas(
user=user,
include_deleted=include_deleted,
db_session=db_session,
get_editable=False,
joinedload_all=True,
)
# If the persona has an image generation tool and it's not available, don't include it
if not (
any(tool.in_code_tool_id == "ImageGenerationTool" for tool in p.tools)
any(tool.in_code_tool_id == "ImageGenerationTool" for tool in persona.tools)
and not is_image_generation_available(db_session=db_session)
)
]
return [PersonaSnapshot.from_model(p) for p in personas]
@basic_router.get("/{persona_id}")
def get_persona(

View File

@@ -120,7 +120,3 @@ class PersonaSnapshot(BaseModel):
class PromptTemplateResponse(BaseModel):
final_prompt_template: str
class PersonaSharedNotificationData(BaseModel):
persona_id: int

View File

@@ -43,8 +43,6 @@ def test_embedding_configuration(
api_url=test_llm_request.api_url,
provider_type=test_llm_request.provider_type,
model_name=test_llm_request.model_name,
api_version=test_llm_request.api_version,
deployment_name=test_llm_request.deployment_name,
normalize=False,
query_prefix=None,
passage_prefix=None,

View File

@@ -17,8 +17,6 @@ class TestEmbeddingRequest(BaseModel):
api_key: str | None = None
api_url: str | None = None
model_name: str | None = None
api_version: str | None = None
deployment_name: str | None = None
# This disables the "model_" protected namespace for pydantic
model_config = {"protected_namespaces": ()}
@@ -28,8 +26,6 @@ class CloudEmbeddingProvider(BaseModel):
provider_type: EmbeddingProvider
api_key: str | None = None
api_url: str | None = None
api_version: str | None = None
deployment_name: str | None = None
@classmethod
def from_request(
@@ -39,8 +35,6 @@ class CloudEmbeddingProvider(BaseModel):
provider_type=cloud_provider_model.provider_type,
api_key=cloud_provider_model.api_key,
api_url=cloud_provider_model.api_url,
api_version=cloud_provider_model.api_version,
deployment_name=cloud_provider_model.deployment_name,
)
@@ -48,5 +42,3 @@ class CloudEmbeddingProviderCreationRequest(BaseModel):
provider_type: EmbeddingProvider
api_key: str | None = None
api_url: str | None = None
api_version: str | None = None
deployment_name: str | None = None

View File

@@ -115,7 +115,6 @@ def set_new_search_settings(
for cc_pair in get_connector_credential_pairs(db_session):
resync_cc_pair(cc_pair, db_session=db_session)
db_session.commit()
return IdReturn(id=new_search_settings.id)

View File

@@ -3,8 +3,6 @@ from datetime import datetime
from datetime import timezone
import jwt
from email_validator import EmailNotValidError
from email_validator import EmailUndeliverableError
from email_validator import validate_email
from fastapi import APIRouter
from fastapi import Body
@@ -37,7 +35,6 @@ from danswer.configs.app_configs import MULTI_TENANT
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
from danswer.configs.app_configs import VALID_EMAIL_DOMAINS
from danswer.configs.constants import AuthType
from danswer.db.auth import get_total_users
from danswer.db.engine import current_tenant_id
from danswer.db.engine import get_session
from danswer.db.models import AccessToken
@@ -63,7 +60,6 @@ from danswer.utils.logger import setup_logger
from ee.danswer.db.api_key import is_api_key_email_address
from ee.danswer.db.external_perm import delete_user__ext_group_for_user__no_commit
from ee.danswer.db.user_group import remove_curator_status__no_commit
from ee.danswer.server.tenants.billing import register_tenant_users
from ee.danswer.server.tenants.provisioning import add_users_to_tenant
from ee.danswer.server.tenants.provisioning import remove_users_from_tenant
@@ -178,29 +174,19 @@ def list_all_users(
def bulk_invite_users(
emails: list[str] = Body(..., embed=True),
current_user: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> int:
"""emails are string validated. If any email fails validation, no emails are
invited and an exception is raised."""
if current_user is None:
raise HTTPException(
status_code=400, detail="Auth is disabled, cannot invite users"
)
tenant_id = current_tenant_id.get()
normalized_emails = []
try:
for email in emails:
email_info = validate_email(email)
normalized_emails.append(email_info.normalized) # type: ignore
except (EmailUndeliverableError, EmailNotValidError):
raise HTTPException(
status_code=400,
detail="One or more emails in the list are invalid",
)
for email in emails:
email_info = validate_email(email) # can raise EmailNotValidError
normalized_emails.append(email_info.normalized) # type: ignore
if MULTI_TENANT:
try:
@@ -213,58 +199,30 @@ def bulk_invite_users(
)
raise
initial_invited_users = get_invited_users()
all_emails = list(set(normalized_emails) | set(get_invited_users()))
all_emails = list(set(normalized_emails) | set(initial_invited_users))
number_of_invited_users = write_invited_users(all_emails)
if MULTI_TENANT and ENABLE_EMAIL_INVITES:
try:
for email in all_emails:
send_user_email_invite(email, current_user)
except Exception as e:
logger.error(f"Error sending email invite to invited users: {e}")
if not MULTI_TENANT:
return number_of_invited_users
try:
logger.info("Registering tenant users")
register_tenant_users(current_tenant_id.get(), get_total_users(db_session))
if ENABLE_EMAIL_INVITES:
try:
for email in all_emails:
send_user_email_invite(email, current_user)
except Exception as e:
logger.error(f"Error sending email invite to invited users: {e}")
return number_of_invited_users
except Exception as e:
logger.error(f"Failed to register tenant users: {str(e)}")
logger.info(
"Reverting changes: removing users from tenant and resetting invited users"
)
write_invited_users(initial_invited_users) # Reset to original state
remove_users_from_tenant(normalized_emails, tenant_id)
raise e
return write_invited_users(all_emails)
@router.patch("/manage/admin/remove-invited-user")
def remove_invited_user(
user_email: UserByEmail,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> int:
user_emails = get_invited_users()
remaining_users = [user for user in user_emails if user != user_email.user_email]
tenant_id = current_tenant_id.get()
remove_users_from_tenant([user_email.user_email], tenant_id)
number_of_invited_users = write_invited_users(remaining_users)
try:
if MULTI_TENANT:
register_tenant_users(current_tenant_id.get(), get_total_users(db_session))
except Exception:
logger.error(
"Request to update number of seats taken in control plane failed. "
"This may cause synchronization issues/out of date enforcement of seat limits."
)
raise
return number_of_invited_users
return write_invited_users(remaining_users)
@router.patch("/manage/admin/deactivate-user")
@@ -463,6 +421,7 @@ def get_current_token_creation(
@router.get("/me")
def verify_user_logged_in(
request: Request,
user: User | None = Depends(optional_user),
db_session: Session = Depends(get_session),
) -> UserInfo:

View File

@@ -4,7 +4,6 @@ import uuid
from collections.abc import Callable
from collections.abc import Generator
from typing import Tuple
from uuid import UUID
from fastapi import APIRouter
from fastapi import Depends
@@ -19,12 +18,10 @@ from sqlalchemy.orm import Session
from danswer.auth.users import current_user
from danswer.chat.chat_utils import create_chat_chain
from danswer.chat.chat_utils import extract_headers
from danswer.chat.process_message import stream_chat_message
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import FileOrigin
from danswer.configs.constants import MessageType
from danswer.configs.model_configs import LITELLM_PASS_THROUGH_HEADERS
from danswer.db.chat import create_chat_session
from danswer.db.chat import create_new_chat_message
from danswer.db.chat import delete_chat_session
@@ -53,6 +50,7 @@ from danswer.llm.answering.prompts.citations_prompt import (
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_default_llms
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.headers import get_litellm_additional_request_headers
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.secondary_llm_flows.chat_session_naming import (
get_renamed_conversation_name,
@@ -73,7 +71,6 @@ from danswer.server.query_and_chat.models import RenameChatSessionResponse
from danswer.server.query_and_chat.models import SearchFeedbackRequest
from danswer.server.query_and_chat.models import UpdateChatSessionThreadRequest
from danswer.server.query_and_chat.token_limit import check_token_rate_limits
from danswer.utils.headers import get_custom_tool_additional_request_headers
from danswer.utils.logger import setup_logger
@@ -132,7 +129,7 @@ def update_chat_session_model(
@router.get("/get-chat-session/{session_id}")
def get_chat_session(
session_id: UUID,
session_id: int,
is_shared: bool = False,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
@@ -232,9 +229,7 @@ def rename_chat_session(
try:
llm, _ = get_default_llms(
additional_headers=extract_headers(
request.headers, LITELLM_PASS_THROUGH_HEADERS
)
additional_headers=get_litellm_additional_request_headers(request.headers)
)
except GenAIDisabledException:
# This may be longer than what the LLM tends to produce but is the most
@@ -255,7 +250,7 @@ def rename_chat_session(
@router.patch("/chat-session/{session_id}")
def patch_chat_session(
session_id: UUID,
session_id: int,
chat_session_update_req: ChatSessionUpdateRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
@@ -272,7 +267,7 @@ def patch_chat_session(
@router.delete("/delete-chat-session/{session_id}")
def delete_chat_session_by_id(
session_id: UUID,
session_id: int,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
@@ -335,10 +330,7 @@ def handle_new_chat_message(
new_msg_req=chat_message_req,
user=user,
use_existing_user_message=chat_message_req.use_existing_user_message,
litellm_additional_headers=extract_headers(
request.headers, LITELLM_PASS_THROUGH_HEADERS
),
custom_tool_additional_headers=get_custom_tool_additional_request_headers(
litellm_additional_headers=get_litellm_additional_request_headers(
request.headers
),
is_connected=is_disconnected_func,

View File

@@ -1,6 +1,5 @@
from datetime import datetime
from typing import Any
from uuid import UUID
from pydantic import BaseModel
from pydantic import model_validator
@@ -35,7 +34,7 @@ class SimpleQueryRequest(BaseModel):
class UpdateChatSessionThreadRequest(BaseModel):
# If not specified, use Danswer default persona
chat_session_id: UUID
chat_session_id: int
new_alternate_model: str
@@ -46,7 +45,7 @@ class ChatSessionCreationRequest(BaseModel):
class CreateChatSessionID(BaseModel):
chat_session_id: UUID
chat_session_id: int
class ChatFeedbackRequest(BaseModel):
@@ -76,7 +75,7 @@ Currently the different branches are generated by changing the search query
class CreateChatMessageRequest(ChunkContext):
"""Before creating messages, be sure to create a chat_session and get an id"""
chat_session_id: UUID
chat_session_id: int
# This is the primary-key (unique identifier) for the previous message of the tree
parent_message_id: int | None
# New message contents
@@ -116,18 +115,13 @@ class CreateChatMessageRequest(ChunkContext):
)
return self
def model_dump(self, *args: Any, **kwargs: Any) -> dict[str, Any]:
data = super().model_dump(*args, **kwargs)
data["chat_session_id"] = str(data["chat_session_id"])
return data
class ChatMessageIdentifier(BaseModel):
message_id: int
class ChatRenameRequest(BaseModel):
chat_session_id: UUID
chat_session_id: int
name: str | None = None
@@ -140,7 +134,7 @@ class RenameChatSessionResponse(BaseModel):
class ChatSessionDetails(BaseModel):
id: UUID
id: int
name: str
persona_id: int | None = None
time_created: str
@@ -181,7 +175,7 @@ class ChatMessageDetail(BaseModel):
overridden_model: str | None
alternate_assistant_id: int | None = None
# Dict mapping citation number to db_doc_id
chat_session_id: UUID | None = None
chat_session_id: int | None = None
citations: dict[int, int] | None = None
files: list[FileDescriptor]
tool_calls: list[ToolCallFinalResult]
@@ -193,14 +187,14 @@ class ChatMessageDetail(BaseModel):
class SearchSessionDetailResponse(BaseModel):
search_session_id: UUID
search_session_id: int
description: str
documents: list[SearchDoc]
messages: list[ChatMessageDetail]
class ChatSessionDetailResponse(BaseModel):
chat_session_id: UUID
chat_session_id: int
description: str
persona_id: int | None = None
persona_name: str | None

View File

@@ -1,7 +1,3 @@
import json
from collections.abc import Generator
from uuid import UUID
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
@@ -190,7 +186,7 @@ def get_user_search_sessions(
@basic_router.get("/search-session/{session_id}")
def get_search_session(
session_id: UUID,
session_id: int,
is_shared: bool = False,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
@@ -269,17 +265,10 @@ def get_answer_with_quote(
logger.notice(f"Received query for one shot answer with quotes: {query}")
def stream_generator() -> Generator[str, None, None]:
try:
for packet in stream_search_answer(
query_req=query_request,
user=user,
max_document_tokens=None,
max_history_tokens=0,
):
yield json.dumps(packet) if isinstance(packet, dict) else packet
except Exception as e:
logger.exception(f"Error in search answer streaming: {e}")
yield json.dumps({"error": str(e)})
return StreamingResponse(stream_generator(), media_type="application/json")
packets = stream_search_answer(
query_req=query_request,
user=user,
max_document_tokens=None,
max_history_tokens=0,
)
return StreamingResponse(packets, media_type="application/json")
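
One side of this hunk wraps the packet stream in a generator that serializes dict packets and turns exceptions into a final JSON error frame instead of an abruptly closed response. A minimal FastAPI-shaped sketch of that pattern, with fake_answer_packets standing in for stream_search_answer:

import json
from collections.abc import Generator, Iterable

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()


def fake_answer_packets() -> Iterable[dict | str]:
    # stand-in for stream_search_answer(...)
    yield {"answer_piece": "Hello"}
    yield {"answer_piece": " world"}


@app.get("/demo-answer-stream")
def demo_answer_stream() -> StreamingResponse:
    def stream_generator() -> Generator[str, None, None]:
        try:
            for packet in fake_answer_packets():
                # packets may be dicts or pre-serialized strings
                yield json.dumps(packet) if isinstance(packet, dict) else packet
        except Exception as e:
            # emit a terminal error frame rather than dropping the stream silently
            yield json.dumps({"error": str(e)})

    return StreamingResponse(stream_generator(), media_type="application/json")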

View File

@@ -13,7 +13,6 @@ from sqlalchemy.orm import Session
from danswer.auth.users import current_user
from danswer.db.engine import get_session_context_manager
from danswer.db.engine import get_session_with_tenant
from danswer.db.models import ChatMessage
from danswer.db.models import ChatSession
from danswer.db.models import TokenRateLimit
@@ -21,7 +20,6 @@ from danswer.db.models import User
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import fetch_versioned_implementation
from ee.danswer.db.token_limit import fetch_all_global_token_rate_limits
from shared_configs.configs import current_tenant_id
logger = setup_logger()
@@ -41,11 +39,11 @@ def check_token_rate_limits(
versioned_rate_limit_strategy = fetch_versioned_implementation(
"danswer.server.query_and_chat.token_limit", "_check_token_rate_limits"
)
return versioned_rate_limit_strategy(user, current_tenant_id.get())
return versioned_rate_limit_strategy(user)
def _check_token_rate_limits(_: User | None, tenant_id: str | None) -> None:
_user_is_rate_limited_by_global(tenant_id)
def _check_token_rate_limits(_: User | None) -> None:
_user_is_rate_limited_by_global()
"""
@@ -53,8 +51,8 @@ Global rate limits
"""
def _user_is_rate_limited_by_global(tenant_id: str | None) -> None:
with get_session_with_tenant(tenant_id) as db_session:
def _user_is_rate_limited_by_global() -> None:
with get_session_context_manager() as db_session:
global_rate_limits = fetch_all_global_token_rate_limits(
db_session=db_session, enabled_only=True, ordered=False
)

View File

@@ -15,6 +15,8 @@ from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.db.notification import create_notification
from danswer.db.notification import dismiss_all_notifications
from danswer.db.notification import dismiss_notification
from danswer.db.notification import get_notification_by_id
from danswer.db.notification import get_notifications
from danswer.db.notification import update_notification_last_shown
from danswer.key_value_store.factory import get_kv_store
@@ -53,7 +55,7 @@ def fetch_settings(
"""Settings and notifications are stuffed into this single endpoint to reduce number of
Postgres calls"""
general_settings = load_settings()
user_notifications = get_reindex_notification(user, db_session)
user_notifications = get_user_notifications(user, db_session)
try:
kv_store = get_kv_store()
@@ -68,7 +70,25 @@ def fetch_settings(
)
def get_reindex_notification(
@basic_router.post("/notifications/{notification_id}/dismiss")
def dismiss_notification_endpoint(
notification_id: int,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
try:
notification = get_notification_by_id(notification_id, user, db_session)
except PermissionError:
raise HTTPException(
status_code=403, detail="Not authorized to dismiss this notification"
)
except ValueError:
raise HTTPException(status_code=404, detail="Notification not found")
dismiss_notification(notification, db_session)
def get_user_notifications(
user: User | None, db_session: Session
) -> list[Notification]:
"""Get notifications for the user, currently the logic is very specific to the reindexing flag"""
@@ -101,7 +121,7 @@ def get_reindex_notification(
if not reindex_notifs:
notif = create_notification(
user_id=user.id if user else None,
user=user,
notif_type=NotificationType.REINDEX,
db_session=db_session,
)

View File

@@ -12,19 +12,12 @@ class PageType(str, Enum):
SEARCH = "search"
class GatingType(str, Enum):
FULL = "full" # Complete restriction of access to the product or service
PARTIAL = "partial" # Full access but warning (no credit card on file)
NONE = "none" # No restrictions, full access to all features
class Notification(BaseModel):
id: int
notif_type: NotificationType
dismissed: bool
last_shown: datetime
first_shown: datetime
additional_data: dict | None = None
@classmethod
def from_model(cls, notif: NotificationDBModel) -> "Notification":
@@ -34,7 +27,6 @@ class Notification(BaseModel):
dismissed=notif.dismissed,
last_shown=notif.last_shown,
first_shown=notif.first_shown,
additional_data=notif.additional_data,
)
@@ -46,7 +38,6 @@ class Settings(BaseModel):
default_page: PageType = PageType.SEARCH
maximum_chat_retention_days: int | None = None
gpu_enabled: bool | None = None
product_gating: GatingType = GatingType.NONE
def check_validity(self) -> None:
chat_page_enabled = self.chat_page_enabled

Some files were not shown because too many files have changed in this diff.