Compare commits

...

64 Commits

Author SHA1 Message Date
pablodanswer ebb9b94c6f squash 2024-09-27 13:50:13 -07:00
pablodanswer a759238372 qol updates 2024-09-27 13:46:31 -07:00
pablodanswer 40cce46996 squash 2024-09-27 13:15:56 -07:00
pablodanswer c9a84d5084 validated indexing swap 2024-09-27 12:58:07 -07:00
pablodanswer ee831b73e4 update logic for tenancy 2024-09-27 12:54:46 -07:00
pablodanswer 7419bf6b06 fix typing issues / build issues + backwards-compatibility for indexing 2024-09-27 12:48:21 -07:00
pablodanswer e5f3f2d73a validated functionality for single-tenant 2024-09-27 10:21:08 -07:00
pablodanswer dc5a91fd85 update logs 2024-09-26 17:31:37 -07:00
pablodanswer 3158db4239 minor formatting fixes 2024-09-26 17:26:58 -07:00
pablodanswer 341bf26ff2 minor update to typing 2024-09-26 17:22:16 -07:00
pablodanswer 516f1840ce remove more logs for clarity 2024-09-26 16:58:26 -07:00
pablodanswer 1f12b074df update for clarity 2024-09-26 16:55:47 -07:00
pablodanswer 8f3f905a99 update tsx 2024-09-26 16:41:47 -07:00
pablodanswer 01e1bd0ee2 remove logs 2024-09-26 13:35:00 -07:00
pablodanswer 5f7be266f0 fix building + types 2024-09-26 13:29:57 -07:00
pablodanswer 478dd1c4bb functional multi tenant connector deletion 2024-09-26 13:14:48 -07:00
pablodanswer f0a5ec223f valid vespa indexing + search for multi tenant use case 2024-09-26 12:44:27 -07:00
pablodanswer daad96d180 valid multi tenant indexing + valid upgrades/updated endpoints 2024-09-26 10:47:29 -07:00
pablodanswer b853e5f22a robustified alembic upgrade 2024-09-25 14:48:38 -07:00
pablodanswer dcc4c61fcb update alembic to temporarily include initial seeding 2024-09-25 14:25:59 -07:00
pablodanswer 5775aec498 more solid schema context passing 2024-09-25 14:15:16 -07:00
pablodanswer b4ee066424 functional janky context passing 2024-09-25 13:52:07 -07:00
pablodanswer 88ade7cb7e investigate context issues 2024-09-25 13:45:52 -07:00
pablodanswer a69a0333a5 valid but janky db session handling 2024-09-25 13:09:55 -07:00
pablodanswer 198f80d224 remove dependency fully in alembic migrations 2024-09-25 12:08:33 -07:00
pablodanswer 4855a80f86 proper schema isolation 2024-09-24 20:02:34 -07:00
pablodanswer 6e78f2094b update alembic to be stateless 2024-09-24 19:15:45 -07:00
pablodanswer c2e953633a update vespa + sidebar 2024-09-24 19:00:31 -07:00
pablodanswer 0ff4ff0abc update typing 2024-09-23 17:51:31 -07:00
pablodanswer 38af754968 migrate all tenant upgrade services to data plane endpoint 2024-09-23 16:09:09 -07:00
pablodanswer 4f9420217e more secure callback 2024-09-23 14:12:31 -07:00
pablodanswer e5584ca364 basic flows functional 2024-09-23 13:51:40 -07:00
pablodanswer ae3218f941 squash - need to update sql alchemy engine 2024-09-22 16:45:06 -07:00
pablodanswer 5b220ac7b1 validatid across several tenants (+ more secure db_sessions) 2024-09-22 14:43:44 -07:00
pablodanswer d1f40cfd30 proper loading of yamls on schema startup 2024-09-22 12:20:34 -07:00
pablodanswer fe3f6d451d ruff 2024-09-22 11:24:48 -07:00
pablodanswer d1641652a2 squash 2024-09-22 11:19:25 -07:00
pablodanswer 17412fb9f7 fix web build 2024-09-22 11:18:55 -07:00
pablodanswer a28ac88341 fix build 2024-09-22 11:15:14 -07:00
pablodanswer 95a11b8adc valid non-multitenant defaults 2024-09-22 11:06:51 -07:00
pablodanswer e9906c37fe functional multi tenancy (excluding accurate provisioning) 2024-09-21 21:31:05 -07:00
pablodanswer 127526d080 temporary stopgap for uperts 2024-09-21 19:50:56 -07:00
pablodanswer d3d63ee8f7 updated alembic migrations 2024-09-21 19:30:59 -07:00
pablodanswer f98c77397d update formatting 2024-09-20 15:56:05 -07:00
pablodanswer 918623eb97 callback toast 2024-08-31 18:16:08 -07:00
pablodanswer 482117c4e7 add sso auth 2024-08-30 08:48:44 -07:00
pablodanswer 827e4169c5 squash 2024-08-29 19:36:16 -07:00
pablodanswer 06c3e2064f fully valid auth scheme 2024-08-29 18:57:53 -07:00
pablodanswer 0a1c8ae980 functional sso callback 2024-08-29 18:22:31 -07:00
pablodanswer db54cb448b revert to functioanl userrole logic 2024-08-29 17:28:53 -07:00
pablodanswer cb0a1e4fdc functional porting over user 2024-08-29 17:25:43 -07:00
pablodanswer 0547fff1d5 functional redirect auth screen 2024-08-29 13:17:20 -07:00
pablodanswer 01841adb43 squash 2024-08-29 12:58:25 -07:00
pablodanswer bc5b269446 finalize billing settings page 2024-08-29 12:55:36 -07:00
pablodanswer 64768f82f3 add billing 2024-08-29 12:17:20 -07:00
pablodanswer 90da2166c2 add proper stripe checkout button for user 2024-08-29 11:49:10 -07:00
pablodanswer c8f7e6185f squash 2024-08-28 16:21:11 -07:00
pablodanswer 7b895008d3 plan out 2024-08-28 14:43:16 -07:00
pablodanswer 3838908e70 proper parentheses 2024-08-28 10:34:33 -07:00
pablodanswer 3112a9df9d extract loading 2024-08-28 10:26:24 -07:00
pablodanswer fb29c70f37 squash 2024-08-28 08:56:48 -07:00
pablodanswer e547cd6a79 minor formatting 2024-08-28 08:53:36 -07:00
pablodanswer 1fa324b135 proper padding 2024-08-28 08:52:30 -07:00
pablodanswer f4f3dd479e ux improvements 2024-08-28 08:45:55 -07:00
111 changed files with 2403 additions and 783 deletions

View File

@@ -1,72 +1,81 @@
import asyncio
from logging.config import fileConfig
from typing import Tuple
from alembic import context
from danswer.db.engine import build_connection_string
from danswer.db.models import Base
from sqlalchemy import pool
from sqlalchemy import pool, text
from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import create_async_engine
from celery.backends.database.session import ResultModelBase # type: ignore
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
# Alembic Config object
config = context.config
# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
fileConfig(config.config_file_name)
# add your model's MetaData object here
# Add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
target_metadata = [Base.metadata, ResultModelBase.metadata]
# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.
def get_schema_options() -> str:
x_args_raw = context.get_x_argument()
x_args = {}
for arg in x_args_raw:
for pair in arg.split(','):
if '=' in pair:
key, value = pair.split('=', 1)
x_args[key] = value
schema_name = x_args.get('schema', 'public')
return schema_name
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
"""Run migrations in 'offline' mode."""
url = build_connection_string()
schema = get_schema_options()
context.configure(
url=url,
target_metadata=target_metadata, # type: ignore
target_metadata=target_metadata, # type: ignore
literal_binds=True,
dialect_opts={"paramstyle": "named"},
version_table_schema=schema,
include_schemas=True,
)
with context.begin_transaction():
context.run_migrations()
def do_run_migrations(connection: Connection) -> None:
context.configure(connection=connection, target_metadata=target_metadata) # type: ignore
schema = get_schema_options()
connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema}"'))
connection.execute(text('COMMIT'))
connection.execute(text(f'SET search_path TO "{schema}"'))
context.configure(
connection=connection,
target_metadata=target_metadata, # type: ignore
version_table_schema=schema,
include_schemas=True,
compare_type=True,
compare_server_default=True,
)
with context.begin_transaction():
context.run_migrations()
async def run_async_migrations() -> None:
"""In this scenario we need to create an Engine
and associate a connection with the context.
"""
print("Running async migrations")
"""Run migrations in 'online' mode."""
connectable = create_async_engine(
build_connection_string(),
poolclass=pool.NullPool,
@@ -77,13 +86,10 @@ async def run_async_migrations() -> None:
await connectable.dispose()
def run_migrations_online() -> None:
"""Run migrations in 'online' mode."""
asyncio.run(run_async_migrations())
if context.is_offline_mode():
run_migrations_offline()
else:
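
Note: a minimal illustrative sketch (not part of this diff) of driving these per-schema migrations. The tenant schema is passed through Alembic's -x arguments, which get_schema_options() above parses; the schema names below are placeholders.
# Illustrative only: run the upgrade once per tenant schema via Alembic's -x flag,
# which get_schema_options() reads and applies as version_table_schema / search_path.
import subprocess
for schema in ["public", "tenant_abc"]:  # placeholder schema names
    subprocess.run(["alembic", "-x", f"schema={schema}", "upgrade", "head"], check=True)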

View File

@@ -9,9 +9,9 @@ from alembic import op
import sqlalchemy as sa
from sqlalchemy.sql import table
from sqlalchemy.dialects import postgresql
from alembic_utils import encrypt_string
import json
from danswer.utils.encryption import encrypt_string_to_bytes
# revision identifiers, used by Alembic.
revision = "0a98909f2757"
@@ -57,7 +57,7 @@ def upgrade() -> None:
# In other words, this upgrade does not apply the encryption. Porting existing sensitive data
# and key rotation currently is not supported and will come out in the future
for row_id, creds, _ in results:
creds_binary = encrypt_string_to_bytes(json.dumps(creds))
creds_binary = encrypt_string(json.dumps(creds))
connection.execute(
creds_table.update()
.where(creds_table.c.id == row_id)
@@ -86,7 +86,7 @@ def upgrade() -> None:
results = connection.execute(sa.select(llm_table))
for row_id, api_key, _ in results:
llm_key = encrypt_string_to_bytes(api_key)
llm_key = encrypt_string(api_key)
connection.execute(
llm_table.update()
.where(llm_table.c.id == row_id)

View File

@@ -8,7 +8,7 @@ Create Date: 2023-11-11 20:51:24.228999
from alembic import op
import sqlalchemy as sa
from danswer.configs.constants import DocumentSource
from alembic_utils import DocumentSource
# revision identifiers, used by Alembic.
revision = "15326fcec57e"

View File

@@ -9,8 +9,7 @@ Create Date: 2024-08-25 12:39:51.731632
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from danswer.configs.chat_configs import NUM_POSTPROCESSED_RESULTS
from alembic_utils import NUM_POSTPROCESSED_RESULTS
# revision identifiers, used by Alembic.
revision = "1f60f60c3401"

View File

@@ -5,11 +5,8 @@ Revises: fad14119fb92
Create Date: 2024-04-15 01:36:02.952809
"""
import json
from typing import cast
from alembic import op
import sqlalchemy as sa
from danswer.dynamic_configs.factory import get_dynamic_config_store
# revision identifiers, used by Alembic.
revision = "703313b75876"
@@ -53,30 +50,6 @@ def upgrade() -> None:
sa.PrimaryKeyConstraint("rate_limit_id", "user_group_id"),
)
try:
settings_json = cast(
str, get_dynamic_config_store().load("token_budget_settings")
)
settings = json.loads(settings_json)
is_enabled = settings.get("enable_token_budget", False)
token_budget = settings.get("token_budget", -1)
period_hours = settings.get("period_hours", -1)
if is_enabled and token_budget > 0 and period_hours > 0:
op.execute(
f"INSERT INTO token_rate_limit \
(enabled, token_budget, period_hours, scope) VALUES \
({is_enabled}, {token_budget}, {period_hours}, 'GLOBAL')"
)
# Delete the dynamic config
get_dynamic_config_store().delete("token_budget_settings")
except Exception:
# Ignore if the dynamic config is not found
pass
def downgrade() -> None:
op.drop_table("token_rate_limit__user_group")

View File

@@ -7,10 +7,8 @@ Create Date: 2024-03-22 21:34:27.629444
"""
from alembic import op
import sqlalchemy as sa
from alembic_utils import IndexModelStatus, RecencyBiasSetting, SearchType
from danswer.db.models import IndexModelStatus
from danswer.search.enums import RecencyBiasSetting
from danswer.search.enums import SearchType
# revision identifiers, used by Alembic.
revision = "776b3bbe9092"

View File

@@ -7,7 +7,7 @@ Create Date: 2024-03-21 12:05:23.956734
"""
from alembic import op
import sqlalchemy as sa
from danswer.configs.constants import DocumentSource
from alembic_utils import DocumentSource
# revision identifiers, used by Alembic.
revision = "91fd3b470d1a"

View File

@@ -10,7 +10,7 @@ from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from sqlalchemy.dialects.postgresql import ENUM
from danswer.configs.constants import DocumentSource
from alembic_utils import DocumentSource
# revision identifiers, used by Alembic.
revision = "b156fa702355"

View File

@@ -0,0 +1,24 @@
"""add tenant id to user model
Revision ID: b25c363470f3
Revises: 1f60f60c3401
Create Date: 2024-08-29 17:03:20.794120
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "b25c363470f3"
down_revision = "1f60f60c3401"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column("user", sa.Column("tenant_id", sa.Text(), nullable=True))
def downgrade() -> None:
op.drop_column("user", "tenant_id")

View File

@@ -19,15 +19,16 @@ depends_on: None = None
def upgrade() -> None:
conn = op.get_bind()
existing_ids_and_chosen_assistants = conn.execute(
sa.text("select id, chosen_assistants from public.user")
sa.text('SELECT id, chosen_assistants FROM "user"')
)
op.drop_column(
"user",
'user',
"chosen_assistants",
)
op.add_column(
"user",
'user',
sa.Column(
"chosen_assistants",
postgresql.JSONB(astext_type=sa.Text()),
@@ -37,7 +38,7 @@ def upgrade() -> None:
for id, chosen_assistants in existing_ids_and_chosen_assistants:
conn.execute(
sa.text(
"update public.user set chosen_assistants = :chosen_assistants where id = :id"
'UPDATE user SET chosen_assistants = :chosen_assistants WHERE id = :id'
),
{"chosen_assistants": json.dumps(chosen_assistants), "id": id},
)
@@ -46,20 +47,20 @@ def upgrade() -> None:
def downgrade() -> None:
conn = op.get_bind()
existing_ids_and_chosen_assistants = conn.execute(
sa.text("select id, chosen_assistants from public.user")
sa.text('SELECT id, chosen_assistants FROM user')
)
op.drop_column(
"user",
'user',
"chosen_assistants",
)
op.add_column(
"user",
'user',
sa.Column("chosen_assistants", postgresql.ARRAY(sa.Integer()), nullable=True),
)
for id, chosen_assistants in existing_ids_and_chosen_assistants:
conn.execute(
sa.text(
"update public.user set chosen_assistants = :chosen_assistants where id = :id"
'UPDATE user SET chosen_assistants = :chosen_assistants WHERE id = :id'
),
{"chosen_assistants": chosen_assistants, "id": id},
)

View File

@@ -8,20 +8,13 @@ Create Date: 2024-01-25 17:12:31.813160
from alembic import op
import sqlalchemy as sa
from sqlalchemy import table, column, String, Integer, Boolean
from danswer.db.search_settings import (
get_new_default_embedding_model,
get_old_default_embedding_model,
user_has_overridden_embedding_model,
)
from danswer.db.models import IndexModelStatus
from alembic_utils import IndexModelStatus
# revision identifiers, used by Alembic.
revision = "dbaa756c2ccf"
down_revision = "7f726bad5367"
branch_labels: None = None
depends_on: None = None
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
@@ -40,9 +33,32 @@ def upgrade() -> None:
),
sa.PrimaryKeyConstraint("id"),
)
# since all index attempts must be associated with an embedding model,
# need to put something in here to avoid nulls. On server startup,
# this value will be overriden
# Define the old default embedding model directly
old_embedding_model = {
"model_name": "sentence-transformers/all-distilroberta-v1",
"model_dim": 768,
"normalize": True,
"query_prefix": "",
"passage_prefix": "",
"index_name": "OPENSEARCH_INDEX_NAME",
"status": IndexModelStatus.PAST,
}
# Define the new default embedding model directly
new_embedding_model = {
"model_name": "intfloat/e5-small-v2",
"model_dim": 384,
"normalize": False,
"query_prefix": "query: ",
"passage_prefix": "passage: ",
"index_name": "danswer_chunk_intfloat_e5_small_v2",
"status": IndexModelStatus.PRESENT,
}
# Assume the user has not overridden the embedding model
user_overridden_embedding_model = False
EmbeddingModel = table(
"embedding_model",
column("id", Integer),
@@ -52,45 +68,23 @@ def upgrade() -> None:
column("query_prefix", String),
column("passage_prefix", String),
column("index_name", String),
column(
"status", sa.Enum(IndexModelStatus, name="indexmodelstatus", native=False)
),
column("status", sa.Enum(IndexModelStatus, name="indexmodelstatus", native=False)),
)
# insert an embedding model row that corresponds to the embedding model
# the user selected via env variables before this change. This is needed since
# all index_attempts must be associated with an embedding model, so without this
# we will run into violations of non-null contraints
old_embedding_model = get_old_default_embedding_model()
# Insert the old embedding model
op.bulk_insert(
EmbeddingModel,
[
{
"model_name": old_embedding_model.model_name,
"model_dim": old_embedding_model.model_dim,
"normalize": old_embedding_model.normalize,
"query_prefix": old_embedding_model.query_prefix,
"passage_prefix": old_embedding_model.passage_prefix,
"index_name": old_embedding_model.index_name,
"status": IndexModelStatus.PRESENT,
}
old_embedding_model
],
)
# if the user has not overridden the default embedding model via env variables,
# insert the new default model into the database to auto-upgrade them
if not user_has_overridden_embedding_model():
new_embedding_model = get_new_default_embedding_model()
# If the user has not overridden the embedding model, insert the new default model
if not user_overridden_embedding_model:
op.bulk_insert(
EmbeddingModel,
[
{
"model_name": new_embedding_model.model_name,
"model_dim": new_embedding_model.model_dim,
"normalize": new_embedding_model.normalize,
"query_prefix": new_embedding_model.query_prefix,
"passage_prefix": new_embedding_model.passage_prefix,
"index_name": new_embedding_model.index_name,
"status": IndexModelStatus.FUTURE,
}
new_embedding_model
],
)
@@ -129,7 +123,6 @@ def upgrade() -> None:
postgresql_where=sa.text("status = 'FUTURE'"),
)
def downgrade() -> None:
op.drop_constraint(
"index_attempt__embedding_model_fk", "index_attempt", type_="foreignkey"

View File

@@ -8,7 +8,7 @@ Create Date: 2024-03-14 18:06:08.523106
from alembic import op
import sqlalchemy as sa
from danswer.configs.constants import DocumentSource
from alembic_utils import DocumentSource
# revision identifiers, used by Alembic.
revision = "e50154680a5c"

backend/alembic_utils.py (new file, 99 lines)
View File

@@ -0,0 +1,99 @@
from cryptography.hazmat.primitives import padding
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.backends import default_backend
from os import urandom
import os
from enum import Enum
ENCRYPTION_KEY_SECRET = os.environ.get("ENCRYPTION_KEY_SECRET") or ""
def _get_trimmed_key(key: str) -> bytes:
encoded_key = key.encode()
key_length = len(encoded_key)
if key_length < 16:
raise RuntimeError("Invalid ENCRYPTION_KEY_SECRET - too short")
elif key_length > 32:
key = key[:32]
elif key_length not in (16, 24, 32):
valid_lengths = [16, 24, 32]
key = key[: min(valid_lengths, key=lambda x: abs(x - key_length))]
return encoded_key
def encrypt_string(input_str: str) -> bytes:
if not ENCRYPTION_KEY_SECRET:
return input_str.encode()
key = _get_trimmed_key(ENCRYPTION_KEY_SECRET)
iv = urandom(16)
padder = padding.PKCS7(algorithms.AES.block_size).padder()
padded_data = padder.update(input_str.encode()) + padder.finalize()
cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
encryptor = cipher.encryptor()
encrypted_data = encryptor.update(padded_data) + encryptor.finalize()
return iv + encrypted_data
NUM_POSTPROCESSED_RESULTS = 20
class IndexModelStatus(str, Enum):
PAST = "PAST"
PRESENT = "PRESENT"
FUTURE = "FUTURE"
class RecencyBiasSetting(str, Enum):
FAVOR_RECENT = "favor_recent" # 2x decay rate
BASE_DECAY = "base_decay"
NO_DECAY = "no_decay"
# Determine based on query if to use base_decay or favor_recent
AUTO = "auto"
class SearchType(str, Enum):
KEYWORD = "keyword"
SEMANTIC = "semantic"
class DocumentSource(str, Enum):
# Special case, document passed in via Danswer APIs without specifying a source type
INGESTION_API = "ingestion_api"
SLACK = "slack"
WEB = "web"
GOOGLE_DRIVE = "google_drive"
GMAIL = "gmail"
REQUESTTRACKER = "requesttracker"
GITHUB = "github"
GITLAB = "gitlab"
GURU = "guru"
BOOKSTACK = "bookstack"
CONFLUENCE = "confluence"
SLAB = "slab"
JIRA = "jira"
PRODUCTBOARD = "productboard"
FILE = "file"
NOTION = "notion"
ZULIP = "zulip"
LINEAR = "linear"
HUBSPOT = "hubspot"
DOCUMENT360 = "document360"
GONG = "gong"
GOOGLE_SITES = "google_sites"
ZENDESK = "zendesk"
LOOPIO = "loopio"
DROPBOX = "dropbox"
SHAREPOINT = "sharepoint"
TEAMS = "teams"
SALESFORCE = "salesforce"
DISCOURSE = "discourse"
AXERO = "axero"
CLICKUP = "clickup"
MEDIAWIKI = "mediawiki"
WIKIPEDIA = "wikipedia"
S3 = "s3"
R2 = "r2"
GOOGLE_CLOUD_STORAGE = "google_cloud_storage"
OCI_STORAGE = "oci_storage"
NOT_APPLICABLE = "not_applicable"

View File

@@ -33,6 +33,7 @@ class UserRead(schemas.BaseUser[uuid.UUID]):
class UserCreate(schemas.BaseUserCreate):
role: UserRole = UserRole.BASIC
tenant_id: str | None = None
class UserUpdate(schemas.BaseUserUpdate):

View File

@@ -1,3 +1,6 @@
from danswer.configs.app_configs import SECRET_JWT_KEY
from datetime import timedelta
import contextlib
import smtplib
import uuid
from collections.abc import AsyncGenerator
@@ -8,6 +11,7 @@ from email.mime.text import MIMEText
from typing import Optional
from typing import Tuple
import jwt
from email_validator import EmailNotValidError
from email_validator import validate_email
from fastapi import APIRouter
@@ -54,6 +58,7 @@ from danswer.db.auth import get_access_token_db
from danswer.db.auth import get_default_admin_user_emails
from danswer.db.auth import get_user_count
from danswer.db.auth import get_user_db
from danswer.db.engine import get_async_session
from danswer.db.engine import get_session
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import AccessToken
@@ -191,11 +196,88 @@ def send_user_verification_email(
s.login(SMTP_USER, SMTP_PASS)
s.send_message(msg)
def verify_sso_token(token: str) -> dict:
try:
payload = jwt.decode(token, "SSO_SECRET_KEY", algorithms=["HS256"])
if datetime.now(timezone.utc) > datetime.fromtimestamp(
payload["exp"], timezone.utc
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Token has expired"
)
return payload
except jwt.PyJWTError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token"
)
async def get_or_create_user(email: str, user_id: str) -> User:
get_async_session_context = contextlib.asynccontextmanager(get_async_session)
get_user_db_context = contextlib.asynccontextmanager(get_user_db)
async with get_async_session_context() as session:
async with get_user_db_context(session) as user_db:
existing_user = await user_db.get_by_email(email)
if existing_user:
return existing_user
new_user = {
"email": email,
"id": uuid.UUID(user_id),
"role": UserRole.BASIC,
"oidc_expiry": None,
"default_model": None,
"chosen_assistants": None,
"hashed_password": "p",
"is_active": True,
"is_superuser": False,
"is_verified": True,
}
created_user: User = await user_db.create(new_user)
return created_user
async def create_user_session(user: User, tenant_id: str) -> str:
# Create a payload user information and tenant_id
payload = {
"sub": str(user.id),
"email": user.email,
"tenant_id": tenant_id,
"exp": datetime.utcnow() + timedelta(seconds=SESSION_EXPIRE_TIME_SECONDS)
}
token = jwt.encode(payload, SECRET_JWT_KEY, algorithm="HS256")
return token
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
reset_password_token_secret = USER_AUTH_SECRET
verification_token_secret = USER_AUTH_SECRET
async def sso_authenticate(
self,
email: str,
tenant_id: str,
) -> User:
try:
user = await self.get_by_email(email)
except Exception:
# user_create = UserCreate(email=email, password=secrets.token_urlsafe(32))
user_create = UserCreate(
role=UserRole.BASIC, password="password", email=email, is_verified=True
)
user = await self.create(user_create)
# Update user with tenant information if needed
if user.tenant_id != tenant_id:
await self.user_db.update(user, {"tenant_id": tenant_id})
return user
async def create(
self,
user_create: schemas.UC | UserCreate,
@@ -210,6 +292,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
user_create.role = UserRole.ADMIN
else:
user_create.role = UserRole.BASIC
return await super().create(user_create, safe=safe, request=request) # type: ignore
async def oauth_callback(
@@ -298,7 +381,6 @@ def get_database_strategy(
strategy = DatabaseStrategy(
access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS # type: ignore
)
return strategy
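
Note: an illustrative sketch (not from the diff) of how the tenant-aware session JWT produced by create_user_session could be consumed; the variable names are placeholders.
# Illustrative only: decode the session token with PyJWT and read the tenant claim.
claims = jwt.decode(token, SECRET_JWT_KEY, algorithms=["HS256"])
tenant_id = claims["tenant_id"]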

View File

@@ -1,3 +1,6 @@
from danswer.configs.app_configs import MULTI_TENANT
from danswer.background.update import get_all_tenant_ids
import json
from datetime import timedelta
from typing import Any
@@ -67,11 +70,12 @@ _SYNC_BATCH_SIZE = 100
def cleanup_connector_credential_pair_task(
connector_id: int,
credential_id: int,
tenant_id: str | None
) -> int:
"""Connector deletion task. This is run as an async task because it is a somewhat slow job.
Needs to potentially update a large number of Postgres and Vespa docs, including deleting them
or updating the ACL"""
engine = get_sqlalchemy_engine()
engine = get_sqlalchemy_engine(schema=tenant_id)
with Session(engine) as db_session:
# validate that the connector / credential pair is deletable
cc_pair = get_connector_credential_pair(
@@ -101,6 +105,7 @@ def cleanup_connector_credential_pair_task(
db_session=db_session,
document_index=document_index,
cc_pair=cc_pair,
tenant_id=tenant_id,
)
except Exception as e:
logger.exception(f"Failed to run connector_deletion due to {e}")
@@ -109,7 +114,7 @@ def cleanup_connector_credential_pair_task(
@build_celery_task_wrapper(name_cc_prune_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def prune_documents_task(connector_id: int, credential_id: int) -> None:
def prune_documents_task(connector_id: int, credential_id: int, tenant_id: str | None) -> None:
"""connector pruning task. For a cc pair, this task pulls all document IDs from the source
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
from the most recently pulled document ID list"""
@@ -167,6 +172,7 @@ def prune_documents_task(connector_id: int, credential_id: int) -> None:
connector_id=connector_id,
credential_id=credential_id,
document_index=document_index,
tenant_id=tenant_id,
)
except Exception as e:
logger.exception(
@@ -177,7 +183,7 @@ def prune_documents_task(connector_id: int, credential_id: int) -> None:
@build_celery_task_wrapper(name_document_set_sync_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def sync_document_set_task(document_set_id: int) -> None:
def sync_document_set_task(document_set_id: int, tenant_id: str | None) -> None:
"""For document sets marked as not up to date, sync the state from postgres
into the datastore. Also handles deletions."""
@@ -210,7 +216,7 @@ def sync_document_set_task(document_set_id: int) -> None:
]
document_index.update(update_requests=update_requests)
with Session(get_sqlalchemy_engine()) as db_session:
with Session(get_sqlalchemy_engine(schema=tenant_id)) as db_session:
try:
cursor = None
while True:
@@ -261,10 +267,10 @@ def sync_document_set_task(document_set_id: int) -> None:
name="check_for_document_sets_sync_task",
soft_time_limit=JOB_TIMEOUT,
)
def check_for_document_sets_sync_task() -> None:
def check_for_document_sets_sync_task(tenant_id: str | None) -> None:
"""Runs periodically to check if any sync tasks should be run and adds them
to the queue"""
with Session(get_sqlalchemy_engine()) as db_session:
with Session(get_sqlalchemy_engine(schema=tenant_id)) as db_session:
# check if any document sets are not synced
document_set_info = fetch_document_sets(
user_id=None, db_session=db_session, include_outdated=True
@@ -281,9 +287,9 @@ def check_for_document_sets_sync_task() -> None:
name="check_for_cc_pair_deletion_task",
soft_time_limit=JOB_TIMEOUT,
)
def check_for_cc_pair_deletion_task() -> None:
def check_for_cc_pair_deletion_task(tenant_id: str | None) -> None:
"""Runs periodically to check if any deletion tasks should be run"""
with Session(get_sqlalchemy_engine()) as db_session:
with Session(get_sqlalchemy_engine(schema=tenant_id)) as db_session:
# check if any document sets are not synced
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
@@ -293,6 +299,7 @@ def check_for_cc_pair_deletion_task() -> None:
kwargs=dict(
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
tenant_id=tenant_id
),
)
@@ -303,7 +310,7 @@ def check_for_cc_pair_deletion_task() -> None:
bind=True,
base=AbortableTask,
)
def kombu_message_cleanup_task(self: Any) -> int:
def kombu_message_cleanup_task(self: Any, tenant_id: str | None) -> int:
"""Runs periodically to clean up the kombu_message table"""
# we will select messages older than this amount to clean up
@@ -315,7 +322,7 @@ def kombu_message_cleanup_task(self: Any) -> int:
ctx["deleted"] = 0
ctx["cleanup_age"] = KOMBU_MESSAGE_CLEANUP_AGE
ctx["page_limit"] = KOMBU_MESSAGE_CLEANUP_PAGE_LIMIT
with Session(get_sqlalchemy_engine()) as db_session:
with Session(get_sqlalchemy_engine(schema=tenant_id)) as db_session:
# Exit the task if we can't take the advisory lock
result = db_session.execute(
text("SELECT pg_try_advisory_lock(:id)"),
@@ -416,11 +423,11 @@ def kombu_message_cleanup_task_helper(ctx: dict, db_session: Session) -> bool:
name="check_for_prune_task",
soft_time_limit=JOB_TIMEOUT,
)
def check_for_prune_task() -> None:
def check_for_prune_task(tenant_id: str | None) -> None:
"""Runs periodically to check if any prune tasks should be run and adds them
to the queue"""
with Session(get_sqlalchemy_engine()) as db_session:
with Session(get_sqlalchemy_engine(schema=tenant_id)) as db_session:
all_cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in all_cc_pairs:
@@ -435,6 +442,7 @@ def check_for_prune_task() -> None:
kwargs=dict(
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
tenant_id=tenant_id,
)
)
@@ -442,31 +450,33 @@ def check_for_prune_task() -> None:
#####
# Celery Beat (Periodic Tasks) Settings
#####
celery_app.conf.beat_schedule = {
"check-for-document-set-sync": {
"task": "check_for_document_sets_sync_task",
"schedule": timedelta(seconds=5),
},
"check-for-cc-pair-deletion": {
"task": "check_for_cc_pair_deletion_task",
# don't need to check too often, since we kick off a deletion initially
# during the API call that actually marks the CC pair for deletion
"schedule": timedelta(minutes=1),
},
}
celery_app.conf.beat_schedule.update(
{
"check-for-prune": {
def schedule_tenant_tasks() -> None:
tenants = get_all_tenant_ids()
for tenant_id in tenants:
# Schedule tasks specific to each tenant
celery_app.conf.beat_schedule[f"check-for-document-set-sync-{tenant_id}"] = {
"task": "check_for_document_sets_sync_task",
"schedule": timedelta(seconds=5),
"args": (tenant_id,),
}
celery_app.conf.beat_schedule[f"check-for-cc-pair-deletion-{tenant_id}"] = {
"task": "check_for_cc_pair_deletion_task",
"schedule": timedelta(seconds=5),
"args": (tenant_id,),
}
celery_app.conf.beat_schedule[f"check-for-prune-{tenant_id}"] = {
"task": "check_for_prune_task",
"schedule": timedelta(seconds=5),
},
}
)
celery_app.conf.beat_schedule.update(
{
"kombu-message-cleanup": {
"args": (tenant_id,),
}
# Schedule tasks that are not tenant-specific
celery_app.conf.beat_schedule["kombu-message-cleanup"] = {
"task": "kombu_message_cleanup_task",
"schedule": timedelta(seconds=3600),
},
}
)
"args": (tenant_id,),
}
schedule_tenant_tasks()

View File

@@ -35,6 +35,7 @@ from danswer.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,
)
from danswer.utils.variable_functionality import noop_fallback
from danswer.configs.app_configs import DEFAULT_SCHEMA
logger = setup_logger()
@@ -46,12 +47,13 @@ def delete_connector_credential_pair_batch(
connector_id: int,
credential_id: int,
document_index: DocumentIndex,
tenant_id: str | None
) -> None:
"""
Removes a batch of documents ids from a cc-pair. If no other cc-pair uses a document anymore
it gets permanently deleted.
"""
with Session(get_sqlalchemy_engine()) as db_session:
with Session(get_sqlalchemy_engine(schema=tenant_id)) as db_session:
# acquire lock for all documents in this batch so that indexing can't
# override the deletion
with prepare_to_modify_documents(
@@ -124,6 +126,7 @@ def delete_connector_credential_pair(
db_session: Session,
document_index: DocumentIndex,
cc_pair: ConnectorCredentialPair,
tenant_id: str | None
) -> int:
connector_id = cc_pair.connector_id
credential_id = cc_pair.credential_id
@@ -135,6 +138,7 @@ def delete_connector_credential_pair(
connector_id=connector_id,
credential_id=credential_id,
limit=_DELETION_BATCH_SIZE,
)
if not documents:
break
@@ -144,6 +148,7 @@ def delete_connector_credential_pair(
connector_id=connector_id,
credential_id=credential_id,
document_index=document_index,
tenant_id=tenant_id,
)
num_docs_deleted += len(documents)

View File

@@ -1,3 +1,4 @@
import time
import traceback
from datetime import datetime
@@ -5,7 +6,7 @@ from datetime import timedelta
from datetime import timezone
from sqlalchemy.orm import Session
from danswer.db.engine import get_sqlalchemy_engine
from danswer.background.indexing.checkpointing import get_time_windows_for_index_attempt
from danswer.background.indexing.tracer import DanswerTracer
from danswer.configs.app_configs import INDEXING_SIZE_WARNING_THRESHOLD
@@ -16,7 +17,6 @@ from danswer.connectors.factory import instantiate_connector
from danswer.connectors.models import IndexAttemptMetadata
from danswer.db.connector_credential_pair import get_last_successful_attempt_time
from danswer.db.connector_credential_pair import update_connector_credential_pair
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import mark_attempt_failed
@@ -44,6 +44,7 @@ def _get_connector_runner(
attempt: IndexAttempt,
start_time: datetime,
end_time: datetime,
tenant_id: str | None
) -> ConnectorRunner:
"""
NOTE: `start_time` and `end_time` are only used for poll connectors
@@ -61,6 +62,7 @@ def _get_connector_runner(
attempt.connector_credential_pair.connector.connector_specific_config,
attempt.connector_credential_pair.credential,
db_session,
)
except Exception as e:
logger.exception(f"Unable to instantiate connector due to {e}")
@@ -82,6 +84,7 @@ def _get_connector_runner(
def _run_indexing(
db_session: Session,
index_attempt: IndexAttempt,
tenant_id: str | None
) -> None:
"""
1. Get documents which are either new or updated from specified application
@@ -102,6 +105,7 @@ def _run_indexing(
primary_index_name=index_name, secondary_index_name=None
)
embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
search_settings=search_settings
)
@@ -113,6 +117,7 @@ def _run_indexing(
ignore_time_skip=index_attempt.from_beginning
or (search_settings.status == IndexModelStatus.FUTURE),
db_session=db_session,
tenant_id=tenant_id,
)
db_cc_pair = index_attempt.connector_credential_pair
@@ -169,6 +174,7 @@ def _run_indexing(
attempt=index_attempt,
start_time=window_start,
end_time=window_end,
tenant_id=tenant_id
)
all_connector_doc_ids: set[str] = set()
@@ -196,7 +202,7 @@ def _run_indexing(
db_session.refresh(index_attempt)
if index_attempt.status != IndexingStatus.IN_PROGRESS:
# Likely due to user manually disabling it or model swap
raise RuntimeError("Index Attempt was canceled")
raise RuntimeError(f"Index Attempt was canceled, status is {index_attempt.status}")
batch_description = []
for doc in doc_batch:
@@ -383,38 +389,30 @@ def _prepare_index_attempt(db_session: Session, index_attempt_id: int) -> IndexA
return attempt
def run_indexing_entrypoint(index_attempt_id: int, is_ee: bool = False) -> None:
"""Entrypoint for indexing run when using dask distributed.
Wraps the actual logic in a `try` block so that we can catch any exceptions
and mark the attempt as failed."""
def run_indexing_entrypoint(index_attempt_id: int, tenant_id: str | None, is_ee: bool = False) -> None:
try:
if is_ee:
global_version.set_ee()
# set the indexing attempt ID so that all log messages from this process
# will have it added as a prefix
IndexAttemptSingleton.set_index_attempt_id(index_attempt_id)
with Session(get_sqlalchemy_engine()) as db_session:
# make sure that it is valid to run this indexing attempt + mark it
# as in progress
with Session(get_sqlalchemy_engine(schema=tenant_id)) as db_session:
attempt = _prepare_index_attempt(db_session, index_attempt_id)
logger.info(
f"Indexing starting: "
f"Indexing starting for tenant {tenant_id}: " if tenant_id is not None else "" +
f"connector='{attempt.connector_credential_pair.connector.name}' "
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
f"credentials='{attempt.connector_credential_pair.connector_id}'"
)
_run_indexing(db_session, attempt)
_run_indexing(db_session, attempt, tenant_id)
logger.info(
f"Indexing finished: "
f"Indexing finished for tenant {tenant_id}: " if tenant_id is not None else "" +
f"connector='{attempt.connector_credential_pair.connector.name}' "
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
f"credentials='{attempt.connector_credential_pair.connector_id}'"
)
except Exception as e:
logger.exception(f"Indexing job with ID '{index_attempt_id}' failed due to {e}")
logger.exception(f"Indexing job with ID '{index_attempt_id}' for tenant {tenant_id} failed due to {e}")

View File

@@ -14,8 +14,11 @@ from danswer.db.tasks import mark_task_start
from danswer.db.tasks import register_task
def name_cc_cleanup_task(connector_id: int, credential_id: int) -> str:
return f"cleanup_connector_credential_pair_{connector_id}_{credential_id}"
def name_cc_cleanup_task(connector_id: int, credential_id: int, tenant_id: str | None = None) -> str:
task_name = f"cleanup_connector_credential_pair_{connector_id}_{credential_id}"
if tenant_id is not None:
task_name += f"_{tenant_id}"
return task_name
def name_document_set_sync_task(document_set_id: int) -> str:

View File

@@ -8,6 +8,7 @@ from dask.distributed import Future
from distributed import LocalCluster
from sqlalchemy.orm import Session
from sqlalchemy import text
from danswer.background.indexing.dask_utils import ResourceLogger
from danswer.background.indexing.job_client import SimpleJob
from danswer.background.indexing.job_client import SimpleJobClient
@@ -45,7 +46,8 @@ from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
from shared_configs.configs import LOG_LEVEL
from shared_configs.configs import MODEL_SERVER_PORT
from danswer.configs.app_configs import MULTI_TENANT
from sqlalchemy.exc import ProgrammingError
logger = setup_logger()
@@ -143,13 +145,14 @@ def _mark_run_failed(
"""Main funcs"""
def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob]) -> None:
def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob], tenant_id: str | None) -> None:
"""Creates new indexing jobs for each connector / credential pair which is:
1. Enabled
2. `refresh_frequency` time has passed since the last indexing run for this pair
3. There is not already an ongoing indexing attempt for this pair
"""
with Session(get_sqlalchemy_engine()) as db_session:
with Session(get_sqlalchemy_engine(schema=tenant_id)) as db_session:
ongoing: set[tuple[int | None, int]] = set()
for attempt_id in existing_jobs:
attempt = get_index_attempt(
@@ -204,12 +207,13 @@ def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob]) -> None:
def cleanup_indexing_jobs(
existing_jobs: dict[int, Future | SimpleJob],
tenant_id: str | None,
timeout_hours: int = CLEANUP_INDEXING_JOBS_TIMEOUT,
) -> dict[int, Future | SimpleJob]:
existing_jobs_copy = existing_jobs.copy()
# clean up completed jobs
with Session(get_sqlalchemy_engine()) as db_session:
with Session(get_sqlalchemy_engine(schema=tenant_id)) as db_session:
for attempt_id, job in existing_jobs.items():
index_attempt = get_index_attempt(
db_session=db_session, index_attempt_id=attempt_id
@@ -247,38 +251,42 @@ def cleanup_indexing_jobs(
)
# clean up in-progress jobs that were never completed
connectors = fetch_connectors(db_session)
for connector in connectors:
in_progress_indexing_attempts = get_inprogress_index_attempts(
connector.id, db_session
)
for index_attempt in in_progress_indexing_attempts:
if index_attempt.id in existing_jobs:
# If index attempt is canceled, stop the run
if index_attempt.status == IndexingStatus.FAILED:
existing_jobs[index_attempt.id].cancel()
# check to see if the job has been updated in last `timeout_hours` hours, if not
# assume it to frozen in some bad state and just mark it as failed. Note: this relies
# on the fact that the `time_updated` field is constantly updated every
# batch of documents indexed
current_db_time = get_db_current_time(db_session=db_session)
time_since_update = current_db_time - index_attempt.time_updated
if time_since_update.total_seconds() > 60 * 60 * timeout_hours:
existing_jobs[index_attempt.id].cancel()
try:
connectors = fetch_connectors(db_session)
for connector in connectors:
in_progress_indexing_attempts = get_inprogress_index_attempts(
connector.id, db_session
)
for index_attempt in in_progress_indexing_attempts:
if index_attempt.id in existing_jobs:
# If index attempt is canceled, stop the run
if index_attempt.status == IndexingStatus.FAILED:
existing_jobs[index_attempt.id].cancel()
# check to see if the job has been updated in last `timeout_hours` hours, if not
# assume it to frozen in some bad state and just mark it as failed. Note: this relies
# on the fact that the `time_updated` field is constantly updated every
# batch of documents indexed
current_db_time = get_db_current_time(db_session=db_session)
time_since_update = current_db_time - index_attempt.time_updated
if time_since_update.total_seconds() > 60 * 60 * timeout_hours:
existing_jobs[index_attempt.id].cancel()
_mark_run_failed(
db_session=db_session,
index_attempt=index_attempt,
failure_reason="Indexing run frozen - no updates in the last three hours. "
"The run will be re-attempted at next scheduled indexing time.",
)
else:
# If job isn't known, simply mark it as failed
_mark_run_failed(
db_session=db_session,
index_attempt=index_attempt,
failure_reason="Indexing run frozen - no updates in the last three hours. "
"The run will be re-attempted at next scheduled indexing time.",
failure_reason=_UNEXPECTED_STATE_FAILURE_REASON,
)
else:
# If job isn't known, simply mark it as failed
_mark_run_failed(
db_session=db_session,
index_attempt=index_attempt,
failure_reason=_UNEXPECTED_STATE_FAILURE_REASON,
)
except ProgrammingError as _:
logger.debug(f"No Connector Table exists for: {tenant_id}")
pass
return existing_jobs_copy
@@ -286,9 +294,11 @@ def kickoff_indexing_jobs(
existing_jobs: dict[int, Future | SimpleJob],
client: Client | SimpleJobClient,
secondary_client: Client | SimpleJobClient,
tenant_id: str | None,
) -> dict[int, Future | SimpleJob]:
existing_jobs_copy = existing_jobs.copy()
engine = get_sqlalchemy_engine()
engine = get_sqlalchemy_engine(schema=tenant_id)
# Don't include jobs waiting in the Dask queue that just haven't started running
# Also (rarely) don't include for jobs that started but haven't updated the indexing tables yet
@@ -337,13 +347,16 @@ def kickoff_indexing_jobs(
run = secondary_client.submit(
run_indexing_entrypoint,
attempt.id,
tenant_id,
global_version.get_is_ee_version(),
pure=False,
)
else:
run = client.submit(
run_indexing_entrypoint,
attempt.id,
tenant_id,
global_version.get_is_ee_version(),
pure=False,
)
@@ -376,41 +389,32 @@ def kickoff_indexing_jobs(
return existing_jobs_copy
def get_all_tenant_ids() -> list[str] | list[None]:
if not MULTI_TENANT:
return [None]
with Session(get_sqlalchemy_engine(schema='public')) as session:
result = session.execute(text("""
SELECT schema_name
FROM information_schema.schemata
WHERE schema_name NOT IN ('pg_catalog', 'information_schema', 'public')
"""))
tenant_ids = [row[0] for row in result]
valid_tenants = [tenant for tenant in tenant_ids if tenant is None or not tenant.startswith('pg_')]
return valid_tenants
def update_loop(
delay: int = 10,
num_workers: int = NUM_INDEXING_WORKERS,
num_secondary_workers: int = NUM_SECONDARY_INDEXING_WORKERS,
) -> None:
engine = get_sqlalchemy_engine()
with Session(engine) as db_session:
check_index_swap(db_session=db_session)
search_settings = get_current_search_settings(db_session)
# So that the first time users aren't surprised by really slow speed of first
# batch of documents indexed
if search_settings.provider_type is None:
logger.notice("Running a first inference to warm up embedding model")
embedding_model = EmbeddingModel.from_db_model(
search_settings=search_settings,
server_host=INDEXING_MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
warm_up_bi_encoder(
embedding_model=embedding_model,
)
client_primary: Client | SimpleJobClient
client_secondary: Client | SimpleJobClient
if DASK_JOB_CLIENT_ENABLED:
cluster_primary = LocalCluster(
n_workers=num_workers,
threads_per_worker=1,
# there are warning about high memory usage + "Event loop unresponsive"
# which are not relevant to us since our workers are expected to use a
# lot of memory + involve CPU intensive tasks that will not relinquish
# the event loop
silence_logs=logging.ERROR,
)
cluster_secondary = LocalCluster(
@@ -426,37 +430,70 @@ def update_loop(
client_primary = SimpleJobClient(n_workers=num_workers)
client_secondary = SimpleJobClient(n_workers=num_secondary_workers)
existing_jobs: dict[int, Future | SimpleJob] = {}
existing_jobs: dict[str | None, dict[int, Future | SimpleJob]] = {}
logger.notice("Startup complete. Waiting for indexing jobs...")
while True:
start = time.time()
start_time_utc = datetime.utcfromtimestamp(start).strftime("%Y-%m-%d %H:%M:%S")
logger.debug(f"Running update, current UTC time: {start_time_utc}")
if existing_jobs:
# TODO: make this debug level once the "no jobs are being scheduled" issue is resolved
logger.debug(
"Found existing indexing jobs: "
f"{[(attempt_id, job.status) for attempt_id, job in existing_jobs.items()]}"
f"{[(tenant_id, list(jobs.keys())) for tenant_id, jobs in existing_jobs.items()]}"
)
try:
with Session(get_sqlalchemy_engine()) as db_session:
check_index_swap(db_session)
existing_jobs = cleanup_indexing_jobs(existing_jobs=existing_jobs)
create_indexing_jobs(existing_jobs=existing_jobs)
existing_jobs = kickoff_indexing_jobs(
existing_jobs=existing_jobs,
client=client_primary,
secondary_client=client_secondary,
)
tenants = get_all_tenant_ids()
for tenant_id in tenants:
try:
logger.debug(f"Processing {'index attempts' if tenant_id is None else f'tenant {tenant_id}'}")
engine = get_sqlalchemy_engine(schema=tenant_id)
with Session(engine) as db_session:
check_index_swap(db_session=db_session)
if not MULTI_TENANT:
search_settings = get_current_search_settings(db_session)
if search_settings.provider_type is None:
logger.notice("Running a first inference to warm up embedding model")
embedding_model = EmbeddingModel.from_db_model(
search_settings=search_settings,
server_host=INDEXING_MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
warm_up_bi_encoder(embedding_model=embedding_model)
logger.notice("First inference complete.")
tenant_jobs = existing_jobs.get(tenant_id, {})
tenant_jobs = cleanup_indexing_jobs(
existing_jobs=tenant_jobs,
tenant_id=tenant_id
)
create_indexing_jobs(
existing_jobs=tenant_jobs,
tenant_id=tenant_id
)
tenant_jobs = kickoff_indexing_jobs(
existing_jobs=tenant_jobs,
client=client_primary,
secondary_client=client_secondary,
tenant_id=tenant_id,
)
existing_jobs[tenant_id] = tenant_jobs
except Exception as e:
logger.exception(f"Failed to process tenant {tenant_id or 'default'}: {e}")
except Exception as e:
logger.exception(f"Failed to run update due to {e}")
sleep_time = delay - (time.time() - start)
if sleep_time > 0:
time.sleep(sleep_time)
def update__main() -> None:
set_is_ee_based_on_env_variable()
init_sqlalchemy_engine(POSTGRES_INDEXER_APP_NAME)

View File

@@ -6,7 +6,6 @@ from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from danswer.configs.chat_configs import PERSONAS_YAML
from danswer.configs.chat_configs import PROMPTS_YAML
from danswer.db.document_set import get_or_create_document_set_by_name
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.input_prompt import insert_input_prompt_if_not_exists
from danswer.db.models import DocumentSet as DocumentSetDBModel
from danswer.db.models import Persona
@@ -18,148 +17,156 @@ from danswer.db.persona import upsert_prompt
from danswer.search.enums import RecencyBiasSetting
def load_prompts_from_yaml(prompts_yaml: str = PROMPTS_YAML) -> None:
def load_prompts_from_yaml(
db_session: Session,
prompts_yaml: str = PROMPTS_YAML
) -> None:
with open(prompts_yaml, "r") as file:
data = yaml.safe_load(file)
all_prompts = data.get("prompts", [])
with Session(get_sqlalchemy_engine()) as db_session:
for prompt in all_prompts:
upsert_prompt(
user=None,
prompt_id=prompt.get("id"),
name=prompt["name"],
description=prompt["description"].strip(),
system_prompt=prompt["system"].strip(),
task_prompt=prompt["task"].strip(),
include_citations=prompt["include_citations"],
datetime_aware=prompt.get("datetime_aware", True),
default_prompt=True,
personas=None,
db_session=db_session,
commit=True,
)
for prompt in all_prompts:
upsert_prompt(
user=None,
prompt_id=prompt.get("id"),
name=prompt["name"],
description=prompt["description"].strip(),
system_prompt=prompt["system"].strip(),
task_prompt=prompt["task"].strip(),
include_citations=prompt["include_citations"],
datetime_aware=prompt.get("datetime_aware", True),
default_prompt=True,
personas=None,
db_session=db_session,
commit=True,
)
def load_personas_from_yaml(
db_session: Session,
personas_yaml: str = PERSONAS_YAML,
default_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
) -> None:
with open(personas_yaml, "r") as file:
data = yaml.safe_load(file)
all_personas = data.get("personas", [])
with Session(get_sqlalchemy_engine()) as db_session:
for persona in all_personas:
doc_set_names = persona["document_sets"]
doc_sets: list[DocumentSetDBModel] = [
get_or_create_document_set_by_name(db_session, name)
for name in doc_set_names
for persona in all_personas:
doc_set_names = persona["document_sets"]
doc_sets: list[DocumentSetDBModel] = [
get_or_create_document_set_by_name(db_session, name)
for name in doc_set_names
]
# Assume if user hasn't set any document sets for the persona, the user may want
# to later attach document sets to the persona manually, therefore, don't overwrite/reset
# the document sets for the persona
doc_set_ids: list[int] | None = None
if doc_sets:
doc_set_ids = [doc_set.id for doc_set in doc_sets]
else:
doc_set_ids = None
prompt_ids: list[int] | None = None
prompt_set_names = persona["prompts"]
if prompt_set_names:
prompts: list[PromptDBModel | None] = [
get_prompt_by_name(prompt_name, user=None, db_session=db_session)
for prompt_name in prompt_set_names
]
if any([prompt is None for prompt in prompts]):
raise ValueError("Invalid Persona configs, not all prompts exist")
# Assume if user hasn't set any document sets for the persona, the user may want
# to later attach document sets to the persona manually, therefore, don't overwrite/reset
# the document sets for the persona
doc_set_ids: list[int] | None = None
if doc_sets:
doc_set_ids = [doc_set.id for doc_set in doc_sets]
else:
doc_set_ids = None
if prompts:
prompt_ids = [prompt.id for prompt in prompts if prompt is not None]
prompt_ids: list[int] | None = None
prompt_set_names = persona["prompts"]
if prompt_set_names:
prompts: list[PromptDBModel | None] = [
get_prompt_by_name(prompt_name, user=None, db_session=db_session)
for prompt_name in prompt_set_names
]
if any([prompt is None for prompt in prompts]):
raise ValueError("Invalid Persona configs, not all prompts exist")
if prompts:
prompt_ids = [prompt.id for prompt in prompts if prompt is not None]
p_id = persona.get("id")
tool_ids = []
if persona.get("image_generation"):
image_gen_tool = (
db_session.query(ToolDBModel)
.filter(ToolDBModel.name == "ImageGenerationTool")
.first()
)
if image_gen_tool:
tool_ids.append(image_gen_tool.id)
llm_model_provider_override = persona.get("llm_model_provider_override")
llm_model_version_override = persona.get("llm_model_version_override")
# Set specific overrides for image generation persona
if persona.get("image_generation"):
llm_model_version_override = "gpt-4o"
existing_persona = (
db_session.query(Persona)
.filter(Persona.name == persona["name"])
p_id = persona.get("id")
tool_ids = []
if persona.get("image_generation"):
image_gen_tool = (
db_session.query(ToolDBModel)
.filter(ToolDBModel.name == "ImageGenerationTool")
.first()
)
if image_gen_tool:
tool_ids.append(image_gen_tool.id)
upsert_persona(
user=None,
persona_id=(-1 * p_id) if p_id is not None else None,
name=persona["name"],
description=persona["description"],
num_chunks=persona.get("num_chunks")
if persona.get("num_chunks") is not None
else default_chunks,
llm_relevance_filter=persona.get("llm_relevance_filter"),
starter_messages=persona.get("starter_messages"),
llm_filter_extraction=persona.get("llm_filter_extraction"),
icon_shape=persona.get("icon_shape"),
icon_color=persona.get("icon_color"),
llm_model_provider_override=llm_model_provider_override,
llm_model_version_override=llm_model_version_override,
recency_bias=RecencyBiasSetting(persona["recency_bias"]),
prompt_ids=prompt_ids,
document_set_ids=doc_set_ids,
tool_ids=tool_ids,
default_persona=True,
is_public=True,
display_priority=existing_persona.display_priority
if existing_persona is not None
else persona.get("display_priority"),
is_visible=existing_persona.is_visible
if existing_persona is not None
else persona.get("is_visible"),
db_session=db_session,
)
llm_model_provider_override = persona.get("llm_model_provider_override")
llm_model_version_override = persona.get("llm_model_version_override")
# Set specific overrides for image generation persona
if persona.get("image_generation"):
llm_model_version_override = "gpt-4o"
def load_input_prompts_from_yaml(input_prompts_yaml: str = INPUT_PROMPT_YAML) -> None:
existing_persona = (
db_session.query(Persona)
.filter(Persona.name == persona["name"])
.first()
)
upsert_persona(
user=None,
persona_id=(-1 * p_id) if p_id is not None else None,
name=persona["name"],
description=persona["description"],
num_chunks=persona.get("num_chunks")
if persona.get("num_chunks") is not None
else default_chunks,
llm_relevance_filter=persona.get("llm_relevance_filter"),
starter_messages=persona.get("starter_messages"),
llm_filter_extraction=persona.get("llm_filter_extraction"),
icon_shape=persona.get("icon_shape"),
icon_color=persona.get("icon_color"),
llm_model_provider_override=llm_model_provider_override,
llm_model_version_override=llm_model_version_override,
recency_bias=RecencyBiasSetting(persona["recency_bias"]),
prompt_ids=prompt_ids,
document_set_ids=doc_set_ids,
tool_ids=tool_ids,
default_persona=True,
is_public=True,
display_priority=existing_persona.display_priority
if existing_persona is not None
else persona.get("display_priority"),
is_visible=existing_persona.is_visible
if existing_persona is not None
else persona.get("is_visible"),
db_session=db_session,
)
def load_input_prompts_from_yaml(
db_session: Session,
input_prompts_yaml: str = INPUT_PROMPT_YAML
) -> None:
with open(input_prompts_yaml, "r") as file:
data = yaml.safe_load(file)
all_input_prompts = data.get("input_prompts", [])
with Session(get_sqlalchemy_engine()) as db_session:
for input_prompt in all_input_prompts:
# If these prompts are deleted (which is a hard delete in the DB), on server startup
# they will be recreated, but the user can always just deactivate them, just a light inconvenience
insert_input_prompt_if_not_exists(
user=None,
input_prompt_id=input_prompt.get("id"),
prompt=input_prompt["prompt"],
content=input_prompt["content"],
is_public=input_prompt["is_public"],
active=input_prompt.get("active", True),
db_session=db_session,
commit=True,
)
for input_prompt in all_input_prompts:
# If these prompts are deleted (which is a hard delete in the DB), on server startup
# they will be recreated, but the user can always just deactivate them, just a light inconvenience
insert_input_prompt_if_not_exists(
user=None,
input_prompt_id=input_prompt.get("id"),
prompt=input_prompt["prompt"],
content=input_prompt["content"],
is_public=input_prompt["is_public"],
active=input_prompt.get("active", True),
db_session=db_session,
commit=True,
)
def load_chat_yamls(
db_session: Session,
prompt_yaml: str = PROMPTS_YAML,
personas_yaml: str = PERSONAS_YAML,
input_prompts_yaml: str = INPUT_PROMPT_YAML,
) -> None:
load_prompts_from_yaml(prompt_yaml)
load_personas_from_yaml(personas_yaml)
load_input_prompts_from_yaml(input_prompts_yaml)
load_prompts_from_yaml(db_session, prompt_yaml)
load_personas_from_yaml(db_session, personas_yaml)
load_input_prompts_from_yaml(db_session, input_prompts_yaml)

View File

@@ -32,7 +32,6 @@ from danswer.db.chat import get_or_create_root_message
from danswer.db.chat import reserve_message_id
from danswer.db.chat import translate_db_message_to_chat_message_detail
from danswer.db.chat import translate_db_search_doc_to_server_search_doc
from danswer.db.engine import get_session_context_manager
from danswer.db.llm import fetch_existing_llm_providers
from danswer.db.models import SearchDoc as DbSearchDoc
from danswer.db.models import ToolCall
@@ -314,12 +313,14 @@ def stream_chat_message_objects(
try:
llm, fast_llm = get_llms_for_persona(
persona=persona,
db_session=db_session,
llm_override=new_msg_req.llm_override or chat_session.llm_override,
additional_headers=litellm_additional_headers,
)
except GenAIDisabledException:
raise RuntimeError("LLM is disabled. Can't use chat flow without LLM.")
llm_provider = llm.config.model_provider
llm_model_name = llm.config.model_name
@@ -631,6 +632,7 @@ def stream_chat_message_objects(
or get_main_llm_from_tuple(
get_llms_for_persona(
persona=persona,
db_session=db_session,
llm_override=(
new_msg_req.llm_override or chat_session.llm_override
),
@@ -799,18 +801,19 @@ def stream_chat_message_objects(
def stream_chat_message(
new_msg_req: CreateChatMessageRequest,
user: User | None,
db_session: Session,
use_existing_user_message: bool = False,
litellm_additional_headers: dict[str, str] | None = None,
is_connected: Callable[[], bool] | None = None,
) -> Iterator[str]:
with get_session_context_manager() as db_session:
objects = stream_chat_message_objects(
new_msg_req=new_msg_req,
user=user,
db_session=db_session,
use_existing_user_message=use_existing_user_message,
litellm_additional_headers=litellm_additional_headers,
is_connected=is_connected,
)
for obj in objects:
yield get_json_line(obj.model_dump())
objects = stream_chat_message_objects(
new_msg_req=new_msg_req,
user=user,
db_session=db_session,
use_existing_user_message=use_existing_user_message,
litellm_additional_headers=litellm_additional_headers,
is_connected=is_connected,
)
for obj in objects:
yield get_json_line(obj.model_dump())
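Since stream_chat_message no longer opens its own session, the caller owns the tenant-scoped one and simply hands it through. Below is a minimal sketch of such a caller; the route path is made up and a couple of import paths are assumptions rather than taken from this diff.

# Sketch only: the /chat/send-message route is hypothetical and the marked
# import paths are assumptions for illustration.
from fastapi import APIRouter, Depends
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session

from danswer.auth.users import current_user
from danswer.chat.process_message import stream_chat_message  # path assumed
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.server.query_and_chat.models import CreateChatMessageRequest  # path assumed

router = APIRouter(prefix="/chat")

@router.post("/send-message")
def send_message(
    new_msg_req: CreateChatMessageRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),  # resolved per tenant from the request
) -> StreamingResponse:
    # The generator yields JSON lines; the session it receives is already
    # pinned to the caller's schema, so nothing tenant-related happens inside.
    return StreamingResponse(
        stream_chat_message(
            new_msg_req=new_msg_req,
            user=user,
            db_session=db_session,
        ),
        media_type="application/json",
    )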

View File

@@ -37,9 +37,11 @@ DISABLE_GENERATIVE_AI = os.environ.get("DISABLE_GENERATIVE_AI", "").lower() == "
WEB_DOMAIN = os.environ.get("WEB_DOMAIN") or "http://localhost:3000"
SECRET_JWT_KEY = os.environ.get("SECRET_JWT_KEY") or "JWT_SECRET_KEY"
#####
# Auth Configs
#####
AUTH_TYPE = AuthType((os.environ.get("AUTH_TYPE") or AuthType.DISABLED.value).lower())
DISABLE_AUTH = AUTH_TYPE == AuthType.DISABLED
@@ -134,7 +136,7 @@ POSTGRES_PASSWORD = urllib.parse.quote_plus(
os.environ.get("POSTGRES_PASSWORD") or "password"
)
POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost"
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432"
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5433"
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
# defaults to False
@@ -366,3 +368,21 @@ CUSTOM_ANSWER_VALIDITY_CONDITIONS = json.loads(
ENTERPRISE_EDITION_ENABLED = (
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() == "true"
)
###
# CLOUD CONFIGS
###
STRIPE_PRICE = os.environ.get("STRIPE_PRICE", "price_1PsYoPHlhTYqRZib2t5ydpq5")
STRIPE_WEBHOOK_SECRET = os.environ.get(
"STRIPE_WEBHOOK_SECRET",
"whsec_1cd766cd6bd08590aa8c46ab5c21ac32cad77c29de2e09a152a01971d6f405d3"
)
DEFAULT_SCHEMA = os.environ.get("DEFAULT_SCHEMA", "public")
DATA_PLANE_SECRET = os.environ.get("DATA_PLANE_SECRET", "your_shared_secret_key")
EXPECTED_API_KEY = os.environ.get("EXPECTED_API_KEY", "your_control_plane_api_key")
MULTI_TENANT = os.environ.get("MULTI_TENANT", "false").lower() == "true"

View File

@@ -88,3 +88,4 @@ HARD_DELETE_CHATS = False
# Internet Search
BING_API_KEY = os.environ.get("BING_API_KEY") or None
VESPA_SEARCHER_THREADS = int(os.environ.get("VESPA_SEARCHER_THREADS") or 2)

View File

@@ -239,7 +239,6 @@ def _datetime_from_string(datetime_string: str) -> datetime:
else:
# If not in UTC, translate it
datetime_object = datetime_object.astimezone(timezone.utc)
return datetime_object

View File

@@ -159,10 +159,12 @@ class LocalFileConnector(LoadConnector):
self,
file_locations: list[Path | str],
batch_size: int = INDEX_BATCH_SIZE,
tenant_id: str | None = None
) -> None:
self.file_locations = [Path(file_location) for file_location in file_locations]
self.batch_size = batch_size
self.pdf_pass: str | None = None
self.tenant_id = tenant_id
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self.pdf_pass = credentials.get("pdf_password")
@@ -170,7 +172,7 @@ class LocalFileConnector(LoadConnector):
def load_from_state(self) -> GenerateDocumentsOutput:
documents: list[Document] = []
with Session(get_sqlalchemy_engine()) as db_session:
with Session(get_sqlalchemy_engine(schema=self.tenant_id)) as db_session:
for file_path in self.file_locations:
current_datetime = datetime.now(timezone.utc)
files = _read_files_and_metadata(

View File

@@ -154,7 +154,7 @@ def handle_regular_answer(
get_editable=False,
),
)
llm, _ = get_llms_for_persona(persona)
llm, _ = get_llms_for_persona(persona, db_session=db_session)
# In cases of threads, split the available tokens between docs and thread context
input_tokens = get_max_input_tokens(
@@ -171,6 +171,7 @@ def handle_regular_answer(
persona=persona,
actual_user_input=query_text,
max_llm_token_override=remaining_tokens,
db_session=db_session,
)
else:
max_document_tokens = (

View File

@@ -51,6 +51,7 @@ def get_chat_session_by_id(
is_shared: bool = False,
) -> ChatSession:
stmt = select(ChatSession).where(ChatSession.id == chat_session_id)
db_session.connection()
if is_shared:
stmt = stmt.where(ChatSession.shared_status == ChatSessionSharedStatus.PUBLIC)
@@ -86,7 +87,6 @@ def get_chat_sessions_by_slack_thread_id(
)
return db_session.scalars(stmt).all()
def get_first_messages_for_chat_sessions(
chat_session_ids: list[int], db_session: Session
) -> dict[int, str]:

View File

@@ -1,10 +1,12 @@
import contextvars
from fastapi import Depends
from fastapi import Request, HTTPException
import contextlib
import time
from collections.abc import AsyncGenerator
from collections.abc import Generator
from datetime import datetime
from typing import ContextManager
from sqlalchemy import event
from sqlalchemy import text
from sqlalchemy.engine import create_engine
@@ -14,7 +16,8 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from danswer.configs.app_configs import SECRET_JWT_KEY
from danswer.configs.app_configs import DEFAULT_SCHEMA
from danswer.configs.app_configs import LOG_POSTGRES_CONN_COUNTS
from danswer.configs.app_configs import LOG_POSTGRES_LATENCY
from danswer.configs.app_configs import POSTGRES_DB
@@ -25,10 +28,18 @@ from danswer.configs.app_configs import POSTGRES_POOL_RECYCLE
from danswer.configs.app_configs import POSTGRES_PORT
from danswer.configs.app_configs import POSTGRES_USER
from danswer.configs.constants import POSTGRES_UNKNOWN_APP_NAME
from danswer.configs.app_configs import MULTI_TENANT
from danswer.utils.logger import setup_logger
from fastapi.security import OAuth2PasswordBearer
from jwt.exceptions import DecodeError, InvalidTokenError
import jwt
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
logger = setup_logger()
SYNC_DB_API = "psycopg2"
ASYNC_DB_API = "asyncpg"
@@ -128,28 +139,32 @@ def init_sqlalchemy_engine(app_name: str) -> None:
global POSTGRES_APP_NAME
POSTGRES_APP_NAME = app_name
_engines: dict[str, Engine] = {}
def get_sqlalchemy_engine() -> Engine:
global _SYNC_ENGINE
if _SYNC_ENGINE is None:
# NOTE: this is a hack to allow for multiple postgres schemas per engine for now.
def get_sqlalchemy_engine(*, schema: str | None = DEFAULT_SCHEMA) -> Engine:
if schema is None:
schema = current_tenant_id.get()
global _engines
if schema not in _engines:
connection_string = build_connection_string(
db_api=SYNC_DB_API, app_name=POSTGRES_APP_NAME + "_sync"
db_api=SYNC_DB_API, app_name=f"{POSTGRES_APP_NAME}_{schema}_sync"
)
_SYNC_ENGINE = create_engine(
_engines[schema] = create_engine(
connection_string,
pool_size=40,
max_overflow=10,
pool_pre_ping=POSTGRES_POOL_PRE_PING,
pool_recycle=POSTGRES_POOL_RECYCLE,
connect_args={"options": f"-c search_path={schema}"}
)
return _SYNC_ENGINE
return _engines[schema]
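The engine cache above keeps one pooled SQLAlchemy engine per Postgres schema and pins each connection's search_path to that schema. A stripped-down, standalone sketch of the same pattern follows; the connection string and schema name are placeholders, not Danswer's real values.

# Standalone illustration of the schema-per-engine pattern.
from sqlalchemy import create_engine, text
from sqlalchemy.engine import Engine

_engines: dict[str, Engine] = {}

def engine_for_schema(schema: str) -> Engine:
    # One engine (and connection pool) per tenant schema; every connection
    # checked out of it resolves unqualified table names inside that schema.
    if schema not in _engines:
        _engines[schema] = create_engine(
            "postgresql+psycopg2://user:password@localhost:5432/postgres",
            pool_pre_ping=True,
            connect_args={"options": f"-c search_path={schema}"},
        )
    return _engines[schema]

with engine_for_schema("tenant_acme").connect() as conn:
    # Prints the schema this connection is scoped to.
    print(conn.execute(text("show search_path")).scalar())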
def get_sqlalchemy_async_engine() -> AsyncEngine:
global _ASYNC_ENGINE
if _ASYNC_ENGINE is None:
# underlying asyncpg cannot accept application_name directly in the connection string
# https://github.com/MagicStack/asyncpg/issues/798
connection_string = build_connection_string()
_ASYNC_ENGINE = create_async_engine(
connection_string,
@@ -163,26 +178,50 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
)
return _ASYNC_ENGINE
current_tenant_id = contextvars.ContextVar(
"current_tenant_id", default=DEFAULT_SCHEMA
)
def get_session_context_manager() -> ContextManager[Session]:
return contextlib.contextmanager(get_session)()
tenant_id = current_tenant_id.get()
return contextlib.contextmanager(lambda: get_session(override_tenant_id=tenant_id))()
def get_current_tenant_id(request: Request) -> str | None:
if not MULTI_TENANT:
return DEFAULT_SCHEMA
def get_session() -> Generator[Session, None, None]:
# The line below was added to monitor the latency caused by Postgres connections
# during API calls.
# with tracer.trace("db.get_session"):
with Session(get_sqlalchemy_engine(), expire_on_commit=False) as session:
token = request.cookies.get("tenant_details")
if not token:
return current_tenant_id.get()
try:
payload = jwt.decode(token, SECRET_JWT_KEY, algorithms=["HS256"])
tenant_id = payload.get("tenant_id")
if not tenant_id:
raise HTTPException(status_code=400, detail="Invalid token: tenant_id missing")
current_tenant_id.set(tenant_id)
return tenant_id
except (DecodeError, InvalidTokenError):
raise HTTPException(status_code=401, detail="Invalid token format")
except Exception:
raise HTTPException(status_code=500, detail="Internal server error")
def get_session(
tenant_id: str = Depends(get_current_tenant_id),
override_tenant_id: str | None = None
) -> Generator[Session, None, None]:
if override_tenant_id:
tenant_id = override_tenant_id
with Session(get_sqlalchemy_engine(schema=tenant_id), expire_on_commit=False) as session:
yield session
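get_current_tenant_id reads the tenant from a signed cookie and stores it in a contextvar, so code that never sees the request (background helpers, get_session_context_manager) still resolves the right schema. The following is a small self-contained sketch of that flow using a throwaway secret; it mirrors the logic above but is not the production wiring.

# Self-contained illustration of the cookie -> contextvar flow.
import contextvars
import jwt

SECRET = "dev-only-secret"  # placeholder, not SECRET_JWT_KEY
current_tenant_id = contextvars.ContextVar("current_tenant_id", default="public")

def resolve_tenant(cookie_token: str | None) -> str:
    # No cookie: fall back to whatever tenant is already in the contextvar,
    # which is the single-tenant default schema.
    if not cookie_token:
        return current_tenant_id.get()
    payload = jwt.decode(cookie_token, SECRET, algorithms=["HS256"])
    tenant_id = payload["tenant_id"]
    current_tenant_id.set(tenant_id)
    return tenant_id

token = jwt.encode({"tenant_id": "tenant_acme"}, SECRET, algorithm="HS256")
print(resolve_tenant(token))        # -> tenant_acme
print(current_tenant_id.get())      # later calls in the same context see the same tenant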
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
async def get_async_session(tenant_id: str | None = None) -> AsyncGenerator[AsyncSession, None]:
async with AsyncSession(
get_sqlalchemy_async_engine(), expire_on_commit=False
) as async_session:
yield async_session
async def warm_up_connections(
sync_connections_to_warm_up: int = 20, async_connections_to_warm_up: int = 20
) -> None:
@@ -190,6 +229,7 @@ async def warm_up_connections(
connections = [
sync_postgres_engine.connect() for _ in range(sync_connections_to_warm_up)
]
for conn in connections:
conn.execute(text("SELECT 1"))
for conn in connections:
@@ -205,7 +245,6 @@ async def warm_up_connections(
for async_conn in async_connections:
await async_conn.close()
def get_session_factory() -> sessionmaker[Session]:
global SessionFactory
if SessionFactory is None:

View File

@@ -135,13 +135,16 @@ def fetch_embedding_provider(
def fetch_default_provider(db_session: Session) -> FullLLMProvider | None:
provider_model = db_session.scalar(
select(LLMProviderModel).where(
LLMProviderModel.is_default_provider == True # noqa: E712
)
)
if not provider_model:
return None
return FullLLMProvider.from_model(provider_model)

View File

@@ -128,10 +128,11 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
oidc_expiry: Mapped[datetime.datetime] = mapped_column(
TIMESTAMPAware(timezone=True), nullable=True
)
tenant_id: Mapped[str] = mapped_column(Text, nullable=True)
default_model: Mapped[str] = mapped_column(Text, nullable=True)
# organized in typical structured fashion
# formatted as `displayName__provider__modelName`
default_model: Mapped[str] = mapped_column(Text, nullable=True)
# relationships
credentials: Mapped[list["Credential"]] = relationship(
@@ -1184,7 +1185,6 @@ class Tool(Base):
# user who created / owns the tool. Will be None for built-in tools.
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
user: Mapped[User | None] = relationship("User", back_populates="custom_tools")
# Relationship to Persona through the association table
personas: Mapped[list["Persona"]] = relationship(

View File

@@ -1,22 +1,11 @@
from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.configs.model_configs import ASYM_PASSAGE_PREFIX
from danswer.configs.model_configs import ASYM_QUERY_PREFIX
from danswer.configs.model_configs import DEFAULT_DOCUMENT_ENCODER_MODEL
from danswer.configs.model_configs import DOC_EMBEDDING_DIM
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
from danswer.configs.model_configs import OLD_DEFAULT_DOCUMENT_ENCODER_MODEL
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.llm import fetch_embedding_provider
from danswer.db.models import CloudEmbeddingProvider
from danswer.db.models import IndexModelStatus
from danswer.db.models import SearchSettings
from danswer.indexing.models import IndexingSetting
from danswer.natural_language_processing.search_nlp_models import clean_model_name
from danswer.search.models import SavedSearchSettings
from danswer.server.manage.embedding.models import (
CloudEmbeddingProvider as ServerCloudEmbeddingProvider,
@@ -174,76 +163,3 @@ def update_search_settings_status(
search_settings.status = new_status
db_session.commit()
def user_has_overridden_embedding_model() -> bool:
return DOCUMENT_ENCODER_MODEL != DEFAULT_DOCUMENT_ENCODER_MODEL
def get_old_default_search_settings() -> SearchSettings:
is_overridden = user_has_overridden_embedding_model()
return SearchSettings(
model_name=(
DOCUMENT_ENCODER_MODEL
if is_overridden
else OLD_DEFAULT_DOCUMENT_ENCODER_MODEL
),
model_dim=(
DOC_EMBEDDING_DIM if is_overridden else OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
),
normalize=(
NORMALIZE_EMBEDDINGS
if is_overridden
else OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
),
query_prefix=(ASYM_QUERY_PREFIX if is_overridden else ""),
passage_prefix=(ASYM_PASSAGE_PREFIX if is_overridden else ""),
status=IndexModelStatus.PRESENT,
index_name="danswer_chunk",
)
def get_new_default_search_settings(is_present: bool) -> SearchSettings:
return SearchSettings(
model_name=DOCUMENT_ENCODER_MODEL,
model_dim=DOC_EMBEDDING_DIM,
normalize=NORMALIZE_EMBEDDINGS,
query_prefix=ASYM_QUERY_PREFIX,
passage_prefix=ASYM_PASSAGE_PREFIX,
status=IndexModelStatus.PRESENT if is_present else IndexModelStatus.FUTURE,
index_name=f"danswer_chunk_{clean_model_name(DOCUMENT_ENCODER_MODEL)}",
)
def get_old_default_embedding_model() -> IndexingSetting:
is_overridden = user_has_overridden_embedding_model()
return IndexingSetting(
model_name=(
DOCUMENT_ENCODER_MODEL
if is_overridden
else OLD_DEFAULT_DOCUMENT_ENCODER_MODEL
),
model_dim=(
DOC_EMBEDDING_DIM if is_overridden else OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
),
normalize=(
NORMALIZE_EMBEDDINGS
if is_overridden
else OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
),
query_prefix=(ASYM_QUERY_PREFIX if is_overridden else ""),
passage_prefix=(ASYM_PASSAGE_PREFIX if is_overridden else ""),
index_name="danswer_chunk",
multipass_indexing=False,
)
def get_new_default_embedding_model() -> IndexingSetting:
return IndexingSetting(
model_name=DOCUMENT_ENCODER_MODEL,
model_dim=DOC_EMBEDDING_DIM,
normalize=NORMALIZE_EMBEDDINGS,
query_prefix=ASYM_QUERY_PREFIX,
passage_prefix=ASYM_PASSAGE_PREFIX,
index_name=f"danswer_chunk_{clean_model_name(DOCUMENT_ENCODER_MODEL)}",
multipass_indexing=False,
)

View File

@@ -0,0 +1,36 @@
from danswer.llm.llm_initialization import load_llm_providers
from danswer.db.connector import create_initial_default_connector
from danswer.db.connector_credential_pair import associate_default_cc_pair
from danswer.db.credentials import create_initial_public_credential
from danswer.db.standard_answer import create_initial_default_standard_answer_category
from danswer.db.persona import delete_old_default_personas
from danswer.chat.load_yamls import load_chat_yamls
from danswer.tools.built_in_tools import auto_add_search_tool_to_personas
from danswer.tools.built_in_tools import load_builtin_tools
from danswer.tools.built_in_tools import refresh_built_in_tools_cache
from danswer.utils.logger import setup_logger
from sqlalchemy.orm import Session
logger = setup_logger()
def setup_postgres(db_session: Session) -> None:
logger.notice("Verifying default connector/credential exist.")
create_initial_public_credential(db_session)
create_initial_default_connector(db_session)
associate_default_cc_pair(db_session)
logger.notice("Verifying default standard answer category exists.")
create_initial_default_standard_answer_category(db_session)
logger.notice("Loading LLM providers from env variables")
load_llm_providers(db_session)
logger.notice("Loading default Prompts and Personas")
delete_old_default_personas(db_session)
load_chat_yamls(db_session)
logger.notice("Loading built-in tools")
load_builtin_tools(db_session)
refresh_built_in_tools_cache(db_session)
auto_add_search_tool_to_personas(db_session)
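With the seeding logic pulled out of main.py into setup_postgres, the same routine can be pointed at any tenant schema. A sketch of how a provisioning path might seed a freshly migrated schema is shown below; the seed_tenant helper is hypothetical, only the imports and setup_postgres call come from this diff.

# Hypothetical helper: seed a new tenant schema once alembic has created its tables.
from sqlalchemy.orm import Session

from danswer.db.engine import get_sqlalchemy_engine
from danswer.db_setup import setup_postgres

def seed_tenant(tenant_schema: str) -> None:
    # get_sqlalchemy_engine(schema=...) pins search_path to the tenant schema,
    # so the default connector, personas, prompts, and tools land there.
    with Session(get_sqlalchemy_engine(schema=tenant_schema)) as db_session:
        setup_postgres(db_session)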

View File

@@ -3,13 +3,21 @@ from danswer.document_index.vespa.index import VespaIndex
def get_default_document_index(
primary_index_name: str,
secondary_index_name: str | None,
primary_index_name: str | None = None,
indices: list[str] | None = None,
secondary_index_name: str | None = None
) -> DocumentIndex:
"""Primary index is the index that is used for querying/updating etc.
Secondary index is for when both the currently used index and the upcoming
index both need to be updated, updates are applied to both indices"""
# Currently only supporting Vespa
indices = [primary_index_name] if primary_index_name is not None else indices
if not indices:
raise ValueError("No indices provided")
return VespaIndex(
index_name=primary_index_name, secondary_index_name=secondary_index_name
indices=indices,
secondary_index_name=secondary_index_name
)
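The factory stays backwards compatible: the legacy single-index call still works, while multi-tenant callers pass the full list of supported indices. Both call styles are sketched below with placeholder index names.

from danswer.document_index.factory import get_default_document_index

# Single-tenant / legacy style: one primary index, optional secondary.
index = get_default_document_index(
    primary_index_name="danswer_chunk_model_a",
    secondary_index_name=None,
)

# Multi-tenant style: one index per supported embedding model.
multi_index = get_default_document_index(
    indices=["danswer_chunk_model_a", "danswer_chunk_model_b"],
)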

View File

@@ -77,7 +77,7 @@ class Verifiable(abc.ABC):
all valid in the schema.
Parameters:
- index_name: The name of the primary index currently used for querying
- indices: The names of the primary indices currently used for querying
- secondary_index_name: The name of the secondary index being built in the background, if it
currently exists. Some functions on the document index act on both the primary and
secondary index, some act on just one.
@@ -86,20 +86,21 @@ class Verifiable(abc.ABC):
@abc.abstractmethod
def __init__(
self,
index_name: str,
indices: list[str],
secondary_index_name: str | None,
*args: Any,
**kwargs: Any
) -> None:
super().__init__(*args, **kwargs)
self.index_name = index_name
self.indices = indices
self.secondary_index_name = secondary_index_name
@abc.abstractmethod
def ensure_indices_exist(
self,
index_embedding_dim: int,
secondary_index_embedding_dim: int | None,
embedding_dims: list[int] | None = None,
index_embedding_dim: int | None = None,
secondary_index_embedding_dim: int | None = None
) -> None:
"""
Verify that the document index exists and is consistent with the expectations in the code.

View File

@@ -1,5 +1,6 @@
schema DANSWER_CHUNK_NAME {
document DANSWER_CHUNK_NAME {
TENANT_ID_REPLACEMENT
# Not to be confused with the UUID generated for this chunk which is called documentid by default
field document_id type string {
indexing: summary | attribute

View File

@@ -335,6 +335,8 @@ def query_vespa(
return inference_chunks
def _get_chunks_via_batch_search(
index_name: str,
chunk_requests: list[VespaChunkRequest],

View File

@@ -13,6 +13,7 @@ from typing import cast
import httpx
import requests
from danswer.configs.app_configs import MULTI_TENANT
from danswer.configs.chat_configs import DOC_TIME_DECAY
from danswer.configs.chat_configs import NUM_RETURNED_HITS
from danswer.configs.chat_configs import TITLE_CONTENT_RATIO
@@ -46,6 +47,8 @@ from danswer.document_index.vespa_constants import BATCH_SIZE
from danswer.document_index.vespa_constants import BOOST
from danswer.document_index.vespa_constants import CONTENT_SUMMARY
from danswer.document_index.vespa_constants import DANSWER_CHUNK_REPLACEMENT_PAT
from danswer.document_index.vespa_constants import TENANT_ID_PAT
from danswer.document_index.vespa_constants import TENANT_ID_REPLACEMENT
from danswer.document_index.vespa_constants import DATE_REPLACEMENT
from danswer.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
from danswer.document_index.vespa_constants import DOCUMENT_REPLACEMENT_PAT
@@ -83,7 +86,7 @@ def in_memory_zip_from_file_bytes(file_contents: dict[str, bytes]) -> BinaryIO:
return zip_buffer
def _create_document_xml_lines(doc_names: list[str | None]) -> str:
def _create_document_xml_lines(doc_names: list[str]) -> str:
doc_lines = [
f'<document type="{doc_name}" mode="index" />'
for doc_name in doc_names
@@ -108,15 +111,29 @@ def add_ngrams_to_schema(schema_content: str) -> str:
class VespaIndex(DocumentIndex):
def __init__(self, index_name: str, secondary_index_name: str | None) -> None:
self.index_name = index_name
def __init__(self, indices: list[str], secondary_index_name: str | None) -> None:
self.indices = indices
self.secondary_index_name = secondary_index_name
@property
def index_name(self) -> str:
if len(self.indices) == 0:
raise ValueError("No indices provided")
return self.indices[0]
def ensure_indices_exist(
self,
index_embedding_dim: int,
secondary_index_embedding_dim: int | None,
embedding_dims: list[int] | None = None,
index_embedding_dim: int | None = None,
secondary_index_embedding_dim: int | None = None
) -> None:
if embedding_dims is None:
if index_embedding_dim is not None:
embedding_dims = [index_embedding_dim]
else:
raise ValueError("No embedding dimensions provided")
deploy_url = f"{VESPA_APPLICATION_ENDPOINT}/tenant/default/prepareandactivate"
logger.debug(f"Sending Vespa zip to {deploy_url}")
@@ -130,9 +147,15 @@ class VespaIndex(DocumentIndex):
with open(services_file, "r") as services_f:
services_template = services_f.read()
schema_names = [self.index_name, self.secondary_index_name]
# Generate schema names from index settings
schema_names = [index_name for index_name in self.indices]
full_schemas = schema_names
if self.secondary_index_name:
full_schemas.append(self.secondary_index_name)
doc_lines = _create_document_xml_lines(full_schemas)
doc_lines = _create_document_xml_lines(schema_names)
services = services_template.replace(DOCUMENT_REPLACEMENT_PAT, doc_lines)
kv_store = get_dynamic_config_store()
@@ -160,27 +183,38 @@ class VespaIndex(DocumentIndex):
with open(schema_file, "r") as schema_f:
schema_template = schema_f.read()
schema = schema_template.replace(
DANSWER_CHUNK_REPLACEMENT_PAT, self.index_name
).replace(VESPA_DIM_REPLACEMENT_PAT, str(index_embedding_dim))
schema = add_ngrams_to_schema(schema) if needs_reindexing else schema
zip_dict[f"schemas/{schema_names[0]}.sd"] = schema.encode("utf-8")
for i, index_name in enumerate(self.indices):
embedding_dim = embedding_dims[i]
logger.info(f"Creating index: {index_name} with embedding dimension: {embedding_dim}")
schema = schema_template.replace(
DANSWER_CHUNK_REPLACEMENT_PAT, index_name
).replace(VESPA_DIM_REPLACEMENT_PAT, str(embedding_dim))
schema = schema.replace(TENANT_ID_PAT, TENANT_ID_REPLACEMENT if MULTI_TENANT else "")
schema = add_ngrams_to_schema(schema) if needs_reindexing else schema
zip_dict[f"schemas/{index_name}.sd"] = schema.encode("utf-8")
if self.secondary_index_name:
logger.info("Creating secondary index:"
f"{self.secondary_index_name} with embedding dimension: {secondary_index_embedding_dim}")
upcoming_schema = schema_template.replace(
DANSWER_CHUNK_REPLACEMENT_PAT, self.secondary_index_name
).replace(VESPA_DIM_REPLACEMENT_PAT, str(secondary_index_embedding_dim))
zip_dict[f"schemas/{schema_names[1]}.sd"] = upcoming_schema.encode("utf-8")
upcoming_schema = upcoming_schema.replace(TENANT_ID_PAT, TENANT_ID_REPLACEMENT if MULTI_TENANT else "")
zip_dict[f"schemas/{self.secondary_index_name}.sd"] = upcoming_schema.encode("utf-8")
zip_file = in_memory_zip_from_file_bytes(zip_dict)
headers = {"Content-Type": "application/zip"}
response = requests.post(deploy_url, headers=headers, data=zip_file)
if response.status_code != 200:
raise RuntimeError(
f"Failed to prepare Vespa Danswer Index. Response: {response.text}"
f"Failed to prepare Vespa Danswer Indexes. Response: {response.text}"
)
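Each index now gets its own schema file rendered from the same template: the chunk name and embedding dimension are substituted per index, and the tenant_id field block is injected only when MULTI_TENANT is set. The snippet below is a standalone illustration of that substitution with a toy template, not the real danswer_chunk.sd contents.

# Standalone illustration of the schema templating above.
MULTI_TENANT = True
TENANT_ID_PAT = "TENANT_ID_REPLACEMENT"
TENANT_ID_REPLACEMENT = """field tenant_id type string {
    indexing: summary | attribute
    rank: filter
    attribute: fast-search
}"""

schema_template = """schema DANSWER_CHUNK_NAME {
    document DANSWER_CHUNK_NAME {
        TENANT_ID_REPLACEMENT
        # toy field standing in for the real embedding field
        field embeddings type tensor<float>(x[VARIABLE_DIM]) {}
    }
}"""

def render_schema(index_name: str, embedding_dim: int) -> str:
    schema = schema_template.replace("DANSWER_CHUNK_NAME", index_name)
    schema = schema.replace("VARIABLE_DIM", str(embedding_dim))
    # Single-tenant deployments strip the placeholder so their schema is unchanged.
    return schema.replace(TENANT_ID_PAT, TENANT_ID_REPLACEMENT if MULTI_TENANT else "")

print(render_schema("danswer_chunk_model_a", 768))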
def index(
self,
chunks: list[DocMetadataAwareIndexChunk],
@@ -230,7 +264,6 @@ class VespaIndex(DocumentIndex):
)
all_doc_ids = {chunk.source_document.id for chunk in cleaned_chunks}
return {
DocumentInsertionRecord(
document_id=doc_id,
@@ -282,7 +315,7 @@ class VespaIndex(DocumentIndex):
raise requests.HTTPError(failure_msg) from e
def update(self, update_requests: list[UpdateRequest]) -> None:
logger.info(f"Updating {len(update_requests)} documents in Vespa")
logger.debug(f"Updating {len(update_requests)} documents in Vespa")
# Handle Vespa character limitations
# Mutating update_requests but it's not used later anyway
@@ -371,6 +404,91 @@ class VespaIndex(DocumentIndex):
time.monotonic() - update_start,
)
def update_single(self, update_request: UpdateRequest) -> None:
"""Note: if the document id does not exist, the update will be a no-op and the
function will complete with no errors or exceptions.
Handle other exceptions if you wish to implement retry behavior
"""
if len(update_request.document_ids) != 1:
raise ValueError("update_request must contain a single document id")
# Handle Vespa character limitations
# Mutating update_request but it's not used later anyway
update_request.document_ids = [
replace_invalid_doc_id_characters(doc_id)
for doc_id in update_request.document_ids
]
update_start = time.monotonic()
# Fetch all chunks for each document ahead of time
index_names = [self.index_name]
if self.secondary_index_name:
index_names.append(self.secondary_index_name)
chunk_id_start_time = time.monotonic()
all_doc_chunk_ids: list[str] = []
for index_name in index_names:
for document_id in update_request.document_ids:
# this calls vespa and can raise http exceptions
doc_chunk_ids = get_all_vespa_ids_for_document_id(
document_id=document_id,
index_name=index_name,
filters=None,
get_large_chunks=True,
)
all_doc_chunk_ids.extend(doc_chunk_ids)
logger.debug(
f"Took {time.monotonic() - chunk_id_start_time:.2f} seconds to fetch all Vespa chunk IDs"
)
# Build the _VespaUpdateRequest objects
update_dict: dict[str, dict] = {"fields": {}}
if update_request.boost is not None:
update_dict["fields"][BOOST] = {"assign": update_request.boost}
if update_request.document_sets is not None:
update_dict["fields"][DOCUMENT_SETS] = {
"assign": {
document_set: 1 for document_set in update_request.document_sets
}
}
if update_request.access is not None:
update_dict["fields"][ACCESS_CONTROL_LIST] = {
"assign": {acl_entry: 1 for acl_entry in update_request.access.to_acl()}
}
if update_request.hidden is not None:
update_dict["fields"][HIDDEN] = {"assign": update_request.hidden}
if not update_dict["fields"]:
logger.error("Update request received but nothing to update")
return
processed_update_requests: list[_VespaUpdateRequest] = []
for document_id in update_request.document_ids:
for doc_chunk_id in all_doc_chunk_ids:
processed_update_requests.append(
_VespaUpdateRequest(
document_id=document_id,
url=f"{DOCUMENT_ID_ENDPOINT.format(index_name=self.index_name)}/{doc_chunk_id}",
update_request=update_dict,
)
)
with httpx.Client(http2=True) as http_client:
for update in processed_update_requests:
http_client.put(
update.url,
headers={"Content-Type": "application/json"},
json=update.update_request,
)
logger.debug(
"Finished updating Vespa documents in %.2f seconds",
time.monotonic() - update_start,
)
return
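A usage sketch of the new update_single path follows. It assumes UpdateRequest accepts these fields as keyword arguments with the unspecified ones defaulting to None, and that a Vespa deployment is reachable; the index name is a placeholder.

from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import UpdateRequest

vespa_index = get_default_document_index(
    primary_index_name="danswer_chunk_model_a",  # placeholder index name
    secondary_index_name=None,
)
update_req = UpdateRequest(
    document_ids=["web_https://example.com/page"],  # exactly one id, per the check above
    hidden=True,  # only the hidden flag changes; boost/document_sets/access stay unset
)
vespa_index.update_single(update_req)  # silently a no-op if the document id is unknown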
def delete(self, doc_ids: list[str]) -> None:
logger.info(f"Deleting {len(doc_ids)} documents from Vespa")

View File

@@ -21,6 +21,7 @@ from danswer.document_index.vespa_constants import CHUNK_ID
from danswer.document_index.vespa_constants import CONTENT
from danswer.document_index.vespa_constants import CONTENT_SUMMARY
from danswer.document_index.vespa_constants import DOC_UPDATED_AT
from danswer.document_index.vespa_constants import TENANT_ID
from danswer.document_index.vespa_constants import DOCUMENT_ID
from danswer.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
from danswer.document_index.vespa_constants import DOCUMENT_SETS
@@ -65,6 +66,8 @@ def _does_document_exist(
raise RuntimeError(
f"Unexpected fetch document by ID value from Vespa "
f"with error {doc_fetch_response.status_code}"
f"Index name: {index_name}"
f"Doc chunk id: {doc_chunk_id}"
)
return True
@@ -95,7 +98,7 @@ def get_existing_documents_from_chunks(
try:
chunk_existence_future = {
executor.submit(
_does_document_exist,
_does_document_exist,
str(get_uuid_from_chunk(chunk)),
index_name,
http_client,
@@ -117,7 +120,9 @@ def get_existing_documents_from_chunks(
@retry(tries=3, delay=1, backoff=2)
def _index_vespa_chunk(
chunk: DocMetadataAwareIndexChunk, index_name: str, http_client: httpx.Client
chunk: DocMetadataAwareIndexChunk,
index_name: str,
http_client: httpx.Client,
) -> None:
json_header = {
"Content-Type": "application/json",
@@ -172,8 +177,10 @@ def _index_vespa_chunk(
DOCUMENT_SETS: {document_set: 1 for document_set in chunk.document_sets},
}
if chunk.tenant_id:
vespa_document_fields[TENANT_ID] = chunk.tenant_id
vespa_url = f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{vespa_chunk_id}"
logger.debug(f'Indexing to URL "{vespa_url}"')
res = http_client.post(
vespa_url, headers=json_header, json={"fields": vespa_document_fields}
)

View File

@@ -7,6 +7,7 @@ from danswer.document_index.interfaces import VespaChunkRequest
from danswer.document_index.vespa_constants import ACCESS_CONTROL_LIST
from danswer.document_index.vespa_constants import CHUNK_ID
from danswer.document_index.vespa_constants import DOC_UPDATED_AT
from danswer.document_index.vespa_constants import TENANT_ID
from danswer.document_index.vespa_constants import DOCUMENT_ID
from danswer.document_index.vespa_constants import DOCUMENT_SETS
from danswer.document_index.vespa_constants import HIDDEN
@@ -19,6 +20,7 @@ logger = setup_logger()
def build_vespa_filters(filters: IndexFilters, include_hidden: bool = False) -> str:
def _build_or_filters(key: str, vals: list[str] | None) -> str:
if vals is None:
return ""
@@ -53,6 +55,10 @@ def build_vespa_filters(filters: IndexFilters, include_hidden: bool = False) ->
filter_str = f"!({HIDDEN}=true) and " if not include_hidden else ""
if filters.tenant_id:
filter_str += f'({TENANT_ID} contains "{filters.tenant_id}") and '
# CAREFUL touching this one, currently there is no second ACL double-check post retrieval
if filters.access_control_list is not None:
filter_str += _build_or_filters(

View File

@@ -8,7 +8,14 @@ VESPA_DIM_REPLACEMENT_PAT = "VARIABLE_DIM"
DANSWER_CHUNK_REPLACEMENT_PAT = "DANSWER_CHUNK_NAME"
DOCUMENT_REPLACEMENT_PAT = "DOCUMENT_REPLACEMENT"
DATE_REPLACEMENT = "DATE_REPLACEMENT"
SEARCH_THREAD_NUMBER_PAT = "SEARCH_THREAD_NUMBER"
TENANT_ID_PAT = "TENANT_ID_REPLACEMENT"
TENANT_ID_REPLACEMENT = """field tenant_id type string {
indexing: summary | attribute
rank: filter
attribute: fast-search
}"""
# config server
VESPA_CONFIG_SERVER_URL = f"http://{VESPA_CONFIG_SERVER_HOST}:{VESPA_TENANT_PORT}"
VESPA_APPLICATION_ENDPOINT = f"{VESPA_CONFIG_SERVER_URL}/application/v2"
@@ -31,7 +38,7 @@ MAX_ID_SEARCH_QUERY_SIZE = 400
VESPA_TIMEOUT = "3s"
BATCH_SIZE = 128 # Specific to Vespa
TENANT_ID = "tenant_id"
DOCUMENT_ID = "document_id"
CHUNK_ID = "chunk_id"
BLURB = "blurb"

View File

@@ -134,6 +134,7 @@ def index_doc_batch_with_handler(
attempt_id: int | None,
db_session: Session,
ignore_time_skip: bool = False,
tenant_id: str | None = None,
) -> tuple[int, int]:
r = (0, 0)
try:
@@ -145,6 +146,7 @@ def index_doc_batch_with_handler(
index_attempt_metadata=index_attempt_metadata,
db_session=db_session,
ignore_time_skip=ignore_time_skip,
tenant_id=tenant_id,
)
except Exception as e:
if INDEXING_EXCEPTION_LIMIT == 0:
@@ -258,6 +260,7 @@ def index_doc_batch(
index_attempt_metadata: IndexAttemptMetadata,
db_session: Session,
ignore_time_skip: bool = False,
tenant_id: str | None = None,
) -> tuple[int, int]:
"""Takes different pieces of the indexing pipeline and applies it to a batch of documents
Note that the documents should already be batched at this point so that it does not inflate the
@@ -316,6 +319,7 @@ def index_doc_batch(
if chunk.source_document.id in ctx.id_to_db_doc_map
else DEFAULT_BOOST
),
tenant_id=tenant_id,
)
for chunk in chunks_with_embeddings
]
@@ -357,6 +361,7 @@ def build_indexing_pipeline(
chunker: Chunker | None = None,
ignore_time_skip: bool = False,
attempt_id: int | None = None,
tenant_id: str | None = None,
) -> IndexingPipelineProtocol:
"""Builds a pipeline which takes in a list (batch) of docs and indexes them."""
search_settings = get_current_search_settings(db_session)
@@ -393,4 +398,5 @@ def build_indexing_pipeline(
ignore_time_skip=ignore_time_skip,
attempt_id=attempt_id,
db_session=db_session,
tenant_id=tenant_id,
)

View File

@@ -55,12 +55,10 @@ class DocAwareChunk(BaseChunk):
f"Chunk ID: '{self.chunk_id}'; {self.source_document.to_short_descriptor()}"
)
class IndexChunk(DocAwareChunk):
embeddings: ChunkEmbedding
title_embedding: Embedding | None
class DocMetadataAwareIndexChunk(IndexChunk):
"""An `IndexChunk` that contains all necessary metadata to be indexed. This includes
the following:
@@ -73,6 +71,7 @@ class DocMetadataAwareIndexChunk(IndexChunk):
negative -> ranked lower.
"""
tenant_id: str | None = None
access: "DocumentAccess"
document_sets: set[str]
boost: int
@@ -84,6 +83,7 @@ class DocMetadataAwareIndexChunk(IndexChunk):
access: "DocumentAccess",
document_sets: set[str],
boost: int,
tenant_id: str | None,
) -> "DocMetadataAwareIndexChunk":
index_chunk_data = index_chunk.model_dump()
return cls(
@@ -91,6 +91,7 @@ class DocMetadataAwareIndexChunk(IndexChunk):
access=access,
document_sets=document_sets,
boost=boost,
tenant_id=tenant_id,
)

View File

@@ -1,3 +1,4 @@
from sqlalchemy.orm import Session
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage
@@ -94,13 +95,14 @@ def compute_max_document_tokens(
def compute_max_document_tokens_for_persona(
persona: Persona,
db_session: Session,
actual_user_input: str | None = None,
max_llm_token_override: int | None = None,
) -> int:
prompt = persona.prompts[0] if persona.prompts else get_default_prompt__read_only()
return compute_max_document_tokens(
prompt_config=PromptConfig.from_model(prompt),
llm_config=get_main_llm_from_tuple(get_llms_for_persona(persona)).config,
llm_config=get_main_llm_from_tuple(get_llms_for_persona(persona, db_session=db_session)).config,
actual_user_input=actual_user_input,
max_llm_token_override=max_llm_token_override,
)

View File

@@ -1,7 +1,7 @@
from danswer.db.engine import get_session_context_manager
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.chat_configs import QA_TIMEOUT
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.db.engine import get_session_context_manager
from danswer.db.llm import fetch_default_provider
from danswer.db.llm import fetch_provider
from danswer.db.models import Persona
@@ -10,16 +10,20 @@ from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.headers import build_llm_extra_headers
from danswer.llm.interfaces import LLM
from danswer.llm.override_models import LLMOverride
from sqlalchemy.orm import Session
from danswer.utils.logger import setup_logger
logger = setup_logger()
def get_main_llm_from_tuple(
llms: tuple[LLM, LLM],
) -> LLM:
return llms[0]
def get_llms_for_persona(
persona: Persona,
db_session: Session,
llm_override: LLMOverride | None = None,
additional_headers: dict[str, str] | None = None,
) -> tuple[LLM, LLM]:
@@ -28,14 +32,15 @@ def get_llms_for_persona(
temperature_override = llm_override.temperature if llm_override else None
provider_name = model_provider_override or persona.llm_model_provider_override
if not provider_name:
return get_default_llms(
temperature=temperature_override or GEN_AI_TEMPERATURE,
additional_headers=additional_headers,
db_session=db_session,
)
with get_session_context_manager() as db_session:
llm_provider = fetch_provider(db_session, provider_name)
llm_provider = fetch_provider(db_session, provider_name)
if not llm_provider:
raise ValueError("No LLM provider found")
@@ -57,7 +62,6 @@ def get_llms_for_persona(
custom_config=llm_provider.custom_config,
additional_headers=additional_headers,
)
return _create_llm(model), _create_llm(fast_model)
@@ -65,13 +69,19 @@ def get_default_llms(
timeout: int = QA_TIMEOUT,
temperature: float = GEN_AI_TEMPERATURE,
additional_headers: dict[str, str] | None = None,
db_session: Session | None = None,
) -> tuple[LLM, LLM]:
if DISABLE_GENERATIVE_AI:
raise GenAIDisabledException()
with get_session_context_manager() as db_session:
if db_session is None:
with get_session_context_manager() as db_session:
llm_provider = fetch_default_provider(db_session)
else:
llm_provider = fetch_default_provider(db_session)
if not llm_provider:
raise ValueError("No default LLM provider found")
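Call sites in this change pass the active session explicitly; when none is given, get_default_llms opens its own via get_session_context_manager, which resolves the tenant from the contextvar. A small sketch of the two call styles, wrapped in a hypothetical helper (the factory import path is assumed):

from sqlalchemy.orm import Session

from danswer.db.models import Persona
from danswer.llm.factory import get_default_llms, get_llms_for_persona  # path assumed

def pick_llms(persona: Persona | None, db_session: Session):
    if persona is not None:
        # Uses the caller's tenant-scoped session to look up the provider row.
        return get_llms_for_persona(persona=persona, db_session=db_session)
    # No session passed: the helper opens a tenant-scoped one itself via
    # get_session_context_manager(), using the tenant in the contextvar.
    return get_default_llms()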

View File

@@ -1,3 +1,5 @@
from danswer.document_index.vespa.index import VespaIndex
import time
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
@@ -12,15 +14,17 @@ from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from httpx_oauth.clients.google import GoogleOAuth2
from sqlalchemy.orm import Session
from danswer.document_index.interfaces import DocumentIndex
from danswer.configs.app_configs import MULTI_TENANT
from danswer import __version__
from danswer.auth.schemas import UserCreate
from danswer.auth.schemas import UserRead
from danswer.auth.schemas import UserUpdate
from danswer.auth.users import auth_backend
from danswer.auth.users import fastapi_users
from danswer.chat.load_yamls import load_chat_yamls
from sqlalchemy.orm import Session
from danswer.indexing.models import IndexingSetting
from danswer.configs.app_configs import APP_API_PREFIX
from danswer.configs.app_configs import APP_HOST
from danswer.configs.app_configs import APP_PORT
@@ -37,30 +41,22 @@ from danswer.configs.constants import KV_REINDEX_KEY
from danswer.configs.constants import KV_SEARCH_SETTINGS
from danswer.configs.constants import POSTGRES_WEB_APP_NAME
from danswer.db.connector import check_connectors_exist
from danswer.db.connector import create_initial_default_connector
from danswer.db.connector_credential_pair import associate_default_cc_pair
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.connector_credential_pair import resync_cc_pair
from danswer.db.credentials import create_initial_public_credential
from danswer.db.document import check_docs_exist
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.engine import init_sqlalchemy_engine
from danswer.db.engine import warm_up_connections
from danswer.db.index_attempt import cancel_indexing_attempts_past_model
from danswer.db.index_attempt import expire_index_attempts
from danswer.db.persona import delete_old_default_personas
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.db.search_settings import update_current_search_settings
from danswer.db.search_settings import update_secondary_search_settings
from danswer.db.standard_answer import create_initial_default_standard_answer_category
from danswer.db.swap_index import check_index_swap
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import DocumentIndex
from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.indexing.models import IndexingSetting
from danswer.llm.llm_initialization import load_llm_providers
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.natural_language_processing.search_nlp_models import warm_up_cross_encoder
@@ -106,9 +102,7 @@ from danswer.server.settings.api import basic_router as settings_router
from danswer.server.token_rate_limits.api import (
router as token_rate_limit_settings_router,
)
from danswer.tools.built_in_tools import auto_add_search_tool_to_personas
from danswer.tools.built_in_tools import load_builtin_tools
from danswer.tools.built_in_tools import refresh_built_in_tools_cache
from danswer.server.tenants.api import basic_router as tenants_router
from danswer.utils.logger import setup_logger
from danswer.utils.telemetry import optional_telemetry
from danswer.utils.telemetry import RecordType
@@ -117,7 +111,8 @@ from danswer.utils.variable_functionality import global_version
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.configs import SUPPORTED_EMBEDDING_MODELS
from danswer.db_setup import setup_postgres
logger = setup_logger()
@@ -170,27 +165,6 @@ def include_router_with_global_prefix_prepended(
application.include_router(router, **final_kwargs)
def setup_postgres(db_session: Session) -> None:
logger.notice("Verifying default connector/credential exist.")
create_initial_public_credential(db_session)
create_initial_default_connector(db_session)
associate_default_cc_pair(db_session)
logger.notice("Verifying default standard answer category exists.")
create_initial_default_standard_answer_category(db_session)
logger.notice("Loading LLM providers from env variables")
load_llm_providers(db_session)
logger.notice("Loading default Prompts and Personas")
delete_old_default_personas(db_session)
load_chat_yamls()
logger.notice("Loading built-in tools")
load_builtin_tools(db_session)
refresh_built_in_tools_cache(db_session)
auto_add_search_tool_to_personas(db_session)
def translate_saved_search_settings(db_session: Session) -> None:
kv_store = get_dynamic_config_store()
@@ -258,23 +232,22 @@ def mark_reindex_flag(db_session: Session) -> None:
def setup_vespa(
document_index: DocumentIndex,
index_setting: IndexingSetting,
secondary_index_setting: IndexingSetting | None,
embedding_dims: list[int],
secondary_embedding_dim: int | None = None
) -> None:
# Vespa startup is a bit slow, so give it a few seconds
wait_time = 5
for _ in range(5):
try:
document_index.ensure_indices_exist(
index_embedding_dim=index_setting.model_dim,
secondary_index_embedding_dim=secondary_index_setting.model_dim
if secondary_index_setting
else None,
embedding_dims=embedding_dims,
secondary_index_embedding_dim=secondary_embedding_dim
)
break
except Exception:
logger.notice(f"Waiting on Vespa, retrying in {wait_time} seconds...")
time.sleep(wait_time)
logger.exception("Error ensuring multi-tenant indices exist")
@asynccontextmanager
@@ -304,6 +277,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
# Break bad state for thrashing indexes
if secondary_search_settings and DISABLE_INDEX_UPDATE_ON_SWAP:
expire_index_attempts(
search_settings_id=search_settings.id, db_session=db_session
)
@@ -316,12 +290,14 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
logger.notice(f'Using Embedding model: "{search_settings.model_name}"')
if search_settings.query_prefix or search_settings.passage_prefix:
logger.notice(f'Query embedding prefix: "{search_settings.query_prefix}"')
logger.notice(
f'Passage embedding prefix: "{search_settings.passage_prefix}"'
)
if search_settings:
if not search_settings.disable_rerank_for_streaming:
logger.notice("Reranking is enabled.")
@@ -347,19 +323,39 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
# ensure Vespa is setup correctly
logger.notice("Verifying Document Index(s) is/are available.")
document_index = get_default_document_index(
primary_index_name=search_settings.index_name,
secondary_index_name=secondary_search_settings.index_name
if secondary_search_settings
else None,
)
setup_vespa(
document_index,
IndexingSetting.from_db_model(search_settings),
IndexingSetting.from_db_model(secondary_search_settings)
if secondary_search_settings
else None,
)
# document_index = get_default_document_index(
# indices=[model.index_name for model in SUPPORTED_EMBEDDING_MODELS]
# ) if MULTI_TENANT else get_default_document_index(
# indices=[model.index_name for model in SUPPORTED_EMBEDDING_MODELS],
# secondary_index_name=secondary_search_settings.index_name if secondary_search_settings else None
# )
if MULTI_TENANT:
document_index = get_default_document_index(
indices=[model.index_name for model in SUPPORTED_EMBEDDING_MODELS]
)
setup_vespa(
document_index,
[model.dim for model in SUPPORTED_EMBEDDING_MODELS],
secondary_embedding_dim=secondary_search_settings.model_dim if secondary_search_settings else None
)
else:
document_index = get_default_document_index(
indices=[search_settings.index_name],
secondary_index_name=secondary_search_settings.index_name if secondary_search_settings else None
)
setup_vespa(
document_index,
[IndexingSetting.from_db_model(search_settings).model_dim],
secondary_embedding_dim=(
IndexingSetting.from_db_model(secondary_search_settings).model_dim
if secondary_search_settings
else None
)
)
logger.notice(f"Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}")
if search_settings.provider_type is None:
@@ -370,6 +366,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
server_port=MODEL_SERVER_PORT,
),
)
logger.notice("Setup complete")
optional_telemetry(record_type=RecordType.VERSION, data={"version": __version__})
yield
@@ -415,6 +412,9 @@ def get_application() -> FastAPI:
include_router_with_global_prefix_prepended(
application, token_rate_limit_settings_router
)
include_router_with_global_prefix_prepended(
application, tenants_router
)
include_router_with_global_prefix_prepended(application, indexing_router)
if AUTH_TYPE == AuthType.DISABLED:

View File

@@ -25,7 +25,6 @@ from danswer.db.chat import get_or_create_root_message
from danswer.db.chat import translate_db_message_to_chat_message_detail
from danswer.db.chat import translate_db_search_doc_to_server_search_doc
from danswer.db.chat import update_search_docs_table_with_relevance
from danswer.db.engine import get_session_context_manager
from danswer.db.models import User
from danswer.db.persona import get_prompt_by_id
from danswer.llm.answering.answer import Answer
@@ -118,7 +117,8 @@ def stream_answer_objects(
one_shot=True,
danswerbot_flow=danswerbot_flow,
)
llm, fast_llm = get_llms_for_persona(persona=chat_session.persona)
llm, fast_llm = get_llms_for_persona(persona=chat_session.persona, db_session=db_session)
llm_tokenizer = get_tokenizer(
model_name=llm.config.model_name,
@@ -139,6 +139,7 @@ def stream_answer_objects(
rephrased_query = query_req.query_override or thread_based_query_rephrase(
user_query=query_msg.message,
history_str=history_str,
db_session=db_session
)
# Given back ahead of the documents for latency reasons
@@ -209,7 +210,8 @@ def stream_answer_objects(
question=query_msg.message,
answer_style_config=answer_config,
prompt_config=PromptConfig.from_model(prompt),
llm=get_main_llm_from_tuple(get_llms_for_persona(persona=chat_session.persona)),
llm=get_main_llm_from_tuple(get_llms_for_persona(persona=chat_session.persona, db_session=db_session)),
# TODO: change back
single_message_history=history_str,
tools=[search_tool],
force_use_tool=ForceUseTool(
@@ -316,17 +318,17 @@ def stream_search_answer(
user: User | None,
max_document_tokens: int | None,
max_history_tokens: int | None,
db_session: Session,
) -> Iterator[str]:
with get_session_context_manager() as session:
objects = stream_answer_objects(
query_req=query_req,
user=user,
max_document_tokens=max_document_tokens,
max_history_tokens=max_history_tokens,
db_session=session,
)
for obj in objects:
yield get_json_line(obj.model_dump())
objects = stream_answer_objects(
query_req=query_req,
user=user,
max_document_tokens=max_document_tokens,
max_history_tokens=max_history_tokens,
db_session=db_session,
)
for obj in objects:
yield get_json_line(obj.model_dump())
def get_search_answer(

View File

@@ -25,7 +25,7 @@ class ThreadMessage(BaseModel):
class DirectQARequest(ChunkContext):
messages: list[ThreadMessage]
prompt_id: int | None
prompt_id: int | None = None
persona_id: int
multilingual_query_expansion: list[str] | None = None
retrieval_options: RetrievalDetails = Field(default_factory=RetrievalDetails)

View File

@@ -11,7 +11,6 @@ class RecencyBiasSetting(str, Enum):
# Determine based on query if to use base_decay or favor_recent
AUTO = "auto"
class OptionalSearchSetting(str, Enum):
ALWAYS = "always"
NEVER = "never"

View File

@@ -27,7 +27,7 @@ class RerankingDetails(BaseModel):
# If model is None (or num_rerank is 0), then reranking is turned off
rerank_model_name: str | None
rerank_provider_type: RerankerProvider | None
rerank_api_key: str | None
rerank_api_key: str | None = None
num_rerank: int
@@ -98,6 +98,7 @@ class BaseFilters(BaseModel):
class IndexFilters(BaseFilters):
access_control_list: list[str] | None
tenant_id: str | None = None
class ChunkMetric(BaseModel):

View File

@@ -29,11 +29,11 @@ from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import FunctionCall
from danswer.utils.threadpool_concurrency import run_functions_in_parallel
from danswer.utils.timing import log_function_time
from danswer.configs.app_configs import MULTI_TENANT
from danswer.db.engine import current_tenant_id
logger = setup_logger()
def query_analysis(query: str) -> tuple[bool, list[str]]:
analysis_model = QueryAnalysisModel()
return analysis_model.predict(query)
@@ -121,6 +121,7 @@ def retrieval_preprocessing(
]
if filter_fn
]
parallel_results = run_functions_in_parallel(functions_to_run)
predicted_time_cutoff, predicted_favor_recent = (
@@ -151,12 +152,15 @@ def retrieval_preprocessing(
user_acl_filters = (
None if bypass_acl else build_access_filters_for_user(user, db_session)
)
final_filters = IndexFilters(
source_type=preset_filters.source_type or predicted_source_filters,
document_set=preset_filters.document_set,
time_cutoff=preset_filters.time_cutoff or predicted_time_cutoff,
tags=preset_filters.tags, # Tags are never auto-extracted
access_control_list=user_acl_filters,
tenant_id=current_tenant_id.get() if MULTI_TENANT else None,
)
llm_evaluation_type = LLMEvaluationType.BASIC

View File

@@ -237,7 +237,7 @@ def retrieve_chunks(
# Currently only uses query expansion on multilingual use cases
query_rephrases = multilingual_query_expansion(
query.query, multilingual_expansion
query.query, multilingual_expansion, db_session=db_session
)
# Just to be extra sure, add the original query.
query_rephrases.append(query.query)

View File

@@ -15,11 +15,11 @@ from danswer.prompts.miscellaneous_prompts import LANGUAGE_REPHRASE_PROMPT
from danswer.utils.logger import setup_logger
from danswer.utils.text_processing import count_punctuation
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from sqlalchemy.orm import Session
logger = setup_logger()
def llm_multilingual_query_expansion(query: str, language: str) -> str:
def llm_multilingual_query_expansion(query: str, language: str, db_session: Session) -> str:
def _get_rephrase_messages() -> list[dict[str, str]]:
messages = [
{
@@ -33,7 +33,7 @@ def llm_multilingual_query_expansion(query: str, language: str) -> str:
return messages
try:
_, fast_llm = get_default_llms(timeout=5)
_, fast_llm = get_default_llms(timeout=5, db_session=db_session)
except GenAIDisabledException:
logger.warning(
"Unable to perform multilingual query expansion, Gen AI disabled"
@@ -51,12 +51,13 @@ def llm_multilingual_query_expansion(query: str, language: str) -> str:
def multilingual_query_expansion(
query: str,
expansion_languages: list[str],
db_session: Session,
use_threads: bool = True,
) -> list[str]:
languages = [language.strip() for language in expansion_languages]
if use_threads:
functions_with_args: list[tuple[Callable, tuple]] = [
(llm_multilingual_query_expansion, (query, language))
(llm_multilingual_query_expansion, (query, language, db_session))
for language in languages
]
@@ -65,7 +66,7 @@ def multilingual_query_expansion(
else:
query_rephrases = [
llm_multilingual_query_expansion(query, language) for language in languages
llm_multilingual_query_expansion(query, language, db_session) for language in languages
]
return query_rephrases
@@ -134,9 +135,10 @@ def history_based_query_rephrase(
def thread_based_query_rephrase(
user_query: str,
history_str: str,
db_session: Session,
llm: LLM | None = None,
size_heuristic: int = 200,
punctuation_heuristic: int = 10,
punctuation_heuristic: int = 10
) -> str:
if not history_str:
return user_query
@@ -149,7 +151,7 @@ def thread_based_query_rephrase(
if llm is None:
try:
llm, _ = get_default_llms()
llm, _ = get_default_llms(db_session=db_session)
except GenAIDisabledException:
# If Generative AI is turned off, just return the original query
return user_query

View File

@@ -1,4 +1,5 @@
import re
from sqlalchemy.orm import Session
from collections.abc import Iterator
from danswer.chat.models import DanswerAnswerPiece
@@ -46,13 +47,14 @@ def extract_answerability_bool(model_raw: str) -> bool:
def get_query_answerability(
user_query: str, skip_check: bool = DISABLE_LLM_QUERY_ANSWERABILITY
db_session: Session,
user_query: str, skip_check: bool = DISABLE_LLM_QUERY_ANSWERABILITY,
) -> tuple[str, bool]:
if skip_check:
return "Query Answerability Evaluation feature is turned off", True
try:
llm, _ = get_default_llms()
llm, _ = get_default_llms(db_session=db_session)
except GenAIDisabledException:
return "Generative AI is turned off - skipping check", True
@@ -67,7 +69,7 @@ def get_query_answerability(
def stream_query_answerability(
user_query: str, skip_check: bool = DISABLE_LLM_QUERY_ANSWERABILITY
db_session: Session, user_query: str, skip_check: bool = DISABLE_LLM_QUERY_ANSWERABILITY,
) -> Iterator[str]:
if skip_check:
yield get_json_line(
@@ -79,7 +81,7 @@ def stream_query_answerability(
return
try:
llm, _ = get_default_llms()
llm, _ = get_default_llms(db_session=db_session)
except GenAIDisabledException:
yield get_json_line(
QueryValidationResponse(

View File

@@ -166,6 +166,5 @@ if __name__ == "__main__":
while True:
user_input = input("Query to Extract Sources: ")
sources = extract_source_filter(
user_input, get_main_llm_from_tuple(get_default_llms()), db_session
user_input, get_main_llm_from_tuple(get_default_llms(db_session=db_session)), db_session
)
print(sources)

View File

@@ -34,6 +34,7 @@ PUBLIC_ENDPOINT_SPECS = [
("/auth/reset-password", {"POST"}),
("/auth/request-verify-token", {"POST"}),
("/auth/verify", {"POST"}),
("/tenants/auth/sso-callback", {"POST"}),
("/users/me", {"GET"}),
("/users/me", {"PATCH"}),
("/users/{id}", {"GET"}),
@@ -42,6 +43,8 @@ PUBLIC_ENDPOINT_SPECS = [
# oauth
("/auth/oauth/authorize", {"GET"}),
("/auth/oauth/callback", {"GET"}),
# tenant service related (must use API key)
("/tenants/create", {"POST"}),
]

View File

@@ -38,6 +38,7 @@ from danswer.server.manage.models import BoostUpdateRequest
from danswer.server.manage.models import HiddenUpdateRequest
from danswer.server.models import StatusResponse
from danswer.utils.logger import setup_logger
from danswer.db.engine import current_tenant_id
router = APIRouter(prefix="/manage")
logger = setup_logger()
@@ -195,7 +196,7 @@ def create_deletion_attempt_for_connector_id(
)
# actually kick off the deletion
cleanup_connector_credential_pair_task.apply_async(
kwargs=dict(connector_id=connector_id, credential_id=credential_id),
kwargs=dict(connector_id=connector_id, credential_id=credential_id, tenant_id=current_tenant_id.get()),
)
if cc_pair.connector.source == DocumentSource.FILE:

View File

@@ -60,6 +60,7 @@ def set_new_search_settings(
search_settings = get_current_search_settings(db_session)
if search_settings_new.index_name is None:
# We define index name here
index_name = f"danswer_chunk_{clean_model_name(search_settings_new.model_name)}"
if (
@@ -97,6 +98,7 @@ def set_new_search_settings(
primary_index_name=search_settings.index_name,
secondary_index_name=new_search_settings.index_name,
)
document_index.ensure_indices_exist(
index_embedding_dim=search_settings.model_dim,
secondary_index_embedding_dim=new_search_settings.model_dim,

View File

@@ -1,13 +1,18 @@
import os
import re
from datetime import datetime
from datetime import timezone
from enum import Enum
import stripe
from email_validator import validate_email
from fastapi import APIRouter
from fastapi import Body
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Request
from fastapi import status
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from sqlalchemy import Column
from sqlalchemy import desc
@@ -27,7 +32,9 @@ from danswer.auth.users import current_user
from danswer.auth.users import optional_user
from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
from danswer.configs.app_configs import STRIPE_PRICE
from danswer.configs.app_configs import VALID_EMAIL_DOMAINS
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import AuthType
from danswer.db.engine import get_session
from danswer.db.models import AccessToken
@@ -47,6 +54,13 @@ from danswer.utils.logger import setup_logger
from ee.danswer.db.api_key import is_api_key_email_address
from ee.danswer.db.user_group import remove_curator_status__no_commit
STRIPE_SECRET_KEY = os.getenv("STRIPE_SECRET_KEY")
STRIPE_WEBHOOK_SECRET = os.getenv("STRIPE_WEBHOOK_SECRET")
# STRIPE_PRICE = os.getenv("STRIPE_PRICE")
stripe.api_key = "sk_test_51NwZq2HlhTYqRZibT2cssHV8E5QcLAUmaRLQPMjGb5aOxOWomVxOmzRgxf82ziDBuGdPP2GIDod8xe6DyqeGgUDi00KbsHPoT4"
logger = setup_logger()
router = APIRouter()
@@ -319,6 +333,63 @@ def verify_user_logged_in(
return user_info
class BillingPlanType(str, Enum):
FREE = "free"
PREMIUM = "premium"
ENTERPRISE = "enterprise"
class CheckoutSessionUpdateBillingStatus(BaseModel):
quantity: int
plan: BillingPlanType
@router.post("/create-checkout-session")
async def create_checkout_session(
request: Request,
checkout_billing: CheckoutSessionUpdateBillingStatus,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> JSONResponse:
quantity = checkout_billing.quantity
plan = checkout_billing.plan
logger.info(f"Creating checkout session for plan: {plan} with quantity: {quantity}")
user_email = "pablosfsanchez@gmail.com"
success_url = f"{WEB_DOMAIN}/admin/plan?success=true"
cancel_url = f"{WEB_DOMAIN}/admin/plan?success=false"
logger.info(f"Stripe price being used: {STRIPE_PRICE}")
logger.info(
f"Creating checkout session with success_url: {success_url} and cancel_url: {cancel_url}"
)
try:
checkout_session = stripe.checkout.Session.create(
customer_email=user_email,
line_items=[
{
"price": STRIPE_PRICE,
"quantity": quantity,
},
],
mode="subscription",
success_url=success_url,
cancel_url=cancel_url,
metadata={"tenant_id": str("random tenant")},
)
logger.info(
f"Checkout session created successfully with id: {checkout_session.id}"
)
except Exception as e:
logger.error(f"Unexpected error: {str(e)}")
raise HTTPException(status_code=500, detail="An unexpected error occurred")
return JSONResponse({"sessionId": checkout_session.id})
"""APIs to adjust user preferences"""

View File

@@ -226,7 +226,7 @@ def rename_chat_session(
try:
llm, _ = get_default_llms(
additional_headers=get_litellm_additional_request_headers(request.headers)
additional_headers=get_litellm_additional_request_headers(request.headers), db_session=db_session
)
except GenAIDisabledException:
# This may be longer than what the LLM tends to produce but is the most
@@ -296,6 +296,7 @@ async def is_disconnected(request: Request) -> Callable[[], bool]:
def handle_new_chat_message(
chat_message_req: CreateChatMessageRequest,
request: Request,
db_session: Session = Depends(get_session),
user: User | None = Depends(current_user),
_: None = Depends(check_token_rate_limits),
is_disconnected_func: Callable[[], bool] = Depends(is_disconnected),
@@ -308,6 +309,8 @@ def handle_new_chat_message(
To avoid extra overhead/latency, this assumes (and checks) that previous messages on the path
have already been set as latest"""
logger.debug(f"Received new chat message: {chat_message_req.message}")
logger.debug("Messge info")
logger.debug(chat_message_req.__dict__)
if (
not chat_message_req.message
@@ -328,6 +331,7 @@ def handle_new_chat_message(
request.headers
),
is_connected=is_disconnected_func,
db_session=db_session,
):
yield json.dumps(packet) if isinstance(packet, dict) else packet
@@ -424,7 +428,7 @@ def get_max_document_tokens(
raise HTTPException(status_code=404, detail="Persona not found")
return MaxSelectedDocumentTokens(
max_tokens=compute_max_document_tokens_for_persona(persona),
max_tokens=compute_max_document_tokens_for_persona(persona, db_session=db_session),
)
@@ -474,7 +478,7 @@ def seed_chat(
root_message = get_or_create_root_message(
chat_session_id=new_chat_session.id, db_session=db_session
)
llm, fast_llm = get_llms_for_persona(persona=new_chat_session.persona)
llm, fast_llm = get_llms_for_persona(persona=new_chat_session.persona, db_session=db_session)
tokenizer = get_tokenizer(
model_name=llm.config.model_name,

View File

@@ -40,6 +40,7 @@ from danswer.server.query_and_chat.models import SourceTag
from danswer.server.query_and_chat.models import TagResponse
from danswer.server.query_and_chat.token_limit import check_token_rate_limits
from danswer.utils.logger import setup_logger
from danswer.db.engine import get_current_tenant_id
logger = setup_logger()
@@ -52,6 +53,7 @@ def admin_search(
question: AdminSearchRequest,
user: User | None = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> AdminSearchResponse:
query = question.query
logger.notice(f"Received admin search query: {query}")
@@ -62,6 +64,7 @@ def admin_search(
time_cutoff=question.filters.time_cutoff,
tags=question.filters.tags,
access_control_list=user_acl_filters,
tenant_id=tenant_id,
)
search_settings = get_current_search_settings(db_session)
document_index = get_default_document_index(
@@ -117,13 +120,13 @@ def get_tags(
@basic_router.post("/query-validation")
def query_validation(
simple_query: SimpleQueryRequest, _: User = Depends(current_user)
simple_query: SimpleQueryRequest, _: User = Depends(current_user), db_session: Session = Depends(get_session)
) -> QueryValidationResponse:
# Note if weak model prompt is chosen, this check does not occur and will simply return that
# the query is valid, this is because weaker models cannot really handle this task well.
# Additionally, some weak model servers cannot handle concurrent inferences.
logger.notice(f"Validating query: {simple_query.query}")
reasoning, answerable = get_query_answerability(simple_query.query)
reasoning, answerable = get_query_answerability(db_session=db_session, user_query=simple_query.query)
return QueryValidationResponse(reasoning=reasoning, answerable=answerable)
@@ -226,14 +229,14 @@ def get_search_session(
# No search responses are answered with a conversational generative AI response
@basic_router.post("/stream-query-validation")
def stream_query_validation(
simple_query: SimpleQueryRequest, _: User = Depends(current_user)
simple_query: SimpleQueryRequest, _: User = Depends(current_user), db_session: Session = Depends(get_session)
) -> StreamingResponse:
# Note if weak model prompt is chosen, this check does not occur and will simply return that
# the query is valid, this is because weaker models cannot really handle this task well.
# Additionally, some weak model servers cannot handle concurrent inferences.
logger.notice(f"Validating query: {simple_query.query}")
return StreamingResponse(
stream_query_answerability(simple_query.query), media_type="application/json"
stream_query_answerability(user_query=simple_query.query, db_session=db_session), media_type="application/json"
)
@@ -242,6 +245,7 @@ def get_answer_with_quote(
query_request: DirectQARequest,
user: User = Depends(current_user),
_: None = Depends(check_token_rate_limits),
db_session: Session = Depends(get_session),
) -> StreamingResponse:
query = query_request.messages[0].message
@@ -252,5 +256,6 @@ def get_answer_with_quote(
user=user,
max_document_tokens=None,
max_history_tokens=0,
db_session=db_session,
)
return StreamingResponse(packets, media_type="application/json")

View File

@@ -33,6 +33,7 @@ def check_token_rate_limits(
) -> None:
# short circuit if no rate limits are set up
# NOTE: result of `any_rate_limit_exists` is cached, so this call is fast 99% of the time
return
if not any_rate_limit_exists():
return

View File

@@ -1,17 +1,16 @@
from typing import cast
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi.security import OAuth2PasswordBearer
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session
from danswer.auth.users import current_admin_user
from danswer.auth.users import current_user
from danswer.auth.users import is_user_admin
from danswer.configs.constants import KV_REINDEX_KEY
from danswer.configs.constants import NotificationType
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.db.notification import create_notification
from danswer.db.notification import dismiss_all_notifications
@@ -27,15 +26,17 @@ from danswer.server.settings.models import UserSettings
from danswer.server.settings.store import load_settings
from danswer.server.settings.store import store_settings
from danswer.utils.logger import setup_logger
from fastapi import HTTPException
logger = setup_logger()
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
admin_router = APIRouter(prefix="/admin/settings")
basic_router = APIRouter(prefix="/settings")
@admin_router.put("")
def put_settings(
settings: Settings, _: User | None = Depends(current_admin_user)
@@ -66,7 +67,7 @@ def fetch_settings(
return UserSettings(
**general_settings.model_dump(),
notifications=user_notifications,
needs_reindexing=needs_reindexing
needs_reindexing=needs_reindexing,
)
@@ -91,6 +92,7 @@ def dismiss_notification_endpoint(
def get_user_notifications(
user: User | None, db_session: Session
) -> list[Notification]:
return cast(list[Notification], [])
"""Get notifications for the user, currently the logic is very specific to the reindexing flag"""
is_admin = is_user_admin(user)
if not is_admin:

View File

@@ -0,0 +1,72 @@
from fastapi import APIRouter
from fastapi import Depends
from sqlalchemy.orm import Session
from fastapi import Body
from danswer.db.engine import get_sqlalchemy_engine
from danswer.auth.users import create_user_session
from danswer.auth.users import get_user_manager
from danswer.auth.users import UserManager
from danswer.auth.users import verify_sso_token
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
from danswer.utils.logger import setup_logger
from fastapi.responses import JSONResponse
from fastapi import HTTPException
from ee.danswer.auth.users import control_plane_dep
from danswer.server.tenants.provisioning import setup_postgres_and_initial_settings
from danswer.server.tenants.provisioning import check_schema_exists
from danswer.server.tenants.provisioning import run_alembic_migrations
from danswer.server.tenants.provisioning import create_tenant_schema
from danswer.configs.app_configs import MULTI_TENANT
logger = setup_logger()
basic_router = APIRouter(prefix="/tenants")
@basic_router.post("/create")
def create_tenant(tenant_id: str, _: None = Depends(control_plane_dep)) -> dict[str, str]:
if not MULTI_TENANT:
raise HTTPException(status_code=403, detail="Multi-tenancy is not enabled")
if not tenant_id:
raise HTTPException(status_code=400, detail="tenant_id is required")
create_tenant_schema(tenant_id)
run_alembic_migrations(tenant_id)
with Session(get_sqlalchemy_engine(schema=tenant_id)) as db_session:
setup_postgres_and_initial_settings(db_session)
logger.info(f"Tenant {tenant_id} created successfully")
return {"status": "success", "message": f"Tenant {tenant_id} created successfully"}
@basic_router.post("/auth/sso-callback")
async def sso_callback(
sso_token: str = Body(..., embed=True),
user_manager: UserManager = Depends(get_user_manager),
) -> JSONResponse:
if not MULTI_TENANT:
raise HTTPException(status_code=403, detail="Multi-tenancy is not enabled")
payload = verify_sso_token(sso_token)
user = await user_manager.sso_authenticate(
payload["email"], payload["tenant_id"]
)
tenant_id = payload["tenant_id"]
schema_exists = await check_schema_exists(tenant_id)
if not schema_exists:
raise HTTPException(status_code=403, detail="Your Danswer app has not been set up yet!")
session_token = await create_user_session(user, payload["tenant_id"])
response = JSONResponse(content={"message": "Authentication successful"})
response.set_cookie(
key="tenant_details",
value=session_token,
max_age=SESSION_EXPIRE_TIME_SECONDS,
expires=SESSION_EXPIRE_TIME_SECONDS,
path="/",
secure=False,
httponly=True,
samesite="lax",
)
return response

View File

@@ -0,0 +1,157 @@
import contextlib
from danswer.search.retrieval.search_runner import download_nltk_data
from danswer.db.engine import get_async_session
from danswer.natural_language_processing.search_nlp_models import warm_up_cross_encoder
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.connector_credential_pair import resync_cc_pair
from danswer.db.index_attempt import cancel_indexing_attempts_past_model
from danswer.db.index_attempt import expire_index_attempts
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.db.swap_index import check_index_swap
from sqlalchemy.orm import Session
from danswer.llm.llm_initialization import load_llm_providers
from danswer.db.connector import create_initial_default_connector
from danswer.db.connector_credential_pair import associate_default_cc_pair
from danswer.db.credentials import create_initial_public_credential
from danswer.db.standard_answer import create_initial_default_standard_answer_category
from danswer.db.persona import delete_old_default_personas
from danswer.chat.load_yamls import load_chat_yamls
from danswer.tools.built_in_tools import auto_add_search_tool_to_personas
from danswer.tools.built_in_tools import load_builtin_tools
from danswer.tools.built_in_tools import refresh_built_in_tools_cache
from danswer.utils.logger import setup_logger
from danswer.db.engine import get_sqlalchemy_engine
from sqlalchemy.schema import CreateSchema
from sqlalchemy import text
from alembic.config import Config
from alembic import command
from danswer.db.engine import build_connection_string
import os
from danswer.db_setup import setup_postgres
logger = setup_logger()
def run_alembic_migrations(schema_name: str) -> None:
logger.info(f"Starting Alembic migrations for schema: {schema_name}")
try:
current_dir = os.path.dirname(os.path.abspath(__file__))
root_dir = os.path.abspath(os.path.join(current_dir, '..', '..', '..'))
alembic_ini_path = os.path.join(root_dir, 'alembic.ini')
# Configure Alembic
alembic_cfg = Config(alembic_ini_path)
alembic_cfg.set_main_option('sqlalchemy.url', build_connection_string())
# Prepare the x arguments
x_arguments = [f"schema={schema_name}"]
alembic_cfg.cmd_opts.x = x_arguments # type: ignore
# Run migrations programmatically
command.upgrade(alembic_cfg, 'head')
logger.info(f"Alembic migrations completed successfully for schema: {schema_name}")
except Exception as e:
logger.exception(f"Alembic migration failed for schema {schema_name}: {str(e)}")
raise
def create_tenant_schema(tenant_id: str) -> None:
with Session(get_sqlalchemy_engine()) as db_session:
with db_session.begin():
result = db_session.execute(
text("""
SELECT schema_name
FROM information_schema.schemata
WHERE schema_name = :schema_name
"""),
{"schema_name": tenant_id}
)
schema_exists = result.scalar() is not None
if not schema_exists:
db_session.execute(CreateSchema(tenant_id))
logger.info(f"Schema {tenant_id} created")
else:
logger.info(f"Schema {tenant_id} already exists")
def setup_postgres_and_initial_settings(db_session: Session) -> None:
check_index_swap(db_session=db_session)
search_settings = get_current_search_settings(db_session)
secondary_search_settings = get_secondary_search_settings(db_session)
# Break bad state for thrashing indexes
if secondary_search_settings and DISABLE_INDEX_UPDATE_ON_SWAP:
expire_index_attempts(
search_settings_id=search_settings.id, db_session=db_session
)
for cc_pair in get_connector_credential_pairs(db_session):
resync_cc_pair(cc_pair, db_session=db_session)
# Expire all old embedding models indexing attempts, technically redundant
cancel_indexing_attempts_past_model(db_session)
logger.notice(f'Using Embedding model: "{search_settings.model_name}"')
if search_settings.query_prefix or search_settings.passage_prefix:
logger.notice(f'Query embedding prefix: "{search_settings.query_prefix}"')
logger.notice(
f'Passage embedding prefix: "{search_settings.passage_prefix}"'
)
if search_settings:
if not search_settings.disable_rerank_for_streaming:
logger.notice("Reranking is enabled.")
if search_settings.multilingual_expansion:
logger.notice(
f"Multilingual query expansion is enabled with {search_settings.multilingual_expansion}."
)
if search_settings.rerank_model_name and not search_settings.provider_type:
warm_up_cross_encoder(search_settings.rerank_model_name)
logger.notice("Verifying query preprocessing (NLTK) data is downloaded")
download_nltk_data()
# setup Postgres with default credentials, llm providers, etc.
setup_postgres(db_session)
# ensure Vespa is setup correctly
logger.notice("Verifying Document Index(s) is/are available.")
logger.notice("Verifying default connector/credential exist.")
create_initial_public_credential(db_session)
create_initial_default_connector(db_session)
associate_default_cc_pair(db_session)
logger.notice("Verifying default standard answer category exists.")
create_initial_default_standard_answer_category(db_session)
logger.notice("Loading LLM providers from env variables")
load_llm_providers(db_session)
logger.notice("Loading default Prompts and Personas")
delete_old_default_personas(db_session)
load_chat_yamls(db_session)
logger.notice("Loading built-in tools")
load_builtin_tools(db_session)
refresh_built_in_tools_cache(db_session)
auto_add_search_tool_to_personas(db_session)
async def check_schema_exists(tenant_id: str) -> bool:
get_async_session_context = contextlib.asynccontextmanager(
get_async_session
)
async with get_async_session_context() as session:
result = await session.execute(
text("SELECT schema_name FROM information_schema.schemata WHERE schema_name = :schema_name"),
{"schema_name": tenant_id}
)
return result.scalar() is not None
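Note: run_alembic_migrations passes the target schema to Alembic via "-x schema=<tenant_id>", which only works if the Alembic env.py reads that x-argument and points the migration connection (including its alembic_version table) at the tenant schema. That side is not shown in this diff; the following is a hedged sketch of what such an env.py fragment typically looks like (the search_path approach is an assumption):

# Hedged env.py-style fragment consuming the "-x schema=..." argument set above.
from alembic import context
from sqlalchemy import create_engine, text

def run_tenant_migrations(connection_url: str) -> None:
    x_args = context.get_x_argument(as_dictionary=True)
    schema = x_args.get("schema", "public")
    engine = create_engine(connection_url)
    with engine.connect() as connection:
        # Scope both the migration DDL and the alembic_version table to the tenant schema.
        connection.execute(text(f'SET search_path TO "{schema}"'))
        context.configure(
            connection=connection,
            target_metadata=None,  # the project's MetaData would go here
            version_table_schema=schema,
        )
        with context.begin_transaction():
            context.run_migrations()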

View File

@@ -13,15 +13,43 @@ from ee.danswer.db.api_key import fetch_user_for_api_key
from ee.danswer.db.saml import get_saml_account
from ee.danswer.server.seeding import get_seed_config
from ee.danswer.utils.secrets import extract_hashed_cookie
import jwt
from danswer.configs.app_configs import DATA_PLANE_SECRET
from danswer.configs.app_configs import EXPECTED_API_KEY
logger = setup_logger()
def verify_auth_setting() -> None:
# All the Auth flows are valid for EE version
logger.notice(f"Using Auth Type: {AUTH_TYPE.value}")
async def control_plane_dep(request: Request) -> None:
auth_header = request.headers.get("Authorization")
api_key = request.headers.get("X-API-KEY")
if api_key != EXPECTED_API_KEY:
logger.warning("Invalid API key")
raise HTTPException(status_code=401, detail="Invalid API key")
if not auth_header or not auth_header.startswith("Bearer "):
logger.warning("Invalid authorization header")
raise HTTPException(status_code=401, detail="Invalid authorization header")
token = auth_header.split(" ")[1]
try:
payload = jwt.decode(token, DATA_PLANE_SECRET, algorithms=["HS256"])
if payload.get("scope") != "tenant:create":
logger.warning("Insufficient permissions")
raise HTTPException(status_code=403, detail="Insufficient permissions")
except jwt.ExpiredSignatureError:
logger.warning("Token has expired")
raise HTTPException(status_code=401, detail="Token has expired")
except jwt.InvalidTokenError:
logger.warning("Invalid token")
raise HTTPException(status_code=401, detail="Invalid token")
async def optional_user_(
request: Request,
user: User | None,
@@ -44,6 +72,7 @@ async def optional_user_(
return user
def api_key_dep(
request: Request, db_session: Session = Depends(get_session)
) -> User | None:
@@ -63,6 +92,7 @@ def api_key_dep(
return user
def get_default_admin_user_emails_() -> list[str]:
seed_config = get_seed_config()
if seed_config and seed_config.admin_user_emails:

View File

@@ -183,7 +183,7 @@ def handle_send_message_simple_with_history(
one_shot=False,
)
llm, _ = get_llms_for_persona(persona=chat_session.persona)
llm, _ = get_llms_for_persona(persona=chat_session.persona, db_session=db_session)
llm_tokenizer = get_tokenizer(
model_name=llm.config.model_name,
@@ -223,6 +223,7 @@ def handle_send_message_simple_with_history(
rephrased_query = req.query_override or thread_based_query_rephrase(
user_query=query,
history_str=history_str,
db_session=db_session
)
full_chat_msg_info = CreateChatMessageRequest(

View File

@@ -53,7 +53,7 @@ def handle_search_request(
query = search_request.message
logger.notice(f"Received document search query: {query}")
llm, fast_llm = get_default_llms()
llm, fast_llm = get_default_llms(db_session=db_session)
search_pipeline = SearchPipeline(
search_request=SearchRequest(
@@ -141,7 +141,7 @@ def get_answer_with_quote(
)
llm = get_main_llm_from_tuple(
get_default_llms() if not persona else get_llms_for_persona(persona)
get_default_llms(db_session=db_session) if not persona else get_llms_for_persona(persona, db_session=db_session)
)
input_tokens = get_max_input_tokens(
model_name=llm.config.model_name, model_provider=llm.config.model_provider
@@ -154,6 +154,7 @@ def get_answer_with_quote(
persona=persona,
actual_user_input=query,
max_llm_token_override=remaining_tokens,
db_session=db_session,
)
answer_details = get_search_answer(

View File

@@ -1,3 +1,4 @@
from pydantic import BaseModel
import os
# Used for logging
@@ -52,8 +53,9 @@ LOG_FILE_NAME = os.environ.get("LOG_FILE_NAME") or "danswer"
# Enable generating persistent log files for local dev environments
DEV_LOGGING_ENABLED = os.environ.get("DEV_LOGGING_ENABLED", "").lower() == "true"
# notset, debug, info, notice, warning, error, or critical
LOG_LEVEL = os.environ.get("LOG_LEVEL", "notice")
LOG_LEVEL = os.environ.get("LOG_LEVEL", "debug")
# Fields which should only be set on new search setting
@@ -68,3 +70,31 @@ PRESERVED_SEARCH_FIELDS = [
"passage_prefix",
"query_prefix",
]
class SupportedEmbeddingModel(BaseModel):
name: str
dim: int
index_name: str
SUPPORTED_EMBEDDING_MODELS = [
SupportedEmbeddingModel(
name="intfloat/e5-small-v2",
dim=384,
index_name="danswer_chunk_intfloat_e5_small_v2"
),
SupportedEmbeddingModel(
name="intfloat/e5-large-v2",
dim=1024,
index_name="danswer_chunk_intfloat_e5_large_v2"
),
SupportedEmbeddingModel(
name="sentence-transformers/all-distilroberta-v1",
dim=768,
index_name="danswer_chunk_sentence_transformers_all_distilroberta_v1"
),
SupportedEmbeddingModel(
name="sentence-transformers/all-mpnet-base-v2",
dim=768,
index_name="danswer_chunk_sentence_transformers_all_mpnet_base_v2"
),
]
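Note: pinning a fixed list of supported embedding models to fixed index names appears to be what lets the multi-tenant document index be provisioned with every index up front and shared across tenants (compare the VespaIndex(indices=[...]) signature in the test fixture below), rather than creating indices per tenant. A small illustrative lookup over this table (the helper name is an assumption; only a subset of the models is repeated here):

from pydantic import BaseModel

class SupportedEmbeddingModel(BaseModel):
    name: str
    dim: int
    index_name: str

SUPPORTED_EMBEDDING_MODELS = [
    SupportedEmbeddingModel(name="intfloat/e5-small-v2", dim=384, index_name="danswer_chunk_intfloat_e5_small_v2"),
    SupportedEmbeddingModel(name="intfloat/e5-large-v2", dim=1024, index_name="danswer_chunk_intfloat_e5_large_v2"),
]

def index_name_for_model(model_name: str) -> str:
    for model in SUPPORTED_EMBEDDING_MODELS:
        if model.name == model_name:
            return model.index_name
    raise ValueError(f"{model_name} is not in SUPPORTED_EMBEDDING_MODELS")

# every index can be created once, up front, and shared by all tenants
all_index_names = [model.index_name for model in SUPPORTED_EMBEDDING_MODELS]
print(index_name_for_model("intfloat/e5-small-v2"), all_index_names)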

View File

@@ -18,7 +18,7 @@ from danswer.db.swap_index import check_index_swap
from danswer.document_index.vespa.index import DOCUMENT_ID_ENDPOINT
from danswer.document_index.vespa.index import VespaIndex
from danswer.indexing.models import IndexingSetting
from danswer.main import setup_postgres
from danswer.db_setup import setup_postgres
from danswer.main import setup_vespa
from tests.integration.common_utils.llm import seed_default_openai_provider
@@ -132,9 +132,9 @@ def reset_vespa() -> None:
index_name = search_settings.index_name
setup_vespa(
document_index=VespaIndex(index_name=index_name, secondary_index_name=None),
index_setting=IndexingSetting.from_db_model(search_settings),
secondary_index_setting=None,
document_index=VespaIndex(indices=[index_name], secondary_index_name=None),
embedding_dims=[search_settings.model_dim],
secondary_embedding_dim=None,
)
for _ in range(5):

View File

@@ -296,7 +296,7 @@ services:
- POSTGRES_USER=${POSTGRES_USER:-postgres}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
ports:
- "5432:5432"
- "5433:5432"
volumes:
- db_volume:/var/lib/postgresql/data

View File

@@ -306,7 +306,7 @@ services:
- POSTGRES_USER=${POSTGRES_USER:-postgres}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
ports:
- "5432:5432"
- "5433:5432"
volumes:
- db_volume:/var/lib/postgresql/data

View File

@@ -150,7 +150,7 @@ services:
- POSTGRES_USER=${POSTGRES_USER:-postgres}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
ports:
- "5432"
- "5433"
volumes:
- db_volume:/var/lib/postgresql/data

View File

@@ -31,6 +31,7 @@ const nextConfig = {
return defaultRedirects.concat([
{
source: "/api/chat/send-message:params*",
destination: "http://127.0.0.1:8080/chat/send-message:params*", // Proxy to Backend
permanent: true,

web/package-lock.json generated
View File

@@ -15,6 +15,7 @@
"@radix-ui/react-dialog": "^1.0.5",
"@radix-ui/react-popover": "^1.0.7",
"@radix-ui/react-tooltip": "^1.0.7",
"@stripe/stripe-js": "^4.4.0",
"@tremor/react": "^3.9.2",
"@types/js-cookie": "^3.0.3",
"@types/lodash": "^4.17.0",
@@ -1670,6 +1671,15 @@
"integrity": "sha512-qC/xYId4NMebE6w/V33Fh9gWxLgURiNYgVNObbJl2LZv0GUUItCcCqC5axQSwRaAgaxl2mELq1rMzlswaQ0Zxg==",
"dev": true
},
"node_modules/@stripe/stripe-js": {
"version": "4.4.0",
"resolved": "https://registry.npmjs.org/@stripe/stripe-js/-/stripe-js-4.4.0.tgz",
"integrity": "sha512-p1WeTOwnAyXQ9I5/YC3+JXoUB6NKMR4qGjBobie2+rgYa3ftUTRS2L5qRluw/tGACty5SxqnfORCdsaymD1XjQ==",
"license": "MIT",
"engines": {
"node": ">=12.16"
}
},
"node_modules/@swc/counter": {
"version": "0.1.3",
"resolved": "https://registry.npmjs.org/@swc/counter/-/counter-0.1.3.tgz",

View File

@@ -16,6 +16,7 @@
"@radix-ui/react-dialog": "^1.0.5",
"@radix-ui/react-popover": "^1.0.7",
"@radix-ui/react-tooltip": "^1.0.7",
"@stripe/stripe-js": "^4.4.0",
"@tremor/react": "^3.9.2",
"@types/js-cookie": "^3.0.3",
"@types/lodash": "^4.17.0",

View File

@@ -112,7 +112,7 @@ export default function Page() {
value={searchTerm}
onChange={(e) => setSearchTerm(e.target.value)}
onKeyDown={handleKeyPress}
className="flex mt-2 max-w-sm h-9 w-full rounded-md border-2 border border-input bg-transparent px-3 py-1 text-sm shadow-sm transition-colors placeholder:text-muted-foreground focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-ring"
className="ml-1 w-96 h-9 flex-none rounded-md border border-border bg-background-50 px-3 py-1 text-sm shadow-sm transition-colors placeholder:text-muted-foreground focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-ring"
/>
{Object.entries(categorizedSources)

View File

@@ -173,7 +173,7 @@ export function ProviderCreationModal({
</a>
</Text>
<div className="flex flex-col gap-y-2">
<div className="flex w-full flex-col gap-y-2">
{useFileUpload ? (
<>
<Label>Upload JSON File</Label>

View File

@@ -457,7 +457,7 @@ export function CCPairIndexingStatusTable({
placeholder="Search connectors..."
value={searchTerm}
onChange={(e) => setSearchTerm(e.target.value)}
className="ml-2 w-96 h-9 flex-none rounded-md border border-border bg-background-50 px-3 py-1 text-sm shadow-sm transition-colors placeholder:text-muted-foreground focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-ring"
className="ml-1 w-96 h-9 flex-none rounded-md border border-border bg-background-50 px-3 py-1 text-sm shadow-sm transition-colors placeholder:text-muted-foreground focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-ring"
/>
<Button className="h-9" onClick={() => toggleSources()}>

View File

@@ -26,11 +26,52 @@ export interface EnterpriseSettings {
custom_popup_header: string | null;
custom_popup_content: string | null;
}
import { FiStar, FiDollarSign, FiAward } from "react-icons/fi";
export enum BillingPlanType {
FREE = "free",
PREMIUM = "premium",
ENTERPRISE = "enterprise",
}
export interface CloudSettings {
numberOfSeats: number;
planType: BillingPlanType;
}
export interface CombinedSettings {
settings: Settings;
enterpriseSettings: EnterpriseSettings | null;
cloudSettings: CloudSettings | null;
customAnalyticsScript: string | null;
isMobile?: boolean;
webVersion: string | null;
}
export const defaultCombinedSettings: CombinedSettings = {
settings: {
chat_page_enabled: true,
search_page_enabled: true,
default_page: "search",
maximum_chat_retention_days: 30,
notifications: [],
needs_reindexing: false,
},
enterpriseSettings: {
application_name: "Danswer",
use_custom_logo: false,
use_custom_logotype: false,
custom_lower_disclaimer_content: null,
custom_header_content: null,
custom_popup_header: null,
custom_popup_content: null,
},
cloudSettings: {
numberOfSeats: 0,
planType: BillingPlanType.FREE,
},
customAnalyticsScript: null,
isMobile: false,
webVersion: null,
};

View File

@@ -147,7 +147,7 @@ function AssistantListItem({
</div>
<div className="text-sm mt-2">{assistant.description}</div>
<div className="mt-2 flex items-start gap-y-2 flex-col gap-x-3">
<div className="mt-2 flex flex-none items-start gap-y-2 flex-col gap-x-3">
<AssistantSharedStatusDisplay assistant={assistant} user={user} />
{assistant.tools.length != 0 && (
<AssistantTools list assistant={assistant} />
@@ -175,6 +175,7 @@ function AssistantListItem({
<FiEdit2 size={16} />
</Link>
<DefaultPopover
content={
<div className="hover:bg-hover rounded p-2 cursor-pointer">

View File

@@ -89,6 +89,7 @@ const Page = async ({
/>
</>
)}
{authTypeMetadata?.authType === "basic" && (
<Card className="mt-4 w-96">
<div className="flex">

View File

@@ -0,0 +1,17 @@
import React from "react";
export const metadata = {
title: "SSO Callback",
};
export default function SSOCallbackLayout({
children,
}: {
children: React.ReactNode;
}) {
return (
<html lang="en">
<body>{children}</body>
</html>
);
}

View File

@@ -0,0 +1,127 @@
"use client";
import { useEffect, useRef, useState } from "react";
import { useRouter, useSearchParams } from "next/navigation";
import { Card, Text } from "@tremor/react";
import { Logo } from "@/components/Logo";
export default function SSOCallback() {
const router = useRouter();
const searchParams = useSearchParams();
const [error, setError] = useState<string | null>(null);
const [authStatus, setAuthStatus] = useState<string>("Authenticating...");
const verificationStartedRef = useRef(false);
useEffect(() => {
const verifyToken = async () => {
if (verificationStartedRef.current) {
return;
}
verificationStartedRef.current = true;
const hashParams = new URLSearchParams(window.location.hash.slice(1));
const ssoToken = hashParams.get("sso_token");
if (!ssoToken) {
setError("No SSO token found in URL hash");
return;
}
window.history.replaceState(null, '', window.location.pathname);
if (!ssoToken) {
setError("No SSO token found");
return;
}
try {
setAuthStatus("Verifying SSO token...");
const response = await fetch(
`/api/tenants/auth/sso-callback`,
{
method: "POST",
headers: {
"Content-Type": "application/json",
},
credentials: "include",
body: JSON.stringify({ sso_token: ssoToken }),
}
)
if (response.ok) {
setAuthStatus("Authentication successful!");
router.replace("/admin/configuration/llm");
} else {
const errorData = await response.json();
setError(errorData.detail || "Authentication failed");
}
} catch (error) {
console.error("Error verifying token:", error);
setError("An unexpected error occurred");
}
};
verifyToken();
}, [router, searchParams]);
return (
<div className="flex items-center justify-center min-h-screen bg-gradient-to-r from-background-50 to-blue-50">
<Card className="max-w-lg p-8 text-center shadow-xl rounded-xl bg-white">
{error ? (
<div className="space-y-4">
<svg
className="w-16 h-16 mx-auto text-neutral-600"
fill="none"
stroke="currentColor"
viewBox="0 0 24 24"
xmlns="http://www.w3.org/2000/svg"
>
<path
strokeLinecap="round"
strokeLinejoin="round"
strokeWidth={2}
d="M12 8v4m0 4h.01M21 12a9 9 0 11-18 0 9 9 0 0118 0z"
/>
</svg>
<Text className="text-xl font-bold text-red-600">{error}</Text>
</div>
) : (
<div className="space-y-6 flex flex-col">
<div className="flex mx-auto">
<Logo height={200} width={200} />
</div>
<Text className="text-2xl font-semibold text-text-900">
{authStatus}
</Text>
<div className="w-full h-2 bg-background-100 rounded-full overflow-hidden">
<div
className="h-full bg-background-600 rounded-full animate-progress"
style={{
animation: "progress 5s ease-out forwards",
width: "0%",
}}
/>
</div>
<style jsx>{`
@keyframes progress {
0% {
width: 0%;
}
60% {
width: 75%;
}
100% {
width: 99%;
}
}
.animate-progress {
animation: progress 5s ease-out forwards;
}
`}</style>
</div>
)}
</Card>
</div>
);
}

View File

@@ -1,12 +1,9 @@
"use client";
import ReactMarkdown from "react-markdown";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import { useContext, useState, useRef, useLayoutEffect } from "react";
import remarkGfm from "remark-gfm";
import { Popover } from "@/components/popover/Popover";
import { ChevronDownIcon } from "@/components/icons/icons";
import { Divider } from "@tremor/react";
import { MinimalMarkdown } from "@/components/chat_search/MinimalMarkdown";
export function ChatBanner() {
@@ -34,27 +31,6 @@ export function ChatBanner() {
return null;
}
const renderMarkdown = (className: string) => (
<ReactMarkdown
className={`w-full text-wrap break-word ${className}`}
components={{
a: ({ node, ...props }) => (
<a
{...props}
className="text-sm text-link hover:text-link-hover"
target="_blank"
rel="noopener noreferrer"
/>
),
p: ({ node, ...props }) => (
<p {...props} className="text-wrap break-word text-sm m-0 w-full" />
),
}}
remarkPlugins={[remarkGfm]}
>
{settings.enterpriseSettings?.custom_header_content}
</ReactMarkdown>
);
return (
<div
className={`
@@ -65,7 +41,6 @@ export function ChatBanner() {
w-full
mx-auto
relative
bg-background-100
shadow-sm
rounded
border-l-8 border-l-400
@@ -81,7 +56,6 @@ export function ChatBanner() {
className="line-clamp-2 text-center w-full overflow-hidden pr-8"
>
<MinimalMarkdown
className=""
content={settings.enterpriseSettings.custom_header_content}
/>
</div>
@@ -90,7 +64,6 @@ export function ChatBanner() {
className="absolute top-0 left-0 invisible w-full"
>
<MinimalMarkdown
className=""
content={settings.enterpriseSettings.custom_header_content}
/>
</div>

View File

@@ -210,6 +210,11 @@ export function ChatPage({
liveAssistant,
llmProviders
);
console.log("persona default", personaDefault);
console.log(
destructureValue(user?.preferences.default_model || "openai:gpt-4o")
);
if (personaDefault) {
llmOverrideManager.setLlmOverride(personaDefault);
@@ -882,6 +887,11 @@ export function ChatPage({
modelOverRide?: LlmOverride;
regenerationRequest?: RegenerationRequest | null;
} = {}) => {
console.log("model override", modelOverRide);
console.log(modelOverRide?.name);
console.log(llmOverrideManager.llmOverride.name);
console.log("HII")
console.log(llmOverrideManager.globalDefault.name);
let frozenSessionId = currentSessionId();
if (currentChatState() != "input") {

View File

@@ -205,7 +205,8 @@ const FolderItem = ({
className="text-sm px-1 flex-1 min-w-0 -my-px mr-2"
/>
) : (
<div className="flex-1 min-w-0">
<div className="flex-1 break-all min-w-0">
{editedFolderName || folder.folder_name}
</div>
)}

View File

@@ -116,6 +116,7 @@ export function ChatInputBar({
const { llmProviders } = useChatContext();
const [_, llmName] = getFinalLLM(llmProviders, selectedAssistant, null);
const suggestionsRef = useRef<HTMLDivElement | null>(null);
const [showSuggestions, setShowSuggestions] = useState(false);
const [showPrompts, setShowPrompts] = useState(false);

View File

@@ -152,6 +152,8 @@ export async function* sendMessage({
}): AsyncGenerator<PacketType, void, unknown> {
const documentsAreSelected =
selectedDocumentIds && selectedDocumentIds.length > 0;
console.log("llm ovverride deatilas", modelProvider, modelVersion)
const body = JSON.stringify({
alternate_assistant_id: alternateAssistantId,
@@ -164,30 +166,30 @@ export async function* sendMessage({
regenerate,
retrieval_options: !documentsAreSelected
? {
run_search:
promptId === null ||
run_search:
promptId === null ||
promptId === undefined ||
queryOverride ||
forceSearch
? "always"
: "auto",
real_time: true,
filters: filters,
}
? "always"
: "auto",
real_time: true,
filters: filters,
}
: null,
query_override: queryOverride,
prompt_override: systemPromptOverride
? {
system_prompt: systemPromptOverride,
}
system_prompt: systemPromptOverride,
}
: null,
llm_override:
temperature || modelVersion
? {
temperature,
model_provider: modelProvider,
model_version: modelVersion,
}
temperature,
model_provider: modelProvider,
model_version: modelVersion,
}
: null,
use_existing_user_message: useExistingUserMessage,
});
@@ -427,11 +429,11 @@ export function processRawChatHistory(
// this is identical to what is computed at streaming time
...(messageInfo.message_type === "assistant"
? {
retrievalType: retrievalType,
query: messageInfo.rephrased_query,
documents: messageInfo?.context_docs?.top_documents || [],
citations: messageInfo?.citations || {},
}
retrievalType: retrievalType,
query: messageInfo.rephrased_query,
documents: messageInfo?.context_docs?.top_documents || [],
citations: messageInfo?.citations || {},
}
: {}),
toolCalls: messageInfo.tool_calls,
parentMessageId: messageInfo.parent_message,
@@ -596,8 +598,7 @@ export function buildChatUrl(
const finalSearchParams: string[] = [];
if (chatSessionId) {
finalSearchParams.push(
`${
search ? SEARCH_PARAM_NAMES.SEARCH_ID : SEARCH_PARAM_NAMES.CHAT_ID
`${search ? SEARCH_PARAM_NAMES.SEARCH_ID : SEARCH_PARAM_NAMES.CHAT_ID
}=${chatSessionId}`
);
}

View File

@@ -13,7 +13,6 @@ import {
sortableKeyboardCoordinates,
verticalListSortingStrategy,
} from "@dnd-kit/sortable";
import { CSS } from "@dnd-kit/utilities";
import { Persona } from "@/app/admin/assistants/interfaces";
import { LLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
import { getFinalLLM } from "@/lib/llm/utils";
@@ -41,6 +40,7 @@ export function AssistantsTab({
coordinateGetter: sortableKeyboardCoordinates,
})
);
console.log("llm providers", llmProviders);
function handleDragEnd(event: DragEndEvent) {
const { active, over } = event;

View File

@@ -0,0 +1,325 @@
"use client";
import { BillingPlanType } from "@/app/admin/settings/interfaces";
import { useContext, useEffect, useState } from "react";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import { Button, Divider, Card } from "@tremor/react";
import { StripeCheckoutButton } from "./StripeCheckoutButton";
import {
CheckmarkIcon,
CheckmarkCircleIcon,
LightningIcon,
PeopleIcon,
XIcon,
ChevronDownIcon,
} from "@/components/icons/icons";
import { FiAward, FiDollarSign, FiStar } from "react-icons/fi";
import Cookies from "js-cookie";
import { Modal } from "@/components/Modal";
import { Logo } from "@/components/Logo";
import { useSearchParams } from "next/navigation";
import { usePopup } from "@/components/admin/connectors/Popup";
export function BillingSettings({ newUser }: { newUser: boolean }) {
const settings = useContext(SettingsContext);
const cloudSettings = settings?.cloudSettings;
const searchParams = useSearchParams();
const [isOpen, setIsOpen] = useState(false);
const [isNewUserOpen, setIsNewUserOpen] = useState(true)
const [newSeats, setNewSeats] = useState<null | number>(null);
const [newPlan, setNewPlan] = useState<null | BillingPlanType>(null);
const { popup, setPopup } = usePopup();
useEffect(() => {
const success = searchParams.get("success");
if (success === "true") {
setPopup({
message: "Your plan has been successfully updated!",
type: "success",
});
} else if (success === "false") {
setPopup({
message: "There was an error updating your plan",
type: "error",
});
}
// Clear the 'success' parameter from the URL
if (success) {
const newUrl = new URL(window.location.href);
newUrl.searchParams.delete("success");
window.history.replaceState({}, "", newUrl);
}
}, []);
// Use actual data from cloudSettings
const currentPlan = cloudSettings?.planType;
const seats = cloudSettings?.numberOfSeats!;
useEffect(() => {
if (cloudSettings) {
setNewSeats(cloudSettings.numberOfSeats);
setNewPlan(cloudSettings.planType);
}
}, [cloudSettings]);
if (!cloudSettings) {
return null;
}
const features = [
{ name: "All Connector Access", included: true },
{ name: "Basic Support", included: true },
{ name: "Custom Branding", included: currentPlan !== BillingPlanType.FREE },
{
name: "Analytics Dashboard",
included: currentPlan !== BillingPlanType.FREE,
},
{ name: "Query History", included: currentPlan !== BillingPlanType.FREE },
{
name: "Priority Support",
included: currentPlan !== BillingPlanType.FREE,
},
{
name: "Service Level Agreements (SLAs)",
included: currentPlan === BillingPlanType.ENTERPRISE,
},
{
name: "Advanced Support",
included: currentPlan === BillingPlanType.ENTERPRISE,
},
{
name: "Custom Integrations",
included: currentPlan === BillingPlanType.ENTERPRISE,
},
{
name: "Dedicated Account Manager",
included: currentPlan === BillingPlanType.ENTERPRISE,
},
];
function getBillingPlanIcon(planType: BillingPlanType) {
switch (planType) {
case BillingPlanType.FREE:
return <FiStar />;
case BillingPlanType.PREMIUM:
return <FiDollarSign />;
case BillingPlanType.ENTERPRISE:
return <FiAward />;
default:
return <FiStar />;
}
}
const handleCloseModal = () => {
setIsNewUserOpen(false);
Cookies.set("new_auth_user", "false");
};
if (newSeats === null || currentPlan === undefined) {
return null;
}
return (
<div className="max-w-4xl mr-auto space-y-8 p-6">
{newUser && isNewUserOpen && (
<Modal
onOutsideClick={handleCloseModal}
className="max-w-lg w-full p-8 bg-background-150 rounded-lg shadow-xl"
>
<>
<h2 className="text-3xl font-semibold text-text-900 mb-6 text-center">
Welcome to Danswer!
</h2>
<div className="text-center mb-8">
<Logo className="mx-auto mb-4" height={150} width={150} />
<p className="text-lg text-text-700 leading-relaxed">
We&apos;re thrilled to have you on board! Here, you can manage your
billing settings and explore your plan details.
</p>
</div>
<div className="flex justify-center">
<Button
onClick={handleCloseModal}
className="px-8 py-3 bg-blue-600 text-white text-lg font-semibold rounded-full hover:bg-blue-700 transition duration-300 ease-in-out focus:outline-none focus:ring-2 focus:ring-blue-500 focus:ring-opacity-50"
>
Let&apos;s Get Started
</Button>
<Button
onClick={() => window.open("mailto:support@danswer.ai")}
className="border-0 hover:underline ml-4 px-4 py-2 bg-gray-200 w-fit text-text-700 text-sm font-medium rounded-full hover:bg-gray-300 transition duration-300 ease-in-out focus:outline-none focus:ring-2 focus:ring-gray-400 focus:ring-opacity-50 flex items-center"
>
Questions?
</Button>
</div>
</>
</Modal>
)}
<Card className="bg-white shadow-lg rounded-lg overflow-hidden">
<div className="px-8 py-6">
<h2 className="text-3xl gap-x-2 font-bold text-text-800 mb-6 flex items-center">
Your Plan
<CheckmarkCircleIcon size={24} className="text-blue-500" />
</h2>
<div className="space-y-6">
<div className="flex justify-between items-center">
<p className="text-lg text-text-600 flex gap-x-2 items-center">
<LightningIcon size={20} />
Tier:
</p>
<span className="text-xl font-semibold text-blue-600 capitalize">
{currentPlan}
</span>
</div>
<div className="flex justify-between items-center">
<p className="text-lg text-text-600 gap-x-2 flex items-center">
<PeopleIcon size={20} />
Current Seats:
</p>
<span className="text-xl font-semibold text-blue-600">
{seats}
</span>
</div>
<Divider />
<div className="mt-6 relative">
<label className="block text-lg font-medium text-text-700 mb-2 flex items-center">
New Tier:
</label>
<div
className="w-full px-4 py-2 text-lg border border-gray-300 rounded-md focus:ring-2 focus:ring-blue-500 focus:border-transparent transition duration-150 ease-in-out flex items-center justify-between cursor-pointer"
onClick={() => setIsOpen(!isOpen)}
>
<span className="flex items-center">
{getBillingPlanIcon(newPlan!)}
<span className="ml-2 capitalize">{newPlan}</span>
</span>
<ChevronDownIcon size={12} />
</div>
{isOpen && (
<div className="absolute z-10 w-full mt-1 bg-white border border-gray-300 rounded-md shadow-lg">
{Object.values(BillingPlanType).map((plan) => (
<div
key={plan}
className="px-4 py-2 hover:bg-gray-100 cursor-pointer flex items-center"
onClick={() => {
setNewPlan(plan);
setIsOpen(false);
if (plan === BillingPlanType.FREE) {
setNewSeats(1);
}
}}
>
{getBillingPlanIcon(plan)}
<span className="ml-2 capitalize">{plan}</span>
</div>
))}
</div>
)}
</div>
<div className="mt-6">
<label className="block text-lg font-medium text-text-700 mb-2 flex items-center">
New Number of Seats:
</label>
{newPlan === BillingPlanType.FREE ? (
<input
type="number"
value={1}
disabled
className="w-full px-4 py-2 text-lg border border-gray-300 rounded-md bg-gray-100 cursor-not-allowed"
/>
) : (
<input
type="number"
value={newSeats}
onChange={(e) => setNewSeats(Number(e.target.value))}
min="1"
className="w-full px-4 py-2 text-lg border border-gray-300 rounded-md focus:ring-2 focus:ring-blue-500 focus:border-transparent transition duration-150 ease-in-out"
/>
)}
</div>
<div className="mt-8 flex justify-center">
<StripeCheckoutButton
currentPlan={currentPlan}
currentQuantity={cloudSettings.numberOfSeats}
newQuantity={newSeats}
newPlan={newPlan!}
/>
</div>
</div>
</div>
</Card>
<Card className="bg-white shadow-lg rounded-lg overflow-hidden">
<div className="px-6 py-4">
<h2 className="text-3xl font-bold text-text-800 mb-4">Features</h2>
<ul className="space-y-3">
{features.map((feature, index) => (
<li key={index} className="flex items-center text-lg">
<span className="mr-3">
{feature.included ? (
<CheckmarkIcon className="text-success" />
) : (
<XIcon className="text-error" />
)}
</span>
<span
className={
feature.included ? "text-text-800" : "text-text-500"
}
>
{feature.name}
</span>
</li>
))}
</ul>
</div>
</Card>
{currentPlan !== BillingPlanType.FREE && (
<Card className="bg-white shadow-lg rounded-lg overflow-hidden">
<div className="px-8 py-6">
<h2 className="text-3xl font-bold text-text-800 mb-4">
Tenant Deletion
</h2>
<p className="text-text-600 mb-6">
Permanently delete your tenant and all associated data.
</p>
<div
className="bg-red-100 border-l-4 border-red-500 text-error p-4 mb-6"
role="alert"
>
<p className="font-bold">Warning:</p>
<p>Deleting your tenant will result in the following:</p>
<ul className="list-disc list-inside mt-2">
<li>
All data associated with this tenant will be permanently
deleted
</li>
<li>This action cannot be undone</li>
</ul>
</div>
<Button
onClick={() => {
alert("not implemented");
}}
className="bg-red-500 hover:bg-red-600 text-white font-bold py-3 px-6 rounded-lg transition duration-300 shadow-md hover:shadow-lg"
>
Delete Tenant
</Button>
</div>
</Card>
)}
{popup}
</div>
);
}

View File

@@ -0,0 +1,72 @@
import { SubLabel } from "@/components/admin/connectors/Field";
import { usePopup } from "@/components/admin/connectors/Popup";
import { useEffect, useState } from "react";
import Dropzone from "react-dropzone";
export function ImageUpload({
selectedFile,
setSelectedFile,
}: {
selectedFile: File | null;
setSelectedFile: (file: File) => void;
}) {
const [tmpImageUrl, setTmpImageUrl] = useState<string>("");
useEffect(() => {
if (selectedFile) {
setTmpImageUrl(URL.createObjectURL(selectedFile));
} else {
setTmpImageUrl("");
}
}, [selectedFile]);
const [dragActive, setDragActive] = useState(false);
const { popup, setPopup } = usePopup();
return (
<>
{popup}
<Dropzone
onDrop={(acceptedFiles) => {
if (acceptedFiles.length !== 1) {
setPopup({
type: "error",
message: "Only one file can be uploaded at a time",
});
}
setTmpImageUrl(URL.createObjectURL(acceptedFiles[0]));
setSelectedFile(acceptedFiles[0]);
setDragActive(false);
}}
onDragLeave={() => setDragActive(false)}
onDragEnter={() => setDragActive(true)}
>
{({ getRootProps, getInputProps }) => (
<section>
<div
{...getRootProps()}
className={
"flex flex-col items-center w-full px-4 py-12 rounded " +
"shadow-lg tracking-wide border border-border cursor-pointer" +
(dragActive ? " border-accent" : "")
}
>
<input {...getInputProps()} />
<b className="text-emphasis">
Drag and drop a .png or .jpg file, or click to select a file!
</b>
</div>
{tmpImageUrl && (
<div className="mt-4 mb-8">
<SubLabel>Uploaded Image:</SubLabel>
<img src={tmpImageUrl} className="mt-4 max-w-xs max-h-64" />
</div>
)}
</section>
)}
</Dropzone>
</>
);
}

View File

@@ -0,0 +1,78 @@
"use client";
import { useState } from "react";
import { loadStripe } from "@stripe/stripe-js";
import { BillingPlanType } from "@/app/admin/settings/interfaces";
export function StripeCheckoutButton({
newQuantity,
newPlan,
currentQuantity,
currentPlan,
}: {
newQuantity: number;
newPlan: BillingPlanType;
currentQuantity: number;
currentPlan: BillingPlanType;
}) {
const [isLoading, setIsLoading] = useState(false);
const handleClick = async () => {
setIsLoading(true);
try {
const response = await fetch("/api/create-checkout-session", {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({ quantity: newQuantity, plan: newPlan }),
});
if (!response.ok) {
throw new Error("Failed to create checkout session");
}
const { sessionId } = await response.json();
const stripe = await loadStripe(
process.env.NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY!
);
if (stripe) {
await stripe.redirectToCheckout({ sessionId });
} else {
throw new Error("Stripe failed to load");
}
} catch (error) {
console.error("Error:", error);
} finally {
setIsLoading(false);
}
};
return (
<button
onClick={handleClick}
className={`py-2 px-4 text-white rounded ${currentPlan === newPlan && currentQuantity === newQuantity
? "bg-gray-400 cursor-not-allowed"
: "bg-blue-500 hover:bg-blue-600"
} disabled:bg-blue-300`}
disabled={
(currentPlan === newPlan && currentQuantity === newQuantity) ||
isLoading
}
>
{isLoading
? "Loading..."
: currentPlan === newPlan && currentQuantity === newQuantity
? "No Changes"
: newPlan > currentPlan ||
(newPlan === currentPlan && newQuantity > currentQuantity)
? "Upgrade Plan"
: newPlan == BillingPlanType.ENTERPRISE
? "Talk to us"
: // : newPlan < currentPlan ||
newPlan === currentPlan && newQuantity < currentQuantity
? "Upgrade Plan"
: "Change Plan"}
</button>
);
}

View File

@@ -0,0 +1,19 @@
import { BillingSettings } from "./BillingSettings";
import { AdminPageTitle } from "@/components/admin/Title";
import { CreditCardIcon } from "@/components/icons/icons";
import { cookies } from "next/headers";
export default async function Whitelabeling() {
const newUser =
cookies().get("new_auth_user")?.value.toLocaleLowerCase() === "true";
return (
<div className="mx-auto container">
<AdminPageTitle
title="Billing"
icon={<CreditCardIcon size={32} className="my-auto" />}
/>
<BillingSettings newUser={newUser} />
</div>
);
}

View File

@@ -9,11 +9,11 @@ export default async function AdminLayout({
return (
<div className="flex h-screen">
<div className="mx-auto my-auto text-lg font-bold text-red-500">
This funcitonality is only available in the Enterprise Edition :(
This functionality is only available in the Enterprise Edition :(
</div>
</div>
);
}
return children;
}
}

View File

@@ -1,10 +1,5 @@
import "./globals.css";
import {
fetchEnterpriseSettingsSS,
fetchSettingsSS,
SettingsError,
} from "@/components/settings/lib";
import { fetchEnterpriseSettingsSS, fetchSettingsSS } from "@/components/settings/lib";
import {
CUSTOM_ANALYTICS_ENABLED,
SERVER_SIDE_ONLY__PAID_ENTERPRISE_FEATURES_ENABLED,
@@ -14,10 +9,8 @@ import { Metadata } from "next";
import { buildClientUrl } from "@/lib/utilsSS";
import { Inter } from "next/font/google";
import Head from "next/head";
import { EnterpriseSettings } from "./admin/settings/interfaces";
import { redirect } from "next/navigation";
import { Button, Card } from "@tremor/react";
import LogoType from "@/components/header/LogoType";
import { CombinedSettings, defaultCombinedSettings, EnterpriseSettings } from "./admin/settings/interfaces";
import { Card } from "@tremor/react";
import { HeaderTitle } from "@/components/header/HeaderTitle";
import { Logo } from "@/components/Logo";
import { UserProvider } from "@/components/user/UserProvider";
@@ -55,10 +48,9 @@ export default async function RootLayout({
}: {
children: React.ReactNode;
}) {
const combinedSettings = await fetchSettingsSS();
if (!combinedSettings) {
// Just display a simple full page error if fetching fails.
const combinedSettings = await fetchSettingsSS() || defaultCombinedSettings;
if (!combinedSettings) {
return (
<html lang="en" className={`${inter.variable} font-sans`}>
<Head>

View File

@@ -22,7 +22,7 @@ export const IsPublicGroupSelector = <T extends IsPublicGroupSelectorFormType>({
enforceGroupSelection?: boolean;
}) => {
const { data: userGroups, isLoading: userGroupsIsLoading } = useUserGroups();
const { isAdmin, user, isLoadingUser } = useUser();
const { isAdmin, user, isLoadingUser, isCurator } = useUser();
const [shouldHideContent, setShouldHideContent] = useState(false);
useEffect(() => {
@@ -87,42 +87,45 @@ export const IsPublicGroupSelector = <T extends IsPublicGroupSelectorFormType>({
)}
{(!formikProps.values.is_public ||
!isAdmin ||
formikProps.values.groups.length > 0) && (
<>
<div className="flex mt-4 gap-x-2 items-center">
<div className="block font-medium text-base">
Assign group access for this {objectName}
</div>
</div>
<Text className="mb-3">
{isAdmin || !enforceGroupSelection ? (
<>
This {objectName} will be visible/accessible by the groups
selected below
</>
) : (
<>
Curators must select one or more groups to give access to this{" "}
{objectName}
</>
)}
</Text>
<FieldArray
name="groups"
render={(arrayHelpers: ArrayHelpers) => (
<div className="flex gap-2 flex-wrap mb-4">
{userGroupsIsLoading ? (
<div className="animate-pulse bg-gray-200 h-8 w-32 rounded"></div>
isCurator ||
formikProps.values.groups.length > 0) &&
(userGroupsIsLoading ? (
<div className="animate-pulse bg-gray-200 h-8 w-32 rounded"></div>
) : (
userGroups &&
userGroups.length > 0 && (
<>
<div className="flex mt-4 gap-x-2 items-center">
<div className="block font-medium text-base">
Assign group access for this {objectName}
</div>
</div>
<Text className="mb-3">
{isAdmin || !enforceGroupSelection ? (
<>
This {objectName} will be visible/accessible by the groups
selected below
</>
) : (
userGroups &&
userGroups.map((userGroup: UserGroup) => {
const ind = formikProps.values.groups.indexOf(userGroup.id);
let isSelected = ind !== -1;
return (
<div
key={userGroup.id}
className={`
<>
Curators must select one or more groups to give access to
this {objectName}
</>
)}
</Text>
<FieldArray
name="groups"
render={(arrayHelpers: ArrayHelpers) => (
<div className="flex gap-2 flex-wrap mb-4">
{userGroups.map((userGroup: UserGroup) => {
const ind = formikProps.values.groups.indexOf(
userGroup.id
);
let isSelected = ind !== -1;
return (
<div
key={userGroup.id}
className={`
px-3
py-1
rounded-lg
@@ -132,32 +135,33 @@ export const IsPublicGroupSelector = <T extends IsPublicGroupSelectorFormType>({
flex
cursor-pointer
${isSelected ? "bg-background-strong" : "hover:bg-hover"}
`}
onClick={() => {
if (isSelected) {
arrayHelpers.remove(ind);
} else {
arrayHelpers.push(userGroup.id);
}
}}
>
<div className="my-auto flex">
<FiUsers className="my-auto mr-2" /> {userGroup.name}
`}
onClick={() => {
if (isSelected) {
arrayHelpers.remove(ind);
} else {
arrayHelpers.push(userGroup.id);
}
}}
>
<div className="my-auto flex">
<FiUsers className="my-auto mr-2" />{" "}
{userGroup.name}
</div>
</div>
</div>
);
})
);
})}
</div>
)}
</div>
)}
/>
<ErrorMessage
name="groups"
component="div"
className="text-error text-sm mt-1"
/>
</>
)}
/>
<ErrorMessage
name="groups"
component="div"
className="text-error text-sm mt-1"
/>
</>
)
))}
</div>
);
};

View File

@@ -21,6 +21,7 @@ import {
AssistantsIconSkeleton,
ClosedBookIcon,
SearchIcon,
CreditCardIcon,
} from "@/components/icons/icons";
import { UserRole } from "@/lib/types";
import { FiActivity, FiBarChart2 } from "react-icons/fi";
@@ -29,16 +30,17 @@ import { User } from "@/lib/types";
import { usePathname } from "next/navigation";
import { SettingsContext } from "../settings/SettingsProvider";
import { useContext } from "react";
import { CustomTooltip } from "../tooltip/CustomTooltip";
export function ClientLayout({
user,
children,
enableEnterprise,
cloudEnabled,
}: {
user: User | null;
children: React.ReactNode;
enableEnterprise: boolean;
cloudEnabled: boolean;
}) {
const isCurator =
user?.role === UserRole.CURATOR || user?.role === UserRole.GLOBAL_CURATOR;
@@ -317,6 +319,20 @@ export function ClientLayout({
},
]
: []),
...(cloudEnabled
? [
{
name: (
<div className="flex">
<CreditCardIcon size={18} />
<div className="ml-1">Billing</div>
</div>
),
link: "/admin/plan",
},
]
: []),
],
},
]

View File

@@ -6,7 +6,10 @@ import {
} from "@/lib/userSS";
import { redirect } from "next/navigation";
import { ClientLayout } from "./ClientLayout";
import { SERVER_SIDE_ONLY__PAID_ENTERPRISE_FEATURES_ENABLED } from "@/lib/constants";
import {
CLOUD_ENABLED,
SERVER_SIDE_ONLY__PAID_ENTERPRISE_FEATURES_ENABLED,
} from "@/lib/constants";
import { AnnouncementBanner } from "../header/AnnouncementBanner";
export async function Layout({ children }: { children: React.ReactNode }) {
@@ -43,6 +46,7 @@ export async function Layout({ children }: { children: React.ReactNode }) {
return (
<ClientLayout
enableEnterprise={SERVER_SIDE_ONLY__PAID_ENTERPRISE_FEATURES_ENABLED}
cloudEnabled={CLOUD_ENABLED}
user={user}
>
<AnnouncementBanner />

Some files were not shown because too many files have changed in this diff.