Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-02-16 23:35:46 +00:00)

Compare commits (64 commits)
Commits compared (SHA1):
ebb9b94c6f, a759238372, 40cce46996, c9a84d5084, ee831b73e4, 7419bf6b06, e5f3f2d73a, dc5a91fd85,
3158db4239, 341bf26ff2, 516f1840ce, 1f12b074df, 8f3f905a99, 01e1bd0ee2, 5f7be266f0, 478dd1c4bb,
f0a5ec223f, daad96d180, b853e5f22a, dcc4c61fcb, 5775aec498, b4ee066424, 88ade7cb7e, a69a0333a5,
198f80d224, 4855a80f86, 6e78f2094b, c2e953633a, 0ff4ff0abc, 38af754968, 4f9420217e, e5584ca364,
ae3218f941, 5b220ac7b1, d1f40cfd30, fe3f6d451d, d1641652a2, 17412fb9f7, a28ac88341, 95a11b8adc,
e9906c37fe, 127526d080, d3d63ee8f7, f98c77397d, 918623eb97, 482117c4e7, 827e4169c5, 06c3e2064f,
0a1c8ae980, db54cb448b, cb0a1e4fdc, 0547fff1d5, 01841adb43, bc5b269446, 64768f82f3, 90da2166c2,
c8f7e6185f, 7b895008d3, 3838908e70, 3112a9df9d, fb29c70f37, e547cd6a79, 1fa324b135, f4f3dd479e
@@ -1,72 +1,81 @@
import asyncio
from logging.config import fileConfig

from typing import Tuple
from alembic import context
from danswer.db.engine import build_connection_string
from danswer.db.models import Base
from sqlalchemy import pool
from sqlalchemy import pool, text
from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import create_async_engine
from celery.backends.database.session import ResultModelBase  # type: ignore

# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
# Alembic Config object
config = context.config

# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None:
    fileConfig(config.config_file_name)

# add your model's MetaData object here
# Add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
target_metadata = [Base.metadata, ResultModelBase.metadata]

# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.


def get_schema_options() -> str:
    x_args_raw = context.get_x_argument()
    x_args = {}
    for arg in x_args_raw:
        for pair in arg.split(','):
            if '=' in pair:
                key, value = pair.split('=', 1)
                x_args[key] = value
    schema_name = x_args.get('schema', 'public')
    return schema_name


def run_migrations_offline() -> None:
    """Run migrations in 'offline' mode.

    This configures the context with just a URL
    and not an Engine, though an Engine is acceptable
    here as well. By skipping the Engine creation
    we don't even need a DBAPI to be available.

    Calls to context.execute() here emit the given string to the
    script output.

    """
    """Run migrations in 'offline' mode."""
    url = build_connection_string()
    schema = get_schema_options()

    context.configure(
        url=url,
        target_metadata=target_metadata,  # type: ignore
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
        version_table_schema=schema,
        include_schemas=True,
    )

    with context.begin_transaction():
        context.run_migrations()


def do_run_migrations(connection: Connection) -> None:
    context.configure(connection=connection, target_metadata=target_metadata)  # type: ignore
    schema = get_schema_options()

    connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema}"'))
    connection.execute(text('COMMIT'))

    connection.execute(text(f'SET search_path TO "{schema}"'))

    context.configure(
        connection=connection,
        target_metadata=target_metadata,  # type: ignore
        version_table_schema=schema,
        include_schemas=True,
        compare_type=True,
        compare_server_default=True,
    )

    with context.begin_transaction():
        context.run_migrations()


async def run_async_migrations() -> None:
    """In this scenario we need to create an Engine
    and associate a connection with the context.

    """

    print("Running async migrations")
    """Run migrations in 'online' mode."""
    connectable = create_async_engine(
        build_connection_string(),
        poolclass=pool.NullPool,
@@ -77,13 +86,10 @@ async def run_async_migrations() -> None:

    await connectable.dispose()


def run_migrations_online() -> None:
    """Run migrations in 'online' mode."""

    asyncio.run(run_async_migrations())


if context.is_offline_mode():
    run_migrations_offline()
else:
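The `get_schema_options()` helper above reads the target schema out of Alembic's `-x` arguments. As a hedged illustration of how a caller might feed it (the helper name and schema value below are hypothetical, not part of the diff), the migration can be invoked per tenant like this:

```python
# Hypothetical invocation sketch: run the upgrade against one tenant schema by
# passing -x schema=<name>, which env.py's get_schema_options() parses via
# context.get_x_argument(). The function name and schema value are illustrative.
import subprocess


def upgrade_schema(schema_name: str) -> None:
    # Equivalent to: alembic -x schema=tenant_a upgrade head
    subprocess.run(
        ["alembic", "-x", f"schema={schema_name}", "upgrade", "head"],
        check=True,
    )


upgrade_schema("tenant_a")
```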
@@ -9,9 +9,9 @@ from alembic import op
import sqlalchemy as sa
from sqlalchemy.sql import table
from sqlalchemy.dialects import postgresql
from alembic_utils import encrypt_string
import json

from danswer.utils.encryption import encrypt_string_to_bytes

# revision identifiers, used by Alembic.
revision = "0a98909f2757"
@@ -57,7 +57,7 @@ def upgrade() -> None:
    # In other words, this upgrade does not apply the encryption. Porting existing sensitive data
    # and key rotation currently is not supported and will come out in the future
    for row_id, creds, _ in results:
        creds_binary = encrypt_string_to_bytes(json.dumps(creds))
        creds_binary = encrypt_string(json.dumps(creds))
        connection.execute(
            creds_table.update()
            .where(creds_table.c.id == row_id)
@@ -86,7 +86,7 @@ def upgrade() -> None:
    results = connection.execute(sa.select(llm_table))

    for row_id, api_key, _ in results:
        llm_key = encrypt_string_to_bytes(api_key)
        llm_key = encrypt_string(api_key)
        connection.execute(
            llm_table.update()
            .where(llm_table.c.id == row_id)

@@ -8,7 +8,7 @@ Create Date: 2023-11-11 20:51:24.228999
from alembic import op
import sqlalchemy as sa

from danswer.configs.constants import DocumentSource
from alembic_utils import DocumentSource

# revision identifiers, used by Alembic.
revision = "15326fcec57e"

@@ -9,8 +9,7 @@ Create Date: 2024-08-25 12:39:51.731632
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

from danswer.configs.chat_configs import NUM_POSTPROCESSED_RESULTS
from alembic_utils import NUM_POSTPROCESSED_RESULTS

# revision identifiers, used by Alembic.
revision = "1f60f60c3401"
@@ -5,11 +5,8 @@ Revises: fad14119fb92
Create Date: 2024-04-15 01:36:02.952809

"""
import json
from typing import cast
from alembic import op
import sqlalchemy as sa
from danswer.dynamic_configs.factory import get_dynamic_config_store

# revision identifiers, used by Alembic.
revision = "703313b75876"
@@ -53,30 +50,6 @@ def upgrade() -> None:
        sa.PrimaryKeyConstraint("rate_limit_id", "user_group_id"),
    )

    try:
        settings_json = cast(
            str, get_dynamic_config_store().load("token_budget_settings")
        )
        settings = json.loads(settings_json)

        is_enabled = settings.get("enable_token_budget", False)
        token_budget = settings.get("token_budget", -1)
        period_hours = settings.get("period_hours", -1)

        if is_enabled and token_budget > 0 and period_hours > 0:
            op.execute(
                f"INSERT INTO token_rate_limit \
                (enabled, token_budget, period_hours, scope) VALUES \
                ({is_enabled}, {token_budget}, {period_hours}, 'GLOBAL')"
            )

        # Delete the dynamic config
        get_dynamic_config_store().delete("token_budget_settings")

    except Exception:
        # Ignore if the dynamic config is not found
        pass


def downgrade() -> None:
    op.drop_table("token_rate_limit__user_group")
@@ -7,10 +7,8 @@ Create Date: 2024-03-22 21:34:27.629444
"""
from alembic import op
import sqlalchemy as sa
from alembic_utils import IndexModelStatus, RecencyBiasSetting, SearchType

from danswer.db.models import IndexModelStatus
from danswer.search.enums import RecencyBiasSetting
from danswer.search.enums import SearchType

# revision identifiers, used by Alembic.
revision = "776b3bbe9092"

@@ -7,7 +7,7 @@ Create Date: 2024-03-21 12:05:23.956734
"""
from alembic import op
import sqlalchemy as sa
from danswer.configs.constants import DocumentSource
from alembic_utils import DocumentSource

# revision identifiers, used by Alembic.
revision = "91fd3b470d1a"

@@ -10,7 +10,7 @@ from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from sqlalchemy.dialects.postgresql import ENUM
from danswer.configs.constants import DocumentSource
from alembic_utils import DocumentSource

# revision identifiers, used by Alembic.
revision = "b156fa702355"
@@ -0,0 +1,24 @@
"""add tenant id to user model

Revision ID: b25c363470f3
Revises: 1f60f60c3401
Create Date: 2024-08-29 17:03:20.794120

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "b25c363470f3"
down_revision = "1f60f60c3401"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column("user", sa.Column("tenant_id", sa.Text(), nullable=True))


def downgrade() -> None:
    op.drop_column("user", "tenant_id")
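A hedged usage sketch, not part of the diff, of how the new nullable `tenant_id` column could be read once this migration has run; the helper name and values are illustrative:

```python
# Illustrative only: list user emails for one tenant after the tenant_id
# column added by the migration above exists.
import sqlalchemy as sa
from sqlalchemy.engine import Connection


def user_emails_for_tenant(connection: Connection, tenant_id: str) -> list[str]:
    rows = connection.execute(
        sa.text('SELECT email FROM "user" WHERE tenant_id = :tenant_id'),
        {"tenant_id": tenant_id},
    )
    return [row[0] for row in rows]
```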
@@ -19,15 +19,16 @@ depends_on: None = None

def upgrade() -> None:
    conn = op.get_bind()

    existing_ids_and_chosen_assistants = conn.execute(
        sa.text("select id, chosen_assistants from public.user")
        sa.text('SELECT id, chosen_assistants FROM "user"')
    )
    op.drop_column(
        "user",
        'user',
        "chosen_assistants",
    )
    op.add_column(
        "user",
        'user',
        sa.Column(
            "chosen_assistants",
            postgresql.JSONB(astext_type=sa.Text()),
@@ -37,7 +38,7 @@ def upgrade() -> None:
    for id, chosen_assistants in existing_ids_and_chosen_assistants:
        conn.execute(
            sa.text(
                "update public.user set chosen_assistants = :chosen_assistants where id = :id"
                'UPDATE user SET chosen_assistants = :chosen_assistants WHERE id = :id'
            ),
            {"chosen_assistants": json.dumps(chosen_assistants), "id": id},
        )
@@ -46,20 +47,20 @@ def upgrade() -> None:
def downgrade() -> None:
    conn = op.get_bind()
    existing_ids_and_chosen_assistants = conn.execute(
        sa.text("select id, chosen_assistants from public.user")
        sa.text('SELECT id, chosen_assistants FROM user')
    )
    op.drop_column(
        "user",
        'user',
        "chosen_assistants",
    )
    op.add_column(
        "user",
        'user',
        sa.Column("chosen_assistants", postgresql.ARRAY(sa.Integer()), nullable=True),
    )
    for id, chosen_assistants in existing_ids_and_chosen_assistants:
        conn.execute(
            sa.text(
                "update public.user set chosen_assistants = :chosen_assistants where id = :id"
                'UPDATE user SET chosen_assistants = :chosen_assistants WHERE id = :id'
            ),
            {"chosen_assistants": chosen_assistants, "id": id},
        )
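One detail worth flagging in the hunk above: `user` is a reserved word in PostgreSQL, so when it is used as a table name in raw SQL it normally has to be double-quoted or schema-qualified, while a bare `FROM user` is generally rejected by the parser. A small illustrative comparison, not taken from the diff:

```python
# Illustrative only: two spellings PostgreSQL accepts for the reserved
# table name "user" in raw SQL, matching the quoting used in the hunk above.
import sqlalchemy as sa

select_quoted = sa.text('SELECT id, chosen_assistants FROM "user"')          # quoted identifier
select_qualified = sa.text("SELECT id, chosen_assistants FROM public.user")  # schema-qualified
```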
@@ -8,20 +8,13 @@ Create Date: 2024-01-25 17:12:31.813160
from alembic import op
import sqlalchemy as sa
from sqlalchemy import table, column, String, Integer, Boolean

from danswer.db.search_settings import (
    get_new_default_embedding_model,
    get_old_default_embedding_model,
    user_has_overridden_embedding_model,
)
from danswer.db.models import IndexModelStatus
from alembic_utils import IndexModelStatus

# revision identifiers, used by Alembic.
revision = "dbaa756c2ccf"
down_revision = "7f726bad5367"
branch_labels: None = None
depends_on: None = None

branch_labels = None
depends_on = None

def upgrade() -> None:
    op.create_table(
@@ -40,9 +33,32 @@ def upgrade() -> None:
        ),
        sa.PrimaryKeyConstraint("id"),
    )
    # since all index attempts must be associated with an embedding model,
    # need to put something in here to avoid nulls. On server startup,
    # this value will be overriden

    # Define the old default embedding model directly
    old_embedding_model = {
        "model_name": "sentence-transformers/all-distilroberta-v1",
        "model_dim": 768,
        "normalize": True,
        "query_prefix": "",
        "passage_prefix": "",
        "index_name": "OPENSEARCH_INDEX_NAME",
        "status": IndexModelStatus.PAST,
    }

    # Define the new default embedding model directly
    new_embedding_model = {
        "model_name": "intfloat/e5-small-v2",
        "model_dim": 384,
        "normalize": False,
        "query_prefix": "query: ",
        "passage_prefix": "passage: ",
        "index_name": "danswer_chunk_intfloat_e5_small_v2",
        "status": IndexModelStatus.PRESENT,
    }

    # Assume the user has not overridden the embedding model
    user_overridden_embedding_model = False

    EmbeddingModel = table(
        "embedding_model",
        column("id", Integer),
@@ -52,45 +68,23 @@ def upgrade() -> None:
        column("query_prefix", String),
        column("passage_prefix", String),
        column("index_name", String),
        column(
            "status", sa.Enum(IndexModelStatus, name="indexmodelstatus", native=False)
        ),
        column("status", sa.Enum(IndexModelStatus, name="indexmodelstatus", native=False)),
    )
    # insert an embedding model row that corresponds to the embedding model
    # the user selected via env variables before this change. This is needed since
    # all index_attempts must be associated with an embedding model, so without this
    # we will run into violations of non-null contraints
    old_embedding_model = get_old_default_embedding_model()

    # Insert the old embedding model
    op.bulk_insert(
        EmbeddingModel,
        [
            {
                "model_name": old_embedding_model.model_name,
                "model_dim": old_embedding_model.model_dim,
                "normalize": old_embedding_model.normalize,
                "query_prefix": old_embedding_model.query_prefix,
                "passage_prefix": old_embedding_model.passage_prefix,
                "index_name": old_embedding_model.index_name,
                "status": IndexModelStatus.PRESENT,
            }
            old_embedding_model
        ],
    )
    # if the user has not overridden the default embedding model via env variables,
    # insert the new default model into the database to auto-upgrade them
    if not user_has_overridden_embedding_model():
        new_embedding_model = get_new_default_embedding_model()

    # If the user has not overridden the embedding model, insert the new default model
    if not user_overridden_embedding_model:
        op.bulk_insert(
            EmbeddingModel,
            [
                {
                    "model_name": new_embedding_model.model_name,
                    "model_dim": new_embedding_model.model_dim,
                    "normalize": new_embedding_model.normalize,
                    "query_prefix": new_embedding_model.query_prefix,
                    "passage_prefix": new_embedding_model.passage_prefix,
                    "index_name": new_embedding_model.index_name,
                    "status": IndexModelStatus.FUTURE,
                }
                new_embedding_model
            ],
        )

@@ -129,7 +123,6 @@ def upgrade() -> None:
        postgresql_where=sa.text("status = 'FUTURE'"),
    )


def downgrade() -> None:
    op.drop_constraint(
        "index_attempt__embedding_model_fk", "index_attempt", type_="foreignkey"

@@ -8,7 +8,7 @@ Create Date: 2024-03-14 18:06:08.523106
from alembic import op
import sqlalchemy as sa

from danswer.configs.constants import DocumentSource
from alembic_utils import DocumentSource

# revision identifiers, used by Alembic.
revision = "e50154680a5c"
backend/alembic_utils.py (new file, 99 lines)
@@ -0,0 +1,99 @@
from cryptography.hazmat.primitives import padding
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.backends import default_backend
from os import urandom
import os
from enum import Enum

ENCRYPTION_KEY_SECRET = os.environ.get("ENCRYPTION_KEY_SECRET") or ""

def _get_trimmed_key(key: str) -> bytes:
    encoded_key = key.encode()
    key_length = len(encoded_key)
    if key_length < 16:
        raise RuntimeError("Invalid ENCRYPTION_KEY_SECRET - too short")
    elif key_length > 32:
        key = key[:32]
    elif key_length not in (16, 24, 32):
        valid_lengths = [16, 24, 32]
        key = key[: min(valid_lengths, key=lambda x: abs(x - key_length))]

    return encoded_key

def encrypt_string(input_str: str) -> bytes:
    if not ENCRYPTION_KEY_SECRET:
        return input_str.encode()

    key = _get_trimmed_key(ENCRYPTION_KEY_SECRET)
    iv = urandom(16)
    padder = padding.PKCS7(algorithms.AES.block_size).padder()
    padded_data = padder.update(input_str.encode()) + padder.finalize()

    cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
    encryptor = cipher.encryptor()
    encrypted_data = encryptor.update(padded_data) + encryptor.finalize()

    return iv + encrypted_data

NUM_POSTPROCESSED_RESULTS = 20

class IndexModelStatus(str, Enum):
    PAST = "PAST"
    PRESENT = "PRESENT"
    FUTURE = "FUTURE"


class RecencyBiasSetting(str, Enum):
    FAVOR_RECENT = "favor_recent"  # 2x decay rate
    BASE_DECAY = "base_decay"
    NO_DECAY = "no_decay"
    # Determine based on query if to use base_decay or favor_recent
    AUTO = "auto"


class SearchType(str, Enum):
    KEYWORD = "keyword"
    SEMANTIC = "semantic"


class DocumentSource(str, Enum):
    # Special case, document passed in via Danswer APIs without specifying a source type
    INGESTION_API = "ingestion_api"
    SLACK = "slack"
    WEB = "web"
    GOOGLE_DRIVE = "google_drive"
    GMAIL = "gmail"
    REQUESTTRACKER = "requesttracker"
    GITHUB = "github"
    GITLAB = "gitlab"
    GURU = "guru"
    BOOKSTACK = "bookstack"
    CONFLUENCE = "confluence"
    SLAB = "slab"
    JIRA = "jira"
    PRODUCTBOARD = "productboard"
    FILE = "file"
    NOTION = "notion"
    ZULIP = "zulip"
    LINEAR = "linear"
    HUBSPOT = "hubspot"
    DOCUMENT360 = "document360"
    GONG = "gong"
    GOOGLE_SITES = "google_sites"
    ZENDESK = "zendesk"
    LOOPIO = "loopio"
    DROPBOX = "dropbox"
    SHAREPOINT = "sharepoint"
    TEAMS = "teams"
    SALESFORCE = "salesforce"
    DISCOURSE = "discourse"
    AXERO = "axero"
    CLICKUP = "clickup"
    MEDIAWIKI = "mediawiki"
    WIKIPEDIA = "wikipedia"
    S3 = "s3"
    R2 = "r2"
    GOOGLE_CLOUD_STORAGE = "google_cloud_storage"
    OCI_STORAGE = "oci_storage"
    NOT_APPLICABLE = "not_applicable"
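`encrypt_string` above prepends the 16-byte IV to the AES-CBC ciphertext. A matching decrypt helper is not part of this diff; the following is only a sketch of what one could look like, reusing the same key handling and padding as the file above:

```python
# Hypothetical counterpart to encrypt_string (not in the diff): split off the
# 16-byte IV, decrypt with AES-CBC, and strip the PKCS7 padding.
def decrypt_bytes(data: bytes) -> str:
    if not ENCRYPTION_KEY_SECRET:
        return data.decode()

    key = _get_trimmed_key(ENCRYPTION_KEY_SECRET)
    iv, ciphertext = data[:16], data[16:]

    cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
    decryptor = cipher.decryptor()
    padded = decryptor.update(ciphertext) + decryptor.finalize()

    unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder()
    return (unpadder.update(padded) + unpadder.finalize()).decode()
```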
@@ -33,6 +33,7 @@ class UserRead(schemas.BaseUser[uuid.UUID]):

class UserCreate(schemas.BaseUserCreate):
    role: UserRole = UserRole.BASIC
    tenant_id: str | None = None


class UserUpdate(schemas.BaseUserUpdate):
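A brief, hypothetical construction of the extended `UserCreate` showing the new `tenant_id` field alongside the fields used elsewhere in this change (all values are placeholders):

```python
# Illustrative only: the new tenant_id field on UserCreate; the other fields
# come from fastapi-users' BaseUserCreate and the role default shown above.
user_create = UserCreate(
    email="alice@example.com",
    password="a-strong-password",
    role=UserRole.BASIC,
    tenant_id="tenant_a",
)
```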
@@ -1,3 +1,6 @@
from danswer.configs.app_configs import SECRET_JWT_KEY
from datetime import timedelta
import contextlib
import smtplib
import uuid
from collections.abc import AsyncGenerator
@@ -8,6 +11,7 @@ from email.mime.text import MIMEText
from typing import Optional
from typing import Tuple

import jwt
from email_validator import EmailNotValidError
from email_validator import validate_email
from fastapi import APIRouter
@@ -54,6 +58,7 @@ from danswer.db.auth import get_access_token_db
from danswer.db.auth import get_default_admin_user_emails
from danswer.db.auth import get_user_count
from danswer.db.auth import get_user_db
from danswer.db.engine import get_async_session
from danswer.db.engine import get_session
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import AccessToken
@@ -191,11 +196,88 @@ def send_user_verification_email(
        s.login(SMTP_USER, SMTP_PASS)
        s.send_message(msg)


def verify_sso_token(token: str) -> dict:
    try:
        payload = jwt.decode(token, "SSO_SECRET_KEY", algorithms=["HS256"])

        if datetime.now(timezone.utc) > datetime.fromtimestamp(
            payload["exp"], timezone.utc
        ):
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED, detail="Token has expired"
            )
        return payload
    except jwt.PyJWTError:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token"
        )


async def get_or_create_user(email: str, user_id: str) -> User:
    get_async_session_context = contextlib.asynccontextmanager(get_async_session)
    get_user_db_context = contextlib.asynccontextmanager(get_user_db)

    async with get_async_session_context() as session:
        async with get_user_db_context(session) as user_db:
            existing_user = await user_db.get_by_email(email)
            if existing_user:
                return existing_user

            new_user = {
                "email": email,
                "id": uuid.UUID(user_id),
                "role": UserRole.BASIC,
                "oidc_expiry": None,
                "default_model": None,
                "chosen_assistants": None,
                "hashed_password": "p",
                "is_active": True,
                "is_superuser": False,
                "is_verified": True,
            }

            created_user: User = await user_db.create(new_user)
            return created_user


async def create_user_session(user: User, tenant_id: str) -> str:
    # Create a payload with user information and tenant_id
    payload = {
        "sub": str(user.id),
        "email": user.email,
        "tenant_id": tenant_id,
        "exp": datetime.utcnow() + timedelta(seconds=SESSION_EXPIRE_TIME_SECONDS)
    }

    token = jwt.encode(payload, SECRET_JWT_KEY, algorithm="HS256")
    return token


class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
    reset_password_token_secret = USER_AUTH_SECRET
    verification_token_secret = USER_AUTH_SECRET

    async def sso_authenticate(
        self,
        email: str,
        tenant_id: str,
    ) -> User:
        try:
            user = await self.get_by_email(email)

        except Exception:
            # user_create = UserCreate(email=email, password=secrets.token_urlsafe(32))
            user_create = UserCreate(
                role=UserRole.BASIC, password="password", email=email, is_verified=True
            )
            user = await self.create(user_create)

        # Update user with tenant information if needed
        if user.tenant_id != tenant_id:
            await self.user_db.update(user, {"tenant_id": tenant_id})
        return user

    async def create(
        self,
        user_create: schemas.UC | UserCreate,
@@ -210,6 +292,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
            user_create.role = UserRole.ADMIN
        else:
            user_create.role = UserRole.BASIC

        return await super().create(user_create, safe=safe, request=request)  # type: ignore

    async def oauth_callback(
@@ -298,7 +381,6 @@ def get_database_strategy(
    strategy = DatabaseStrategy(
        access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS  # type: ignore
    )

    return strategy
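For context on the token shape `verify_sso_token` expects, here is a hedged sketch of minting a compatible HS256 token. The "SSO_SECRET_KEY" literal and algorithm mirror the handler above; the claim values are illustrative and this helper is not part of the diff:

```python
# Hypothetical example: create a token that verify_sso_token() above would
# accept (same secret literal and HS256 algorithm; payload values illustrative).
from datetime import datetime, timedelta, timezone

import jwt


def mint_sso_token(email: str, tenant_id: str) -> str:
    payload = {
        "email": email,
        "tenant_id": tenant_id,
        # "exp" is validated both by jwt.decode and by the explicit check above
        "exp": datetime.now(timezone.utc) + timedelta(minutes=10),
    }
    return jwt.encode(payload, "SSO_SECRET_KEY", algorithm="HS256")
```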
@@ -1,3 +1,6 @@
|
||||
|
||||
from danswer.configs.app_configs import MULTI_TENANT
|
||||
from danswer.background.update import get_all_tenant_ids
|
||||
import json
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
@@ -67,11 +70,12 @@ _SYNC_BATCH_SIZE = 100
|
||||
def cleanup_connector_credential_pair_task(
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
tenant_id: str | None
|
||||
) -> int:
|
||||
"""Connector deletion task. This is run as an async task because it is a somewhat slow job.
|
||||
Needs to potentially update a large number of Postgres and Vespa docs, including deleting them
|
||||
or updating the ACL"""
|
||||
engine = get_sqlalchemy_engine()
|
||||
engine = get_sqlalchemy_engine(schema=tenant_id)
|
||||
with Session(engine) as db_session:
|
||||
# validate that the connector / credential pair is deletable
|
||||
cc_pair = get_connector_credential_pair(
|
||||
@@ -101,6 +105,7 @@ def cleanup_connector_credential_pair_task(
|
||||
db_session=db_session,
|
||||
document_index=document_index,
|
||||
cc_pair=cc_pair,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to run connector_deletion due to {e}")
|
||||
@@ -109,7 +114,7 @@ def cleanup_connector_credential_pair_task(
|
||||
|
||||
@build_celery_task_wrapper(name_cc_prune_task)
|
||||
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
|
||||
def prune_documents_task(connector_id: int, credential_id: int) -> None:
|
||||
def prune_documents_task(connector_id: int, credential_id: int, tenant_id: str | None) -> None:
|
||||
"""connector pruning task. For a cc pair, this task pulls all document IDs from the source
|
||||
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
|
||||
from the most recently pulled document ID list"""
|
||||
@@ -167,6 +172,7 @@ def prune_documents_task(connector_id: int, credential_id: int) -> None:
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
document_index=document_index,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
@@ -177,7 +183,7 @@ def prune_documents_task(connector_id: int, credential_id: int) -> None:
|
||||
|
||||
@build_celery_task_wrapper(name_document_set_sync_task)
|
||||
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
|
||||
def sync_document_set_task(document_set_id: int) -> None:
|
||||
def sync_document_set_task(document_set_id: int, tenant_id: str | None) -> None:
|
||||
"""For document sets marked as not up to date, sync the state from postgres
|
||||
into the datastore. Also handles deletions."""
|
||||
|
||||
@@ -210,7 +216,7 @@ def sync_document_set_task(document_set_id: int) -> None:
|
||||
]
|
||||
document_index.update(update_requests=update_requests)
|
||||
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
with Session(get_sqlalchemy_engine(schema=tenant_id)) as db_session:
|
||||
try:
|
||||
cursor = None
|
||||
while True:
|
||||
@@ -261,10 +267,10 @@ def sync_document_set_task(document_set_id: int) -> None:
|
||||
name="check_for_document_sets_sync_task",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
)
|
||||
def check_for_document_sets_sync_task() -> None:
|
||||
def check_for_document_sets_sync_task(tenant_id: str | None) -> None:
|
||||
"""Runs periodically to check if any sync tasks should be run and adds them
|
||||
to the queue"""
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
with Session(get_sqlalchemy_engine(schema=tenant_id)) as db_session:
|
||||
# check if any document sets are not synced
|
||||
document_set_info = fetch_document_sets(
|
||||
user_id=None, db_session=db_session, include_outdated=True
|
||||
@@ -281,9 +287,9 @@ def check_for_document_sets_sync_task() -> None:
|
||||
name="check_for_cc_pair_deletion_task",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
)
|
||||
def check_for_cc_pair_deletion_task() -> None:
|
||||
def check_for_cc_pair_deletion_task(tenant_id: str | None) -> None:
|
||||
"""Runs periodically to check if any deletion tasks should be run"""
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
with Session(get_sqlalchemy_engine(schema=tenant_id)) as db_session:
|
||||
# check if any document sets are not synced
|
||||
cc_pairs = get_connector_credential_pairs(db_session)
|
||||
for cc_pair in cc_pairs:
|
||||
@@ -293,6 +299,7 @@ def check_for_cc_pair_deletion_task() -> None:
|
||||
kwargs=dict(
|
||||
connector_id=cc_pair.connector.id,
|
||||
credential_id=cc_pair.credential.id,
|
||||
tenant_id=tenant_id
|
||||
),
|
||||
)
|
||||
|
||||
@@ -303,7 +310,7 @@ def check_for_cc_pair_deletion_task() -> None:
|
||||
bind=True,
|
||||
base=AbortableTask,
|
||||
)
|
||||
def kombu_message_cleanup_task(self: Any) -> int:
|
||||
def kombu_message_cleanup_task(self: Any, tenant_id: str | None) -> int:
|
||||
"""Runs periodically to clean up the kombu_message table"""
|
||||
|
||||
# we will select messages older than this amount to clean up
|
||||
@@ -315,7 +322,7 @@ def kombu_message_cleanup_task(self: Any) -> int:
|
||||
ctx["deleted"] = 0
|
||||
ctx["cleanup_age"] = KOMBU_MESSAGE_CLEANUP_AGE
|
||||
ctx["page_limit"] = KOMBU_MESSAGE_CLEANUP_PAGE_LIMIT
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
with Session(get_sqlalchemy_engine(schema=tenant_id)) as db_session:
|
||||
# Exit the task if we can't take the advisory lock
|
||||
result = db_session.execute(
|
||||
text("SELECT pg_try_advisory_lock(:id)"),
|
||||
@@ -416,11 +423,11 @@ def kombu_message_cleanup_task_helper(ctx: dict, db_session: Session) -> bool:
|
||||
name="check_for_prune_task",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
)
|
||||
def check_for_prune_task() -> None:
|
||||
def check_for_prune_task(tenant_id: str | None) -> None:
|
||||
"""Runs periodically to check if any prune tasks should be run and adds them
|
||||
to the queue"""
|
||||
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
with Session(get_sqlalchemy_engine(schema=tenant_id)) as db_session:
|
||||
all_cc_pairs = get_connector_credential_pairs(db_session)
|
||||
|
||||
for cc_pair in all_cc_pairs:
|
||||
@@ -435,6 +442,7 @@ def check_for_prune_task() -> None:
|
||||
kwargs=dict(
|
||||
connector_id=cc_pair.connector.id,
|
||||
credential_id=cc_pair.credential.id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -442,31 +450,33 @@ def check_for_prune_task() -> None:
|
||||
#####
|
||||
# Celery Beat (Periodic Tasks) Settings
|
||||
#####
|
||||
celery_app.conf.beat_schedule = {
|
||||
"check-for-document-set-sync": {
|
||||
"task": "check_for_document_sets_sync_task",
|
||||
"schedule": timedelta(seconds=5),
|
||||
},
|
||||
"check-for-cc-pair-deletion": {
|
||||
"task": "check_for_cc_pair_deletion_task",
|
||||
# don't need to check too often, since we kick off a deletion initially
|
||||
# during the API call that actually marks the CC pair for deletion
|
||||
"schedule": timedelta(minutes=1),
|
||||
},
|
||||
}
|
||||
celery_app.conf.beat_schedule.update(
|
||||
{
|
||||
"check-for-prune": {
|
||||
|
||||
def schedule_tenant_tasks() -> None:
|
||||
tenants = get_all_tenant_ids()
|
||||
|
||||
for tenant_id in tenants:
|
||||
# Schedule tasks specific to each tenant
|
||||
celery_app.conf.beat_schedule[f"check-for-document-set-sync-{tenant_id}"] = {
|
||||
"task": "check_for_document_sets_sync_task",
|
||||
"schedule": timedelta(seconds=5),
|
||||
"args": (tenant_id,),
|
||||
}
|
||||
celery_app.conf.beat_schedule[f"check-for-cc-pair-deletion-{tenant_id}"] = {
|
||||
"task": "check_for_cc_pair_deletion_task",
|
||||
"schedule": timedelta(seconds=5),
|
||||
"args": (tenant_id,),
|
||||
}
|
||||
celery_app.conf.beat_schedule[f"check-for-prune-{tenant_id}"] = {
|
||||
"task": "check_for_prune_task",
|
||||
"schedule": timedelta(seconds=5),
|
||||
},
|
||||
}
|
||||
)
|
||||
celery_app.conf.beat_schedule.update(
|
||||
{
|
||||
"kombu-message-cleanup": {
|
||||
"args": (tenant_id,),
|
||||
}
|
||||
|
||||
# Schedule tasks that are not tenant-specific
|
||||
celery_app.conf.beat_schedule["kombu-message-cleanup"] = {
|
||||
"task": "kombu_message_cleanup_task",
|
||||
"schedule": timedelta(seconds=3600),
|
||||
},
|
||||
}
|
||||
)
|
||||
"args": (tenant_id,),
|
||||
}
|
||||
|
||||
schedule_tenant_tasks()
|
||||
|
||||
@@ -35,6 +35,7 @@ from danswer.utils.variable_functionality import (
|
||||
fetch_versioned_implementation_with_fallback,
|
||||
)
|
||||
from danswer.utils.variable_functionality import noop_fallback
|
||||
from danswer.configs.app_configs import DEFAULT_SCHEMA
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -46,12 +47,13 @@ def delete_connector_credential_pair_batch(
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
document_index: DocumentIndex,
|
||||
tenant_id: str | None
|
||||
) -> None:
|
||||
"""
|
||||
Removes a batch of documents ids from a cc-pair. If no other cc-pair uses a document anymore
|
||||
it gets permanently deleted.
|
||||
"""
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
with Session(get_sqlalchemy_engine(schema=tenant_id)) as db_session:
|
||||
# acquire lock for all documents in this batch so that indexing can't
|
||||
# override the deletion
|
||||
with prepare_to_modify_documents(
|
||||
@@ -124,6 +126,7 @@ def delete_connector_credential_pair(
|
||||
db_session: Session,
|
||||
document_index: DocumentIndex,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
tenant_id: str | None
|
||||
) -> int:
|
||||
connector_id = cc_pair.connector_id
|
||||
credential_id = cc_pair.credential_id
|
||||
@@ -135,6 +138,7 @@ def delete_connector_credential_pair(
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
limit=_DELETION_BATCH_SIZE,
|
||||
|
||||
)
|
||||
if not documents:
|
||||
break
|
||||
@@ -144,6 +148,7 @@ def delete_connector_credential_pair(
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
document_index=document_index,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
num_docs_deleted += len(documents)
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
|
||||
import time
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
@@ -5,7 +6,7 @@ from datetime import timedelta
|
||||
from datetime import timezone
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.background.indexing.checkpointing import get_time_windows_for_index_attempt
|
||||
from danswer.background.indexing.tracer import DanswerTracer
|
||||
from danswer.configs.app_configs import INDEXING_SIZE_WARNING_THRESHOLD
|
||||
@@ -16,7 +17,6 @@ from danswer.connectors.factory import instantiate_connector
|
||||
from danswer.connectors.models import IndexAttemptMetadata
|
||||
from danswer.db.connector_credential_pair import get_last_successful_attempt_time
|
||||
from danswer.db.connector_credential_pair import update_connector_credential_pair
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.index_attempt import get_index_attempt
|
||||
from danswer.db.index_attempt import mark_attempt_failed
|
||||
@@ -44,6 +44,7 @@ def _get_connector_runner(
|
||||
attempt: IndexAttempt,
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
tenant_id: str | None
|
||||
) -> ConnectorRunner:
|
||||
"""
|
||||
NOTE: `start_time` and `end_time` are only used for poll connectors
|
||||
@@ -61,6 +62,7 @@ def _get_connector_runner(
|
||||
attempt.connector_credential_pair.connector.connector_specific_config,
|
||||
attempt.connector_credential_pair.credential,
|
||||
db_session,
|
||||
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Unable to instantiate connector due to {e}")
|
||||
@@ -82,6 +84,7 @@ def _get_connector_runner(
|
||||
def _run_indexing(
|
||||
db_session: Session,
|
||||
index_attempt: IndexAttempt,
|
||||
tenant_id: str | None
|
||||
) -> None:
|
||||
"""
|
||||
1. Get documents which are either new or updated from specified application
|
||||
@@ -102,6 +105,7 @@ def _run_indexing(
|
||||
primary_index_name=index_name, secondary_index_name=None
|
||||
)
|
||||
|
||||
|
||||
embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
|
||||
search_settings=search_settings
|
||||
)
|
||||
@@ -113,6 +117,7 @@ def _run_indexing(
|
||||
ignore_time_skip=index_attempt.from_beginning
|
||||
or (search_settings.status == IndexModelStatus.FUTURE),
|
||||
db_session=db_session,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
db_cc_pair = index_attempt.connector_credential_pair
|
||||
@@ -169,6 +174,7 @@ def _run_indexing(
|
||||
attempt=index_attempt,
|
||||
start_time=window_start,
|
||||
end_time=window_end,
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
|
||||
all_connector_doc_ids: set[str] = set()
|
||||
@@ -196,7 +202,7 @@ def _run_indexing(
|
||||
db_session.refresh(index_attempt)
|
||||
if index_attempt.status != IndexingStatus.IN_PROGRESS:
|
||||
# Likely due to user manually disabling it or model swap
|
||||
raise RuntimeError("Index Attempt was canceled")
|
||||
raise RuntimeError(f"Index Attempt was canceled, status is {index_attempt.status}")
|
||||
|
||||
batch_description = []
|
||||
for doc in doc_batch:
|
||||
@@ -383,38 +389,30 @@ def _prepare_index_attempt(db_session: Session, index_attempt_id: int) -> IndexA
|
||||
|
||||
return attempt
|
||||
|
||||
|
||||
def run_indexing_entrypoint(index_attempt_id: int, is_ee: bool = False) -> None:
|
||||
"""Entrypoint for indexing run when using dask distributed.
|
||||
Wraps the actual logic in a `try` block so that we can catch any exceptions
|
||||
and mark the attempt as failed."""
|
||||
def run_indexing_entrypoint(index_attempt_id: int, tenant_id: str | None, is_ee: bool = False) -> None:
|
||||
try:
|
||||
if is_ee:
|
||||
global_version.set_ee()
|
||||
|
||||
# set the indexing attempt ID so that all log messages from this process
|
||||
# will have it added as a prefix
|
||||
IndexAttemptSingleton.set_index_attempt_id(index_attempt_id)
|
||||
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
# make sure that it is valid to run this indexing attempt + mark it
|
||||
# as in progress
|
||||
with Session(get_sqlalchemy_engine(schema=tenant_id)) as db_session:
|
||||
attempt = _prepare_index_attempt(db_session, index_attempt_id)
|
||||
|
||||
logger.info(
|
||||
f"Indexing starting: "
|
||||
f"Indexing starting for tenant {tenant_id}: " if tenant_id is not None else "" +
|
||||
f"connector='{attempt.connector_credential_pair.connector.name}' "
|
||||
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
|
||||
f"credentials='{attempt.connector_credential_pair.connector_id}'"
|
||||
)
|
||||
|
||||
_run_indexing(db_session, attempt)
|
||||
_run_indexing(db_session, attempt, tenant_id)
|
||||
|
||||
logger.info(
|
||||
f"Indexing finished: "
|
||||
f"Indexing finished for tenant {tenant_id}: " if tenant_id is not None else "" +
|
||||
f"connector='{attempt.connector_credential_pair.connector.name}' "
|
||||
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
|
||||
f"credentials='{attempt.connector_credential_pair.connector_id}'"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Indexing job with ID '{index_attempt_id}' failed due to {e}")
|
||||
logger.exception(f"Indexing job with ID '{index_attempt_id}' for tenant {tenant_id} failed due to {e}")
|
||||
|
||||
@@ -14,8 +14,11 @@ from danswer.db.tasks import mark_task_start
from danswer.db.tasks import register_task


def name_cc_cleanup_task(connector_id: int, credential_id: int) -> str:
    return f"cleanup_connector_credential_pair_{connector_id}_{credential_id}"
def name_cc_cleanup_task(connector_id: int, credential_id: int, tenant_id: str | None = None) -> str:
    task_name = f"cleanup_connector_credential_pair_{connector_id}_{credential_id}"
    if tenant_id is not None:
        task_name += f"_{tenant_id}"
    return task_name
|
||||
|
||||
|
||||
def name_document_set_sync_task(document_set_id: int) -> str:
|
||||
|
||||
@@ -8,6 +8,7 @@ from dask.distributed import Future
|
||||
from distributed import LocalCluster
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from sqlalchemy import text
|
||||
from danswer.background.indexing.dask_utils import ResourceLogger
|
||||
from danswer.background.indexing.job_client import SimpleJob
|
||||
from danswer.background.indexing.job_client import SimpleJobClient
|
||||
@@ -45,7 +46,8 @@ from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
|
||||
from shared_configs.configs import LOG_LEVEL
|
||||
from shared_configs.configs import MODEL_SERVER_PORT
|
||||
|
||||
from danswer.configs.app_configs import MULTI_TENANT
|
||||
from sqlalchemy.exc import ProgrammingError
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -143,13 +145,14 @@ def _mark_run_failed(
|
||||
"""Main funcs"""
|
||||
|
||||
|
||||
def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob]) -> None:
|
||||
def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob], tenant_id: str | None) -> None:
|
||||
|
||||
"""Creates new indexing jobs for each connector / credential pair which is:
|
||||
1. Enabled
|
||||
2. `refresh_frequency` time has passed since the last indexing run for this pair
|
||||
3. There is not already an ongoing indexing attempt for this pair
|
||||
"""
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
with Session(get_sqlalchemy_engine(schema=tenant_id)) as db_session:
|
||||
ongoing: set[tuple[int | None, int]] = set()
|
||||
for attempt_id in existing_jobs:
|
||||
attempt = get_index_attempt(
|
||||
@@ -204,12 +207,13 @@ def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob]) -> None:
|
||||
|
||||
def cleanup_indexing_jobs(
|
||||
existing_jobs: dict[int, Future | SimpleJob],
|
||||
tenant_id: str | None,
|
||||
timeout_hours: int = CLEANUP_INDEXING_JOBS_TIMEOUT,
|
||||
) -> dict[int, Future | SimpleJob]:
|
||||
existing_jobs_copy = existing_jobs.copy()
|
||||
|
||||
# clean up completed jobs
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
with Session(get_sqlalchemy_engine(schema=tenant_id)) as db_session:
|
||||
for attempt_id, job in existing_jobs.items():
|
||||
index_attempt = get_index_attempt(
|
||||
db_session=db_session, index_attempt_id=attempt_id
|
||||
@@ -247,38 +251,42 @@ def cleanup_indexing_jobs(
|
||||
)
|
||||
|
||||
# clean up in-progress jobs that were never completed
|
||||
connectors = fetch_connectors(db_session)
|
||||
for connector in connectors:
|
||||
in_progress_indexing_attempts = get_inprogress_index_attempts(
|
||||
connector.id, db_session
|
||||
)
|
||||
for index_attempt in in_progress_indexing_attempts:
|
||||
if index_attempt.id in existing_jobs:
|
||||
# If index attempt is canceled, stop the run
|
||||
if index_attempt.status == IndexingStatus.FAILED:
|
||||
existing_jobs[index_attempt.id].cancel()
|
||||
# check to see if the job has been updated in last `timeout_hours` hours, if not
|
||||
# assume it to frozen in some bad state and just mark it as failed. Note: this relies
|
||||
# on the fact that the `time_updated` field is constantly updated every
|
||||
# batch of documents indexed
|
||||
current_db_time = get_db_current_time(db_session=db_session)
|
||||
time_since_update = current_db_time - index_attempt.time_updated
|
||||
if time_since_update.total_seconds() > 60 * 60 * timeout_hours:
|
||||
existing_jobs[index_attempt.id].cancel()
|
||||
try:
|
||||
connectors = fetch_connectors(db_session)
|
||||
for connector in connectors:
|
||||
in_progress_indexing_attempts = get_inprogress_index_attempts(
|
||||
connector.id, db_session
|
||||
)
|
||||
|
||||
for index_attempt in in_progress_indexing_attempts:
|
||||
if index_attempt.id in existing_jobs:
|
||||
# If index attempt is canceled, stop the run
|
||||
if index_attempt.status == IndexingStatus.FAILED:
|
||||
existing_jobs[index_attempt.id].cancel()
|
||||
# check to see if the job has been updated in last `timeout_hours` hours, if not
|
||||
# assume it to frozen in some bad state and just mark it as failed. Note: this relies
|
||||
# on the fact that the `time_updated` field is constantly updated every
|
||||
# batch of documents indexed
|
||||
current_db_time = get_db_current_time(db_session=db_session)
|
||||
time_since_update = current_db_time - index_attempt.time_updated
|
||||
if time_since_update.total_seconds() > 60 * 60 * timeout_hours:
|
||||
existing_jobs[index_attempt.id].cancel()
|
||||
_mark_run_failed(
|
||||
db_session=db_session,
|
||||
index_attempt=index_attempt,
|
||||
failure_reason="Indexing run frozen - no updates in the last three hours. "
|
||||
"The run will be re-attempted at next scheduled indexing time.",
|
||||
)
|
||||
else:
|
||||
# If job isn't known, simply mark it as failed
|
||||
_mark_run_failed(
|
||||
db_session=db_session,
|
||||
index_attempt=index_attempt,
|
||||
failure_reason="Indexing run frozen - no updates in the last three hours. "
|
||||
"The run will be re-attempted at next scheduled indexing time.",
|
||||
failure_reason=_UNEXPECTED_STATE_FAILURE_REASON,
|
||||
)
|
||||
else:
|
||||
# If job isn't known, simply mark it as failed
|
||||
_mark_run_failed(
|
||||
db_session=db_session,
|
||||
index_attempt=index_attempt,
|
||||
failure_reason=_UNEXPECTED_STATE_FAILURE_REASON,
|
||||
)
|
||||
|
||||
except ProgrammingError as _:
|
||||
logger.debug(f"No Connector Table exists for: {tenant_id}")
|
||||
pass
|
||||
return existing_jobs_copy
|
||||
|
||||
|
||||
@@ -286,9 +294,11 @@ def kickoff_indexing_jobs(
|
||||
existing_jobs: dict[int, Future | SimpleJob],
|
||||
client: Client | SimpleJobClient,
|
||||
secondary_client: Client | SimpleJobClient,
|
||||
tenant_id: str | None,
|
||||
) -> dict[int, Future | SimpleJob]:
|
||||
|
||||
existing_jobs_copy = existing_jobs.copy()
|
||||
engine = get_sqlalchemy_engine()
|
||||
engine = get_sqlalchemy_engine(schema=tenant_id)
|
||||
|
||||
# Don't include jobs waiting in the Dask queue that just haven't started running
|
||||
# Also (rarely) don't include for jobs that started but haven't updated the indexing tables yet
|
||||
@@ -337,13 +347,16 @@ def kickoff_indexing_jobs(
|
||||
run = secondary_client.submit(
|
||||
run_indexing_entrypoint,
|
||||
attempt.id,
|
||||
tenant_id,
|
||||
global_version.get_is_ee_version(),
|
||||
pure=False,
|
||||
)
|
||||
else:
|
||||
|
||||
run = client.submit(
|
||||
run_indexing_entrypoint,
|
||||
attempt.id,
|
||||
tenant_id,
|
||||
global_version.get_is_ee_version(),
|
||||
pure=False,
|
||||
)
|
||||
@@ -376,41 +389,32 @@ def kickoff_indexing_jobs(
|
||||
return existing_jobs_copy
|
||||
|
||||
|
||||
def get_all_tenant_ids() -> list[str] | list[None]:
    if not MULTI_TENANT:
        return [None]
    with Session(get_sqlalchemy_engine(schema='public')) as session:
        result = session.execute(text("""
            SELECT schema_name
            FROM information_schema.schemata
            WHERE schema_name NOT IN ('pg_catalog', 'information_schema', 'public')
        """))
        tenant_ids = [row[0] for row in result]
    valid_tenants = [tenant for tenant in tenant_ids if tenant is None or not tenant.startswith('pg_')]

    return valid_tenants
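`get_all_tenant_ids` above enumerates tenant schemas, and the rest of this change passes each schema into `get_sqlalchemy_engine(schema=...)`. That engine helper is not shown in the diff; one common SQLAlchemy pattern for this kind of schema routing, offered only as a hedged sketch rather than the project's actual implementation, is the `schema_translate_map` execution option:

```python
# Hedged sketch, NOT the implementation from the diff: route unqualified table
# references onto a tenant schema with SQLAlchemy's schema_translate_map.
from sqlalchemy import create_engine
from sqlalchemy.engine import Engine

POSTGRES_URL = "postgresql+psycopg2://user:pass@localhost:5432/danswer"  # placeholder


def engine_for_schema(schema: str | None) -> Engine:
    engine = create_engine(POSTGRES_URL, pool_pre_ping=True)
    if schema is None:
        return engine
    # Unqualified (schema=None) table references are remapped to the tenant schema.
    return engine.execution_options(schema_translate_map={None: schema})
```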
|
||||
|
||||
|
||||
def update_loop(
|
||||
delay: int = 10,
|
||||
num_workers: int = NUM_INDEXING_WORKERS,
|
||||
num_secondary_workers: int = NUM_SECONDARY_INDEXING_WORKERS,
|
||||
) -> None:
|
||||
engine = get_sqlalchemy_engine()
|
||||
with Session(engine) as db_session:
|
||||
check_index_swap(db_session=db_session)
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
|
||||
# So that the first time users aren't surprised by really slow speed of first
|
||||
# batch of documents indexed
|
||||
|
||||
if search_settings.provider_type is None:
|
||||
logger.notice("Running a first inference to warm up embedding model")
|
||||
embedding_model = EmbeddingModel.from_db_model(
|
||||
search_settings=search_settings,
|
||||
server_host=INDEXING_MODEL_SERVER_HOST,
|
||||
server_port=MODEL_SERVER_PORT,
|
||||
)
|
||||
|
||||
warm_up_bi_encoder(
|
||||
embedding_model=embedding_model,
|
||||
)
|
||||
|
||||
client_primary: Client | SimpleJobClient
|
||||
client_secondary: Client | SimpleJobClient
|
||||
if DASK_JOB_CLIENT_ENABLED:
|
||||
cluster_primary = LocalCluster(
|
||||
n_workers=num_workers,
|
||||
threads_per_worker=1,
|
||||
# there are warning about high memory usage + "Event loop unresponsive"
|
||||
# which are not relevant to us since our workers are expected to use a
|
||||
# lot of memory + involve CPU intensive tasks that will not relinquish
|
||||
# the event loop
|
||||
silence_logs=logging.ERROR,
|
||||
)
|
||||
cluster_secondary = LocalCluster(
|
||||
@@ -426,37 +430,70 @@ def update_loop(
|
||||
client_primary = SimpleJobClient(n_workers=num_workers)
|
||||
client_secondary = SimpleJobClient(n_workers=num_secondary_workers)
|
||||
|
||||
existing_jobs: dict[int, Future | SimpleJob] = {}
|
||||
existing_jobs: dict[str | None, dict[int, Future | SimpleJob]] = {}
|
||||
|
||||
logger.notice("Startup complete. Waiting for indexing jobs...")
|
||||
while True:
|
||||
start = time.time()
|
||||
start_time_utc = datetime.utcfromtimestamp(start).strftime("%Y-%m-%d %H:%M:%S")
|
||||
logger.debug(f"Running update, current UTC time: {start_time_utc}")
|
||||
|
||||
if existing_jobs:
|
||||
# TODO: make this debug level once the "no jobs are being scheduled" issue is resolved
|
||||
logger.debug(
|
||||
"Found existing indexing jobs: "
|
||||
f"{[(attempt_id, job.status) for attempt_id, job in existing_jobs.items()]}"
|
||||
f"{[(tenant_id, list(jobs.keys())) for tenant_id, jobs in existing_jobs.items()]}"
|
||||
)
|
||||
|
||||
try:
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
check_index_swap(db_session)
|
||||
existing_jobs = cleanup_indexing_jobs(existing_jobs=existing_jobs)
|
||||
create_indexing_jobs(existing_jobs=existing_jobs)
|
||||
existing_jobs = kickoff_indexing_jobs(
|
||||
existing_jobs=existing_jobs,
|
||||
client=client_primary,
|
||||
secondary_client=client_secondary,
|
||||
)
|
||||
tenants = get_all_tenant_ids()
|
||||
|
||||
for tenant_id in tenants:
|
||||
try:
|
||||
logger.debug(f"Processing {'index attempts' if tenant_id is None else f'tenant {tenant_id}'}")
|
||||
engine = get_sqlalchemy_engine(schema=tenant_id)
|
||||
with Session(engine) as db_session:
|
||||
check_index_swap(db_session=db_session)
|
||||
if not MULTI_TENANT:
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
if search_settings.provider_type is None:
|
||||
logger.notice("Running a first inference to warm up embedding model")
|
||||
embedding_model = EmbeddingModel.from_db_model(
|
||||
search_settings=search_settings,
|
||||
server_host=INDEXING_MODEL_SERVER_HOST,
|
||||
server_port=MODEL_SERVER_PORT,
|
||||
)
|
||||
warm_up_bi_encoder(embedding_model=embedding_model)
|
||||
logger.notice("First inference complete.")
|
||||
|
||||
tenant_jobs = existing_jobs.get(tenant_id, {})
|
||||
|
||||
tenant_jobs = cleanup_indexing_jobs(
|
||||
existing_jobs=tenant_jobs,
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
create_indexing_jobs(
|
||||
existing_jobs=tenant_jobs,
|
||||
tenant_id=tenant_id
|
||||
)
|
||||
tenant_jobs = kickoff_indexing_jobs(
|
||||
existing_jobs=tenant_jobs,
|
||||
client=client_primary,
|
||||
secondary_client=client_secondary,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
existing_jobs[tenant_id] = tenant_jobs
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to process tenant {tenant_id or 'default'}: {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to run update due to {e}")
|
||||
|
||||
sleep_time = delay - (time.time() - start)
|
||||
if sleep_time > 0:
|
||||
time.sleep(sleep_time)
|
||||
|
||||
|
||||
def update__main() -> None:
|
||||
set_is_ee_based_on_env_variable()
|
||||
init_sqlalchemy_engine(POSTGRES_INDEXER_APP_NAME)
|
||||
|
||||
@@ -6,7 +6,6 @@ from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
|
||||
from danswer.configs.chat_configs import PERSONAS_YAML
|
||||
from danswer.configs.chat_configs import PROMPTS_YAML
|
||||
from danswer.db.document_set import get_or_create_document_set_by_name
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.input_prompt import insert_input_prompt_if_not_exists
|
||||
from danswer.db.models import DocumentSet as DocumentSetDBModel
|
||||
from danswer.db.models import Persona
|
||||
@@ -18,148 +17,156 @@ from danswer.db.persona import upsert_prompt
|
||||
from danswer.search.enums import RecencyBiasSetting
|
||||
|
||||
|
||||
def load_prompts_from_yaml(prompts_yaml: str = PROMPTS_YAML) -> None:
|
||||
def load_prompts_from_yaml(
|
||||
db_session: Session,
|
||||
prompts_yaml: str = PROMPTS_YAML
|
||||
) -> None:
|
||||
with open(prompts_yaml, "r") as file:
|
||||
data = yaml.safe_load(file)
|
||||
|
||||
all_prompts = data.get("prompts", [])
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
for prompt in all_prompts:
|
||||
upsert_prompt(
|
||||
user=None,
|
||||
prompt_id=prompt.get("id"),
|
||||
name=prompt["name"],
|
||||
description=prompt["description"].strip(),
|
||||
system_prompt=prompt["system"].strip(),
|
||||
task_prompt=prompt["task"].strip(),
|
||||
include_citations=prompt["include_citations"],
|
||||
datetime_aware=prompt.get("datetime_aware", True),
|
||||
default_prompt=True,
|
||||
personas=None,
|
||||
db_session=db_session,
|
||||
commit=True,
|
||||
)
|
||||
for prompt in all_prompts:
|
||||
upsert_prompt(
|
||||
user=None,
|
||||
prompt_id=prompt.get("id"),
|
||||
name=prompt["name"],
|
||||
description=prompt["description"].strip(),
|
||||
system_prompt=prompt["system"].strip(),
|
||||
task_prompt=prompt["task"].strip(),
|
||||
include_citations=prompt["include_citations"],
|
||||
datetime_aware=prompt.get("datetime_aware", True),
|
||||
default_prompt=True,
|
||||
personas=None,
|
||||
db_session=db_session,
|
||||
commit=True,
|
||||
)
|
||||
|
||||
|
||||
def load_personas_from_yaml(
    db_session: Session,
    personas_yaml: str = PERSONAS_YAML,
    default_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
) -> None:
    with open(personas_yaml, "r") as file:
        data = yaml.safe_load(file)

    all_personas = data.get("personas", [])
    with Session(get_sqlalchemy_engine()) as db_session:
    for persona in all_personas:
        doc_set_names = persona["document_sets"]
        doc_sets: list[DocumentSetDBModel] = [
            get_or_create_document_set_by_name(db_session, name)
            for name in doc_set_names
        ]

        # Assume if user hasn't set any document sets for the persona, the user may want
        # to later attach document sets to the persona manually, therefore, don't overwrite/reset
        # the document sets for the persona
        doc_set_ids: list[int] | None = None
        if doc_sets:
            doc_set_ids = [doc_set.id for doc_set in doc_sets]
        else:
            doc_set_ids = None

        prompt_ids: list[int] | None = None
        prompt_set_names = persona["prompts"]
        if prompt_set_names:
            prompts: list[PromptDBModel | None] = [
                get_prompt_by_name(prompt_name, user=None, db_session=db_session)
                for prompt_name in prompt_set_names
            ]
            if any([prompt is None for prompt in prompts]):
                raise ValueError("Invalid Persona configs, not all prompts exist")

            if prompts:
                prompt_ids = [prompt.id for prompt in prompts if prompt is not None]

        p_id = persona.get("id")
        tool_ids = []
        if persona.get("image_generation"):
            image_gen_tool = (
                db_session.query(ToolDBModel)
                .filter(ToolDBModel.name == "ImageGenerationTool")
                .first()
            )
            if image_gen_tool:
                tool_ids.append(image_gen_tool.id)

        llm_model_provider_override = persona.get("llm_model_provider_override")
        llm_model_version_override = persona.get("llm_model_version_override")

        # Set specific overrides for image generation persona
        if persona.get("image_generation"):
            llm_model_version_override = "gpt-4o"

        existing_persona = (
            db_session.query(Persona)
            .filter(Persona.name == persona["name"])
            .first()
        )

        upsert_persona(
            user=None,
            persona_id=(-1 * p_id) if p_id is not None else None,
            name=persona["name"],
            description=persona["description"],
            num_chunks=persona.get("num_chunks")
            if persona.get("num_chunks") is not None
            else default_chunks,
            llm_relevance_filter=persona.get("llm_relevance_filter"),
            starter_messages=persona.get("starter_messages"),
            llm_filter_extraction=persona.get("llm_filter_extraction"),
            icon_shape=persona.get("icon_shape"),
            icon_color=persona.get("icon_color"),
            llm_model_provider_override=llm_model_provider_override,
            llm_model_version_override=llm_model_version_override,
            recency_bias=RecencyBiasSetting(persona["recency_bias"]),
            prompt_ids=prompt_ids,
            document_set_ids=doc_set_ids,
            tool_ids=tool_ids,
            default_persona=True,
            is_public=True,
            display_priority=existing_persona.display_priority
            if existing_persona is not None
            else persona.get("display_priority"),
            is_visible=existing_persona.is_visible
            if existing_persona is not None
            else persona.get("is_visible"),
            db_session=db_session,
        )
|
||||
def load_input_prompts_from_yaml(input_prompts_yaml: str = INPUT_PROMPT_YAML) -> None:
def load_input_prompts_from_yaml(
    db_session: Session,
    input_prompts_yaml: str = INPUT_PROMPT_YAML
) -> None:
    with open(input_prompts_yaml, "r") as file:
        data = yaml.safe_load(file)

    all_input_prompts = data.get("input_prompts", [])
    with Session(get_sqlalchemy_engine()) as db_session:
    for input_prompt in all_input_prompts:
        # If these prompts are deleted (which is a hard delete in the DB), on server startup
        # they will be recreated, but the user can always just deactivate them, just a light inconvenience
        insert_input_prompt_if_not_exists(
            user=None,
            input_prompt_id=input_prompt.get("id"),
            prompt=input_prompt["prompt"],
            content=input_prompt["content"],
            is_public=input_prompt["is_public"],
            active=input_prompt.get("active", True),
            db_session=db_session,
            commit=True,
        )


def load_chat_yamls(
    db_session: Session,
    prompt_yaml: str = PROMPTS_YAML,
    personas_yaml: str = PERSONAS_YAML,
    input_prompts_yaml: str = INPUT_PROMPT_YAML,
) -> None:
    load_prompts_from_yaml(prompt_yaml)
    load_personas_from_yaml(personas_yaml)
    load_input_prompts_from_yaml(input_prompts_yaml)
    load_prompts_from_yaml(db_session, prompt_yaml)
    load_personas_from_yaml(db_session, personas_yaml)
    load_input_prompts_from_yaml(db_session, input_prompts_yaml)
|
||||
@@ -32,7 +32,6 @@ from danswer.db.chat import get_or_create_root_message
|
||||
from danswer.db.chat import reserve_message_id
|
||||
from danswer.db.chat import translate_db_message_to_chat_message_detail
|
||||
from danswer.db.chat import translate_db_search_doc_to_server_search_doc
|
||||
from danswer.db.engine import get_session_context_manager
|
||||
from danswer.db.llm import fetch_existing_llm_providers
|
||||
from danswer.db.models import SearchDoc as DbSearchDoc
|
||||
from danswer.db.models import ToolCall
|
||||
@@ -314,12 +313,14 @@ def stream_chat_message_objects(
|
||||
try:
|
||||
llm, fast_llm = get_llms_for_persona(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
llm_override=new_msg_req.llm_override or chat_session.llm_override,
|
||||
additional_headers=litellm_additional_headers,
|
||||
)
|
||||
except GenAIDisabledException:
|
||||
raise RuntimeError("LLM is disabled. Can't use chat flow without LLM.")
|
||||
|
||||
|
||||
llm_provider = llm.config.model_provider
|
||||
llm_model_name = llm.config.model_name
|
||||
|
||||
@@ -631,6 +632,7 @@ def stream_chat_message_objects(
|
||||
or get_main_llm_from_tuple(
|
||||
get_llms_for_persona(
|
||||
persona=persona,
|
||||
db_session=db_session,
|
||||
llm_override=(
|
||||
new_msg_req.llm_override or chat_session.llm_override
|
||||
),
|
||||
@@ -799,18 +801,19 @@ def stream_chat_message_objects(
|
||||
def stream_chat_message(
|
||||
new_msg_req: CreateChatMessageRequest,
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
use_existing_user_message: bool = False,
|
||||
litellm_additional_headers: dict[str, str] | None = None,
|
||||
is_connected: Callable[[], bool] | None = None,
|
||||
) -> Iterator[str]:
|
||||
with get_session_context_manager() as db_session:
|
||||
objects = stream_chat_message_objects(
|
||||
new_msg_req=new_msg_req,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
use_existing_user_message=use_existing_user_message,
|
||||
litellm_additional_headers=litellm_additional_headers,
|
||||
is_connected=is_connected,
|
||||
)
|
||||
for obj in objects:
|
||||
yield get_json_line(obj.model_dump())
|
||||
|
||||
objects = stream_chat_message_objects(
|
||||
new_msg_req=new_msg_req,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
use_existing_user_message=use_existing_user_message,
|
||||
litellm_additional_headers=litellm_additional_headers,
|
||||
is_connected=is_connected,
|
||||
)
|
||||
for obj in objects:
|
||||
yield get_json_line(obj.model_dump())
|
||||
|
||||
@@ -37,9 +37,11 @@ DISABLE_GENERATIVE_AI = os.environ.get("DISABLE_GENERATIVE_AI", "").lower() == "
WEB_DOMAIN = os.environ.get("WEB_DOMAIN") or "http://localhost:3000"


SECRET_JWT_KEY = os.environ.get("SECRET_JWT_KEY") or "JWT_SECRET_KEY"
#####
# Auth Configs
#####

AUTH_TYPE = AuthType((os.environ.get("AUTH_TYPE") or AuthType.DISABLED.value).lower())
DISABLE_AUTH = AUTH_TYPE == AuthType.DISABLED

@@ -134,7 +136,7 @@ POSTGRES_PASSWORD = urllib.parse.quote_plus(
    os.environ.get("POSTGRES_PASSWORD") or "password"
)
POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost"
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432"
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5433"
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"

# defaults to False
@@ -366,3 +368,21 @@ CUSTOM_ANSWER_VALIDITY_CONDITIONS = json.loads(
ENTERPRISE_EDITION_ENABLED = (
    os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() == "true"
)

###
# CLOUD CONFIGS
###
STRIPE_PRICE = os.environ.get("STRIPE_PRICE", "price_1PsYoPHlhTYqRZib2t5ydpq5")

STRIPE_WEBHOOK_SECRET = os.environ.get(
    "STRIPE_WEBHOOK_SECRET",
    "whsec_1cd766cd6bd08590aa8c46ab5c21ac32cad77c29de2e09a152a01971d6f405d3"
)

DEFAULT_SCHEMA = os.environ.get("DEFAULT_SCHEMA", "public")

DATA_PLANE_SECRET = os.environ.get("DATA_PLANE_SECRET", "your_shared_secret_key")
EXPECTED_API_KEY = os.environ.get("EXPECTED_API_KEY", "your_control_plane_api_key")

MULTI_TENANT = os.environ.get("MULTI_TENANT", "false").lower() == "true"
|
||||
@@ -88,3 +88,4 @@ HARD_DELETE_CHATS = False

# Internet Search
BING_API_KEY = os.environ.get("BING_API_KEY") or None
VESPA_SEARCHER_THREADS = int(os.environ.get("VESPA_SEARCHER_THREADS") or 2)

@@ -239,7 +239,6 @@ def _datetime_from_string(datetime_string: str) -> datetime:
    else:
        # If not in UTC, translate it
        datetime_object = datetime_object.astimezone(timezone.utc)

    return datetime_object
|
||||
|
||||
|
||||
@@ -159,10 +159,12 @@ class LocalFileConnector(LoadConnector):
        self,
        file_locations: list[Path | str],
        batch_size: int = INDEX_BATCH_SIZE,
        tenant_id: str | None = None
    ) -> None:
        self.file_locations = [Path(file_location) for file_location in file_locations]
        self.batch_size = batch_size
        self.pdf_pass: str | None = None
        self.tenant_id = tenant_id

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        self.pdf_pass = credentials.get("pdf_password")
@@ -170,7 +172,7 @@ class LocalFileConnector(LoadConnector):

    def load_from_state(self) -> GenerateDocumentsOutput:
        documents: list[Document] = []
        with Session(get_sqlalchemy_engine()) as db_session:
        with Session(get_sqlalchemy_engine(schema=self.tenant_id)) as db_session:
            for file_path in self.file_locations:
                current_datetime = datetime.now(timezone.utc)
                files = _read_files_and_metadata(
|
||||
@@ -154,7 +154,7 @@ def handle_regular_answer(
            get_editable=False,
        ),
    )
    llm, _ = get_llms_for_persona(persona)
    llm, _ = get_llms_for_persona(persona, db_session=db_session)

    # In cases of threads, split the available tokens between docs and thread context
    input_tokens = get_max_input_tokens(
@@ -171,6 +171,7 @@ def handle_regular_answer(
            persona=persona,
            actual_user_input=query_text,
            max_llm_token_override=remaining_tokens,
            db_session=db_session,
        )
    else:
        max_document_tokens = (
|
||||
@@ -51,6 +51,7 @@ def get_chat_session_by_id(
    is_shared: bool = False,
) -> ChatSession:
    stmt = select(ChatSession).where(ChatSession.id == chat_session_id)
    db_session.connection()

    if is_shared:
        stmt = stmt.where(ChatSession.shared_status == ChatSessionSharedStatus.PUBLIC)
@@ -86,7 +87,6 @@ def get_chat_sessions_by_slack_thread_id(
    )
    return db_session.scalars(stmt).all()


def get_first_messages_for_chat_sessions(
    chat_session_ids: list[int], db_session: Session
) -> dict[int, str]:
|
||||
@@ -1,10 +1,12 @@
import contextvars
from fastapi import Depends
from fastapi import Request, HTTPException
import contextlib
import time
from collections.abc import AsyncGenerator
from collections.abc import Generator
from datetime import datetime
from typing import ContextManager

from sqlalchemy import event
from sqlalchemy import text
from sqlalchemy.engine import create_engine
@@ -14,7 +16,8 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker

from danswer.configs.app_configs import SECRET_JWT_KEY
from danswer.configs.app_configs import DEFAULT_SCHEMA
from danswer.configs.app_configs import LOG_POSTGRES_CONN_COUNTS
from danswer.configs.app_configs import LOG_POSTGRES_LATENCY
from danswer.configs.app_configs import POSTGRES_DB
@@ -25,10 +28,18 @@ from danswer.configs.app_configs import POSTGRES_POOL_RECYCLE
from danswer.configs.app_configs import POSTGRES_PORT
from danswer.configs.app_configs import POSTGRES_USER
from danswer.configs.constants import POSTGRES_UNKNOWN_APP_NAME
from danswer.configs.app_configs import MULTI_TENANT
from danswer.utils.logger import setup_logger
from fastapi.security import OAuth2PasswordBearer
from jwt.exceptions import DecodeError, InvalidTokenError
import jwt


oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

logger = setup_logger()

SYNC_DB_API = "psycopg2"
ASYNC_DB_API = "asyncpg"

@@ -128,28 +139,32 @@ def init_sqlalchemy_engine(app_name: str) -> None:
    global POSTGRES_APP_NAME
    POSTGRES_APP_NAME = app_name

_engines: dict[str, Engine] = {}


def get_sqlalchemy_engine() -> Engine:
    global _SYNC_ENGINE
    if _SYNC_ENGINE is None:
# NOTE: this is a hack to allow for multiple postgres schemas per engine for now.
def get_sqlalchemy_engine(*, schema: str | None = DEFAULT_SCHEMA) -> Engine:
    if schema is None:
        schema = current_tenant_id.get()

    global _engines
    if schema not in _engines:
        connection_string = build_connection_string(
            db_api=SYNC_DB_API, app_name=POSTGRES_APP_NAME + "_sync"
            db_api=SYNC_DB_API, app_name=f"{POSTGRES_APP_NAME}_{schema}_sync"
        )
        _SYNC_ENGINE = create_engine(
        _engines[schema] = create_engine(
            connection_string,
            pool_size=40,
            max_overflow=10,
            pool_pre_ping=POSTGRES_POOL_PRE_PING,
            pool_recycle=POSTGRES_POOL_RECYCLE,
            connect_args={"options": f"-c search_path={schema}"}
        )
    return _SYNC_ENGINE
    return _engines[schema]


def get_sqlalchemy_async_engine() -> AsyncEngine:
    global _ASYNC_ENGINE
    if _ASYNC_ENGINE is None:
        # underlying asyncpg cannot accept application_name directly in the connection string
        # https://github.com/MagicStack/asyncpg/issues/798
        connection_string = build_connection_string()
        _ASYNC_ENGINE = create_async_engine(
            connection_string,
@@ -163,26 +178,50 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
        )
    return _ASYNC_ENGINE

current_tenant_id = contextvars.ContextVar(
    "current_tenant_id", default=DEFAULT_SCHEMA
)


def get_session_context_manager() -> ContextManager[Session]:
    return contextlib.contextmanager(get_session)()
    tenant_id = current_tenant_id.get()
    return contextlib.contextmanager(lambda: get_session(override_tenant_id=tenant_id))()


def get_current_tenant_id(request: Request) -> str | None:
    if not MULTI_TENANT:
        return DEFAULT_SCHEMA

def get_session() -> Generator[Session, None, None]:
    # The line below was added to monitor the latency caused by Postgres connections
    # during API calls.
    # with tracer.trace("db.get_session"):
    with Session(get_sqlalchemy_engine(), expire_on_commit=False) as session:
    token = request.cookies.get("tenant_details")
    if not token:
        return current_tenant_id.get()

    try:
        payload = jwt.decode(token, SECRET_JWT_KEY, algorithms=["HS256"])
        tenant_id = payload.get("tenant_id")
        if not tenant_id:
            raise HTTPException(status_code=400, detail="Invalid token: tenant_id missing")
        current_tenant_id.set(tenant_id)
        return tenant_id
    except (DecodeError, InvalidTokenError):
        raise HTTPException(status_code=401, detail="Invalid token format")
    except Exception:
        raise HTTPException(status_code=500, detail="Internal server error")


def get_session(
    tenant_id: str = Depends(get_current_tenant_id),
    override_tenant_id: str | None = None
) -> Generator[Session, None, None]:
    if override_tenant_id:
        tenant_id = override_tenant_id

    with Session(get_sqlalchemy_engine(schema=tenant_id), expire_on_commit=False) as session:
        yield session


async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
async def get_async_session(tenant_id: str | None = None) -> AsyncGenerator[AsyncSession, None]:
    async with AsyncSession(
        get_sqlalchemy_async_engine(), expire_on_commit=False
    ) as async_session:
        yield async_session


async def warm_up_connections(
    sync_connections_to_warm_up: int = 20, async_connections_to_warm_up: int = 20
) -> None:
@@ -190,6 +229,7 @@ async def warm_up_connections(
    connections = [
        sync_postgres_engine.connect() for _ in range(sync_connections_to_warm_up)
    ]

    for conn in connections:
        conn.execute(text("SELECT 1"))
    for conn in connections:
@@ -205,7 +245,6 @@ async def warm_up_connections(
    for async_conn in async_connections:
        await async_conn.close()


def get_session_factory() -> sessionmaker[Session]:
    global SessionFactory
    if SessionFactory is None:
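Sketch only (not part of the diff): how the tenant contextvar, the per-schema engine cache, and the session helpers above might fit together outside of a FastAPI request; "tenant_abc" is a made-up schema name.

from sqlalchemy import text

from danswer.db.engine import current_tenant_id, get_session_context_manager

# in a request, get_current_tenant_id() would set this from the "tenant_details" JWT cookie
token = current_tenant_id.set("tenant_abc")
try:
    with get_session_context_manager() as db_session:
        # the engine backing this session was created with search_path=tenant_abc
        db_session.execute(text("SELECT 1"))
finally:
    current_tenant_id.reset(token)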
|
||||
@@ -135,13 +135,16 @@ def fetch_embedding_provider(


def fetch_default_provider(db_session: Session) -> FullLLMProvider | None:

    provider_model = db_session.scalar(
        select(LLMProviderModel).where(
            LLMProviderModel.is_default_provider == True  # noqa: E712
        )
    )

    if not provider_model:
        return None

    return FullLLMProvider.from_model(provider_model)
|
||||
|
||||
|
||||
@@ -128,10 +128,11 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
    oidc_expiry: Mapped[datetime.datetime] = mapped_column(
        TIMESTAMPAware(timezone=True), nullable=True
    )
    tenant_id: Mapped[str] = mapped_column(Text, nullable=True)

    default_model: Mapped[str] = mapped_column(Text, nullable=True)
    # organized in typical structured fashion
    # formatted as `displayName__provider__modelName`
    default_model: Mapped[str] = mapped_column(Text, nullable=True)

    # relationships
    credentials: Mapped[list["Credential"]] = relationship(
@@ -1184,7 +1185,6 @@ class Tool(Base):

    # user who created / owns the tool. Will be None for built-in tools.
    user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)

    user: Mapped[User | None] = relationship("User", back_populates="custom_tools")
    # Relationship to Persona through the association table
    personas: Mapped[list["Persona"]] = relationship(
|
||||
@@ -1,22 +1,11 @@
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.model_configs import ASYM_PASSAGE_PREFIX
|
||||
from danswer.configs.model_configs import ASYM_QUERY_PREFIX
|
||||
from danswer.configs.model_configs import DEFAULT_DOCUMENT_ENCODER_MODEL
|
||||
from danswer.configs.model_configs import DOC_EMBEDDING_DIM
|
||||
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
|
||||
from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
|
||||
from danswer.configs.model_configs import OLD_DEFAULT_DOCUMENT_ENCODER_MODEL
|
||||
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
|
||||
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.llm import fetch_embedding_provider
|
||||
from danswer.db.models import CloudEmbeddingProvider
|
||||
from danswer.db.models import IndexModelStatus
|
||||
from danswer.db.models import SearchSettings
|
||||
from danswer.indexing.models import IndexingSetting
|
||||
from danswer.natural_language_processing.search_nlp_models import clean_model_name
|
||||
from danswer.search.models import SavedSearchSettings
|
||||
from danswer.server.manage.embedding.models import (
|
||||
CloudEmbeddingProvider as ServerCloudEmbeddingProvider,
|
||||
@@ -174,76 +163,3 @@ def update_search_settings_status(
|
||||
search_settings.status = new_status
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def user_has_overridden_embedding_model() -> bool:
|
||||
return DOCUMENT_ENCODER_MODEL != DEFAULT_DOCUMENT_ENCODER_MODEL
|
||||
|
||||
|
||||
def get_old_default_search_settings() -> SearchSettings:
|
||||
is_overridden = user_has_overridden_embedding_model()
|
||||
return SearchSettings(
|
||||
model_name=(
|
||||
DOCUMENT_ENCODER_MODEL
|
||||
if is_overridden
|
||||
else OLD_DEFAULT_DOCUMENT_ENCODER_MODEL
|
||||
),
|
||||
model_dim=(
|
||||
DOC_EMBEDDING_DIM if is_overridden else OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
|
||||
),
|
||||
normalize=(
|
||||
NORMALIZE_EMBEDDINGS
|
||||
if is_overridden
|
||||
else OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
|
||||
),
|
||||
query_prefix=(ASYM_QUERY_PREFIX if is_overridden else ""),
|
||||
passage_prefix=(ASYM_PASSAGE_PREFIX if is_overridden else ""),
|
||||
status=IndexModelStatus.PRESENT,
|
||||
index_name="danswer_chunk",
|
||||
)
|
||||
|
||||
|
||||
def get_new_default_search_settings(is_present: bool) -> SearchSettings:
|
||||
return SearchSettings(
|
||||
model_name=DOCUMENT_ENCODER_MODEL,
|
||||
model_dim=DOC_EMBEDDING_DIM,
|
||||
normalize=NORMALIZE_EMBEDDINGS,
|
||||
query_prefix=ASYM_QUERY_PREFIX,
|
||||
passage_prefix=ASYM_PASSAGE_PREFIX,
|
||||
status=IndexModelStatus.PRESENT if is_present else IndexModelStatus.FUTURE,
|
||||
index_name=f"danswer_chunk_{clean_model_name(DOCUMENT_ENCODER_MODEL)}",
|
||||
)
|
||||
|
||||
|
||||
def get_old_default_embedding_model() -> IndexingSetting:
|
||||
is_overridden = user_has_overridden_embedding_model()
|
||||
return IndexingSetting(
|
||||
model_name=(
|
||||
DOCUMENT_ENCODER_MODEL
|
||||
if is_overridden
|
||||
else OLD_DEFAULT_DOCUMENT_ENCODER_MODEL
|
||||
),
|
||||
model_dim=(
|
||||
DOC_EMBEDDING_DIM if is_overridden else OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
|
||||
),
|
||||
normalize=(
|
||||
NORMALIZE_EMBEDDINGS
|
||||
if is_overridden
|
||||
else OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
|
||||
),
|
||||
query_prefix=(ASYM_QUERY_PREFIX if is_overridden else ""),
|
||||
passage_prefix=(ASYM_PASSAGE_PREFIX if is_overridden else ""),
|
||||
index_name="danswer_chunk",
|
||||
multipass_indexing=False,
|
||||
)
|
||||
|
||||
|
||||
def get_new_default_embedding_model() -> IndexingSetting:
|
||||
return IndexingSetting(
|
||||
model_name=DOCUMENT_ENCODER_MODEL,
|
||||
model_dim=DOC_EMBEDDING_DIM,
|
||||
normalize=NORMALIZE_EMBEDDINGS,
|
||||
query_prefix=ASYM_QUERY_PREFIX,
|
||||
passage_prefix=ASYM_PASSAGE_PREFIX,
|
||||
index_name=f"danswer_chunk_{clean_model_name(DOCUMENT_ENCODER_MODEL)}",
|
||||
multipass_indexing=False,
|
||||
)
|
||||
|
||||
36 backend/danswer/db_setup.py Normal file
@@ -0,0 +1,36 @@
from danswer.llm.llm_initialization import load_llm_providers
from danswer.db.connector import create_initial_default_connector
from danswer.db.connector_credential_pair import associate_default_cc_pair
from danswer.db.credentials import create_initial_public_credential
from danswer.db.standard_answer import create_initial_default_standard_answer_category
from danswer.db.persona import delete_old_default_personas
from danswer.chat.load_yamls import load_chat_yamls
from danswer.tools.built_in_tools import auto_add_search_tool_to_personas
from danswer.tools.built_in_tools import load_builtin_tools
from danswer.tools.built_in_tools import refresh_built_in_tools_cache
from danswer.utils.logger import setup_logger
from sqlalchemy.orm import Session

logger = setup_logger()


def setup_postgres(db_session: Session) -> None:
    logger.notice("Verifying default connector/credential exist.")
    create_initial_public_credential(db_session)
    create_initial_default_connector(db_session)
    associate_default_cc_pair(db_session)

    logger.notice("Verifying default standard answer category exists.")
    create_initial_default_standard_answer_category(db_session)

    logger.notice("Loading LLM providers from env variables")
    load_llm_providers(db_session)

    logger.notice("Loading default Prompts and Personas")
    delete_old_default_personas(db_session)
    load_chat_yamls(db_session)

    logger.notice("Loading built-in tools")
    load_builtin_tools(db_session)
    refresh_built_in_tools_cache(db_session)
    auto_add_search_tool_to_personas(db_session)
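Illustrative sketch, assuming the tenant-aware engine shown earlier in this diff: seeding a single tenant's schema with the new helper might look like this ("tenant_abc" is hypothetical).

from sqlalchemy.orm import Session

from danswer.db.engine import get_sqlalchemy_engine
from danswer.db_setup import setup_postgres

with Session(get_sqlalchemy_engine(schema="tenant_abc")) as db_session:
    # seeds default connector/credential pair, standard answer category,
    # LLM providers, prompts/personas and built-in tools for that schema
    setup_postgres(db_session)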
@@ -3,13 +3,21 @@ from danswer.document_index.vespa.index import VespaIndex


def get_default_document_index(
    primary_index_name: str,
    secondary_index_name: str | None,
    primary_index_name: str | None = None,
    indices: list[str] | None = None,
    secondary_index_name: str | None = None
) -> DocumentIndex:
    """Primary index is the index that is used for querying/updating etc.
    Secondary index is for when both the currently used index and the upcoming
    index both need to be updated, updates are applied to both indices"""
    # Currently only supporting Vespa

    indices = [primary_index_name] if primary_index_name is not None else indices
    if not indices:
        raise ValueError("No indices provided")

    return VespaIndex(
        index_name=primary_index_name, secondary_index_name=secondary_index_name
        indices=indices,
        secondary_index_name=secondary_index_name
    )
||||
|
||||
@@ -77,7 +77,7 @@ class Verifiable(abc.ABC):
|
||||
all valid in the schema.
|
||||
|
||||
Parameters:
|
||||
- index_name: The name of the primary index currently used for querying
|
||||
- indices: The names of the primary indices currently used for querying
|
||||
- secondary_index_name: The name of the secondary index being built in the background, if it
|
||||
currently exists. Some functions on the document index act on both the primary and
|
||||
secondary index, some act on just one.
|
||||
@@ -86,20 +86,21 @@ class Verifiable(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def __init__(
|
||||
self,
|
||||
index_name: str,
|
||||
indices: list[str],
|
||||
secondary_index_name: str | None,
|
||||
*args: Any,
|
||||
**kwargs: Any
|
||||
) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.index_name = index_name
|
||||
self.indices = indices
|
||||
self.secondary_index_name = secondary_index_name
|
||||
|
||||
@abc.abstractmethod
|
||||
def ensure_indices_exist(
|
||||
self,
|
||||
index_embedding_dim: int,
|
||||
secondary_index_embedding_dim: int | None,
|
||||
embedding_dims: list[int] | None = None,
|
||||
index_embedding_dim: int | None = None,
|
||||
secondary_index_embedding_dim: int | None = None
|
||||
) -> None:
|
||||
"""
|
||||
Verify that the document index exists and is consistent with the expectations in the code.
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
schema DANSWER_CHUNK_NAME {
|
||||
document DANSWER_CHUNK_NAME {
|
||||
TENANT_ID_REPLACEMENT
|
||||
# Not to be confused with the UUID generated for this chunk which is called documentid by default
|
||||
field document_id type string {
|
||||
indexing: summary | attribute
|
||||
|
||||
@@ -335,6 +335,8 @@ def query_vespa(
|
||||
return inference_chunks
|
||||
|
||||
|
||||
|
||||
|
||||
def _get_chunks_via_batch_search(
|
||||
index_name: str,
|
||||
chunk_requests: list[VespaChunkRequest],
|
||||
|
||||
@@ -13,6 +13,7 @@ from typing import cast
|
||||
import httpx
|
||||
import requests
|
||||
|
||||
from danswer.configs.app_configs import MULTI_TENANT
|
||||
from danswer.configs.chat_configs import DOC_TIME_DECAY
|
||||
from danswer.configs.chat_configs import NUM_RETURNED_HITS
|
||||
from danswer.configs.chat_configs import TITLE_CONTENT_RATIO
|
||||
@@ -46,6 +47,8 @@ from danswer.document_index.vespa_constants import BATCH_SIZE
|
||||
from danswer.document_index.vespa_constants import BOOST
|
||||
from danswer.document_index.vespa_constants import CONTENT_SUMMARY
|
||||
from danswer.document_index.vespa_constants import DANSWER_CHUNK_REPLACEMENT_PAT
|
||||
from danswer.document_index.vespa_constants import TENANT_ID_PAT
|
||||
from danswer.document_index.vespa_constants import TENANT_ID_REPLACEMENT
|
||||
from danswer.document_index.vespa_constants import DATE_REPLACEMENT
|
||||
from danswer.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
|
||||
from danswer.document_index.vespa_constants import DOCUMENT_REPLACEMENT_PAT
|
||||
@@ -83,7 +86,7 @@ def in_memory_zip_from_file_bytes(file_contents: dict[str, bytes]) -> BinaryIO:
|
||||
return zip_buffer
|
||||
|
||||
|
||||
def _create_document_xml_lines(doc_names: list[str | None]) -> str:
|
||||
def _create_document_xml_lines(doc_names: list[str]) -> str:
|
||||
doc_lines = [
|
||||
f'<document type="{doc_name}" mode="index" />'
|
||||
for doc_name in doc_names
|
||||
@@ -108,15 +111,29 @@ def add_ngrams_to_schema(schema_content: str) -> str:
|
||||
|
||||
|
||||
class VespaIndex(DocumentIndex):
|
||||
def __init__(self, index_name: str, secondary_index_name: str | None) -> None:
|
||||
self.index_name = index_name
|
||||
def __init__(self, indices: list[str], secondary_index_name: str | None) -> None:
|
||||
self.indices = indices
|
||||
self.secondary_index_name = secondary_index_name
|
||||
|
||||
@property
|
||||
def index_name(self) -> str:
|
||||
if len(self.indices) == 0:
|
||||
raise ValueError("No indices provided")
|
||||
return self.indices[0]
|
||||
|
||||
def ensure_indices_exist(
|
||||
self,
|
||||
index_embedding_dim: int,
|
||||
secondary_index_embedding_dim: int | None,
|
||||
embedding_dims: list[int] | None = None,
|
||||
index_embedding_dim: int | None = None,
|
||||
secondary_index_embedding_dim: int | None = None
|
||||
) -> None:
|
||||
|
||||
if embedding_dims is None:
|
||||
if index_embedding_dim is not None:
|
||||
embedding_dims = [index_embedding_dim]
|
||||
else:
|
||||
raise ValueError("No embedding dimensions provided")
|
||||
|
||||
deploy_url = f"{VESPA_APPLICATION_ENDPOINT}/tenant/default/prepareandactivate"
|
||||
logger.debug(f"Sending Vespa zip to {deploy_url}")
|
||||
|
||||
@@ -130,9 +147,15 @@ class VespaIndex(DocumentIndex):
|
||||
with open(services_file, "r") as services_f:
|
||||
services_template = services_f.read()
|
||||
|
||||
schema_names = [self.index_name, self.secondary_index_name]
|
||||
# Generate schema names from index settings
|
||||
schema_names = [index_name for index_name in self.indices]
|
||||
|
||||
full_schemas = schema_names
|
||||
if self.secondary_index_name:
|
||||
full_schemas.append(self.secondary_index_name)
|
||||
|
||||
doc_lines = _create_document_xml_lines(full_schemas)
|
||||
|
||||
doc_lines = _create_document_xml_lines(schema_names)
|
||||
services = services_template.replace(DOCUMENT_REPLACEMENT_PAT, doc_lines)
|
||||
kv_store = get_dynamic_config_store()
|
||||
|
||||
@@ -160,27 +183,38 @@ class VespaIndex(DocumentIndex):
|
||||
|
||||
with open(schema_file, "r") as schema_f:
|
||||
schema_template = schema_f.read()
|
||||
schema = schema_template.replace(
|
||||
DANSWER_CHUNK_REPLACEMENT_PAT, self.index_name
|
||||
).replace(VESPA_DIM_REPLACEMENT_PAT, str(index_embedding_dim))
|
||||
schema = add_ngrams_to_schema(schema) if needs_reindexing else schema
|
||||
zip_dict[f"schemas/{schema_names[0]}.sd"] = schema.encode("utf-8")
|
||||
|
||||
for i, index_name in enumerate(self.indices):
|
||||
embedding_dim = embedding_dims[i]
|
||||
logger.info(f"Creating index: {index_name} with embedding dimension: {embedding_dim}")
|
||||
|
||||
schema = schema_template.replace(
|
||||
DANSWER_CHUNK_REPLACEMENT_PAT, index_name
|
||||
).replace(VESPA_DIM_REPLACEMENT_PAT, str(embedding_dim))
|
||||
schema = schema.replace(TENANT_ID_PAT, TENANT_ID_REPLACEMENT if MULTI_TENANT else "")
|
||||
schema = add_ngrams_to_schema(schema) if needs_reindexing else schema
|
||||
zip_dict[f"schemas/{index_name}.sd"] = schema.encode("utf-8")
|
||||
|
||||
if self.secondary_index_name:
|
||||
logger.info("Creating secondary index:"
|
||||
f"{self.secondary_index_name} with embedding dimension: {secondary_index_embedding_dim}")
|
||||
upcoming_schema = schema_template.replace(
|
||||
DANSWER_CHUNK_REPLACEMENT_PAT, self.secondary_index_name
|
||||
).replace(VESPA_DIM_REPLACEMENT_PAT, str(secondary_index_embedding_dim))
|
||||
zip_dict[f"schemas/{schema_names[1]}.sd"] = upcoming_schema.encode("utf-8")
|
||||
upcoming_schema = upcoming_schema.replace(TENANT_ID_PAT, TENANT_ID_REPLACEMENT if MULTI_TENANT else "")
|
||||
zip_dict[f"schemas/{self.secondary_index_name}.sd"] = upcoming_schema.encode("utf-8")
|
||||
|
||||
zip_file = in_memory_zip_from_file_bytes(zip_dict)
|
||||
|
||||
headers = {"Content-Type": "application/zip"}
|
||||
response = requests.post(deploy_url, headers=headers, data=zip_file)
|
||||
|
||||
if response.status_code != 200:
|
||||
raise RuntimeError(
|
||||
f"Failed to prepare Vespa Danswer Index. Response: {response.text}"
|
||||
f"Failed to prepare Vespa Danswer Indexes. Response: {response.text}"
|
||||
)
|
||||
|
||||
|
||||
def index(
|
||||
self,
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
@@ -230,7 +264,6 @@ class VespaIndex(DocumentIndex):
|
||||
)
|
||||
|
||||
all_doc_ids = {chunk.source_document.id for chunk in cleaned_chunks}
|
||||
|
||||
return {
|
||||
DocumentInsertionRecord(
|
||||
document_id=doc_id,
|
||||
@@ -282,7 +315,7 @@ class VespaIndex(DocumentIndex):
|
||||
raise requests.HTTPError(failure_msg) from e
|
||||
|
||||
def update(self, update_requests: list[UpdateRequest]) -> None:
|
||||
logger.info(f"Updating {len(update_requests)} documents in Vespa")
|
||||
logger.debug(f"Updating {len(update_requests)} documents in Vespa")
|
||||
|
||||
# Handle Vespa character limitations
|
||||
# Mutating update_requests but it's not used later anyway
|
||||
@@ -371,6 +404,91 @@ class VespaIndex(DocumentIndex):
|
||||
time.monotonic() - update_start,
|
||||
)
|
||||
|
||||
def update_single(self, update_request: UpdateRequest) -> None:
|
||||
"""Note: if the document id does not exist, the update will be a no-op and the
|
||||
function will complete with no errors or exceptions.
|
||||
Handle other exceptions if you wish to implement retry behavior
|
||||
"""
|
||||
if len(update_request.document_ids) != 1:
|
||||
raise ValueError("update_request must contain a single document id")
|
||||
|
||||
# Handle Vespa character limitations
|
||||
# Mutating update_request but it's not used later anyway
|
||||
update_request.document_ids = [
|
||||
replace_invalid_doc_id_characters(doc_id)
|
||||
for doc_id in update_request.document_ids
|
||||
]
|
||||
|
||||
update_start = time.monotonic()
|
||||
|
||||
# Fetch all chunks for each document ahead of time
|
||||
index_names = [self.index_name]
|
||||
if self.secondary_index_name:
|
||||
index_names.append(self.secondary_index_name)
|
||||
|
||||
chunk_id_start_time = time.monotonic()
|
||||
all_doc_chunk_ids: list[str] = []
|
||||
for index_name in index_names:
|
||||
for document_id in update_request.document_ids:
|
||||
# this calls vespa and can raise http exceptions
|
||||
doc_chunk_ids = get_all_vespa_ids_for_document_id(
|
||||
document_id=document_id,
|
||||
index_name=index_name,
|
||||
filters=None,
|
||||
get_large_chunks=True,
|
||||
)
|
||||
all_doc_chunk_ids.extend(doc_chunk_ids)
|
||||
logger.debug(
|
||||
f"Took {time.monotonic() - chunk_id_start_time:.2f} seconds to fetch all Vespa chunk IDs"
|
||||
)
|
||||
|
||||
# Build the _VespaUpdateRequest objects
|
||||
update_dict: dict[str, dict] = {"fields": {}}
|
||||
if update_request.boost is not None:
|
||||
update_dict["fields"][BOOST] = {"assign": update_request.boost}
|
||||
if update_request.document_sets is not None:
|
||||
update_dict["fields"][DOCUMENT_SETS] = {
|
||||
"assign": {
|
||||
document_set: 1 for document_set in update_request.document_sets
|
||||
}
|
||||
}
|
||||
if update_request.access is not None:
|
||||
update_dict["fields"][ACCESS_CONTROL_LIST] = {
|
||||
"assign": {acl_entry: 1 for acl_entry in update_request.access.to_acl()}
|
||||
}
|
||||
if update_request.hidden is not None:
|
||||
update_dict["fields"][HIDDEN] = {"assign": update_request.hidden}
|
||||
|
||||
if not update_dict["fields"]:
|
||||
logger.error("Update request received but nothing to update")
|
||||
return
|
||||
|
||||
processed_update_requests: list[_VespaUpdateRequest] = []
|
||||
for document_id in update_request.document_ids:
|
||||
for doc_chunk_id in all_doc_chunk_ids:
|
||||
processed_update_requests.append(
|
||||
_VespaUpdateRequest(
|
||||
document_id=document_id,
|
||||
url=f"{DOCUMENT_ID_ENDPOINT.format(index_name=self.index_name)}/{doc_chunk_id}",
|
||||
update_request=update_dict,
|
||||
)
|
||||
)
|
||||
|
||||
with httpx.Client(http2=True) as http_client:
|
||||
for update in processed_update_requests:
|
||||
http_client.put(
|
||||
update.url,
|
||||
headers={"Content-Type": "application/json"},
|
||||
json=update.update_request,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"Finished updating Vespa documents in %.2f seconds",
|
||||
time.monotonic() - update_start,
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
def delete(self, doc_ids: list[str]) -> None:
|
||||
logger.info(f"Deleting {len(doc_ids)} documents from Vespa")
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ from danswer.document_index.vespa_constants import CHUNK_ID
|
||||
from danswer.document_index.vespa_constants import CONTENT
|
||||
from danswer.document_index.vespa_constants import CONTENT_SUMMARY
|
||||
from danswer.document_index.vespa_constants import DOC_UPDATED_AT
|
||||
from danswer.document_index.vespa_constants import TENANT_ID
|
||||
from danswer.document_index.vespa_constants import DOCUMENT_ID
|
||||
from danswer.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
|
||||
from danswer.document_index.vespa_constants import DOCUMENT_SETS
|
||||
@@ -65,6 +66,8 @@ def _does_document_exist(
|
||||
raise RuntimeError(
|
||||
f"Unexpected fetch document by ID value from Vespa "
|
||||
f"with error {doc_fetch_response.status_code}"
|
||||
f"Index name: {index_name}"
|
||||
f"Doc chunk id: {doc_chunk_id}"
|
||||
)
|
||||
return True
|
||||
|
||||
@@ -95,7 +98,7 @@ def get_existing_documents_from_chunks(
|
||||
try:
|
||||
chunk_existence_future = {
|
||||
executor.submit(
|
||||
_does_document_exist,
|
||||
_does_document_exist,
|
||||
str(get_uuid_from_chunk(chunk)),
|
||||
index_name,
|
||||
http_client,
|
||||
@@ -117,7 +120,9 @@ def get_existing_documents_from_chunks(
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _index_vespa_chunk(
|
||||
chunk: DocMetadataAwareIndexChunk, index_name: str, http_client: httpx.Client
|
||||
chunk: DocMetadataAwareIndexChunk,
|
||||
index_name: str,
|
||||
http_client: httpx.Client,
|
||||
) -> None:
|
||||
json_header = {
|
||||
"Content-Type": "application/json",
|
||||
@@ -172,8 +177,10 @@ def _index_vespa_chunk(
|
||||
DOCUMENT_SETS: {document_set: 1 for document_set in chunk.document_sets},
|
||||
}
|
||||
|
||||
if chunk.tenant_id:
|
||||
vespa_document_fields[TENANT_ID] = chunk.tenant_id
|
||||
|
||||
vespa_url = f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{vespa_chunk_id}"
|
||||
logger.debug(f'Indexing to URL "{vespa_url}"')
|
||||
res = http_client.post(
|
||||
vespa_url, headers=json_header, json={"fields": vespa_document_fields}
|
||||
)
|
||||
|
||||
@@ -7,6 +7,7 @@ from danswer.document_index.interfaces import VespaChunkRequest
from danswer.document_index.vespa_constants import ACCESS_CONTROL_LIST
from danswer.document_index.vespa_constants import CHUNK_ID
from danswer.document_index.vespa_constants import DOC_UPDATED_AT
from danswer.document_index.vespa_constants import TENANT_ID
from danswer.document_index.vespa_constants import DOCUMENT_ID
from danswer.document_index.vespa_constants import DOCUMENT_SETS
from danswer.document_index.vespa_constants import HIDDEN
@@ -19,6 +20,7 @@ logger = setup_logger()


def build_vespa_filters(filters: IndexFilters, include_hidden: bool = False) -> str:

    def _build_or_filters(key: str, vals: list[str] | None) -> str:
        if vals is None:
            return ""
@@ -53,6 +55,10 @@ def build_vespa_filters(filters: IndexFilters, include_hidden: bool = False) ->

    filter_str = f"!({HIDDEN}=true) and " if not include_hidden else ""

    if filters.tenant_id:
        filter_str += f'({TENANT_ID} contains "{filters.tenant_id}") and '

    # CAREFUL touching this one, currently there is no second ACL double-check post retrieval
    if filters.access_control_list is not None:
        filter_str += _build_or_filters(
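For illustration only (a toy stand-in, not the real helper): the effect of the new tenant clause on the filter prefix that build_vespa_filters produces.

def tenant_filter_prefix(tenant_id: str | None, include_hidden: bool = False) -> str:
    # mirrors just the prefix-building part of the logic above
    filter_str = "" if include_hidden else "!(hidden=true) and "
    if tenant_id:
        filter_str += f'(tenant_id contains "{tenant_id}") and '
    return filter_str


print(tenant_filter_prefix("tenant_abc"))
# !(hidden=true) and (tenant_id contains "tenant_abc") and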
|
||||
@@ -8,7 +8,14 @@ VESPA_DIM_REPLACEMENT_PAT = "VARIABLE_DIM"
DANSWER_CHUNK_REPLACEMENT_PAT = "DANSWER_CHUNK_NAME"
DOCUMENT_REPLACEMENT_PAT = "DOCUMENT_REPLACEMENT"
DATE_REPLACEMENT = "DATE_REPLACEMENT"
SEARCH_THREAD_NUMBER_PAT = "SEARCH_THREAD_NUMBER"
TENANT_ID_PAT = "TENANT_ID_REPLACEMENT"

TENANT_ID_REPLACEMENT = """field tenant_id type string {
    indexing: summary | attribute
    rank: filter
    attribute: fast-search
}"""
# config server
VESPA_CONFIG_SERVER_URL = f"http://{VESPA_CONFIG_SERVER_HOST}:{VESPA_TENANT_PORT}"
VESPA_APPLICATION_ENDPOINT = f"{VESPA_CONFIG_SERVER_URL}/application/v2"
@@ -31,7 +38,7 @@ MAX_ID_SEARCH_QUERY_SIZE = 400
VESPA_TIMEOUT = "3s"
BATCH_SIZE = 128  # Specific to Vespa

TENANT_ID = "tenant_id"
DOCUMENT_ID = "document_id"
CHUNK_ID = "chunk_id"
BLURB = "blurb"
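A minimal sketch of how these constants are used when rendering a schema (the template string below is a toy stand-in for the real .sd file): in multi-tenant mode the tenant_id field block is spliced in, otherwise the placeholder is stripped.

from danswer.configs.app_configs import MULTI_TENANT
from danswer.document_index.vespa_constants import TENANT_ID_PAT, TENANT_ID_REPLACEMENT

schema_template = "schema danswer_chunk {\n  document danswer_chunk {\n    TENANT_ID_REPLACEMENT\n  }\n}"
schema = schema_template.replace(
    TENANT_ID_PAT, TENANT_ID_REPLACEMENT if MULTI_TENANT else ""
)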
|
||||
@@ -134,6 +134,7 @@ def index_doc_batch_with_handler(
|
||||
attempt_id: int | None,
|
||||
db_session: Session,
|
||||
ignore_time_skip: bool = False,
|
||||
tenant_id: str | None = None,
|
||||
) -> tuple[int, int]:
|
||||
r = (0, 0)
|
||||
try:
|
||||
@@ -145,6 +146,7 @@ def index_doc_batch_with_handler(
|
||||
index_attempt_metadata=index_attempt_metadata,
|
||||
db_session=db_session,
|
||||
ignore_time_skip=ignore_time_skip,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
except Exception as e:
|
||||
if INDEXING_EXCEPTION_LIMIT == 0:
|
||||
@@ -258,6 +260,7 @@ def index_doc_batch(
|
||||
index_attempt_metadata: IndexAttemptMetadata,
|
||||
db_session: Session,
|
||||
ignore_time_skip: bool = False,
|
||||
tenant_id: str | None = None,
|
||||
) -> tuple[int, int]:
|
||||
"""Takes different pieces of the indexing pipeline and applies it to a batch of documents
|
||||
Note that the documents should already be batched at this point so that it does not inflate the
|
||||
@@ -316,6 +319,7 @@ def index_doc_batch(
|
||||
if chunk.source_document.id in ctx.id_to_db_doc_map
|
||||
else DEFAULT_BOOST
|
||||
),
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
for chunk in chunks_with_embeddings
|
||||
]
|
||||
@@ -357,6 +361,7 @@ def build_indexing_pipeline(
|
||||
chunker: Chunker | None = None,
|
||||
ignore_time_skip: bool = False,
|
||||
attempt_id: int | None = None,
|
||||
tenant_id: str | None = None,
|
||||
) -> IndexingPipelineProtocol:
|
||||
"""Builds a pipeline which takes in a list (batch) of docs and indexes them."""
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
@@ -393,4 +398,5 @@ def build_indexing_pipeline(
|
||||
ignore_time_skip=ignore_time_skip,
|
||||
attempt_id=attempt_id,
|
||||
db_session=db_session,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
@@ -55,12 +55,10 @@ class DocAwareChunk(BaseChunk):
|
||||
f"Chunk ID: '{self.chunk_id}'; {self.source_document.to_short_descriptor()}"
|
||||
)
|
||||
|
||||
|
||||
class IndexChunk(DocAwareChunk):
|
||||
embeddings: ChunkEmbedding
|
||||
title_embedding: Embedding | None
|
||||
|
||||
|
||||
class DocMetadataAwareIndexChunk(IndexChunk):
|
||||
"""An `IndexChunk` that contains all necessary metadata to be indexed. This includes
|
||||
the following:
|
||||
@@ -73,6 +71,7 @@ class DocMetadataAwareIndexChunk(IndexChunk):
|
||||
negative -> ranked lower.
|
||||
"""
|
||||
|
||||
tenant_id: str | None = None
|
||||
access: "DocumentAccess"
|
||||
document_sets: set[str]
|
||||
boost: int
|
||||
@@ -84,6 +83,7 @@ class DocMetadataAwareIndexChunk(IndexChunk):
|
||||
access: "DocumentAccess",
|
||||
document_sets: set[str],
|
||||
boost: int,
|
||||
tenant_id: str | None,
|
||||
) -> "DocMetadataAwareIndexChunk":
|
||||
index_chunk_data = index_chunk.model_dump()
|
||||
return cls(
|
||||
@@ -91,6 +91,7 @@ class DocMetadataAwareIndexChunk(IndexChunk):
|
||||
access=access,
|
||||
document_sets=document_sets,
|
||||
boost=boost,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
from sqlalchemy.orm import Session
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage

@@ -94,13 +95,14 @@ def compute_max_document_tokens(

def compute_max_document_tokens_for_persona(
    persona: Persona,
    db_session: Session,
    actual_user_input: str | None = None,
    max_llm_token_override: int | None = None,
) -> int:
    prompt = persona.prompts[0] if persona.prompts else get_default_prompt__read_only()
    return compute_max_document_tokens(
        prompt_config=PromptConfig.from_model(prompt),
        llm_config=get_main_llm_from_tuple(get_llms_for_persona(persona)).config,
        llm_config=get_main_llm_from_tuple(get_llms_for_persona(persona, db_session=db_session)).config,
        actual_user_input=actual_user_input,
        max_llm_token_override=max_llm_token_override,
    )
|
||||
@@ -1,7 +1,7 @@
|
||||
from danswer.db.engine import get_session_context_manager
|
||||
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
|
||||
from danswer.configs.chat_configs import QA_TIMEOUT
|
||||
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
|
||||
from danswer.db.engine import get_session_context_manager
|
||||
from danswer.db.llm import fetch_default_provider
|
||||
from danswer.db.llm import fetch_provider
|
||||
from danswer.db.models import Persona
|
||||
@@ -10,16 +10,20 @@ from danswer.llm.exceptions import GenAIDisabledException
|
||||
from danswer.llm.headers import build_llm_extra_headers
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.llm.override_models import LLMOverride
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
def get_main_llm_from_tuple(
|
||||
llms: tuple[LLM, LLM],
|
||||
) -> LLM:
|
||||
return llms[0]
|
||||
|
||||
|
||||
def get_llms_for_persona(
|
||||
persona: Persona,
|
||||
db_session: Session,
|
||||
llm_override: LLMOverride | None = None,
|
||||
additional_headers: dict[str, str] | None = None,
|
||||
) -> tuple[LLM, LLM]:
|
||||
@@ -28,14 +32,15 @@ def get_llms_for_persona(
|
||||
temperature_override = llm_override.temperature if llm_override else None
|
||||
|
||||
provider_name = model_provider_override or persona.llm_model_provider_override
|
||||
|
||||
if not provider_name:
|
||||
return get_default_llms(
|
||||
temperature=temperature_override or GEN_AI_TEMPERATURE,
|
||||
additional_headers=additional_headers,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
with get_session_context_manager() as db_session:
|
||||
llm_provider = fetch_provider(db_session, provider_name)
|
||||
llm_provider = fetch_provider(db_session, provider_name)
|
||||
|
||||
if not llm_provider:
|
||||
raise ValueError("No LLM provider found")
|
||||
@@ -57,7 +62,6 @@ def get_llms_for_persona(
|
||||
custom_config=llm_provider.custom_config,
|
||||
additional_headers=additional_headers,
|
||||
)
|
||||
|
||||
return _create_llm(model), _create_llm(fast_model)
|
||||
|
||||
|
||||
@@ -65,13 +69,19 @@ def get_default_llms(
|
||||
timeout: int = QA_TIMEOUT,
|
||||
temperature: float = GEN_AI_TEMPERATURE,
|
||||
additional_headers: dict[str, str] | None = None,
|
||||
db_session: Session | None = None,
|
||||
) -> tuple[LLM, LLM]:
|
||||
if DISABLE_GENERATIVE_AI:
|
||||
raise GenAIDisabledException()
|
||||
|
||||
with get_session_context_manager() as db_session:
|
||||
|
||||
if db_session is None:
|
||||
with get_session_context_manager() as db_session:
|
||||
llm_provider = fetch_default_provider(db_session)
|
||||
else:
|
||||
llm_provider = fetch_default_provider(db_session)
|
||||
|
||||
|
||||
if not llm_provider:
|
||||
raise ValueError("No default LLM provider found")
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
|
||||
from danswer.document_index.vespa.index import VespaIndex
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
@@ -12,15 +14,17 @@ from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from httpx_oauth.clients.google import GoogleOAuth2
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.document_index.interfaces import DocumentIndex
|
||||
from danswer.configs.app_configs import MULTI_TENANT
|
||||
from danswer import __version__
|
||||
from danswer.auth.schemas import UserCreate
|
||||
from danswer.auth.schemas import UserRead
|
||||
from danswer.auth.schemas import UserUpdate
|
||||
from danswer.auth.users import auth_backend
|
||||
from danswer.auth.users import fastapi_users
|
||||
from danswer.chat.load_yamls import load_chat_yamls
|
||||
from sqlalchemy.orm import Session
|
||||
from danswer.indexing.models import IndexingSetting
|
||||
from danswer.configs.app_configs import APP_API_PREFIX
|
||||
from danswer.configs.app_configs import APP_HOST
|
||||
from danswer.configs.app_configs import APP_PORT
|
||||
@@ -37,30 +41,22 @@ from danswer.configs.constants import KV_REINDEX_KEY
|
||||
from danswer.configs.constants import KV_SEARCH_SETTINGS
|
||||
from danswer.configs.constants import POSTGRES_WEB_APP_NAME
|
||||
from danswer.db.connector import check_connectors_exist
|
||||
from danswer.db.connector import create_initial_default_connector
|
||||
from danswer.db.connector_credential_pair import associate_default_cc_pair
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from danswer.db.connector_credential_pair import resync_cc_pair
|
||||
from danswer.db.credentials import create_initial_public_credential
|
||||
from danswer.db.document import check_docs_exist
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.engine import init_sqlalchemy_engine
|
||||
from danswer.db.engine import warm_up_connections
|
||||
from danswer.db.index_attempt import cancel_indexing_attempts_past_model
|
||||
from danswer.db.index_attempt import expire_index_attempts
|
||||
from danswer.db.persona import delete_old_default_personas
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.db.search_settings import get_secondary_search_settings
|
||||
from danswer.db.search_settings import update_current_search_settings
|
||||
from danswer.db.search_settings import update_secondary_search_settings
|
||||
from danswer.db.standard_answer import create_initial_default_standard_answer_category
|
||||
from danswer.db.swap_index import check_index_swap
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.document_index.interfaces import DocumentIndex
|
||||
from danswer.dynamic_configs.factory import get_dynamic_config_store
|
||||
from danswer.dynamic_configs.interface import ConfigNotFoundError
|
||||
from danswer.indexing.models import IndexingSetting
|
||||
from danswer.llm.llm_initialization import load_llm_providers
|
||||
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
|
||||
from danswer.natural_language_processing.search_nlp_models import warm_up_cross_encoder
|
||||
@@ -106,9 +102,7 @@ from danswer.server.settings.api import basic_router as settings_router
|
||||
from danswer.server.token_rate_limits.api import (
|
||||
router as token_rate_limit_settings_router,
|
||||
)
|
||||
from danswer.tools.built_in_tools import auto_add_search_tool_to_personas
|
||||
from danswer.tools.built_in_tools import load_builtin_tools
|
||||
from danswer.tools.built_in_tools import refresh_built_in_tools_cache
|
||||
from danswer.server.tenants.api import basic_router as tenants_router
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.telemetry import optional_telemetry
|
||||
from danswer.utils.telemetry import RecordType
|
||||
@@ -117,7 +111,8 @@ from danswer.utils.variable_functionality import global_version
|
||||
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
from shared_configs.configs import MODEL_SERVER_HOST
|
||||
from shared_configs.configs import MODEL_SERVER_PORT
|
||||
|
||||
from shared_configs.configs import SUPPORTED_EMBEDDING_MODELS
|
||||
from danswer.db_setup import setup_postgres
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -170,27 +165,6 @@ def include_router_with_global_prefix_prepended(
|
||||
application.include_router(router, **final_kwargs)
|
||||
|
||||
|
||||
def setup_postgres(db_session: Session) -> None:
|
||||
logger.notice("Verifying default connector/credential exist.")
|
||||
create_initial_public_credential(db_session)
|
||||
create_initial_default_connector(db_session)
|
||||
associate_default_cc_pair(db_session)
|
||||
|
||||
logger.notice("Verifying default standard answer category exists.")
|
||||
create_initial_default_standard_answer_category(db_session)
|
||||
|
||||
logger.notice("Loading LLM providers from env variables")
|
||||
load_llm_providers(db_session)
|
||||
|
||||
logger.notice("Loading default Prompts and Personas")
|
||||
delete_old_default_personas(db_session)
|
||||
load_chat_yamls()
|
||||
|
||||
logger.notice("Loading built-in tools")
|
||||
load_builtin_tools(db_session)
|
||||
refresh_built_in_tools_cache(db_session)
|
||||
auto_add_search_tool_to_personas(db_session)
|
||||
|
||||
|
||||
def translate_saved_search_settings(db_session: Session) -> None:
|
||||
kv_store = get_dynamic_config_store()
|
||||
@@ -258,23 +232,22 @@ def mark_reindex_flag(db_session: Session) -> None:
|
||||
|
||||
def setup_vespa(
|
||||
document_index: DocumentIndex,
|
||||
index_setting: IndexingSetting,
|
||||
secondary_index_setting: IndexingSetting | None,
|
||||
embedding_dims: list[int],
|
||||
secondary_embedding_dim: int | None = None
|
||||
) -> None:
|
||||
# Vespa startup is a bit slow, so give it a few seconds
|
||||
wait_time = 5
|
||||
for _ in range(5):
|
||||
try:
|
||||
document_index.ensure_indices_exist(
|
||||
index_embedding_dim=index_setting.model_dim,
|
||||
secondary_index_embedding_dim=secondary_index_setting.model_dim
|
||||
if secondary_index_setting
|
||||
else None,
|
||||
embedding_dims=embedding_dims,
|
||||
secondary_index_embedding_dim=secondary_embedding_dim
|
||||
)
|
||||
break
|
||||
except Exception:
|
||||
logger.notice(f"Waiting on Vespa, retrying in {wait_time} seconds...")
|
||||
time.sleep(wait_time)
|
||||
logger.exception("Error ensuring multi-tenant indices exist")
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
@@ -304,6 +277,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
|
||||
|
||||
# Break bad state for thrashing indexes
|
||||
if secondary_search_settings and DISABLE_INDEX_UPDATE_ON_SWAP:
|
||||
|
||||
expire_index_attempts(
|
||||
search_settings_id=search_settings.id, db_session=db_session
|
||||
)
|
||||
@@ -316,12 +290,14 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
|
||||
|
||||
logger.notice(f'Using Embedding model: "{search_settings.model_name}"')
|
||||
if search_settings.query_prefix or search_settings.passage_prefix:
|
||||
|
||||
logger.notice(f'Query embedding prefix: "{search_settings.query_prefix}"')
|
||||
logger.notice(
|
||||
f'Passage embedding prefix: "{search_settings.passage_prefix}"'
|
||||
)
|
||||
|
||||
if search_settings:
|
||||
|
||||
if not search_settings.disable_rerank_for_streaming:
|
||||
logger.notice("Reranking is enabled.")
|
||||
|
||||
@@ -347,19 +323,39 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
|
||||
|
||||
# ensure Vespa is setup correctly
|
||||
logger.notice("Verifying Document Index(s) is/are available.")
|
||||
document_index = get_default_document_index(
|
||||
primary_index_name=search_settings.index_name,
|
||||
secondary_index_name=secondary_search_settings.index_name
|
||||
if secondary_search_settings
|
||||
else None,
|
||||
)
|
||||
setup_vespa(
|
||||
document_index,
|
||||
IndexingSetting.from_db_model(search_settings),
|
||||
IndexingSetting.from_db_model(secondary_search_settings)
|
||||
if secondary_search_settings
|
||||
else None,
|
||||
)
|
||||
|
||||
# document_index = get_default_document_index(
|
||||
# indices=[model.index_name for model in SUPPORTED_EMBEDDING_MODELS]
|
||||
# ) if MULTI_TENANT else get_default_document_index(
|
||||
# indices=[model.index_name for model in SUPPORTED_EMBEDDING_MODELS],
|
||||
# secondary_index_name=secondary_search_settings.index_name if secondary_search_settings else None
|
||||
# )
|
||||
|
||||
if MULTI_TENANT:
|
||||
document_index = get_default_document_index(
|
||||
indices=[model.index_name for model in SUPPORTED_EMBEDDING_MODELS]
|
||||
)
|
||||
setup_vespa(
|
||||
document_index,
|
||||
[model.dim for model in SUPPORTED_EMBEDDING_MODELS],
|
||||
secondary_embedding_dim=secondary_search_settings.model_dim if secondary_search_settings else None
|
||||
)
|
||||
|
||||
else:
|
||||
document_index = get_default_document_index(
|
||||
indices=[search_settings.index_name],
|
||||
secondary_index_name=secondary_search_settings.index_name if secondary_search_settings else None
|
||||
)
|
||||
|
||||
setup_vespa(
|
||||
document_index,
|
||||
[IndexingSetting.from_db_model(search_settings).model_dim],
|
||||
secondary_embedding_dim=(
|
||||
IndexingSetting.from_db_model(secondary_search_settings).model_dim
|
||||
if secondary_search_settings
|
||||
else None
|
||||
)
|
||||
)
|
||||
|
||||
logger.notice(f"Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}")
|
||||
if search_settings.provider_type is None:
|
||||
@@ -370,6 +366,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
|
||||
server_port=MODEL_SERVER_PORT,
|
||||
),
|
||||
)
|
||||
logger.notice("Setup complete")
|
||||
|
||||
optional_telemetry(record_type=RecordType.VERSION, data={"version": __version__})
|
||||
yield
|
||||
@@ -415,6 +412,9 @@ def get_application() -> FastAPI:
|
||||
include_router_with_global_prefix_prepended(
|
||||
application, token_rate_limit_settings_router
|
||||
)
|
||||
include_router_with_global_prefix_prepended(
|
||||
application, tenants_router
|
||||
)
|
||||
include_router_with_global_prefix_prepended(application, indexing_router)
|
||||
|
||||
if AUTH_TYPE == AuthType.DISABLED:
|
||||
|
||||
@@ -25,7 +25,6 @@ from danswer.db.chat import get_or_create_root_message
|
||||
from danswer.db.chat import translate_db_message_to_chat_message_detail
|
||||
from danswer.db.chat import translate_db_search_doc_to_server_search_doc
|
||||
from danswer.db.chat import update_search_docs_table_with_relevance
|
||||
from danswer.db.engine import get_session_context_manager
|
||||
from danswer.db.models import User
|
||||
from danswer.db.persona import get_prompt_by_id
|
||||
from danswer.llm.answering.answer import Answer
|
||||
@@ -118,7 +117,8 @@ def stream_answer_objects(
|
||||
one_shot=True,
|
||||
danswerbot_flow=danswerbot_flow,
|
||||
)
|
||||
llm, fast_llm = get_llms_for_persona(persona=chat_session.persona)
|
||||
|
||||
llm, fast_llm = get_llms_for_persona(persona=chat_session.persona, db_session=db_session)
|
||||
|
||||
llm_tokenizer = get_tokenizer(
|
||||
model_name=llm.config.model_name,
|
||||
@@ -139,6 +139,7 @@ def stream_answer_objects(
|
||||
rephrased_query = query_req.query_override or thread_based_query_rephrase(
|
||||
user_query=query_msg.message,
|
||||
history_str=history_str,
|
||||
db_session=db_session
|
||||
)
|
||||
|
||||
# Given back ahead of the documents for latency reasons
|
||||
@@ -209,7 +210,8 @@ def stream_answer_objects(
|
||||
question=query_msg.message,
|
||||
answer_style_config=answer_config,
|
||||
prompt_config=PromptConfig.from_model(prompt),
|
||||
llm=get_main_llm_from_tuple(get_llms_for_persona(persona=chat_session.persona)),
|
||||
llm=get_main_llm_from_tuple(get_llms_for_persona(persona=chat_session.persona, db_session=db_session)),
|
||||
# TODO: change back
|
||||
single_message_history=history_str,
|
||||
tools=[search_tool],
|
||||
force_use_tool=ForceUseTool(
|
||||
@@ -316,17 +318,17 @@ def stream_search_answer(
|
||||
user: User | None,
|
||||
max_document_tokens: int | None,
|
||||
max_history_tokens: int | None,
|
||||
db_session: Session,
|
||||
) -> Iterator[str]:
|
||||
with get_session_context_manager() as session:
|
||||
objects = stream_answer_objects(
|
||||
query_req=query_req,
|
||||
user=user,
|
||||
max_document_tokens=max_document_tokens,
|
||||
max_history_tokens=max_history_tokens,
|
||||
db_session=session,
|
||||
)
|
||||
for obj in objects:
|
||||
yield get_json_line(obj.model_dump())
|
||||
objects = stream_answer_objects(
|
||||
query_req=query_req,
|
||||
user=user,
|
||||
max_document_tokens=max_document_tokens,
|
||||
max_history_tokens=max_history_tokens,
|
||||
db_session=db_session,
|
||||
)
|
||||
for obj in objects:
|
||||
yield get_json_line(obj.model_dump())
|
||||
|
||||
|
||||
def get_search_answer(
|
||||
|
||||
@@ -25,7 +25,7 @@ class ThreadMessage(BaseModel):
|
||||
|
||||
class DirectQARequest(ChunkContext):
|
||||
messages: list[ThreadMessage]
|
||||
prompt_id: int | None
|
||||
prompt_id: int | None = None
|
||||
persona_id: int
|
||||
multilingual_query_expansion: list[str] | None = None
|
||||
retrieval_options: RetrievalDetails = Field(default_factory=RetrievalDetails)
|
||||
|
||||
@@ -11,7 +11,6 @@ class RecencyBiasSetting(str, Enum):
|
||||
# Determine based on query if to use base_decay or favor_recent
|
||||
AUTO = "auto"
|
||||
|
||||
|
||||
class OptionalSearchSetting(str, Enum):
|
||||
ALWAYS = "always"
|
||||
NEVER = "never"
|
||||
|
||||
@@ -27,7 +27,7 @@ class RerankingDetails(BaseModel):
|
||||
# If model is None (or num_rerank is 0), then reranking is turned off
|
||||
rerank_model_name: str | None
|
||||
rerank_provider_type: RerankerProvider | None
|
||||
rerank_api_key: str | None
|
||||
rerank_api_key: str | None = None
|
||||
|
||||
num_rerank: int
|
||||
|
||||
@@ -98,6 +98,7 @@ class BaseFilters(BaseModel):
|
||||
|
||||
class IndexFilters(BaseFilters):
|
||||
access_control_list: list[str] | None
|
||||
tenant_id: str | None = None
|
||||
|
||||
|
||||
class ChunkMetric(BaseModel):
|
||||
|
||||
@@ -29,11 +29,11 @@ from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.threadpool_concurrency import FunctionCall
|
||||
from danswer.utils.threadpool_concurrency import run_functions_in_parallel
|
||||
from danswer.utils.timing import log_function_time
|
||||
|
||||
from danswer.configs.app_configs import MULTI_TENANT
|
||||
from danswer.db.engine import current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def query_analysis(query: str) -> tuple[bool, list[str]]:
|
||||
analysis_model = QueryAnalysisModel()
|
||||
return analysis_model.predict(query)
|
||||
@@ -121,6 +121,7 @@ def retrieval_preprocessing(
|
||||
]
|
||||
if filter_fn
|
||||
]
|
||||
|
||||
parallel_results = run_functions_in_parallel(functions_to_run)
|
||||
|
||||
predicted_time_cutoff, predicted_favor_recent = (
|
||||
@@ -151,12 +152,15 @@ def retrieval_preprocessing(
|
||||
user_acl_filters = (
|
||||
None if bypass_acl else build_access_filters_for_user(user, db_session)
|
||||
)
|
||||
|
||||
|
||||
final_filters = IndexFilters(
|
||||
source_type=preset_filters.source_type or predicted_source_filters,
|
||||
document_set=preset_filters.document_set,
|
||||
time_cutoff=preset_filters.time_cutoff or predicted_time_cutoff,
|
||||
tags=preset_filters.tags, # Tags are never auto-extracted
|
||||
access_control_list=user_acl_filters,
|
||||
tenant_id=current_tenant_id.get() if MULTI_TENANT else None,
|
||||
)
|
||||
|
||||
llm_evaluation_type = LLMEvaluationType.BASIC
|
||||
|
||||
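The tenant filter above is populated from current_tenant_id, a ContextVar imported from danswer.db.engine. As a rough illustration of how such a request-scoped value can be set, here is a minimal sketch assuming a FastAPI middleware and an X-Tenant-ID header; neither the middleware nor the header is part of this change, and the real lookup (cookie, JWT, or session) may differ:

from contextvars import ContextVar

from fastapi import FastAPI, Request

# Hypothetical stand-in for danswer.db.engine.current_tenant_id
current_tenant_id: ContextVar[str] = ContextVar("current_tenant_id", default="public")

app = FastAPI()

@app.middleware("http")
async def set_tenant_context(request: Request, call_next):
    # Assumed source of the tenant id; the real code may read it from a cookie or JWT instead
    token = current_tenant_id.set(request.headers.get("X-Tenant-ID", "public"))
    try:
        return await call_next(request)
    finally:
        # Reset so the value never leaks across requests that share a worker
        current_tenant_id.reset(token)

Anything running inside the request, such as the retrieval preprocessing above, can then call current_tenant_id.get() without threading the value through every signature.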
@@ -237,7 +237,7 @@ def retrieve_chunks(
|
||||
|
||||
# Currently only uses query expansion on multilingual use cases
|
||||
query_rephrases = multilingual_query_expansion(
|
||||
query.query, multilingual_expansion
|
||||
query.query, multilingual_expansion, db_session=db_session
|
||||
)
|
||||
# Just to be extra sure, add the original query.
|
||||
query_rephrases.append(query.query)
|
||||
|
||||
@@ -15,11 +15,11 @@ from danswer.prompts.miscellaneous_prompts import LANGUAGE_REPHRASE_PROMPT
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.text_processing import count_punctuation
|
||||
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def llm_multilingual_query_expansion(query: str, language: str) -> str:
|
||||
def llm_multilingual_query_expansion(query: str, language: str, db_session: Session) -> str:
|
||||
def _get_rephrase_messages() -> list[dict[str, str]]:
|
||||
messages = [
|
||||
{
|
||||
@@ -33,7 +33,7 @@ def llm_multilingual_query_expansion(query: str, language: str) -> str:
|
||||
return messages
|
||||
|
||||
try:
|
||||
_, fast_llm = get_default_llms(timeout=5)
|
||||
_, fast_llm = get_default_llms(timeout=5, db_session=db_session)
|
||||
except GenAIDisabledException:
|
||||
logger.warning(
|
||||
"Unable to perform multilingual query expansion, Gen AI disabled"
|
||||
@@ -51,12 +51,13 @@ def llm_multilingual_query_expansion(query: str, language: str) -> str:
|
||||
def multilingual_query_expansion(
|
||||
query: str,
|
||||
expansion_languages: list[str],
|
||||
db_session: Session,
|
||||
use_threads: bool = True,
|
||||
) -> list[str]:
|
||||
languages = [language.strip() for language in expansion_languages]
|
||||
if use_threads:
|
||||
functions_with_args: list[tuple[Callable, tuple]] = [
|
||||
(llm_multilingual_query_expansion, (query, language))
|
||||
(llm_multilingual_query_expansion, (query, language, db_session))
|
||||
for language in languages
|
||||
]
|
||||
|
||||
@@ -65,7 +66,7 @@ def multilingual_query_expansion(
|
||||
|
||||
else:
|
||||
query_rephrases = [
|
||||
llm_multilingual_query_expansion(query, language) for language in languages
|
||||
llm_multilingual_query_expansion(query, language, db_session) for language in languages
|
||||
]
|
||||
return query_rephrases
|
||||
|
||||
@@ -134,9 +135,10 @@ def history_based_query_rephrase(
|
||||
def thread_based_query_rephrase(
|
||||
user_query: str,
|
||||
history_str: str,
|
||||
db_session: Session,
|
||||
llm: LLM | None = None,
|
||||
size_heuristic: int = 200,
|
||||
punctuation_heuristic: int = 10,
|
||||
punctuation_heuristic: int = 10
|
||||
) -> str:
|
||||
if not history_str:
|
||||
return user_query
|
||||
@@ -149,7 +151,7 @@ def thread_based_query_rephrase(
|
||||
|
||||
if llm is None:
|
||||
try:
|
||||
llm, _ = get_default_llms()
|
||||
llm, _ = get_default_llms(db_session=db_session)
|
||||
except GenAIDisabledException:
|
||||
# If Generative AI is turned off, just return the original query
|
||||
return user_query
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import re
|
||||
from sqlalchemy.orm import Session
|
||||
from collections.abc import Iterator
|
||||
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
@@ -46,13 +47,14 @@ def extract_answerability_bool(model_raw: str) -> bool:
|
||||
|
||||
|
||||
def get_query_answerability(
|
||||
user_query: str, skip_check: bool = DISABLE_LLM_QUERY_ANSWERABILITY
|
||||
db_session: Session,
|
||||
user_query: str, skip_check: bool = DISABLE_LLM_QUERY_ANSWERABILITY,
|
||||
) -> tuple[str, bool]:
|
||||
if skip_check:
|
||||
return "Query Answerability Evaluation feature is turned off", True
|
||||
|
||||
try:
|
||||
llm, _ = get_default_llms()
|
||||
llm, _ = get_default_llms(db_session=db_session)
|
||||
except GenAIDisabledException:
|
||||
return "Generative AI is turned off - skipping check", True
|
||||
|
||||
@@ -67,7 +69,7 @@ def get_query_answerability(
|
||||
|
||||
|
||||
def stream_query_answerability(
|
||||
user_query: str, skip_check: bool = DISABLE_LLM_QUERY_ANSWERABILITY
|
||||
db_session: Session, user_query: str, skip_check: bool = DISABLE_LLM_QUERY_ANSWERABILITY,
|
||||
) -> Iterator[str]:
|
||||
if skip_check:
|
||||
yield get_json_line(
|
||||
@@ -79,7 +81,7 @@ def stream_query_answerability(
|
||||
return
|
||||
|
||||
try:
|
||||
llm, _ = get_default_llms()
|
||||
llm, _ = get_default_llms(db_session=db_session)
|
||||
except GenAIDisabledException:
|
||||
yield get_json_line(
|
||||
QueryValidationResponse(
|
||||
|
||||
@@ -166,6 +166,5 @@ if __name__ == "__main__":
|
||||
while True:
|
||||
user_input = input("Query to Extract Sources: ")
|
||||
sources = extract_source_filter(
|
||||
user_input, get_main_llm_from_tuple(get_default_llms()), db_session
|
||||
user_input, get_main_llm_from_tuple(get_default_llms(db_session=db_session)), db_session
|
||||
)
|
||||
print(sources)
|
||||
|
||||
@@ -34,6 +34,7 @@ PUBLIC_ENDPOINT_SPECS = [
|
||||
("/auth/reset-password", {"POST"}),
|
||||
("/auth/request-verify-token", {"POST"}),
|
||||
("/auth/verify", {"POST"}),
|
||||
("/tenants/auth/sso-callback", {"POST"}),
|
||||
("/users/me", {"GET"}),
|
||||
("/users/me", {"PATCH"}),
|
||||
("/users/{id}", {"GET"}),
|
||||
@@ -42,6 +43,8 @@ PUBLIC_ENDPOINT_SPECS = [
|
||||
# oauth
|
||||
("/auth/oauth/authorize", {"GET"}),
|
||||
("/auth/oauth/callback", {"GET"}),
|
||||
# tenant service related (must use API key)
|
||||
("/tenants/create", {"POST"}),
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -38,6 +38,7 @@ from danswer.server.manage.models import BoostUpdateRequest
|
||||
from danswer.server.manage.models import HiddenUpdateRequest
|
||||
from danswer.server.models import StatusResponse
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.db.engine import current_tenant_id
|
||||
|
||||
router = APIRouter(prefix="/manage")
|
||||
logger = setup_logger()
|
||||
@@ -195,7 +196,7 @@ def create_deletion_attempt_for_connector_id(
|
||||
)
|
||||
# actually kick off the deletion
|
||||
cleanup_connector_credential_pair_task.apply_async(
|
||||
kwargs=dict(connector_id=connector_id, credential_id=credential_id),
|
||||
kwargs=dict(connector_id=connector_id, credential_id=credential_id, tenant_id=current_tenant_id.get()),
|
||||
)
|
||||
|
||||
if cc_pair.connector.source == DocumentSource.FILE:
|
||||
|
||||
@@ -60,6 +60,7 @@ def set_new_search_settings(
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
|
||||
if search_settings_new.index_name is None:
|
||||
|
||||
# We define index name here
|
||||
index_name = f"danswer_chunk_{clean_model_name(search_settings_new.model_name)}"
|
||||
if (
|
||||
@@ -97,6 +98,7 @@ def set_new_search_settings(
|
||||
primary_index_name=search_settings.index_name,
|
||||
secondary_index_name=new_search_settings.index_name,
|
||||
)
|
||||
|
||||
document_index.ensure_indices_exist(
|
||||
index_embedding_dim=search_settings.model_dim,
|
||||
secondary_index_embedding_dim=new_search_settings.model_dim,
|
||||
|
||||
@@ -1,13 +1,18 @@
|
||||
import os
|
||||
import re
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from enum import Enum
|
||||
|
||||
import stripe
|
||||
from email_validator import validate_email
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Body
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
from fastapi import status
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import Column
|
||||
from sqlalchemy import desc
|
||||
@@ -27,7 +32,9 @@ from danswer.auth.users import current_user
|
||||
from danswer.auth.users import optional_user
|
||||
from danswer.configs.app_configs import AUTH_TYPE
|
||||
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
|
||||
from danswer.configs.app_configs import STRIPE_PRICE
|
||||
from danswer.configs.app_configs import VALID_EMAIL_DOMAINS
|
||||
from danswer.configs.app_configs import WEB_DOMAIN
|
||||
from danswer.configs.constants import AuthType
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.models import AccessToken
|
||||
@@ -47,6 +54,13 @@ from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.db.api_key import is_api_key_email_address
|
||||
from ee.danswer.db.user_group import remove_curator_status__no_commit
|
||||
|
||||
STRIPE_SECRET_KEY = os.getenv("STRIPE_SECRET_KEY")
|
||||
STRIPE_WEBHOOK_SECRET = os.getenv("STRIPE_WEBHOOK_SECRET")
|
||||
# STRIPE_PRICE = os.getenv("STRIPE_PRICE")
|
||||
|
||||
|
||||
stripe.api_key = "sk_test_51NwZq2HlhTYqRZibT2cssHV8E5QcLAUmaRLQPMjGb5aOxOWomVxOmzRgxf82ziDBuGdPP2GIDod8xe6DyqeGgUDi00KbsHPoT4"
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter()
|
||||
@@ -319,6 +333,63 @@ def verify_user_logged_in(
|
||||
return user_info
|
||||
|
||||
|
||||
class BillingPlanType(str, Enum):
|
||||
FREE = "free"
|
||||
PREMIUM = "premium"
|
||||
ENTERPRISE = "enterprise"
|
||||
|
||||
|
||||
class CheckoutSessionUpdateBillingStatus(BaseModel):
|
||||
quantity: int
|
||||
plan: BillingPlanType
|
||||
|
||||
|
||||
@router.post("/create-checkout-session")
|
||||
async def create_checkout_session(
|
||||
request: Request,
|
||||
checkout_billing: CheckoutSessionUpdateBillingStatus,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> JSONResponse:
|
||||
quantity = checkout_billing.quantity
|
||||
plan = checkout_billing.plan
|
||||
|
||||
logger.info(f"Creating checkout session for plan: {plan} with quantity: {quantity}")
|
||||
|
||||
user_email = "pablosfsanchez@gmail.com"
|
||||
|
||||
success_url = f"{WEB_DOMAIN}/admin/plan?success=true"
|
||||
cancel_url = f"{WEB_DOMAIN}/admin/plan?success=false"
|
||||
|
||||
logger.info(f"Stripe price being used: {STRIPE_PRICE}")
|
||||
logger.info(
|
||||
f"Creating checkout session with success_url: {success_url} and cancel_url: {cancel_url}"
|
||||
)
|
||||
|
||||
try:
|
||||
checkout_session = stripe.checkout.Session.create(
|
||||
customer_email=user_email,
|
||||
line_items=[
|
||||
{
|
||||
"price": STRIPE_PRICE,
|
||||
"quantity": quantity,
|
||||
},
|
||||
],
|
||||
mode="subscription",
|
||||
success_url=success_url,
|
||||
cancel_url=cancel_url,
|
||||
metadata={"tenant_id": str("random tenant")},
|
||||
)
|
||||
logger.info(
|
||||
f"Checkout session created successfully with id: {checkout_session.id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="An unexpected error occurred")
|
||||
|
||||
return JSONResponse({"sessionId": checkout_session.id})
|
||||
|
||||
|
||||
"""APIs to adjust user preferences"""
|
||||
|
||||
|
||||
|
||||
@@ -226,7 +226,7 @@ def rename_chat_session(
|
||||
|
||||
try:
|
||||
llm, _ = get_default_llms(
|
||||
additional_headers=get_litellm_additional_request_headers(request.headers)
|
||||
additional_headers=get_litellm_additional_request_headers(request.headers), db_session=db_session
|
||||
)
|
||||
except GenAIDisabledException:
|
||||
# This may be longer than what the LLM tends to produce but is the most
|
||||
@@ -296,6 +296,7 @@ async def is_disconnected(request: Request) -> Callable[[], bool]:
|
||||
def handle_new_chat_message(
|
||||
chat_message_req: CreateChatMessageRequest,
|
||||
request: Request,
|
||||
db_session: Session = Depends(get_session),
|
||||
user: User | None = Depends(current_user),
|
||||
_: None = Depends(check_token_rate_limits),
|
||||
is_disconnected_func: Callable[[], bool] = Depends(is_disconnected),
|
||||
@@ -308,6 +309,8 @@ def handle_new_chat_message(
|
||||
To avoid extra overhead/latency, this assumes (and checks) that previous messages on the path
|
||||
have already been set as latest"""
|
||||
logger.debug(f"Received new chat message: {chat_message_req.message}")
|
||||
logger.debug("Messge info")
|
||||
logger.debug(chat_message_req.__dict__)
|
||||
|
||||
if (
|
||||
not chat_message_req.message
|
||||
@@ -328,6 +331,7 @@ def handle_new_chat_message(
|
||||
request.headers
|
||||
),
|
||||
is_connected=is_disconnected_func,
|
||||
db_session=db_session,
|
||||
):
|
||||
yield json.dumps(packet) if isinstance(packet, dict) else packet
|
||||
|
||||
@@ -424,7 +428,7 @@ def get_max_document_tokens(
|
||||
raise HTTPException(status_code=404, detail="Persona not found")
|
||||
|
||||
return MaxSelectedDocumentTokens(
|
||||
max_tokens=compute_max_document_tokens_for_persona(persona),
|
||||
max_tokens=compute_max_document_tokens_for_persona(persona, db_session=db_session),
|
||||
)
|
||||
|
||||
|
||||
@@ -474,7 +478,7 @@ def seed_chat(
|
||||
root_message = get_or_create_root_message(
|
||||
chat_session_id=new_chat_session.id, db_session=db_session
|
||||
)
|
||||
llm, fast_llm = get_llms_for_persona(persona=new_chat_session.persona)
|
||||
llm, fast_llm = get_llms_for_persona(persona=new_chat_session.persona, db_session=db_session)
|
||||
|
||||
tokenizer = get_tokenizer(
|
||||
model_name=llm.config.model_name,
|
||||
|
||||
@@ -40,6 +40,7 @@ from danswer.server.query_and_chat.models import SourceTag
|
||||
from danswer.server.query_and_chat.models import TagResponse
|
||||
from danswer.server.query_and_chat.token_limit import check_token_rate_limits
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.db.engine import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -52,6 +53,7 @@ def admin_search(
|
||||
question: AdminSearchRequest,
|
||||
user: User | None = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> AdminSearchResponse:
|
||||
query = question.query
|
||||
logger.notice(f"Received admin search query: {query}")
|
||||
@@ -62,6 +64,7 @@ def admin_search(
|
||||
time_cutoff=question.filters.time_cutoff,
|
||||
tags=question.filters.tags,
|
||||
access_control_list=user_acl_filters,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
document_index = get_default_document_index(
|
||||
@@ -117,13 +120,13 @@ def get_tags(
|
||||
|
||||
@basic_router.post("/query-validation")
|
||||
def query_validation(
|
||||
simple_query: SimpleQueryRequest, _: User = Depends(current_user)
|
||||
simple_query: SimpleQueryRequest, _: User = Depends(current_user), db_session: Session = Depends(get_session)
|
||||
) -> QueryValidationResponse:
|
||||
# Note if weak model prompt is chosen, this check does not occur and will simply return that
|
||||
# the query is valid, this is because weaker models cannot really handle this task well.
|
||||
# Additionally, some weak model servers cannot handle concurrent inferences.
|
||||
logger.notice(f"Validating query: {simple_query.query}")
|
||||
reasoning, answerable = get_query_answerability(simple_query.query)
|
||||
reasoning, answerable = get_query_answerability(db_session=db_session, user_query=simple_query.query)
|
||||
return QueryValidationResponse(reasoning=reasoning, answerable=answerable)
|
||||
|
||||
|
||||
@@ -226,14 +229,14 @@ def get_search_session(
|
||||
# No search responses are answered with a conversational generative AI response
|
||||
@basic_router.post("/stream-query-validation")
|
||||
def stream_query_validation(
|
||||
simple_query: SimpleQueryRequest, _: User = Depends(current_user)
|
||||
simple_query: SimpleQueryRequest, _: User = Depends(current_user), db_session: Session = Depends(get_session)
|
||||
) -> StreamingResponse:
|
||||
# Note if weak model prompt is chosen, this check does not occur and will simply return that
|
||||
# the query is valid, this is because weaker models cannot really handle this task well.
|
||||
# Additionally, some weak model servers cannot handle concurrent inferences.
|
||||
logger.notice(f"Validating query: {simple_query.query}")
|
||||
return StreamingResponse(
|
||||
stream_query_answerability(simple_query.query), media_type="application/json"
|
||||
stream_query_answerability(user_query=simple_query.query, db_session=db_session), media_type="application/json"
|
||||
)
|
||||
|
||||
|
||||
@@ -242,6 +245,7 @@ def get_answer_with_quote(
|
||||
query_request: DirectQARequest,
|
||||
user: User = Depends(current_user),
|
||||
_: None = Depends(check_token_rate_limits),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> StreamingResponse:
|
||||
query = query_request.messages[0].message
|
||||
|
||||
@@ -252,5 +256,6 @@ def get_answer_with_quote(
|
||||
user=user,
|
||||
max_document_tokens=None,
|
||||
max_history_tokens=0,
|
||||
db_session=db_session,
|
||||
)
|
||||
return StreamingResponse(packets, media_type="application/json")
|
||||
|
||||
@@ -33,6 +33,7 @@ def check_token_rate_limits(
|
||||
) -> None:
|
||||
# short circuit if no rate limits are set up
|
||||
# NOTE: result of `any_rate_limit_exists` is cached, so this call is fast 99% of the time
|
||||
return
|
||||
if not any_rate_limit_exists():
|
||||
return
|
||||
|
||||
|
||||
@@ -1,17 +1,16 @@
|
||||
from typing import cast
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import current_admin_user
|
||||
from danswer.auth.users import current_user
|
||||
from danswer.auth.users import is_user_admin
|
||||
from danswer.configs.constants import KV_REINDEX_KEY
|
||||
from danswer.configs.constants import NotificationType
|
||||
from danswer.db.engine import get_session
|
||||
|
||||
from danswer.db.models import User
|
||||
from danswer.db.notification import create_notification
|
||||
from danswer.db.notification import dismiss_all_notifications
|
||||
@@ -27,15 +26,17 @@ from danswer.server.settings.models import UserSettings
|
||||
from danswer.server.settings.store import load_settings
|
||||
from danswer.server.settings.store import store_settings
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||
|
||||
admin_router = APIRouter(prefix="/admin/settings")
|
||||
basic_router = APIRouter(prefix="/settings")
|
||||
|
||||
|
||||
|
||||
@admin_router.put("")
|
||||
def put_settings(
|
||||
settings: Settings, _: User | None = Depends(current_admin_user)
|
||||
@@ -66,7 +67,7 @@ def fetch_settings(
|
||||
return UserSettings(
|
||||
**general_settings.model_dump(),
|
||||
notifications=user_notifications,
|
||||
needs_reindexing=needs_reindexing
|
||||
needs_reindexing=needs_reindexing,
|
||||
)
|
||||
|
||||
|
||||
@@ -91,6 +92,7 @@ def dismiss_notification_endpoint(
|
||||
def get_user_notifications(
|
||||
user: User | None, db_session: Session
|
||||
) -> list[Notification]:
|
||||
return cast(list[Notification], [])
|
||||
"""Get notifications for the user, currently the logic is very specific to the reindexing flag"""
|
||||
is_admin = is_user_admin(user)
|
||||
if not is_admin:
|
||||
|
||||
0 backend/danswer/server/tenants/__init__.py Normal file
72 backend/danswer/server/tenants/api.py Normal file
@@ -0,0 +1,72 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from sqlalchemy.orm import Session
|
||||
from fastapi import Body
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.auth.users import create_user_session
|
||||
from danswer.auth.users import get_user_manager
|
||||
from danswer.auth.users import UserManager
|
||||
from danswer.auth.users import verify_sso_token
|
||||
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
|
||||
from danswer.utils.logger import setup_logger
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi import HTTPException
|
||||
from ee.danswer.auth.users import control_plane_dep
|
||||
from danswer.server.tenants.provisioning import setup_postgres_and_initial_settings
|
||||
from danswer.server.tenants.provisioning import check_schema_exists
|
||||
from danswer.server.tenants.provisioning import run_alembic_migrations
|
||||
from danswer.server.tenants.provisioning import create_tenant_schema
|
||||
from danswer.configs.app_configs import MULTI_TENANT
|
||||
|
||||
logger = setup_logger()
|
||||
basic_router = APIRouter(prefix="/tenants")
|
||||
|
||||
@basic_router.post("/create")
|
||||
def create_tenant(tenant_id: str, _: None = Depends(control_plane_dep)) -> dict[str, str]:
|
||||
if not MULTI_TENANT:
|
||||
raise HTTPException(status_code=403, detail="Multi-tenancy is not enabled")
|
||||
|
||||
if not tenant_id:
|
||||
raise HTTPException(status_code=400, detail="tenant_id is required")
|
||||
|
||||
create_tenant_schema(tenant_id)
|
||||
run_alembic_migrations(tenant_id)
|
||||
with Session(get_sqlalchemy_engine(schema=tenant_id)) as db_session:
|
||||
setup_postgres_and_initial_settings(db_session)
|
||||
|
||||
logger.info(f"Tenant {tenant_id} created successfully")
|
||||
return {"status": "success", "message": f"Tenant {tenant_id} created successfully"}
|
||||
|
||||
|
||||
@basic_router.post("/auth/sso-callback")
|
||||
async def sso_callback(
|
||||
sso_token: str = Body(..., embed=True),
|
||||
user_manager: UserManager = Depends(get_user_manager),
|
||||
) -> JSONResponse:
|
||||
if not MULTI_TENANT:
|
||||
raise HTTPException(status_code=403, detail="Multi-tenancy is not enabled")
|
||||
|
||||
payload = verify_sso_token(sso_token)
|
||||
user = await user_manager.sso_authenticate(
|
||||
payload["email"], payload["tenant_id"]
|
||||
)
|
||||
|
||||
tenant_id = payload["tenant_id"]
|
||||
schema_exists = await check_schema_exists(tenant_id)
|
||||
if not schema_exists:
|
||||
raise HTTPException(status_code=403, detail="Your Danswer app has not been set up yet!")
|
||||
|
||||
session_token = await create_user_session(user, payload["tenant_id"])
|
||||
|
||||
response = JSONResponse(content={"message": "Authentication successful"})
|
||||
response.set_cookie(
|
||||
key="tenant_details",
|
||||
value=session_token,
|
||||
max_age=SESSION_EXPIRE_TIME_SECONDS,
|
||||
expires=SESSION_EXPIRE_TIME_SECONDS,
|
||||
path="/",
|
||||
secure=False,
|
||||
httponly=True,
|
||||
samesite="lax",
|
||||
)
|
||||
return response
|
||||
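For reference, the handler above expects a JWT carrying email and tenant_id claims, and the new frontend page further below reads the token from the URL fragment. A hedged sketch of how a control plane might mint that token and build the redirect; the secret name, expiry, and HS256 algorithm are assumptions that must match whatever verify_sso_token actually checks:

from datetime import datetime, timedelta, timezone

import jwt  # PyJWT

SSO_SIGNING_SECRET = "change-me"            # assumed shared secret
DATA_PLANE_URL = "https://app.example.com"  # hypothetical deployment URL

def build_sso_redirect(email: str, tenant_id: str) -> str:
    payload = {
        "email": email,
        "tenant_id": tenant_id,
        "exp": datetime.now(timezone.utc) + timedelta(minutes=5),
    }
    sso_token = jwt.encode(payload, SSO_SIGNING_SECRET, algorithm="HS256")
    # The token rides in the fragment so it is not sent to the server or logged by proxies
    return f"{DATA_PLANE_URL}/auth/sso-callback#sso_token={sso_token}"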
157 backend/danswer/server/tenants/provisioning.py Normal file
@@ -0,0 +1,157 @@
|
||||
import contextlib
|
||||
from danswer.search.retrieval.search_runner import download_nltk_data
|
||||
from danswer.db.engine import get_async_session
|
||||
from danswer.natural_language_processing.search_nlp_models import warm_up_cross_encoder
|
||||
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from danswer.db.connector_credential_pair import resync_cc_pair
|
||||
from danswer.db.index_attempt import cancel_indexing_attempts_past_model
|
||||
from danswer.db.index_attempt import expire_index_attempts
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.db.search_settings import get_secondary_search_settings
|
||||
from danswer.db.swap_index import check_index_swap
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
from danswer.llm.llm_initialization import load_llm_providers
|
||||
from danswer.db.connector import create_initial_default_connector
|
||||
from danswer.db.connector_credential_pair import associate_default_cc_pair
|
||||
from danswer.db.credentials import create_initial_public_credential
|
||||
from danswer.db.standard_answer import create_initial_default_standard_answer_category
|
||||
from danswer.db.persona import delete_old_default_personas
|
||||
from danswer.chat.load_yamls import load_chat_yamls
|
||||
from danswer.tools.built_in_tools import auto_add_search_tool_to_personas
|
||||
from danswer.tools.built_in_tools import load_builtin_tools
|
||||
from danswer.tools.built_in_tools import refresh_built_in_tools_cache
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from sqlalchemy.schema import CreateSchema
|
||||
from sqlalchemy import text
|
||||
from alembic.config import Config
|
||||
from alembic import command
|
||||
from danswer.db.engine import build_connection_string
|
||||
import os
|
||||
from danswer.db_setup import setup_postgres
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
def run_alembic_migrations(schema_name: str) -> None:
|
||||
logger.info(f"Starting Alembic migrations for schema: {schema_name}")
|
||||
try:
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
root_dir = os.path.abspath(os.path.join(current_dir, '..', '..', '..'))
|
||||
alembic_ini_path = os.path.join(root_dir, 'alembic.ini')
|
||||
|
||||
# Configure Alembic
|
||||
alembic_cfg = Config(alembic_ini_path)
|
||||
alembic_cfg.set_main_option('sqlalchemy.url', build_connection_string())
|
||||
|
||||
# Prepare the x arguments
|
||||
x_arguments = [f"schema={schema_name}"]
|
||||
alembic_cfg.cmd_opts.x = x_arguments # type: ignore
|
||||
|
||||
# Run migrations programmatically
|
||||
command.upgrade(alembic_cfg, 'head')
|
||||
|
||||
logger.info(f"Alembic migrations completed successfully for schema: {schema_name}")
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Alembic migration failed for schema {schema_name}: {str(e)}")
|
||||
raise
|
||||
|
||||
def create_tenant_schema(tenant_id: str) -> None:
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
with db_session.begin():
|
||||
result = db_session.execute(
|
||||
text("""
|
||||
SELECT schema_name
|
||||
FROM information_schema.schemata
|
||||
WHERE schema_name = :schema_name
|
||||
"""),
|
||||
{"schema_name": tenant_id}
|
||||
)
|
||||
schema_exists = result.scalar() is not None
|
||||
|
||||
if not schema_exists:
|
||||
db_session.execute(CreateSchema(tenant_id))
|
||||
logger.info(f"Schema {tenant_id} created")
|
||||
else:
|
||||
logger.info(f"Schema {tenant_id} already exists")
|
||||
|
||||
|
||||
def setup_postgres_and_initial_settings(db_session: Session) -> None:
|
||||
check_index_swap(db_session=db_session)
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
secondary_search_settings = get_secondary_search_settings(db_session)
|
||||
|
||||
# Break bad state for thrashing indexes
|
||||
if secondary_search_settings and DISABLE_INDEX_UPDATE_ON_SWAP:
|
||||
expire_index_attempts(
|
||||
search_settings_id=search_settings.id, db_session=db_session
|
||||
)
|
||||
|
||||
for cc_pair in get_connector_credential_pairs(db_session):
|
||||
resync_cc_pair(cc_pair, db_session=db_session)
|
||||
|
||||
# Expire all old embedding models indexing attempts, technically redundant
|
||||
cancel_indexing_attempts_past_model(db_session)
|
||||
|
||||
logger.notice(f'Using Embedding model: "{search_settings.model_name}"')
|
||||
if search_settings.query_prefix or search_settings.passage_prefix:
|
||||
logger.notice(f'Query embedding prefix: "{search_settings.query_prefix}"')
|
||||
logger.notice(
|
||||
f'Passage embedding prefix: "{search_settings.passage_prefix}"'
|
||||
)
|
||||
|
||||
if search_settings:
|
||||
if not search_settings.disable_rerank_for_streaming:
|
||||
logger.notice("Reranking is enabled.")
|
||||
|
||||
if search_settings.multilingual_expansion:
|
||||
logger.notice(
|
||||
f"Multilingual query expansion is enabled with {search_settings.multilingual_expansion}."
|
||||
)
|
||||
|
||||
if search_settings.rerank_model_name and not search_settings.provider_type:
|
||||
warm_up_cross_encoder(search_settings.rerank_model_name)
|
||||
|
||||
logger.notice("Verifying query preprocessing (NLTK) data is downloaded")
|
||||
download_nltk_data()
|
||||
|
||||
# setup Postgres with default credentials, llm providers, etc.
|
||||
setup_postgres(db_session)
|
||||
|
||||
# ensure Vespa is setup correctly
|
||||
logger.notice("Verifying Document Index(s) is/are available.")
|
||||
|
||||
|
||||
logger.notice("Verifying default connector/credential exist.")
|
||||
create_initial_public_credential(db_session)
|
||||
create_initial_default_connector(db_session)
|
||||
associate_default_cc_pair(db_session)
|
||||
|
||||
logger.notice("Verifying default standard answer category exists.")
|
||||
create_initial_default_standard_answer_category(db_session)
|
||||
|
||||
logger.notice("Loading LLM providers from env variables")
|
||||
load_llm_providers(db_session)
|
||||
|
||||
logger.notice("Loading default Prompts and Personas")
|
||||
delete_old_default_personas(db_session)
|
||||
load_chat_yamls(db_session)
|
||||
|
||||
logger.notice("Loading built-in tools")
|
||||
load_builtin_tools(db_session)
|
||||
refresh_built_in_tools_cache(db_session)
|
||||
auto_add_search_tool_to_personas(db_session)
|
||||
|
||||
|
||||
async def check_schema_exists(tenant_id: str) -> bool:
|
||||
get_async_session_context = contextlib.asynccontextmanager(
|
||||
get_async_session
|
||||
)
|
||||
async with get_async_session_context() as session:
|
||||
result = await session.execute(
|
||||
text("SELECT schema_name FROM information_schema.schemata WHERE schema_name = :schema_name"),
|
||||
{"schema_name": tenant_id}
|
||||
)
|
||||
return result.scalar() is not None
|
||||
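A note on run_alembic_migrations above: because the target schema travels through Alembic's -x mechanism, the same per-tenant upgrade can also be run by hand, which is useful when provisioning or repairing a tenant manually. A sketch via subprocess so the example stays in Python; the working directory (wherever alembic.ini lives) is an assumption:

import subprocess

def run_alembic_migrations_cli(schema_name: str) -> None:
    # Equivalent to the programmatic command.upgrade(...) call above
    subprocess.run(
        ["alembic", "-x", f"schema={schema_name}", "upgrade", "head"],
        cwd="backend",  # assumed location of alembic.ini
        check=True,
    )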
@@ -13,15 +13,43 @@ from ee.danswer.db.api_key import fetch_user_for_api_key
|
||||
from ee.danswer.db.saml import get_saml_account
|
||||
from ee.danswer.server.seeding import get_seed_config
|
||||
from ee.danswer.utils.secrets import extract_hashed_cookie
|
||||
import jwt
|
||||
from danswer.configs.app_configs import DATA_PLANE_SECRET
|
||||
from danswer.configs.app_configs import EXPECTED_API_KEY
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def verify_auth_setting() -> None:
|
||||
# All the Auth flows are valid for EE version
|
||||
logger.notice(f"Using Auth Type: {AUTH_TYPE.value}")
|
||||
|
||||
|
||||
async def control_plane_dep(request: Request) -> None:
|
||||
auth_header = request.headers.get("Authorization")
|
||||
api_key = request.headers.get("X-API-KEY")
|
||||
|
||||
if api_key != EXPECTED_API_KEY:
|
||||
logger.warning("Invalid API key")
|
||||
raise HTTPException(status_code=401, detail="Invalid API key")
|
||||
|
||||
if not auth_header or not auth_header.startswith("Bearer "):
|
||||
logger.warning("Invalid authorization header")
|
||||
raise HTTPException(status_code=401, detail="Invalid authorization header")
|
||||
|
||||
token = auth_header.split(" ")[1]
|
||||
try:
|
||||
payload = jwt.decode(token, DATA_PLANE_SECRET, algorithms=["HS256"])
|
||||
if payload.get("scope") != "tenant:create":
|
||||
logger.warning("Insufficient permissions")
|
||||
raise HTTPException(status_code=403, detail="Insufficient permissions")
|
||||
except jwt.ExpiredSignatureError:
|
||||
logger.warning("Token has expired")
|
||||
raise HTTPException(status_code=401, detail="Token has expired")
|
||||
except jwt.InvalidTokenError:
|
||||
logger.warning("Invalid token")
|
||||
raise HTTPException(status_code=401, detail="Invalid token")
|
||||
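control_plane_dep therefore requires two credentials at once: the static X-API-KEY and a short-lived Bearer JWT whose scope claim is tenant:create, signed with DATA_PLANE_SECRET using HS256. A hedged sketch of the calling side; the requests dependency, base URL, and five-minute expiry are assumptions, not part of this change:

from datetime import datetime, timedelta, timezone

import jwt  # PyJWT
import requests  # assumed to be available to the control plane

def provision_tenant(base_url: str, tenant_id: str, data_plane_secret: str, api_key: str) -> None:
    token = jwt.encode(
        {"scope": "tenant:create", "exp": datetime.now(timezone.utc) + timedelta(minutes=5)},
        data_plane_secret,
        algorithm="HS256",
    )
    resp = requests.post(
        f"{base_url}/tenants/create",
        params={"tenant_id": tenant_id},  # create_tenant takes tenant_id as a query parameter
        headers={"Authorization": f"Bearer {token}", "X-API-KEY": api_key},
        timeout=30,
    )
    resp.raise_for_status()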
|
||||
|
||||
async def optional_user_(
|
||||
request: Request,
|
||||
user: User | None,
|
||||
@@ -44,6 +72,7 @@ async def optional_user_(
|
||||
return user
|
||||
|
||||
|
||||
|
||||
def api_key_dep(
|
||||
request: Request, db_session: Session = Depends(get_session)
|
||||
) -> User | None:
|
||||
@@ -63,6 +92,7 @@ def api_key_dep(
|
||||
return user
|
||||
|
||||
|
||||
|
||||
def get_default_admin_user_emails_() -> list[str]:
|
||||
seed_config = get_seed_config()
|
||||
if seed_config and seed_config.admin_user_emails:
|
||||
|
||||
@@ -183,7 +183,7 @@ def handle_send_message_simple_with_history(
|
||||
one_shot=False,
|
||||
)
|
||||
|
||||
llm, _ = get_llms_for_persona(persona=chat_session.persona)
|
||||
llm, _ = get_llms_for_persona(persona=chat_session.persona, db_session=db_session)
|
||||
|
||||
llm_tokenizer = get_tokenizer(
|
||||
model_name=llm.config.model_name,
|
||||
@@ -223,6 +223,7 @@ def handle_send_message_simple_with_history(
|
||||
rephrased_query = req.query_override or thread_based_query_rephrase(
|
||||
user_query=query,
|
||||
history_str=history_str,
|
||||
db_session=db_session
|
||||
)
|
||||
|
||||
full_chat_msg_info = CreateChatMessageRequest(
|
||||
|
||||
@@ -53,7 +53,7 @@ def handle_search_request(
|
||||
query = search_request.message
|
||||
logger.notice(f"Received document search query: {query}")
|
||||
|
||||
llm, fast_llm = get_default_llms()
|
||||
llm, fast_llm = get_default_llms(db_session=db_session)
|
||||
|
||||
search_pipeline = SearchPipeline(
|
||||
search_request=SearchRequest(
|
||||
@@ -141,7 +141,7 @@ def get_answer_with_quote(
|
||||
)
|
||||
|
||||
llm = get_main_llm_from_tuple(
|
||||
get_default_llms() if not persona else get_llms_for_persona(persona)
|
||||
get_default_llms(db_session=db_session) if not persona else get_llms_for_persona(persona, db_session=db_session)
|
||||
)
|
||||
input_tokens = get_max_input_tokens(
|
||||
model_name=llm.config.model_name, model_provider=llm.config.model_provider
|
||||
@@ -154,6 +154,7 @@ def get_answer_with_quote(
|
||||
persona=persona,
|
||||
actual_user_input=query,
|
||||
max_llm_token_override=remaining_tokens,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
answer_details = get_search_answer(
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from pydantic import BaseModel
|
||||
import os
|
||||
|
||||
# Used for logging
|
||||
@@ -52,8 +53,9 @@ LOG_FILE_NAME = os.environ.get("LOG_FILE_NAME") or "danswer"
|
||||
|
||||
# Enable generating persistent log files for local dev environments
|
||||
DEV_LOGGING_ENABLED = os.environ.get("DEV_LOGGING_ENABLED", "").lower() == "true"
|
||||
|
||||
# notset, debug, info, notice, warning, error, or critical
|
||||
LOG_LEVEL = os.environ.get("LOG_LEVEL", "notice")
|
||||
LOG_LEVEL = os.environ.get("LOG_LEVEL", "debug")
|
||||
|
||||
|
||||
# Fields which should only be set on new search setting
|
||||
@@ -68,3 +70,31 @@ PRESERVED_SEARCH_FIELDS = [
|
||||
"passage_prefix",
|
||||
"query_prefix",
|
||||
]
|
||||
|
||||
class SupportedEmbeddingModel(BaseModel):
|
||||
name: str
|
||||
dim: int
|
||||
index_name: str
|
||||
|
||||
SUPPORTED_EMBEDDING_MODELS = [
|
||||
SupportedEmbeddingModel(
|
||||
name="intfloat/e5-small-v2",
|
||||
dim=384,
|
||||
index_name="danswer_chunk_intfloat_e5_small_v2"
|
||||
),
|
||||
SupportedEmbeddingModel(
|
||||
name="intfloat/e5-large-v2",
|
||||
dim=1024,
|
||||
index_name="danswer_chunk_intfloat_e5_large_v2"
|
||||
),
|
||||
SupportedEmbeddingModel(
|
||||
name="sentence-transformers/all-distilroberta-v1",
|
||||
dim=768,
|
||||
index_name="danswer_chunk_sentence_transformers_all_distilroberta_v1"
|
||||
),
|
||||
SupportedEmbeddingModel(
|
||||
name="sentence-transformers/all-mpnet-base-v2",
|
||||
dim=768,
|
||||
index_name="danswer_chunk_sentence_transformers_all_mpnet_base_v2"
|
||||
),
|
||||
]
|
||||
|
||||
@@ -18,7 +18,7 @@ from danswer.db.swap_index import check_index_swap
|
||||
from danswer.document_index.vespa.index import DOCUMENT_ID_ENDPOINT
|
||||
from danswer.document_index.vespa.index import VespaIndex
|
||||
from danswer.indexing.models import IndexingSetting
|
||||
from danswer.main import setup_postgres
|
||||
from danswer.db_setup import setup_postgres
|
||||
from danswer.main import setup_vespa
|
||||
from tests.integration.common_utils.llm import seed_default_openai_provider
|
||||
|
||||
@@ -132,9 +132,9 @@ def reset_vespa() -> None:
|
||||
index_name = search_settings.index_name
|
||||
|
||||
setup_vespa(
|
||||
document_index=VespaIndex(index_name=index_name, secondary_index_name=None),
|
||||
index_setting=IndexingSetting.from_db_model(search_settings),
|
||||
secondary_index_setting=None,
|
||||
document_index=VespaIndex(indices=[index_name], secondary_index_name=None),
|
||||
embedding_dims=[search_settings.model_dim],
|
||||
secondary_embedding_dim=None,
|
||||
)
|
||||
|
||||
for _ in range(5):
|
||||
|
||||
@@ -296,7 +296,7 @@ services:
|
||||
- POSTGRES_USER=${POSTGRES_USER:-postgres}
|
||||
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
|
||||
ports:
|
||||
- "5432:5432"
|
||||
- "5433:5432"
|
||||
volumes:
|
||||
- db_volume:/var/lib/postgresql/data
|
||||
|
||||
|
||||
@@ -306,7 +306,7 @@ services:
|
||||
- POSTGRES_USER=${POSTGRES_USER:-postgres}
|
||||
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
|
||||
ports:
|
||||
- "5432:5432"
|
||||
- "5433:5432"
|
||||
volumes:
|
||||
- db_volume:/var/lib/postgresql/data
|
||||
|
||||
|
||||
@@ -150,7 +150,7 @@ services:
|
||||
- POSTGRES_USER=${POSTGRES_USER:-postgres}
|
||||
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
|
||||
ports:
|
||||
- "5432"
|
||||
- "5433"
|
||||
volumes:
|
||||
- db_volume:/var/lib/postgresql/data
|
||||
|
||||
|
||||
@@ -31,6 +31,7 @@ const nextConfig = {
|
||||
|
||||
return defaultRedirects.concat([
|
||||
{
|
||||
|
||||
source: "/api/chat/send-message:params*",
|
||||
destination: "http://127.0.0.1:8080/chat/send-message:params*", // Proxy to Backend
|
||||
permanent: true,
|
||||
|
||||
10 web/package-lock.json generated
@@ -15,6 +15,7 @@
|
||||
"@radix-ui/react-dialog": "^1.0.5",
|
||||
"@radix-ui/react-popover": "^1.0.7",
|
||||
"@radix-ui/react-tooltip": "^1.0.7",
|
||||
"@stripe/stripe-js": "^4.4.0",
|
||||
"@tremor/react": "^3.9.2",
|
||||
"@types/js-cookie": "^3.0.3",
|
||||
"@types/lodash": "^4.17.0",
|
||||
@@ -1670,6 +1671,15 @@
|
||||
"integrity": "sha512-qC/xYId4NMebE6w/V33Fh9gWxLgURiNYgVNObbJl2LZv0GUUItCcCqC5axQSwRaAgaxl2mELq1rMzlswaQ0Zxg==",
|
||||
"dev": true
|
||||
},
|
||||
"node_modules/@stripe/stripe-js": {
|
||||
"version": "4.4.0",
|
||||
"resolved": "https://registry.npmjs.org/@stripe/stripe-js/-/stripe-js-4.4.0.tgz",
|
||||
"integrity": "sha512-p1WeTOwnAyXQ9I5/YC3+JXoUB6NKMR4qGjBobie2+rgYa3ftUTRS2L5qRluw/tGACty5SxqnfORCdsaymD1XjQ==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=12.16"
|
||||
}
|
||||
},
|
||||
"node_modules/@swc/counter": {
|
||||
"version": "0.1.3",
|
||||
"resolved": "https://registry.npmjs.org/@swc/counter/-/counter-0.1.3.tgz",
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
"@radix-ui/react-dialog": "^1.0.5",
|
||||
"@radix-ui/react-popover": "^1.0.7",
|
||||
"@radix-ui/react-tooltip": "^1.0.7",
|
||||
"@stripe/stripe-js": "^4.4.0",
|
||||
"@tremor/react": "^3.9.2",
|
||||
"@types/js-cookie": "^3.0.3",
|
||||
"@types/lodash": "^4.17.0",
|
||||
|
||||
@@ -112,7 +112,7 @@ export default function Page() {
|
||||
value={searchTerm}
|
||||
onChange={(e) => setSearchTerm(e.target.value)}
|
||||
onKeyDown={handleKeyPress}
|
||||
className="flex mt-2 max-w-sm h-9 w-full rounded-md border-2 border border-input bg-transparent px-3 py-1 text-sm shadow-sm transition-colors placeholder:text-muted-foreground focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-ring"
|
||||
className="ml-1 w-96 h-9 flex-none rounded-md border border-border bg-background-50 px-3 py-1 text-sm shadow-sm transition-colors placeholder:text-muted-foreground focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-ring"
|
||||
/>
|
||||
|
||||
{Object.entries(categorizedSources)
|
||||
|
||||
@@ -173,7 +173,7 @@ export function ProviderCreationModal({
|
||||
</a>
|
||||
</Text>
|
||||
|
||||
<div className="flex flex-col gap-y-2">
|
||||
<div className="flex w-full flex-col gap-y-2">
|
||||
{useFileUpload ? (
|
||||
<>
|
||||
<Label>Upload JSON File</Label>
|
||||
|
||||
@@ -457,7 +457,7 @@ export function CCPairIndexingStatusTable({
|
||||
placeholder="Search connectors..."
|
||||
value={searchTerm}
|
||||
onChange={(e) => setSearchTerm(e.target.value)}
|
||||
className="ml-2 w-96 h-9 flex-none rounded-md border border-border bg-background-50 px-3 py-1 text-sm shadow-sm transition-colors placeholder:text-muted-foreground focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-ring"
|
||||
className="ml-1 w-96 h-9 flex-none rounded-md border border-border bg-background-50 px-3 py-1 text-sm shadow-sm transition-colors placeholder:text-muted-foreground focus-visible:outline-none focus-visible:ring-1 focus-visible:ring-ring"
|
||||
/>
|
||||
|
||||
<Button className="h-9" onClick={() => toggleSources()}>
|
||||
|
||||
@@ -26,11 +26,52 @@ export interface EnterpriseSettings {
|
||||
custom_popup_header: string | null;
|
||||
custom_popup_content: string | null;
|
||||
}
|
||||
import { FiStar, FiDollarSign, FiAward } from "react-icons/fi";
|
||||
|
||||
export enum BillingPlanType {
|
||||
FREE = "free",
|
||||
PREMIUM = "premium",
|
||||
ENTERPRISE = "enterprise",
|
||||
}
|
||||
|
||||
export interface CloudSettings {
|
||||
numberOfSeats: number;
|
||||
planType: BillingPlanType;
|
||||
}
|
||||
|
||||
export interface CombinedSettings {
|
||||
settings: Settings;
|
||||
enterpriseSettings: EnterpriseSettings | null;
|
||||
cloudSettings: CloudSettings | null;
|
||||
customAnalyticsScript: string | null;
|
||||
isMobile?: boolean;
|
||||
webVersion: string | null;
|
||||
}
|
||||
|
||||
|
||||
export const defaultCombinedSettings: CombinedSettings = {
|
||||
settings: {
|
||||
chat_page_enabled: true,
|
||||
search_page_enabled: true,
|
||||
default_page: "search",
|
||||
maximum_chat_retention_days: 30,
|
||||
notifications: [],
|
||||
needs_reindexing: false,
|
||||
},
|
||||
enterpriseSettings: {
|
||||
application_name: "Danswer",
|
||||
use_custom_logo: false,
|
||||
use_custom_logotype: false,
|
||||
custom_lower_disclaimer_content: null,
|
||||
custom_header_content: null,
|
||||
custom_popup_header: null,
|
||||
custom_popup_content: null,
|
||||
},
|
||||
cloudSettings: {
|
||||
numberOfSeats: 0,
|
||||
planType: BillingPlanType.FREE,
|
||||
},
|
||||
customAnalyticsScript: null,
|
||||
isMobile: false,
|
||||
webVersion: null,
|
||||
};
|
||||
@@ -147,7 +147,7 @@ function AssistantListItem({
|
||||
</div>
|
||||
|
||||
<div className="text-sm mt-2">{assistant.description}</div>
|
||||
<div className="mt-2 flex items-start gap-y-2 flex-col gap-x-3">
|
||||
<div className="mt-2 flex flex-none items-start gap-y-2 flex-col gap-x-3">
|
||||
<AssistantSharedStatusDisplay assistant={assistant} user={user} />
|
||||
{assistant.tools.length != 0 && (
|
||||
<AssistantTools list assistant={assistant} />
|
||||
@@ -175,6 +175,7 @@ function AssistantListItem({
|
||||
<FiEdit2 size={16} />
|
||||
</Link>
|
||||
|
||||
|
||||
<DefaultPopover
|
||||
content={
|
||||
<div className="hover:bg-hover rounded p-2 cursor-pointer">
|
||||
|
||||
@@ -89,6 +89,7 @@ const Page = async ({
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
|
||||
{authTypeMetadata?.authType === "basic" && (
|
||||
<Card className="mt-4 w-96">
|
||||
<div className="flex">
|
||||
|
||||
17 web/src/app/auth/sso-callback/layout.tsx Normal file
@@ -0,0 +1,17 @@
|
||||
import React from "react";
|
||||
|
||||
export const metadata = {
|
||||
title: "SSO Callback",
|
||||
};
|
||||
|
||||
export default function SSOCallbackLayout({
|
||||
children,
|
||||
}: {
|
||||
children: React.ReactNode;
|
||||
}) {
|
||||
return (
|
||||
<html lang="en">
|
||||
<body>{children}</body>
|
||||
</html>
|
||||
);
|
||||
}
|
||||
127 web/src/app/auth/sso-callback/page.tsx Normal file
@@ -0,0 +1,127 @@
|
||||
"use client";
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
import { useRouter, useSearchParams } from "next/navigation";
|
||||
import { Card, Text } from "@tremor/react";
|
||||
import { Logo } from "@/components/Logo";
|
||||
|
||||
export default function SSOCallback() {
|
||||
const router = useRouter();
|
||||
const searchParams = useSearchParams();
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const [authStatus, setAuthStatus] = useState<string>("Authenticating...");
|
||||
|
||||
const verificationStartedRef = useRef(false);
|
||||
|
||||
useEffect(() => {
|
||||
const verifyToken = async () => {
|
||||
if (verificationStartedRef.current) {
|
||||
return;
|
||||
}
|
||||
|
||||
verificationStartedRef.current = true;
|
||||
const hashParams = new URLSearchParams(window.location.hash.slice(1));
|
||||
const ssoToken = hashParams.get("sso_token");
|
||||
|
||||
if (!ssoToken) {
|
||||
setError("No SSO token found in URL hash");
|
||||
return;
|
||||
}
|
||||
|
||||
window.history.replaceState(null, '', window.location.pathname);
|
||||
|
||||
if (!ssoToken) {
|
||||
setError("No SSO token found");
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
setAuthStatus("Verifying SSO token...");
|
||||
const response = await fetch(
|
||||
`/api/tenants/auth/sso-callback`,
|
||||
{
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
credentials: "include",
|
||||
body: JSON.stringify({ sso_token: ssoToken }),
|
||||
}
|
||||
)
|
||||
|
||||
if (response.ok) {
|
||||
setAuthStatus("Authentication successful!");
|
||||
router.replace("/admin/configuration/llm");
|
||||
} else {
|
||||
const errorData = await response.json();
|
||||
setError(errorData.detail || "Authentication failed");
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error verifying token:", error);
|
||||
setError("An unexpected error occurred");
|
||||
}
|
||||
};
|
||||
|
||||
verifyToken();
|
||||
}, [router, searchParams]);
|
||||
|
||||
|
||||
return (
|
||||
<div className="flex items-center justify-center min-h-screen bg-gradient-to-r from-background-50 to-blue-50">
|
||||
<Card className="max-w-lg p-8 text-center shadow-xl rounded-xl bg-white">
|
||||
{error ? (
|
||||
<div className="space-y-4">
|
||||
<svg
|
||||
className="w-16 h-16 mx-auto text-neutral-600"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
viewBox="0 0 24 24"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
>
|
||||
<path
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
strokeWidth={2}
|
||||
d="M12 8v4m0 4h.01M21 12a9 9 0 11-18 0 9 9 0 0118 0z"
|
||||
/>
|
||||
</svg>
|
||||
<Text className="text-xl font-bold text-red-600">{error}</Text>
|
||||
</div>
|
||||
) : (
|
||||
<div className="space-y-6 flex flex-col">
|
||||
<div className="flex mx-auto">
|
||||
<Logo height={200} width={200} />
|
||||
</div>
|
||||
<Text className="text-2xl font-semibold text-text-900">
|
||||
{authStatus}
|
||||
</Text>
|
||||
<div className="w-full h-2 bg-background-100 rounded-full overflow-hidden">
|
||||
<div
|
||||
className="h-full bg-background-600 rounded-full animate-progress"
|
||||
style={{
|
||||
animation: "progress 5s ease-out forwards",
|
||||
width: "0%",
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
<style jsx>{`
|
||||
@keyframes progress {
|
||||
0% {
|
||||
width: 0%;
|
||||
}
|
||||
60% {
|
||||
width: 75%;
|
||||
}
|
||||
100% {
|
||||
width: 99%;
|
||||
}
|
||||
}
|
||||
.animate-progress {
|
||||
animation: progress 5s ease-out forwards;
|
||||
}
|
||||
`}</style>
|
||||
</div>
|
||||
)}
|
||||
</Card>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
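Note on the SSO callback above: the client only assumes that POST /api/tenants/auth/sso-callback accepts a JSON body with an sso_token field and that a failed response carries a detail message; the backend side of that contract is not shown in this diff. A minimal sketch of the shapes the page relies on (the type names are illustrative, not from the source):

// Hypothetical request/response types inferred from the client code above.
interface SSOCallbackRequest {
  sso_token: string; // read from the URL hash and sent with credentials: "include"
}

interface SSOCallbackErrorResponse {
  detail?: string; // surfaced to the user via setError, with a generic fallback
}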
@@ -1,12 +1,9 @@
"use client";

import ReactMarkdown from "react-markdown";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import { useContext, useState, useRef, useLayoutEffect } from "react";
import remarkGfm from "remark-gfm";
import { Popover } from "@/components/popover/Popover";
import { ChevronDownIcon } from "@/components/icons/icons";
import { Divider } from "@tremor/react";
import { MinimalMarkdown } from "@/components/chat_search/MinimalMarkdown";

export function ChatBanner() {
@@ -34,27 +31,6 @@ export function ChatBanner() {
    return null;
  }

  const renderMarkdown = (className: string) => (
    <ReactMarkdown
      className={`w-full text-wrap break-word ${className}`}
      components={{
        a: ({ node, ...props }) => (
          <a
            {...props}
            className="text-sm text-link hover:text-link-hover"
            target="_blank"
            rel="noopener noreferrer"
          />
        ),
        p: ({ node, ...props }) => (
          <p {...props} className="text-wrap break-word text-sm m-0 w-full" />
        ),
      }}
      remarkPlugins={[remarkGfm]}
    >
      {settings.enterpriseSettings?.custom_header_content}
    </ReactMarkdown>
  );
  return (
    <div
      className={`
@@ -65,7 +41,6 @@ export function ChatBanner() {
        w-full
        mx-auto
        relative
        bg-background-100
        shadow-sm
        rounded
        border-l-8 border-l-400
@@ -81,7 +56,6 @@ export function ChatBanner() {
          className="line-clamp-2 text-center w-full overflow-hidden pr-8"
        >
          <MinimalMarkdown
            className=""
            content={settings.enterpriseSettings.custom_header_content}
          />
        </div>
@@ -90,7 +64,6 @@ export function ChatBanner() {
          className="absolute top-0 left-0 invisible w-full"
        >
          <MinimalMarkdown
            className=""
            content={settings.enterpriseSettings.custom_header_content}
          />
        </div>
@@ -210,6 +210,11 @@ export function ChatPage({
      liveAssistant,
      llmProviders
    );
    console.log("persona default", personaDefault);
    console.log(
      destructureValue(user?.preferences.default_model || "openai:gpt-4o")
    );

    if (personaDefault) {
      llmOverrideManager.setLlmOverride(personaDefault);
@@ -882,6 +887,11 @@ export function ChatPage({
    modelOverRide?: LlmOverride;
    regenerationRequest?: RegenerationRequest | null;
  } = {}) => {
    console.log("model override", modelOverRide);
    console.log(modelOverRide?.name);
    console.log(llmOverrideManager.llmOverride.name);
    console.log("HII");
    console.log(llmOverrideManager.globalDefault.name);
    let frozenSessionId = currentSessionId();

    if (currentChatState() != "input") {
@@ -205,7 +205,8 @@ const FolderItem = ({
          className="text-sm px-1 flex-1 min-w-0 -my-px mr-2"
        />
      ) : (
        <div className="flex-1 min-w-0">
        <div className="flex-1 break-all min-w-0">

          {editedFolderName || folder.folder_name}
        </div>
      )}
@@ -116,6 +116,7 @@ export function ChatInputBar({
  const { llmProviders } = useChatContext();
  const [_, llmName] = getFinalLLM(llmProviders, selectedAssistant, null);

  const suggestionsRef = useRef<HTMLDivElement | null>(null);
  const [showSuggestions, setShowSuggestions] = useState(false);
  const [showPrompts, setShowPrompts] = useState(false);
@@ -152,6 +152,8 @@ export async function* sendMessage({
}): AsyncGenerator<PacketType, void, unknown> {
  const documentsAreSelected =
    selectedDocumentIds && selectedDocumentIds.length > 0;

  console.log("llm override details", modelProvider, modelVersion);

  const body = JSON.stringify({
    alternate_assistant_id: alternateAssistantId,
@@ -164,30 +166,30 @@ export async function* sendMessage({
    regenerate,
    retrieval_options: !documentsAreSelected
      ? {
          run_search:
            promptId === null ||
            promptId === undefined ||
            queryOverride ||
            forceSearch
              ? "always"
              : "auto",
          real_time: true,
          filters: filters,
        }
      : null,
    query_override: queryOverride,
    prompt_override: systemPromptOverride
      ? {
          system_prompt: systemPromptOverride,
        }
      : null,
    llm_override:
      temperature || modelVersion
        ? {
            temperature,
            model_provider: modelProvider,
            model_version: modelVersion,
          }
        : null,
    use_existing_user_message: useExistingUserMessage,
  });
@@ -427,11 +429,11 @@ export function processRawChatHistory(
      // this is identical to what is computed at streaming time
      ...(messageInfo.message_type === "assistant"
        ? {
            retrievalType: retrievalType,
            query: messageInfo.rephrased_query,
            documents: messageInfo?.context_docs?.top_documents || [],
            citations: messageInfo?.citations || {},
          }
        : {}),
      toolCalls: messageInfo.tool_calls,
      parentMessageId: messageInfo.parent_message,
@@ -596,8 +598,7 @@ export function buildChatUrl(
  const finalSearchParams: string[] = [];
  if (chatSessionId) {
    finalSearchParams.push(
      `${
        search ? SEARCH_PARAM_NAMES.SEARCH_ID : SEARCH_PARAM_NAMES.CHAT_ID
      `${search ? SEARCH_PARAM_NAMES.SEARCH_ID : SEARCH_PARAM_NAMES.CHAT_ID
      }=${chatSessionId}`
    );
  }
@@ -13,7 +13,6 @@ import {
  sortableKeyboardCoordinates,
  verticalListSortingStrategy,
} from "@dnd-kit/sortable";
import { CSS } from "@dnd-kit/utilities";
import { Persona } from "@/app/admin/assistants/interfaces";
import { LLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
import { getFinalLLM } from "@/lib/llm/utils";
@@ -41,6 +40,7 @@ export function AssistantsTab({
      coordinateGetter: sortableKeyboardCoordinates,
    })
  );
  console.log("llm providers", llmProviders);

  function handleDragEnd(event: DragEndEvent) {
    const { active, over } = event;
web/src/app/ee/admin/plan/BillingSettings.tsx (new file, 325 lines)
@@ -0,0 +1,325 @@
"use client";

import { BillingPlanType } from "@/app/admin/settings/interfaces";
import { useContext, useEffect, useState } from "react";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import { Button, Divider, Card } from "@tremor/react";
import { StripeCheckoutButton } from "./StripeCheckoutButton";
import {
  CheckmarkIcon,
  CheckmarkCircleIcon,
  LightningIcon,
  PeopleIcon,
  XIcon,
  ChevronDownIcon,
} from "@/components/icons/icons";
import { FiAward, FiDollarSign, FiStar } from "react-icons/fi";
import Cookies from "js-cookie";
import { Modal } from "@/components/Modal";
import { Logo } from "@/components/Logo";
import { useSearchParams } from "next/navigation";
import { usePopup } from "@/components/admin/connectors/Popup";

export function BillingSettings({ newUser }: { newUser: boolean }) {
  const settings = useContext(SettingsContext);
  const cloudSettings = settings?.cloudSettings;

  const searchParams = useSearchParams();

  const [isOpen, setIsOpen] = useState(false);
  const [isNewUserOpen, setIsNewUserOpen] = useState(true);

  const [newSeats, setNewSeats] = useState<null | number>(null);
  const [newPlan, setNewPlan] = useState<null | BillingPlanType>(null);

  const { popup, setPopup } = usePopup();

  useEffect(() => {
    const success = searchParams.get("success");
    if (success === "true") {
      setPopup({
        message: "Your plan has been successfully updated!",
        type: "success",
      });
    } else if (success === "false") {
      setPopup({
        message: "There was an error updating your plan",
        type: "error",
      });
    }

    // Clear the 'success' parameter from the URL
    if (success) {
      const newUrl = new URL(window.location.href);
      newUrl.searchParams.delete("success");
      window.history.replaceState({}, "", newUrl);
    }
  }, []);

  // Use actual data from cloudSettings
  const currentPlan = cloudSettings?.planType;
  const seats = cloudSettings?.numberOfSeats!;

  useEffect(() => {
    if (cloudSettings) {
      setNewSeats(cloudSettings.numberOfSeats);
      setNewPlan(cloudSettings.planType);
    }
  }, [cloudSettings]);

  if (!cloudSettings) {
    return null;
  }

  const features = [
    { name: "All Connector Access", included: true },
    { name: "Basic Support", included: true },
    { name: "Custom Branding", included: currentPlan !== BillingPlanType.FREE },
    {
      name: "Analytics Dashboard",
      included: currentPlan !== BillingPlanType.FREE,
    },
    { name: "Query History", included: currentPlan !== BillingPlanType.FREE },
    {
      name: "Priority Support",
      included: currentPlan !== BillingPlanType.FREE,
    },
    {
      name: "Service Level Agreements (SLAs)",
      included: currentPlan === BillingPlanType.ENTERPRISE,
    },
    {
      name: "Advanced Support",
      included: currentPlan === BillingPlanType.ENTERPRISE,
    },
    {
      name: "Custom Integrations",
      included: currentPlan === BillingPlanType.ENTERPRISE,
    },
    {
      name: "Dedicated Account Manager",
      included: currentPlan === BillingPlanType.ENTERPRISE,
    },
  ];

  function getBillingPlanIcon(planType: BillingPlanType) {
    switch (planType) {
      case BillingPlanType.FREE:
        return <FiStar />;
      case BillingPlanType.PREMIUM:
        return <FiDollarSign />;
      case BillingPlanType.ENTERPRISE:
        return <FiAward />;
      default:
        return <FiStar />;
    }
  }

  const handleCloseModal = () => {
    setIsNewUserOpen(false);
    Cookies.set("new_auth_user", "false");
  };

  if (newSeats === null || currentPlan === undefined) {
    return null;
  }

  return (
    <div className="max-w-4xl mr-auto space-y-8 p-6">
      {newUser && isNewUserOpen && (
        <Modal
          onOutsideClick={handleCloseModal}
          className="max-w-lg w-full p-8 bg-background-150 rounded-lg shadow-xl"
        >
          <>
            <h2 className="text-3xl font-semibold text-text-900 mb-6 text-center">
              Welcome to Danswer!
            </h2>
            <div className="text-center mb-8">
              <Logo className="mx-auto mb-4" height={150} width={150} />
              <p className="text-lg text-text-700 leading-relaxed">
                We're thrilled to have you on board! Here, you can manage your
                billing settings and explore your plan details.
              </p>
            </div>
            <div className="flex justify-center">
              <Button
                onClick={handleCloseModal}
                className="px-8 py-3 bg-blue-600 text-white text-lg font-semibold rounded-full hover:bg-blue-700 transition duration-300 ease-in-out focus:outline-none focus:ring-2 focus:ring-blue-500 focus:ring-opacity-50"
              >
                Let's Get Started
              </Button>
              <Button
                onClick={() => window.open("mailto:support@danswer.ai")}
                className="border-0 hover:underline ml-4 px-4 py-2 bg-gray-200 w-fit text-text-700 text-sm font-medium rounded-full hover:bg-gray-300 transition duration-300 ease-in-out focus:outline-none focus:ring-2 focus:ring-gray-400 focus:ring-opacity-50 flex items-center"
              >
                Questions?
              </Button>
            </div>
          </>
        </Modal>
      )}
      <Card className="bg-white shadow-lg rounded-lg overflow-hidden">
        <div className="px-8 py-6">
          <h2 className="text-3xl gap-x-2 font-bold text-text-800 mb-6 flex items-center">
            Your Plan
            <CheckmarkCircleIcon size={24} className="text-blue-500" />
          </h2>
          <div className="space-y-6">
            <div className="flex justify-between items-center">
              <p className="text-lg text-text-600 flex gap-x-2 items-center">
                <LightningIcon size={20} />
                Tier:
              </p>
              <span className="text-xl font-semibold text-blue-600 capitalize">
                {currentPlan}
              </span>
            </div>
            <div className="flex justify-between items-center">
              <p className="text-lg text-text-600 gap-x-2 flex items-center">
                <PeopleIcon size={20} />
                Current Seats:
              </p>
              <span className="text-xl font-semibold text-blue-600">
                {seats}
              </span>
            </div>

            <Divider />

            <div className="mt-6 relative">
              <label className="block text-lg font-medium text-text-700 mb-2 flex items-center">
                New Tier:
              </label>
              <div
                className="w-full px-4 py-2 text-lg border border-gray-300 rounded-md focus:ring-2 focus:ring-blue-500 focus:border-transparent transition duration-150 ease-in-out flex items-center justify-between cursor-pointer"
                onClick={() => setIsOpen(!isOpen)}
              >
                <span className="flex items-center">
                  {getBillingPlanIcon(newPlan!)}
                  <span className="ml-2 capitalize">{newPlan}</span>
                </span>
                <ChevronDownIcon size={12} />
              </div>
              {isOpen && (
                <div className="absolute z-10 w-full mt-1 bg-white border border-gray-300 rounded-md shadow-lg">
                  {Object.values(BillingPlanType).map((plan) => (
                    <div
                      key={plan}
                      className="px-4 py-2 hover:bg-gray-100 cursor-pointer flex items-center"
                      onClick={() => {
                        setNewPlan(plan);
                        setIsOpen(false);
                        if (plan === BillingPlanType.FREE) {
                          setNewSeats(1);
                        }
                      }}
                    >
                      {getBillingPlanIcon(plan)}
                      <span className="ml-2 capitalize">{plan}</span>
                    </div>
                  ))}
                </div>
              )}
            </div>
            <div className="mt-6">
              <label className="block text-lg font-medium text-text-700 mb-2 flex items-center">
                New Number of Seats:
              </label>
              {newPlan === BillingPlanType.FREE ? (
                <input
                  type="number"
                  value={1}
                  disabled
                  className="w-full px-4 py-2 text-lg border border-gray-300 rounded-md bg-gray-100 cursor-not-allowed"
                />
              ) : (
                <input
                  type="number"
                  value={newSeats}
                  onChange={(e) => setNewSeats(Number(e.target.value))}
                  min="1"
                  className="w-full px-4 py-2 text-lg border border-gray-300 rounded-md focus:ring-2 focus:ring-blue-500 focus:border-transparent transition duration-150 ease-in-out"
                />
              )}
            </div>
            <div className="mt-8 flex justify-center">
              <StripeCheckoutButton
                currentPlan={currentPlan}
                currentQuantity={cloudSettings.numberOfSeats}
                newQuantity={newSeats}
                newPlan={newPlan!}
              />
            </div>
          </div>
        </div>
      </Card>

      <Card className="bg-white shadow-lg rounded-lg overflow-hidden">
        <div className="px-6 py-4">
          <h2 className="text-3xl font-bold text-text-800 mb-4">Features</h2>
          <ul className="space-y-3">
            {features.map((feature, index) => (
              <li key={index} className="flex items-center text-lg">
                <span className="mr-3">
                  {feature.included ? (
                    <CheckmarkIcon className="text-success" />
                  ) : (
                    <XIcon className="text-error" />
                  )}
                </span>
                <span
                  className={
                    feature.included ? "text-text-800" : "text-text-500"
                  }
                >
                  {feature.name}
                </span>
              </li>
            ))}
          </ul>
        </div>
      </Card>
      {currentPlan !== BillingPlanType.FREE && (
        <Card className="bg-white shadow-lg rounded-lg overflow-hidden">
          <div className="px-8 py-6">
            <h2 className="text-3xl font-bold text-text-800 mb-4">
              Tenant Deletion
            </h2>
            <p className="text-text-600 mb-6">
              Permanently delete your tenant and all associated data.
            </p>
            <div
              className="bg-red-100 border-l-4 border-red-500 text-error p-4 mb-6"
              role="alert"
            >
              <p className="font-bold">Warning:</p>
              <p>Deleting your tenant will result in the following:</p>
              <ul className="list-disc list-inside mt-2">
                <li>
                  All data associated with this tenant will be permanently
                  deleted
                </li>
                <li>This action cannot be undone</li>
              </ul>
            </div>
            <Button
              onClick={() => {
                alert("not implemented");
              }}
              className="bg-red-500 hover:bg-red-600 text-white font-bold py-3 px-6 rounded-lg transition duration-300 shadow-md hover:shadow-lg"
            >
              Delete Tenant
            </Button>
          </div>
        </Card>
      )}
      {popup}
    </div>
  );
}
web/src/app/ee/admin/plan/ImageUpload.tsx (new file, 72 lines)
@@ -0,0 +1,72 @@
import { SubLabel } from "@/components/admin/connectors/Field";
import { usePopup } from "@/components/admin/connectors/Popup";
import { useEffect, useState } from "react";
import Dropzone from "react-dropzone";

export function ImageUpload({
  selectedFile,
  setSelectedFile,
}: {
  selectedFile: File | null;
  setSelectedFile: (file: File) => void;
}) {
  const [tmpImageUrl, setTmpImageUrl] = useState<string>("");

  useEffect(() => {
    if (selectedFile) {
      setTmpImageUrl(URL.createObjectURL(selectedFile));
    } else {
      setTmpImageUrl("");
    }
  }, [selectedFile]);

  const [dragActive, setDragActive] = useState(false);
  const { popup, setPopup } = usePopup();

  return (
    <>
      {popup}
      <Dropzone
        onDrop={(acceptedFiles) => {
          if (acceptedFiles.length !== 1) {
            setPopup({
              type: "error",
              message: "Only one file can be uploaded at a time",
            });
          }

          setTmpImageUrl(URL.createObjectURL(acceptedFiles[0]));
          setSelectedFile(acceptedFiles[0]);
          setDragActive(false);
        }}
        onDragLeave={() => setDragActive(false)}
        onDragEnter={() => setDragActive(true)}
      >
        {({ getRootProps, getInputProps }) => (
          <section>
            <div
              {...getRootProps()}
              className={
                "flex flex-col items-center w-full px-4 py-12 rounded " +
                "shadow-lg tracking-wide border border-border cursor-pointer" +
                (dragActive ? " border-accent" : "")
              }
            >
              <input {...getInputProps()} />
              <b className="text-emphasis">
                Drag and drop a .png or .jpg file, or click to select a file!
              </b>
            </div>

            {tmpImageUrl && (
              <div className="mt-4 mb-8">
                <SubLabel>Uploaded Image:</SubLabel>
                <img src={tmpImageUrl} className="mt-4 max-w-xs max-h-64" />
              </div>
            )}
          </section>
        )}
      </Dropzone>
    </>
  );
}
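ImageUpload is a controlled component: the parent owns the selected file and passes the setter down. A minimal usage sketch, assuming a parent admin form (the wrapper component and its state are illustrative, not part of this diff):

// Hypothetical parent wiring for ImageUpload.
import { useState } from "react";
import { ImageUpload } from "@/app/ee/admin/plan/ImageUpload";

function LogoUploadSection() {
  const [selectedFile, setSelectedFile] = useState<File | null>(null);
  return (
    <ImageUpload selectedFile={selectedFile} setSelectedFile={setSelectedFile} />
  );
}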
web/src/app/ee/admin/plan/StripeCheckoutButton.tsx (new file, 78 lines)
@@ -0,0 +1,78 @@
"use client";

import { useState } from "react";
import { loadStripe } from "@stripe/stripe-js";
import { BillingPlanType } from "@/app/admin/settings/interfaces";

export function StripeCheckoutButton({
  newQuantity,
  newPlan,
  currentQuantity,
  currentPlan,
}: {
  newQuantity: number;
  newPlan: BillingPlanType;
  currentQuantity: number;
  currentPlan: BillingPlanType;
}) {
  const [isLoading, setIsLoading] = useState(false);

  const handleClick = async () => {
    setIsLoading(true);
    try {
      const response = await fetch("/api/create-checkout-session", {
        method: "POST",
        headers: {
          "Content-Type": "application/json",
        },
        body: JSON.stringify({ quantity: newQuantity, plan: newPlan }),
      });

      if (!response.ok) {
        throw new Error("Failed to create checkout session");
      }

      const { sessionId } = await response.json();
      const stripe = await loadStripe(
        process.env.NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY!
      );
      if (stripe) {
        await stripe.redirectToCheckout({ sessionId });
      } else {
        throw new Error("Stripe failed to load");
      }
    } catch (error) {
      console.error("Error:", error);
    } finally {
      setIsLoading(false);
    }
  };

  return (
    <button
      onClick={handleClick}
      className={`py-2 px-4 text-white rounded ${
        currentPlan === newPlan && currentQuantity === newQuantity
          ? "bg-gray-400 cursor-not-allowed"
          : "bg-blue-500 hover:bg-blue-600"
      } disabled:bg-blue-300`}
      disabled={
        (currentPlan === newPlan && currentQuantity === newQuantity) ||
        isLoading
      }
    >
      {isLoading
        ? "Loading..."
        : currentPlan === newPlan && currentQuantity === newQuantity
          ? "No Changes"
          : newPlan > currentPlan ||
              (newPlan === currentPlan && newQuantity > currentQuantity)
            ? "Upgrade Plan"
            : newPlan == BillingPlanType.ENTERPRISE
              ? "Talk to us"
              : // : newPlan < currentPlan ||
                newPlan === currentPlan && newQuantity < currentQuantity
                ? "Upgrade Plan"
                : "Change Plan"}
    </button>
  );
}
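The button assumes a backend route at /api/create-checkout-session that creates a Stripe Checkout session; that route is not part of this diff. The only response shape the client depends on is a JSON object with a sessionId string, as implied by the destructuring above:

// Shape inferred from `const { sessionId } = await response.json();` in the component.
interface CreateCheckoutSessionResponse {
  sessionId: string; // passed straight to stripe.redirectToCheckout
}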
web/src/app/ee/admin/plan/page.tsx (new file, 19 lines)
@@ -0,0 +1,19 @@
import { BillingSettings } from "./BillingSettings";
import { AdminPageTitle } from "@/components/admin/Title";
import { CreditCardIcon } from "@/components/icons/icons";
import { cookies } from "next/headers";

export default async function Whitelabeling() {
  const newUser =
    cookies().get("new_auth_user")?.value.toLocaleLowerCase() === "true";

  return (
    <div className="mx-auto container">
      <AdminPageTitle
        title="Billing"
        icon={<CreditCardIcon size={32} className="my-auto" />}
      />
      <BillingSettings newUser={newUser} />
    </div>
  );
}
@@ -9,11 +9,11 @@ export default async function AdminLayout({
    return (
      <div className="flex h-screen">
        <div className="mx-auto my-auto text-lg font-bold text-red-500">
          This funcitonality is only available in the Enterprise Edition :(
          This functionality is only available in the Enterprise Edition :(
        </div>
      </div>
    );
  }

  return children;
}
}
@@ -1,10 +1,5 @@
import "./globals.css";

import {
  fetchEnterpriseSettingsSS,
  fetchSettingsSS,
  SettingsError,
} from "@/components/settings/lib";
import { fetchEnterpriseSettingsSS, fetchSettingsSS } from "@/components/settings/lib";
import {
  CUSTOM_ANALYTICS_ENABLED,
  SERVER_SIDE_ONLY__PAID_ENTERPRISE_FEATURES_ENABLED,
@@ -14,10 +9,8 @@ import { Metadata } from "next";
import { buildClientUrl } from "@/lib/utilsSS";
import { Inter } from "next/font/google";
import Head from "next/head";
import { EnterpriseSettings } from "./admin/settings/interfaces";
import { redirect } from "next/navigation";
import { Button, Card } from "@tremor/react";
import LogoType from "@/components/header/LogoType";
import { CombinedSettings, defaultCombinedSettings, EnterpriseSettings } from "./admin/settings/interfaces";
import { Card } from "@tremor/react";
import { HeaderTitle } from "@/components/header/HeaderTitle";
import { Logo } from "@/components/Logo";
import { UserProvider } from "@/components/user/UserProvider";
@@ -55,10 +48,9 @@ export default async function RootLayout({
}: {
  children: React.ReactNode;
}) {
  const combinedSettings = await fetchSettingsSS();
  if (!combinedSettings) {
    // Just display a simple full page error if fetching fails.
  const combinedSettings = await fetchSettingsSS() || defaultCombinedSettings;

  if (!combinedSettings) {
    return (
      <html lang="en" className={`${inter.variable} font-sans`}>
        <Head>
@@ -22,7 +22,7 @@ export const IsPublicGroupSelector = <T extends IsPublicGroupSelectorFormType>({
enforceGroupSelection?: boolean;
}) => {
const { data: userGroups, isLoading: userGroupsIsLoading } = useUserGroups();
const { isAdmin, user, isLoadingUser } = useUser();
const { isAdmin, user, isLoadingUser, isCurator } = useUser();
const [shouldHideContent, setShouldHideContent] = useState(false);

useEffect(() => {
@@ -87,42 +87,45 @@ export const IsPublicGroupSelector = <T extends IsPublicGroupSelectorFormType>({
)}

{(!formikProps.values.is_public ||
!isAdmin ||
formikProps.values.groups.length > 0) && (
<>
<div className="flex mt-4 gap-x-2 items-center">
<div className="block font-medium text-base">
Assign group access for this {objectName}
</div>
</div>
<Text className="mb-3">
{isAdmin || !enforceGroupSelection ? (
<>
This {objectName} will be visible/accessible by the groups
selected below
</>
) : (
<>
Curators must select one or more groups to give access to this{" "}
{objectName}
</>
)}
</Text>
<FieldArray
name="groups"
render={(arrayHelpers: ArrayHelpers) => (
<div className="flex gap-2 flex-wrap mb-4">
{userGroupsIsLoading ? (
<div className="animate-pulse bg-gray-200 h-8 w-32 rounded"></div>
isCurator ||
formikProps.values.groups.length > 0) &&
(userGroupsIsLoading ? (
<div className="animate-pulse bg-gray-200 h-8 w-32 rounded"></div>
) : (
userGroups &&
userGroups.length > 0 && (
<>
<div className="flex mt-4 gap-x-2 items-center">
<div className="block font-medium text-base">
Assign group access for this {objectName}
</div>
</div>
<Text className="mb-3">
{isAdmin || !enforceGroupSelection ? (
<>
This {objectName} will be visible/accessible by the groups
selected below
</>
) : (
userGroups &&
userGroups.map((userGroup: UserGroup) => {
const ind = formikProps.values.groups.indexOf(userGroup.id);
let isSelected = ind !== -1;
return (
<div
key={userGroup.id}
className={`
<>
Curators must select one or more groups to give access to
this {objectName}
</>
)}
</Text>
<FieldArray
name="groups"
render={(arrayHelpers: ArrayHelpers) => (
<div className="flex gap-2 flex-wrap mb-4">
{userGroups.map((userGroup: UserGroup) => {
const ind = formikProps.values.groups.indexOf(
userGroup.id
);
let isSelected = ind !== -1;
return (
<div
key={userGroup.id}
className={`
px-3
py-1
rounded-lg
@@ -132,32 +135,33 @@ export const IsPublicGroupSelector = <T extends IsPublicGroupSelectorFormType>({
flex
cursor-pointer
${isSelected ? "bg-background-strong" : "hover:bg-hover"}
`}
onClick={() => {
if (isSelected) {
arrayHelpers.remove(ind);
} else {
arrayHelpers.push(userGroup.id);
}
}}
>
<div className="my-auto flex">
<FiUsers className="my-auto mr-2" /> {userGroup.name}
`}
onClick={() => {
if (isSelected) {
arrayHelpers.remove(ind);
} else {
arrayHelpers.push(userGroup.id);
}
}}
>
<div className="my-auto flex">
<FiUsers className="my-auto mr-2" />{" "}
{userGroup.name}
</div>
</div>
</div>
);
})
);
})}
</div>
)}
</div>
)}
/>
<ErrorMessage
name="groups"
component="div"
className="text-error text-sm mt-1"
/>
</>
)}
/>
<ErrorMessage
name="groups"
component="div"
className="text-error text-sm mt-1"
/>
</>
)
))}
</div>
);
};
@@ -21,6 +21,7 @@ import {
  AssistantsIconSkeleton,
  ClosedBookIcon,
  SearchIcon,
  CreditCardIcon,
} from "@/components/icons/icons";
import { UserRole } from "@/lib/types";
import { FiActivity, FiBarChart2 } from "react-icons/fi";
@@ -29,16 +30,17 @@ import { User } from "@/lib/types";
import { usePathname } from "next/navigation";
import { SettingsContext } from "../settings/SettingsProvider";
import { useContext } from "react";
import { CustomTooltip } from "../tooltip/CustomTooltip";

export function ClientLayout({
  user,
  children,
  enableEnterprise,
  cloudEnabled,
}: {
  user: User | null;
  children: React.ReactNode;
  enableEnterprise: boolean;
  cloudEnabled: boolean;
}) {
  const isCurator =
    user?.role === UserRole.CURATOR || user?.role === UserRole.GLOBAL_CURATOR;
@@ -317,6 +319,20 @@ export function ClientLayout({
              },
            ]
          : []),

        ...(cloudEnabled
          ? [
              {
                name: (
                  <div className="flex">
                    <CreditCardIcon size={18} />
                    <div className="ml-1">Billing</div>
                  </div>
                ),
                link: "/admin/plan",
              },
            ]
          : []),
      ],
    },
  ]
@@ -6,7 +6,10 @@ import {
} from "@/lib/userSS";
import { redirect } from "next/navigation";
import { ClientLayout } from "./ClientLayout";
import { SERVER_SIDE_ONLY__PAID_ENTERPRISE_FEATURES_ENABLED } from "@/lib/constants";
import {
  CLOUD_ENABLED,
  SERVER_SIDE_ONLY__PAID_ENTERPRISE_FEATURES_ENABLED,
} from "@/lib/constants";
import { AnnouncementBanner } from "../header/AnnouncementBanner";

export async function Layout({ children }: { children: React.ReactNode }) {
@@ -43,6 +46,7 @@ export async function Layout({ children }: { children: React.ReactNode }) {
  return (
    <ClientLayout
      enableEnterprise={SERVER_SIDE_ONLY__PAID_ENTERPRISE_FEATURES_ENABLED}
      cloudEnabled={CLOUD_ENABLED}
      user={user}
    >
      <AnnouncementBanner />
Some files were not shown because too many files have changed in this diff.