Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-02-17 07:45:47 +00:00)

Compare commits: eval/split...v0.4.14 (91 commits)
Commits in this range (SHA1):

a54ea9f9fa, 73a92c046d, 459bd46846, 445f7e70ba, ca893f9918, 1be1959d80, 1654378850,
d6d391d244, 7c283b090d, 40226678af, 288e6fa606, 5307d38472, d619602a6f, 348a2176f0,
89b6da36a6, 036d5c737e, 60a87d9472, eb9bb56829, d151082871, e4b1f5b963, 3938a053aa,
7932e764d6, fb6695a983, 015f415b71, 96b582070b, 4a0a927a64, ea9a9cb553, 38af12ab97,
1b3154188d, 1f321826ad, cbfbe4e5d8, 3aa0e0124b, f2f60c9cc0, 6c32821ad4, d839595330,
e422f96dff, d28f460330, 8e441d975d, 5c78af1f07, e325e063ed, c81b45300b, 26a1e963d1,
2a983263c7, 2a37c95a5e, c277a74f82, e4b31cd0d9, a40d2a1e2e, c9fb99d719, a4d71e08aa,
546bfbd24b, 27824d6cc6, 9d5c4ad634, 9b32003816, 8bc4123ed7, d58aaf7a59, a0056a1b3c,
d2584c773a, 807bef8ada, 5afddacbb2, 4fb6a88f1e, 7057be6a88, 91be8e7bfb, 9651ea828b,
6ee74bd0d1, 48a0d29a5c, 6ff8e6c0ea, 2470c68506, 866bc803b1, 9c6084bd0d, a0b46c60c6,
4029233df0, 6c88c0156c, 33332d08f2, 17005fb705, 48a7fe80b1, 1276732409, f91b92a898,
6222f533be, 1b49d17239, 2f5f19642e, 6db4634871, 5cfed45cef, 581ffde35a, 6313e6d91d,
c09c94bf32, 0e8ba111c8, 2ba24b1734, 44820b4909, eb3e7610fc, 7fbbb174bb, 3854ca11af
.github/pull_request_template.md (new file, vendored, 25 lines)

```markdown
## Description
[Provide a brief description of the changes in this PR]


## How Has This Been Tested?
[Describe the tests you ran to verify your changes]


## Accepted Risk
[Any known risks or failure modes to point out to reviewers]


## Related Issue(s)
[If applicable, link to the issue(s) this PR addresses]


## Checklist:
- [ ] All of the automated tests pass
- [ ] All PR comments are addressed and marked resolved
- [ ] If there are migrations, they have been rebased to latest main
- [ ] If there are new dependencies, they are added to the requirements
- [ ] If there are new environment variables, they are added to all of the deployment methods
- [ ] If there are new APIs that don't require auth, they are added to PUBLIC_ENDPOINT_SPECS
- [ ] Docker images build and basic functionalities work
- [ ] Author has done a final read through of the PR right before merge
```
```diff
@@ -68,7 +68,9 @@ RUN apt-get update && \
     rm -f /usr/local/lib/python3.11/site-packages/tornado/test/test.key

 # Pre-downloading models for setups with limited egress
-RUN python -c "from transformers import AutoTokenizer; AutoTokenizer.from_pretrained('intfloat/e5-base-v2')"
+RUN python -c "from tokenizers import Tokenizer; \
+Tokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1')"

 # Pre-downloading NLTK for setups with limited egress
 RUN python -c "import nltk; \
```
```diff
@@ -18,14 +18,17 @@ RUN apt-get remove -y --allow-remove-essential perl-base && \
     apt-get autoremove -y

 # Pre-downloading models for setups with limited egress
-RUN python -c "from transformers import AutoModel, AutoTokenizer, TFDistilBertForSequenceClassification; \
-from huggingface_hub import snapshot_download; \
-AutoTokenizer.from_pretrained('danswer/intent-model'); \
-AutoTokenizer.from_pretrained('intfloat/e5-base-v2'); \
-AutoTokenizer.from_pretrained('mixedbread-ai/mxbai-rerank-xsmall-v1'); \
-snapshot_download('danswer/intent-model'); \
-snapshot_download('intfloat/e5-base-v2'); \
-snapshot_download('mixedbread-ai/mxbai-rerank-xsmall-v1')"
+RUN python -c "from transformers import AutoTokenizer; \
+AutoTokenizer.from_pretrained('danswer/intent-model', cache_folder='/root/.cache/temp_huggingface/hub/'); \
+AutoTokenizer.from_pretrained('mixedbread-ai/mxbai-rerank-xsmall-v1', cache_folder='/root/.cache/temp_huggingface/hub/'); \
+from transformers import TFDistilBertForSequenceClassification; \
+TFDistilBertForSequenceClassification.from_pretrained('danswer/intent-model', cache_dir='/root/.cache/temp_huggingface/hub/'); \
+from huggingface_hub import snapshot_download; \
+snapshot_download('danswer/intent-model', cache_dir='/root/.cache/temp_huggingface/hub/'); \
+snapshot_download('nomic-ai/nomic-embed-text-v1', cache_dir='/root/.cache/temp_huggingface/hub/'); \
+snapshot_download('mixedbread-ai/mxbai-rerank-xsmall-v1', cache_dir='/root/.cache/temp_huggingface/hub/'); \
+from sentence_transformers import SentenceTransformer; \
+SentenceTransformer(model_name_or_path='nomic-ai/nomic-embed-text-v1', trust_remote_code=True, cache_folder='/root/.cache/temp_huggingface/hub/');"

 WORKDIR /app
```
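Both Dockerfile hunks serve the same goal: models are fetched at image-build time so a deployment with no outbound network access never downloads at runtime. A minimal Python sketch of the same pre-download idea, assuming the staging cache path from the Dockerfile above:

```python
# Sketch only: the repo ids mirror the Dockerfile above; the cache path is the
# temporary staging directory used there, not an official requirement.
from huggingface_hub import snapshot_download

CACHE_DIR = "/root/.cache/temp_huggingface/hub/"

def predownload(repo_ids: list[str], cache_dir: str = CACHE_DIR) -> None:
    for repo_id in repo_ids:
        # snapshot_download is idempotent: files already in the cache are reused
        snapshot_download(repo_id, cache_dir=cache_dir)

if __name__ == "__main__":
    predownload(
        [
            "danswer/intent-model",
            "nomic-ai/nomic-embed-text-v1",
            "mixedbread-ai/mxbai-rerank-xsmall-v1",
        ]
    )
```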
```diff
@@ -17,15 +17,11 @@ depends_on: None = None


 def upgrade() -> None:
-    # ### commands auto generated by Alembic - please adjust! ###
     op.add_column(
         "chat_session",
         sa.Column("current_alternate_model", sa.String(), nullable=True),
     )
-    # ### end Alembic commands ###


 def downgrade() -> None:
-    # ### commands auto generated by Alembic - please adjust! ###
     op.drop_column("chat_session", "current_alternate_model")
-    # ### end Alembic commands ###
```
New file (26 lines):

```python
"""add_indexing_start_to_connector

Revision ID: 08a1eda20fe1
Revises: 8a87bd6ec550
Create Date: 2024-07-23 11:12:39.462397

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "08a1eda20fe1"
down_revision = "8a87bd6ec550"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.add_column(
        "connector", sa.Column("indexing_start", sa.DateTime(), nullable=True)
    )


def downgrade() -> None:
    op.drop_column("connector", "indexing_start")
```
New file (70 lines):

```python
"""Add icon_color and icon_shape to Persona

Revision ID: 325975216eb3
Revises: 91ffac7e65b3
Create Date: 2024-07-24 21:29:31.784562

"""
import random

from alembic import op
import sqlalchemy as sa
from sqlalchemy.sql import table, column, select

# revision identifiers, used by Alembic.
revision = "325975216eb3"
down_revision = "91ffac7e65b3"
branch_labels: None = None
depends_on: None = None


colorOptions = [
    "#FF6FBF",
    "#6FB1FF",
    "#B76FFF",
    "#FFB56F",
    "#6FFF8D",
    "#FF6F6F",
    "#6FFFFF",
]


# Function to generate a random shape ensuring at least 3 of the middle 4 squares are filled
def generate_random_shape() -> int:
    center_squares = [12, 10, 6, 14, 13, 11, 7, 15]
    center_fill = random.choice(center_squares)
    remaining_squares = [i for i in range(16) if not (center_fill & (1 << i))]
    random.shuffle(remaining_squares)
    for i in range(10 - bin(center_fill).count("1")):
        center_fill |= 1 << remaining_squares[i]
    return center_fill


def upgrade() -> None:
    op.add_column("persona", sa.Column("icon_color", sa.String(), nullable=True))
    op.add_column("persona", sa.Column("icon_shape", sa.Integer(), nullable=True))
    op.add_column("persona", sa.Column("uploaded_image_id", sa.String(), nullable=True))

    persona = table(
        "persona",
        column("id", sa.Integer),
        column("icon_color", sa.String),
        column("icon_shape", sa.Integer),
    )

    conn = op.get_bind()
    personas = conn.execute(select(persona.c.id))

    for persona_id in personas:
        random_color = random.choice(colorOptions)
        random_shape = generate_random_shape()
        conn.execute(
            persona.update()
            .where(persona.c.id == persona_id[0])
            .values(icon_color=random_color, icon_shape=random_shape)
        )


def downgrade() -> None:
    op.drop_column("persona", "icon_shape")
    op.drop_column("persona", "uploaded_image_id")
    op.drop_column("persona", "icon_color")
```
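Since the top-up loop in generate_random_shape keeps adding bits until the popcount reaches 10, every generated icon has exactly 10 of 16 cells filled. A small sketch to visualize such a mask, assuming bit i maps to cell (i // 4, i % 4) of a 4x4 grid; the real bit-to-cell layout lives in the frontend and is not part of this diff:

```python
def render_shape(mask: int) -> str:
    """Render a 16-bit icon mask as a 4x4 grid of '#' (filled) and '.' (empty)."""
    return "\n".join(
        "".join("#" if mask & (1 << (row * 4 + col)) else "." for col in range(4))
        for row in range(4)
    )

# Under this assumed layout, the four middle cells are bit indices 5, 6, 9 and 10
center_only = (1 << 5) | (1 << 6) | (1 << 9) | (1 << 10)
print(render_shape(center_only))
# ....
# .##.
# .##.
# ....
```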
```diff
@@ -18,7 +18,6 @@ depends_on: None = None


 def upgrade() -> None:
-    # ### commands auto generated by Alembic - please adjust! ###
     op.add_column(
         "chat_message", sa.Column("alternate_assistant_id", sa.Integer(), nullable=True)
     )
@@ -29,10 +28,8 @@ def upgrade() -> None:
         ["alternate_assistant_id"],
         ["id"],
     )
-    # ### end Alembic commands ###


 def downgrade() -> None:
-    # ### commands auto generated by Alembic - please adjust! ###
     op.drop_constraint("fk_chat_message_persona", "chat_message", type_="foreignkey")
     op.drop_column("chat_message", "alternate_assistant_id")
```
New file (49 lines):

```python
"""Add display_model_names to llm_provider

Revision ID: 473a1a7ca408
Revises: 325975216eb3
Create Date: 2024-07-25 14:31:02.002917

"""

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "473a1a7ca408"
down_revision = "325975216eb3"
branch_labels: None = None
depends_on: None = None

default_models_by_provider = {
    "openai": ["gpt-4", "gpt-4o", "gpt-4o-mini"],
    "bedrock": [
        "meta.llama3-1-70b-instruct-v1:0",
        "meta.llama3-1-8b-instruct-v1:0",
        "anthropic.claude-3-opus-20240229-v1:0",
        "mistral.mistral-large-2402-v1:0",
        "anthropic.claude-3-5-sonnet-20240620-v1:0",
    ],
    "anthropic": ["claude-3-opus-20240229", "claude-3-5-sonnet-20240620"],
}


def upgrade() -> None:
    op.add_column(
        "llm_provider",
        sa.Column("display_model_names", postgresql.ARRAY(sa.String()), nullable=True),
    )

    connection = op.get_bind()
    for provider, models in default_models_by_provider.items():
        connection.execute(
            sa.text(
                "UPDATE llm_provider SET display_model_names = :models WHERE provider = :provider"
            ),
            {"models": models, "provider": provider},
        )


def downgrade() -> None:
    op.drop_column("llm_provider", "display_model_names")
```
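The backfill binds a Python list straight to the display_model_names ARRAY column through a parameterized UPDATE; with the psycopg2 driver, lists are adapted to Postgres arrays natively. A hedged sketch of the same pattern outside Alembic (the engine is a placeholder):

```python
import sqlalchemy as sa
from sqlalchemy.engine import Engine

def set_display_models(engine: Engine, provider: str, models: list[str]) -> None:
    # Same parameterized statement as the migration; psycopg2 converts the
    # Python list bound to :models into a Postgres ARRAY value.
    with engine.begin() as conn:
        conn.execute(
            sa.text(
                "UPDATE llm_provider SET display_model_names = :models "
                "WHERE provider = :provider"
            ),
            {"models": models, "provider": provider},
        )
```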
New file (72 lines):

```python
"""Add type to credentials

Revision ID: 4ea2c93919c1
Revises: 473a1a7ca408
Create Date: 2024-07-18 13:07:13.655895

"""

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "4ea2c93919c1"
down_revision = "473a1a7ca408"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    # Add the new 'source' column to the 'credential' table
    op.add_column(
        "credential",
        sa.Column(
            "source",
            sa.String(length=100),  # Use String instead of Enum
            nullable=True,  # Initially allow NULL values
        ),
    )
    op.add_column(
        "credential",
        sa.Column(
            "name",
            sa.String(),
            nullable=True,
        ),
    )

    # Create a temporary table that maps each credential to a single connector source.
    # This is needed because a credential can be associated with multiple connectors,
    # but we want to assign a single source to each credential.
    # We use DISTINCT ON to ensure we only get one row per credential_id.
    op.execute(
        """
        CREATE TEMPORARY TABLE temp_connector_credential AS
        SELECT DISTINCT ON (cc.credential_id)
            cc.credential_id,
            c.source AS connector_source
        FROM connector_credential_pair cc
        JOIN connector c ON cc.connector_id = c.id
        """
    )

    # Update the 'source' column in the 'credential' table
    op.execute(
        """
        UPDATE credential cred
        SET source = COALESCE(
            (SELECT connector_source
             FROM temp_connector_credential temp
             WHERE cred.id = temp.credential_id),
            'NOT_APPLICABLE'
        )
        """
    )
    # If no exception was raised, alter the column
    op.alter_column("credential", "source", nullable=True)  # TODO modify
    # # ### end Alembic commands ###


def downgrade() -> None:
    op.drop_column("credential", "source")
    op.drop_column("credential", "name")
```
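DISTINCT ON (cc.credential_id) keeps a single row per credential; with no ORDER BY, Postgres picks an arbitrary matching row, so the backfill is effectively a "first mapping wins" rule. A pure-Python sketch of that rule over made-up rows:

```python
rows = [
    # (credential_id, connector_source) -- hypothetical data for illustration
    (1, "slack"),
    (1, "confluence"),  # a second mapping for credential 1 is dropped
    (2, "github"),
]

first_source_per_credential: dict[int, str] = {}
for credential_id, source in rows:
    # setdefault keeps the first value seen per key, like DISTINCT ON does here
    first_source_per_credential.setdefault(credential_id, source)

assert first_source_per_credential == {1: "slack", 2: "github"}
```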
New file (41 lines):

```python
"""add_llm_group_permissions_control

Revision ID: 795b20b85b4b
Revises: 05c07bf07c00
Create Date: 2024-07-19 11:54:35.701558

"""
from alembic import op
import sqlalchemy as sa


revision = "795b20b85b4b"
down_revision = "05c07bf07c00"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.create_table(
        "llm_provider__user_group",
        sa.Column("llm_provider_id", sa.Integer(), nullable=False),
        sa.Column("user_group_id", sa.Integer(), nullable=False),
        sa.ForeignKeyConstraint(
            ["llm_provider_id"],
            ["llm_provider.id"],
        ),
        sa.ForeignKeyConstraint(
            ["user_group_id"],
            ["user_group.id"],
        ),
        sa.PrimaryKeyConstraint("llm_provider_id", "user_group_id"),
    )
    op.add_column(
        "llm_provider",
        sa.Column("is_public", sa.Boolean(), nullable=False, server_default="true"),
    )


def downgrade() -> None:
    op.drop_table("llm_provider__user_group")
    op.drop_column("llm_provider", "is_public")
```
New file (103 lines):

```python
"""associate index attempts with ccpair

Revision ID: 8a87bd6ec550
Revises: 4ea2c93919c1
Create Date: 2024-07-22 15:15:52.558451

"""
from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "8a87bd6ec550"
down_revision = "4ea2c93919c1"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    # Add the new connector_credential_pair_id column
    op.add_column(
        "index_attempt",
        sa.Column("connector_credential_pair_id", sa.Integer(), nullable=True),
    )

    # Create a foreign key constraint to the connector_credential_pair table
    op.create_foreign_key(
        "fk_index_attempt_connector_credential_pair_id",
        "index_attempt",
        "connector_credential_pair",
        ["connector_credential_pair_id"],
        ["id"],
    )

    # Populate the new connector_credential_pair_id column using existing connector_id and credential_id
    op.execute(
        """
        UPDATE index_attempt ia
        SET connector_credential_pair_id =
            CASE
                WHEN ia.credential_id IS NULL THEN
                    (SELECT id FROM connector_credential_pair
                     WHERE connector_id = ia.connector_id
                     LIMIT 1)
                ELSE
                    (SELECT id FROM connector_credential_pair
                     WHERE connector_id = ia.connector_id
                     AND credential_id = ia.credential_id)
            END
        WHERE ia.connector_id IS NOT NULL
        """
    )

    # Make the new connector_credential_pair_id column non-nullable
    op.alter_column("index_attempt", "connector_credential_pair_id", nullable=False)

    # Drop the old connector_id and credential_id columns
    op.drop_column("index_attempt", "connector_id")
    op.drop_column("index_attempt", "credential_id")

    # Update the index to use connector_credential_pair_id
    op.create_index(
        "ix_index_attempt_latest_for_connector_credential_pair",
        "index_attempt",
        ["connector_credential_pair_id", "time_created"],
    )


def downgrade() -> None:
    # Add back the old connector_id and credential_id columns
    op.add_column(
        "index_attempt", sa.Column("connector_id", sa.Integer(), nullable=True)
    )
    op.add_column(
        "index_attempt", sa.Column("credential_id", sa.Integer(), nullable=True)
    )

    # Populate the old connector_id and credential_id columns using the connector_credential_pair_id
    op.execute(
        """
        UPDATE index_attempt ia
        SET connector_id = ccp.connector_id, credential_id = ccp.credential_id
        FROM connector_credential_pair ccp
        WHERE ia.connector_credential_pair_id = ccp.id
        """
    )

    # Make the old connector_id and credential_id columns non-nullable
    op.alter_column("index_attempt", "connector_id", nullable=False)
    op.alter_column("index_attempt", "credential_id", nullable=False)

    # Drop the new connector_credential_pair_id column
    op.drop_constraint(
        "fk_index_attempt_connector_credential_pair_id",
        "index_attempt",
        type_="foreignkey",
    )
    op.drop_column("index_attempt", "connector_credential_pair_id")

    op.create_index(
        "ix_index_attempt_latest_for_connector_credential_pair",
        "index_attempt",
        ["connector_id", "credential_id", "time_created"],
    )
```
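The CASE expression in the backfill resolves each index_attempt to a cc-pair: an exact (connector_id, credential_id) match when a credential is recorded, otherwise any pair for the connector (the LIMIT 1 fallback). A pure-Python sketch of the same resolution rule over hypothetical ids:

```python
cc_pairs = [
    # (cc_pair_id, connector_id, credential_id) -- hypothetical rows
    (10, 1, 100),
    (11, 1, 101),
    (12, 2, 100),
]

def resolve_cc_pair_id(connector_id: int, credential_id: int | None) -> int | None:
    for pair_id, conn_id, cred_id in cc_pairs:
        if conn_id != connector_id:
            continue
        if credential_id is None or cred_id == credential_id:
            return pair_id  # first match mirrors the SQL LIMIT 1 fallback
    return None

assert resolve_cc_pair_id(1, 101) == 11   # exact match
assert resolve_cc_pair_id(2, None) == 12  # credential unknown: any pair for connector 2
```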
backend/alembic/versions/91ffac7e65b3_add_expiry_time.py (new file, 26 lines)

```python
"""add expiry time

Revision ID: 91ffac7e65b3
Revises: bc9771dccadf
Create Date: 2024-06-24 09:39:56.462242

"""

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "91ffac7e65b3"
down_revision = "795b20b85b4b"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.add_column(
        "user", sa.Column("oidc_expiry", sa.DateTime(timezone=True), nullable=True)
    )


def downgrade() -> None:
    op.drop_column("user", "oidc_expiry")
```
```diff
@@ -16,7 +16,6 @@ depends_on: None = None


 def upgrade() -> None:
-    # ### commands auto generated by Alembic - please adjust! ###
     op.alter_column(
         "connector_credential_pair",
         "last_attempt_status",
@@ -29,11 +28,9 @@ def upgrade() -> None:
         ),
         nullable=True,
     )
-    # ### end Alembic commands ###


 def downgrade() -> None:
-    # ### commands auto generated by Alembic - please adjust! ###
     op.alter_column(
         "connector_credential_pair",
         "last_attempt_status",
@@ -46,4 +43,3 @@ def downgrade() -> None:
         ),
         nullable=False,
     )
-    # ### end Alembic commands ###
```
backend/alembic/versions/e1392f05e840_added_input_prompts.py (new file, 58 lines)

```python
"""Added input prompts

Revision ID: e1392f05e840
Revises: 08a1eda20fe1
Create Date: 2024-07-13 19:09:22.556224

"""

import fastapi_users_db_sqlalchemy

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "e1392f05e840"
down_revision = "08a1eda20fe1"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.create_table(
        "inputprompt",
        sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
        sa.Column("prompt", sa.String(), nullable=False),
        sa.Column("content", sa.String(), nullable=False),
        sa.Column("active", sa.Boolean(), nullable=False),
        sa.Column("is_public", sa.Boolean(), nullable=False),
        sa.Column(
            "user_id",
            fastapi_users_db_sqlalchemy.generics.GUID(),
            nullable=True,
        ),
        sa.ForeignKeyConstraint(
            ["user_id"],
            ["user.id"],
        ),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_table(
        "inputprompt__user",
        sa.Column("input_prompt_id", sa.Integer(), nullable=False),
        sa.Column("user_id", sa.Integer(), nullable=False),
        sa.ForeignKeyConstraint(
            ["input_prompt_id"],
            ["inputprompt.id"],
        ),
        sa.ForeignKeyConstraint(
            ["user_id"],
            ["inputprompt.id"],
        ),
        sa.PrimaryKeyConstraint("input_prompt_id", "user_id"),
    )


def downgrade() -> None:
    op.drop_table("inputprompt__user")
    op.drop_table("inputprompt")
```
```diff
@@ -1,6 +1,8 @@
 import smtplib
 import uuid
 from collections.abc import AsyncGenerator
+from datetime import datetime
+from datetime import timezone
 from email.mime.multipart import MIMEMultipart
 from email.mime.text import MIMEText
 from typing import Optional
@@ -50,8 +52,10 @@ from danswer.db.auth import get_default_admin_user_emails
 from danswer.db.auth import get_user_count
 from danswer.db.auth import get_user_db
 from danswer.db.engine import get_session
+from danswer.db.engine import get_sqlalchemy_engine
 from danswer.db.models import AccessToken
 from danswer.db.models import User
+from danswer.db.users import get_user_by_email
 from danswer.utils.logger import setup_logger
 from danswer.utils.telemetry import optional_telemetry
 from danswer.utils.telemetry import RecordType
@@ -92,12 +96,18 @@ def user_needs_to_be_verified() -> bool:
     return AUTH_TYPE != AuthType.BASIC or REQUIRE_EMAIL_VERIFICATION


-def verify_email_in_whitelist(email: str) -> None:
+def verify_email_is_invited(email: str) -> None:
     whitelist = get_invited_users()
     if (whitelist and email not in whitelist) or not email:
         raise PermissionError("User not on allowed user whitelist")


+def verify_email_in_whitelist(email: str) -> None:
+    with Session(get_sqlalchemy_engine()) as db_session:
+        if not get_user_by_email(email, db_session):
+            verify_email_is_invited(email)
+
+
 def verify_email_domain(email: str) -> None:
     if VALID_EMAIL_DOMAINS:
         if email.count("@") != 1:
@@ -147,7 +157,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
         safe: bool = False,
         request: Optional[Request] = None,
     ) -> models.UP:
-        verify_email_in_whitelist(user_create.email)
+        verify_email_is_invited(user_create.email)
         verify_email_domain(user_create.email)
         if hasattr(user_create, "role"):
             user_count = await get_user_count()
@@ -173,7 +183,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
         verify_email_in_whitelist(account_email)
         verify_email_domain(account_email)

-        return await super().oauth_callback(  # type: ignore
+        user = await super().oauth_callback(  # type: ignore
             oauth_name=oauth_name,
             access_token=access_token,
             account_id=account_id,
@@ -185,6 +195,14 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
             is_verified_by_default=is_verified_by_default,
         )

+        # NOTE: google oauth expires after 1hr. We don't want to force the user to
+        # re-authenticate that frequently, so for now we'll just ignore this for
+        # google oauth users
+        if expires_at and AUTH_TYPE != AuthType.GOOGLE_OAUTH:
+            oidc_expiry = datetime.fromtimestamp(expires_at, tz=timezone.utc)
+            await self.user_db.update(user, update_dict={"oidc_expiry": oidc_expiry})
+
+        return user

     async def on_after_register(
         self, user: User, request: Optional[Request] = None
     ) -> None:
@@ -227,10 +245,12 @@ cookie_transport = CookieTransport(
 def get_database_strategy(
     access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db),
 ) -> DatabaseStrategy:
-    return DatabaseStrategy(
+    strategy = DatabaseStrategy(
         access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS  # type: ignore
     )
+
+    return strategy


 auth_backend = AuthenticationBackend(
     name="database",
@@ -327,6 +347,12 @@ async def double_check_user(
             detail="Access denied. User is not verified.",
         )

+    if user.oidc_expiry and user.oidc_expiry < datetime.now(timezone.utc):
+        raise HTTPException(
+            status_code=status.HTTP_403_FORBIDDEN,
+            detail="Access denied. User's OIDC token has expired.",
+        )
+
     return user
@@ -345,4 +371,5 @@ async def current_admin_user(user: User | None = Depends(current_user)) -> User
             status_code=status.HTTP_403_FORBIDDEN,
             detail="Access denied. User is not an admin.",
         )
+
     return user
```
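The auth changes persist the OAuth provider's expires_at (a Unix timestamp) as a timezone-aware oidc_expiry and later reject requests once it has passed, skipping Google OAuth whose tokens expire hourly. A self-contained sketch of both halves of that check:

```python
from datetime import datetime, timedelta, timezone

def to_oidc_expiry(expires_at: int) -> datetime:
    # Same conversion as the oauth_callback hook above
    return datetime.fromtimestamp(expires_at, tz=timezone.utc)

def is_expired(oidc_expiry: datetime | None) -> bool:
    # Same comparison as double_check_user above
    return oidc_expiry is not None and oidc_expiry < datetime.now(timezone.utc)

one_hour_ahead = int((datetime.now(timezone.utc) + timedelta(hours=1)).timestamp())
assert not is_expired(to_oidc_expiry(one_hour_ahead))
assert is_expired(to_oidc_expiry(0))  # the epoch is long past
```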
```diff
@@ -14,6 +14,7 @@ from danswer.background.task_utils import name_cc_cleanup_task
 from danswer.background.task_utils import name_cc_prune_task
 from danswer.background.task_utils import name_document_set_sync_task
 from danswer.configs.app_configs import JOB_TIMEOUT
+from danswer.configs.constants import POSTGRES_CELERY_APP_NAME
 from danswer.connectors.factory import instantiate_connector
 from danswer.connectors.models import InputType
 from danswer.db.connector_credential_pair import get_connector_credential_pair
@@ -38,7 +39,9 @@ from danswer.utils.logger import setup_logger

 logger = setup_logger()

-connection_string = build_connection_string(db_api=SYNC_DB_API)
+connection_string = build_connection_string(
+    db_api=SYNC_DB_API, app_name=POSTGRES_CELERY_APP_NAME
+)
 celery_broker_url = f"sqla+{connection_string}"
 celery_backend_url = f"db+{connection_string}"
 celery_app = Celery(__name__, broker=celery_broker_url, backend=celery_backend_url)
```
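build_connection_string itself is not shown in this diff, but the app_name it now receives corresponds to Postgres's standard application_name connection parameter, which makes each Danswer process identifiable in pg_stat_activity. A hedged sketch of one way to attach such a label with SQLAlchemy and psycopg2 (placeholder DSN):

```python
from sqlalchemy import create_engine

# Placeholder DSN -- not a real deployment value
dsn = "postgresql+psycopg2://user:pass@localhost:5432/danswer"

# psycopg2 forwards extra connect_args to libpq, including application_name
engine = create_engine(dsn, connect_args={"application_name": "celery"})

# Once connected, the label is visible server-side:
#   SELECT application_name FROM pg_stat_activity;
```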
```diff
@@ -6,8 +6,8 @@ from sqlalchemy.orm import Session
 from danswer.background.task_utils import name_cc_cleanup_task
 from danswer.background.task_utils import name_cc_prune_task
 from danswer.background.task_utils import name_document_set_sync_task
+from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
 from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
-from danswer.configs.app_configs import PREVENT_SIMULTANEOUS_PRUNING
 from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
     rate_limit_builder,
 )
@@ -80,7 +80,7 @@ def should_prune_cc_pair(
             return True
         return False

-    if PREVENT_SIMULTANEOUS_PRUNING:
+    if not ALLOW_SIMULTANEOUS_PRUNING:
         pruning_type_task_name = name_cc_prune_task()
         last_pruning_type_task = get_latest_task_by_type(
             pruning_type_task_name, db_session
@@ -89,11 +89,9 @@ def should_prune_cc_pair(
         if last_pruning_type_task and check_task_is_live_and_not_timed_out(
             last_pruning_type_task, db_session
         ):
-            logger.info("Another Connector is already pruning. Skipping.")
             return False

     if check_task_is_live_and_not_timed_out(last_pruning_task, db_session):
-        logger.info(f"Connector '{connector.name}' is already pruning. Skipping.")
         return False

     if not last_pruning_task.start_time:
```
```diff
@@ -20,7 +20,7 @@ from danswer.db.connector_credential_pair import update_connector_credential_pai
 from danswer.db.engine import get_sqlalchemy_engine
 from danswer.db.index_attempt import get_index_attempt
 from danswer.db.index_attempt import mark_attempt_failed
-from danswer.db.index_attempt import mark_attempt_in_progress__no_commit
+from danswer.db.index_attempt import mark_attempt_in_progress
 from danswer.db.index_attempt import mark_attempt_succeeded
 from danswer.db.index_attempt import update_docs_indexed
 from danswer.db.models import IndexAttempt
@@ -49,19 +49,19 @@ def _get_document_generator(
     are the complete list of existing documents of the connector. If the task
     of type LOAD_STATE, the list will be considered complete and otherwise incomplete.
     """
-    task = attempt.connector.input_type
+    task = attempt.connector_credential_pair.connector.input_type

     try:
         runnable_connector = instantiate_connector(
-            attempt.connector.source,
+            attempt.connector_credential_pair.connector.source,
             task,
-            attempt.connector.connector_specific_config,
-            attempt.credential,
+            attempt.connector_credential_pair.connector.connector_specific_config,
+            attempt.connector_credential_pair.credential,
             db_session,
         )
     except Exception as e:
         logger.exception(f"Unable to instantiate connector due to {e}")
-        disable_connector(attempt.connector.id, db_session)
+        disable_connector(attempt.connector_credential_pair.connector.id, db_session)
         raise e

     if task == InputType.LOAD_STATE:
@@ -70,7 +70,10 @@ def _get_document_generator(

     elif task == InputType.POLL:
         assert isinstance(runnable_connector, PollConnector)
-        if attempt.connector_id is None or attempt.credential_id is None:
+        if (
+            attempt.connector_credential_pair.connector_id is None
+            or attempt.connector_credential_pair.connector_id is None
+        ):
             raise ValueError(
                 f"Polling attempt {attempt.id} is missing connector_id or credential_id, "
                 f"can't fetch time range."
@@ -127,16 +130,21 @@ def _run_indexing(
         db_session=db_session,
     )

-    db_connector = index_attempt.connector
-    db_credential = index_attempt.credential
+    db_connector = index_attempt.connector_credential_pair.connector
+    db_credential = index_attempt.connector_credential_pair.credential

     last_successful_index_time = (
-        0.0
-        if index_attempt.from_beginning
-        else get_last_successful_attempt_time(
-            connector_id=db_connector.id,
-            credential_id=db_credential.id,
-            embedding_model=index_attempt.embedding_model,
-            db_session=db_session,
+        db_connector.indexing_start.timestamp()
+        if index_attempt.from_beginning and db_connector.indexing_start is not None
+        else (
+            0.0
+            if index_attempt.from_beginning
+            else get_last_successful_attempt_time(
+                connector_id=db_connector.id,
+                credential_id=db_credential.id,
+                embedding_model=index_attempt.embedding_model,
+                db_session=db_session,
+            )
         )
     )
@@ -189,7 +197,7 @@ def _run_indexing(
             )

             new_docs, total_batch_chunks = indexing_pipeline(
-                documents=doc_batch,
+                document_batch=doc_batch,
                 index_attempt_metadata=IndexAttemptMetadata(
                     connector_id=db_connector.id,
                     credential_id=db_credential.id,
@@ -250,8 +258,8 @@ def _run_indexing(
             if is_primary:
                 update_connector_credential_pair(
                     db_session=db_session,
-                    connector_id=index_attempt.connector.id,
-                    credential_id=index_attempt.credential.id,
+                    connector_id=index_attempt.connector_credential_pair.connector.id,
+                    credential_id=index_attempt.connector_credential_pair.credential.id,
                     net_docs=net_doc_change,
                 )
             raise e
@@ -299,9 +307,7 @@ def _prepare_index_attempt(db_session: Session, index_attempt_id: int) -> IndexA
         )

-    # only commit once, to make sure this all happens in a single transaction
-    mark_attempt_in_progress__no_commit(attempt)
-    if attempt.embedding_model.status != IndexModelStatus.PRESENT:
-        db_session.commit()
+    mark_attempt_in_progress(attempt, db_session)

     return attempt
@@ -324,17 +330,17 @@ def run_indexing_entrypoint(index_attempt_id: int, is_ee: bool = False) -> None:
         attempt = _prepare_index_attempt(db_session, index_attempt_id)

         logger.info(
-            f"Running indexing attempt for connector: '{attempt.connector.name}', "
-            f"with config: '{attempt.connector.connector_specific_config}', and "
-            f"with credentials: '{attempt.credential_id}'"
+            f"Running indexing attempt for connector: '{attempt.connector_credential_pair.connector.name}', "
+            f"with config: '{attempt.connector_credential_pair.connector.connector_specific_config}', and "
+            f"with credentials: '{attempt.connector_credential_pair.connector_id}'"
         )

         _run_indexing(db_session, attempt)

         logger.info(
-            f"Completed indexing attempt for connector: '{attempt.connector.name}', "
-            f"with config: '{attempt.connector.connector_specific_config}', and "
-            f"with credentials: '{attempt.credential_id}'"
+            f"Completed indexing attempt for connector: '{attempt.connector_credential_pair.connector.name}', "
+            f"with config: '{attempt.connector_credential_pair.connector.connector_specific_config}', and "
+            f"with credentials: '{attempt.connector_credential_pair.connector_id}'"
         )
     except Exception as e:
         logger.exception(f"Indexing job with ID '{index_attempt_id}' failed due to {e}")
```
```diff
@@ -16,15 +16,19 @@ from danswer.configs.app_configs import CLEANUP_INDEXING_JOBS_TIMEOUT
 from danswer.configs.app_configs import DASK_JOB_CLIENT_ENABLED
 from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
 from danswer.configs.app_configs import NUM_INDEXING_WORKERS
+from danswer.configs.app_configs import NUM_SECONDARY_INDEXING_WORKERS
+from danswer.configs.constants import POSTGRES_INDEXER_APP_NAME
 from danswer.db.connector import fetch_connectors
+from danswer.db.connector_credential_pair import fetch_connector_credential_pairs
 from danswer.db.embedding_model import get_current_db_embedding_model
 from danswer.db.embedding_model import get_secondary_db_embedding_model
 from danswer.db.engine import get_db_current_time
 from danswer.db.engine import get_sqlalchemy_engine
+from danswer.db.engine import init_sqlalchemy_engine
 from danswer.db.index_attempt import create_index_attempt
 from danswer.db.index_attempt import get_index_attempt
 from danswer.db.index_attempt import get_inprogress_index_attempts
-from danswer.db.index_attempt import get_last_attempt
+from danswer.db.index_attempt import get_last_attempt_for_cc_pair
 from danswer.db.index_attempt import get_not_started_index_attempts
 from danswer.db.index_attempt import mark_attempt_failed
 from danswer.db.models import Connector
@@ -33,7 +37,7 @@ from danswer.db.models import IndexAttempt
 from danswer.db.models import IndexingStatus
 from danswer.db.models import IndexModelStatus
 from danswer.db.swap_index import check_index_swap
-from danswer.search.search_nlp_models import warm_up_encoders
+from danswer.natural_language_processing.search_nlp_models import warm_up_encoders
 from danswer.utils.logger import setup_logger
 from danswer.utils.variable_functionality import global_version
 from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
@@ -66,20 +70,26 @@ def _should_create_new_indexing(
         return False

     # When switching over models, always index at least once
-    if model.status == IndexModelStatus.FUTURE and not last_index:
-        if connector.id == 0:  # Ingestion API
-            return False
+    if model.status == IndexModelStatus.FUTURE:
+        if last_index:
+            # secondary indexes should not index again after success
+            # or else the model will never be able to swap
+            if last_index.status == IndexingStatus.SUCCESS:
+                return False
+        else:
+            if connector.id == 0:  # Ingestion API
+                return False
         return True

     # If the connector is disabled, don't index
-    # NOTE: during an embedding model switch over, we ignore this
-    # and index the disabled connectors as well (which is why this if
-    # statement is below the first condition above)
+    # NOTE: during an embedding model switch over, the following logic
+    # is bypassed by the above check for a future model
     if connector.disabled:
         return False

     if connector.refresh_freq is None:
         return False

     if not last_index:
         return True
@@ -111,8 +121,8 @@ def _mark_run_failed(
     """Marks the `index_attempt` row as failed + updates the `
     connector_credential_pair` to reflect that the run failed"""
     logger.warning(
-        f"Marking in-progress attempt 'connector: {index_attempt.connector_id}, "
-        f"credential: {index_attempt.credential_id}' as failed due to {failure_reason}"
+        f"Marking in-progress attempt 'connector: {index_attempt.connector_credential_pair.connector_id}, "
+        f"credential: {index_attempt.connector_credential_pair.credential_id}' as failed due to {failure_reason}"
     )
     mark_attempt_failed(
         index_attempt=index_attempt,
@@ -131,7 +141,7 @@ def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob]) -> None:
     3. There is not already an ongoing indexing attempt for this pair
     """
     with Session(get_sqlalchemy_engine()) as db_session:
-        ongoing: set[tuple[int | None, int | None, int]] = set()
+        ongoing: set[tuple[int | None, int]] = set()
         for attempt_id in existing_jobs:
             attempt = get_index_attempt(
                 db_session=db_session, index_attempt_id=attempt_id
@@ -144,8 +154,7 @@ def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob]) -> None:
                 continue
             ongoing.add(
                 (
-                    attempt.connector_id,
-                    attempt.credential_id,
+                    attempt.connector_credential_pair_id,
                     attempt.embedding_model_id,
                 )
             )
@@ -155,31 +164,26 @@ def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob]) -> None:
         if secondary_embedding_model is not None:
             embedding_models.append(secondary_embedding_model)

-        all_connectors = fetch_connectors(db_session)
-        for connector in all_connectors:
-            for association in connector.credentials:
-                for model in embedding_models:
-                    credential = association.credential
-
-                    # Check if there is an ongoing indexing attempt for this connector + credential pair
-                    if (connector.id, credential.id, model.id) in ongoing:
-                        continue
-
-                    last_attempt = get_last_attempt(
-                        connector.id, credential.id, model.id, db_session
-                    )
-                    if not _should_create_new_indexing(
-                        connector=connector,
-                        last_index=last_attempt,
-                        model=model,
-                        secondary_index_building=len(embedding_models) > 1,
-                        db_session=db_session,
-                    ):
-                        continue
-
-                    create_index_attempt(
-                        connector.id, credential.id, model.id, db_session
-                    )
+        all_connector_credential_pairs = fetch_connector_credential_pairs(db_session)
+        for cc_pair in all_connector_credential_pairs:
+            for model in embedding_models:
+                # Check if there is an ongoing indexing attempt for this connector credential pair
+                if (cc_pair.id, model.id) in ongoing:
+                    continue
+
+                last_attempt = get_last_attempt_for_cc_pair(
+                    cc_pair.id, model.id, db_session
+                )
+                if not _should_create_new_indexing(
+                    connector=cc_pair.connector,
+                    last_index=last_attempt,
+                    model=model,
+                    secondary_index_building=len(embedding_models) > 1,
+                    db_session=db_session,
+                ):
+                    continue
+
+                create_index_attempt(cc_pair.id, model.id, db_session)
@@ -271,6 +275,8 @@ def kickoff_indexing_jobs(
     # Don't include jobs waiting in the Dask queue that just haven't started running
     # Also (rarely) don't include for jobs that started but haven't updated the indexing tables yet
     with Session(engine) as db_session:
+        # get_not_started_index_attempts orders its returned results from oldest to newest
+        # we must process attempts in a FIFO manner to prevent connector starvation
         new_indexing_attempts = [
             (attempt, attempt.embedding_model)
             for attempt in get_not_started_index_attempts(db_session)
@@ -288,7 +294,7 @@ def kickoff_indexing_jobs(
             if embedding_model is not None
             else False
         )
-        if attempt.connector is None:
+        if attempt.connector_credential_pair.connector is None:
             logger.warning(
                 f"Skipping index attempt as Connector has been deleted: {attempt}"
             )
@@ -297,7 +303,7 @@ def kickoff_indexing_jobs(
                 attempt, db_session, failure_reason="Connector is null"
             )
             continue
-        if attempt.credential is None:
+        if attempt.connector_credential_pair.credential is None:
             logger.warning(
                 f"Skipping index attempt as Credential has been deleted: {attempt}"
             )
@@ -326,32 +332,35 @@ def kickoff_indexing_jobs(
         secondary_str = "(secondary index) " if use_secondary_index else ""
         logger.info(
             f"Kicked off {secondary_str}"
-            f"indexing attempt for connector: '{attempt.connector.name}', "
-            f"with config: '{attempt.connector.connector_specific_config}', and "
-            f"with credentials: '{attempt.credential_id}'"
+            f"indexing attempt for connector: '{attempt.connector_credential_pair.connector.name}', "
+            f"with config: '{attempt.connector_credential_pair.connector.connector_specific_config}', and "
+            f"with credentials: '{attempt.connector_credential_pair.credential_id}'"
         )
         existing_jobs_copy[attempt.id] = run

     return existing_jobs_copy


-def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> None:
+def update_loop(
+    delay: int = 10,
+    num_workers: int = NUM_INDEXING_WORKERS,
+    num_secondary_workers: int = NUM_SECONDARY_INDEXING_WORKERS,
+) -> None:
     engine = get_sqlalchemy_engine()
     with Session(engine) as db_session:
         check_index_swap(db_session=db_session)
         db_embedding_model = get_current_db_embedding_model(db_session)

         # So that the first time users aren't surprised by really slow speed of first
         # batch of documents indexed
-        if db_embedding_model.cloud_provider_id is None:
-            logger.info("Running a first inference to warm up embedding model")
-            warm_up_encoders(
-                model_name=db_embedding_model.model_name,
-                normalize=db_embedding_model.normalize,
-                model_server_host=INDEXING_MODEL_SERVER_HOST,
-                model_server_port=MODEL_SERVER_PORT,
-            )
+        if db_embedding_model.cloud_provider_id is None:
+            logger.info("Running a first inference to warm up embedding model")
+            warm_up_encoders(
+                embedding_model=db_embedding_model,
+                model_server_host=INDEXING_MODEL_SERVER_HOST,
+                model_server_port=MODEL_SERVER_PORT,
+            )

     client_primary: Client | SimpleJobClient
     client_secondary: Client | SimpleJobClient
@@ -366,7 +375,7 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non
             silence_logs=logging.ERROR,
         )
         cluster_secondary = LocalCluster(
-            n_workers=num_workers,
+            n_workers=num_secondary_workers,
             threads_per_worker=1,
             silence_logs=logging.ERROR,
         )
@@ -376,7 +385,7 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non
         client_primary.register_worker_plugin(ResourceLogger())
     else:
         client_primary = SimpleJobClient(n_workers=num_workers)
-        client_secondary = SimpleJobClient(n_workers=num_workers)
+        client_secondary = SimpleJobClient(n_workers=num_secondary_workers)

     existing_jobs: dict[int, Future | SimpleJob] = {}
@@ -411,6 +420,7 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non

 def update__main() -> None:
     set_is_ee_based_on_env_variable()
+    init_sqlalchemy_engine(POSTGRES_INDEXER_APP_NAME)

     logger.info("Starting Indexing Loop")
     update_loop()
```
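The reworked _should_create_new_indexing gives a FUTURE (soon-to-swap) embedding model special treatment: every connector must index at least once, but never again after a success, otherwise the model swap would never complete. A compact sketch of that rule, with stand-in strings for the IndexingStatus enum:

```python
def should_index_future_model(last_index_status: str | None, connector_id: int) -> bool:
    """Stand-in for the FUTURE-model branch above; statuses are plain strings."""
    if last_index_status == "SUCCESS":
        return False  # already indexed once -- let the model swap proceed
    if last_index_status is None and connector_id == 0:
        return False  # the Ingestion API connector (id 0) is skipped
    return True

assert should_index_future_model(None, connector_id=5)        # never indexed: run it
assert not should_index_future_model("SUCCESS", connector_id=5)
assert not should_index_future_model(None, connector_id=0)    # Ingestion API
```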
```diff
@@ -35,6 +35,7 @@ def llm_doc_from_inference_section(inference_section: InferenceSection) -> LlmDo
 def create_chat_chain(
     chat_session_id: int,
     db_session: Session,
+    prefetch_tool_calls: bool = True,
 ) -> tuple[ChatMessage, list[ChatMessage]]:
     """Build the linear chain of messages without including the root message"""
     mainline_messages: list[ChatMessage] = []
@@ -43,6 +44,7 @@ def create_chat_chain(
         user_id=None,
         db_session=db_session,
         skip_permission_check=True,
+        prefetch_tool_calls=prefetch_tool_calls,
     )
     id_to_msg = {msg.id: msg for msg in all_chat_messages}
```
backend/danswer/chat/input_prompts.yaml (new file, 24 lines)

```yaml
input_prompts:
  - id: -5
    prompt: "Elaborate"
    content: "Elaborate on the above, give me a more in depth explanation."
    active: true
    is_public: true

  - id: -4
    prompt: "Reword"
    content: "Help me rewrite the following politely and concisely for professional communication:\n"
    active: true
    is_public: true

  - id: -3
    prompt: "Email"
    content: "Write a professional email for me including a subject line, signature, etc. Template the parts that need editing with [ ]. The email should cover the following points:\n"
    active: true
    is_public: true

  - id: -2
    prompt: "Debug"
    content: "Provide step-by-step troubleshooting instructions for the following issue:\n"
    active: true
    is_public: true
```
```diff
@@ -1,11 +1,13 @@
 import yaml
 from sqlalchemy.orm import Session

+from danswer.configs.chat_configs import INPUT_PROMPT_YAML
 from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
 from danswer.configs.chat_configs import PERSONAS_YAML
 from danswer.configs.chat_configs import PROMPTS_YAML
 from danswer.db.document_set import get_or_create_document_set_by_name
 from danswer.db.engine import get_sqlalchemy_engine
+from danswer.db.input_prompt import insert_input_prompt_if_not_exists
 from danswer.db.models import DocumentSet as DocumentSetDBModel
 from danswer.db.models import Prompt as PromptDBModel
 from danswer.db.persona import get_prompt_by_name
@@ -88,6 +90,8 @@ def load_personas_from_yaml(
             llm_relevance_filter=persona.get("llm_relevance_filter"),
             starter_messages=persona.get("starter_messages"),
             llm_filter_extraction=persona.get("llm_filter_extraction"),
+            icon_shape=persona.get("icon_shape"),
+            icon_color=persona.get("icon_color"),
             llm_model_provider_override=None,
             llm_model_version_override=None,
             recency_bias=RecencyBiasSetting(persona["recency_bias"]),
@@ -99,9 +103,32 @@ def load_personas_from_yaml(
     )


+def load_input_prompts_from_yaml(input_prompts_yaml: str = INPUT_PROMPT_YAML) -> None:
+    with open(input_prompts_yaml, "r") as file:
+        data = yaml.safe_load(file)
+
+    all_input_prompts = data.get("input_prompts", [])
+    with Session(get_sqlalchemy_engine()) as db_session:
+        for input_prompt in all_input_prompts:
+            # If these prompts are deleted (which is a hard delete in the DB), on server startup
+            # they will be recreated, but the user can always just deactivate them, just a light inconvenience
+            insert_input_prompt_if_not_exists(
+                user=None,
+                input_prompt_id=input_prompt.get("id"),
+                prompt=input_prompt["prompt"],
+                content=input_prompt["content"],
+                is_public=input_prompt["is_public"],
+                active=input_prompt.get("active", True),
+                db_session=db_session,
+                commit=True,
+            )
+
+
 def load_chat_yamls(
     prompt_yaml: str = PROMPTS_YAML,
     personas_yaml: str = PERSONAS_YAML,
+    input_prompts_yaml: str = INPUT_PROMPT_YAML,
 ) -> None:
     load_prompts_from_yaml(prompt_yaml)
     load_personas_from_yaml(personas_yaml)
+    load_input_prompts_from_yaml(input_prompts_yaml)
```
```diff
@@ -37,7 +37,8 @@ personas:
     #  - "Engineer Onboarding"
     #  - "Benefits"
     document_sets: []
+    icon_shape: 23013
+    icon_color: "#6FB1FF"

   - id: 1
     name: "GPT"
@@ -50,6 +51,8 @@ personas:
     llm_filter_extraction: true
     recency_bias: "auto"
     document_sets: []
+    icon_shape: 50910
+    icon_color: "#FF6F6F"


   - id: 2
@@ -63,3 +66,6 @@ personas:
     llm_filter_extraction: true
     recency_bias: "auto"
     document_sets: []
+    icon_shape: 45519
+    icon_color: "#6FFF8D"
```
```diff
@@ -51,7 +51,7 @@ from danswer.llm.exceptions import GenAIDisabledException
 from danswer.llm.factory import get_llms_for_persona
 from danswer.llm.factory import get_main_llm_from_tuple
 from danswer.llm.interfaces import LLMConfig
-from danswer.llm.utils import get_default_llm_tokenizer
+from danswer.natural_language_processing.utils import get_tokenizer
 from danswer.search.enums import OptionalSearchSetting
 from danswer.search.enums import QueryFlow
 from danswer.search.enums import SearchType
@@ -187,37 +187,46 @@ def _handle_internet_search_tool_response_summary(
     )


-def _check_should_force_search(
-    new_msg_req: CreateChatMessageRequest,
-) -> ForceUseTool | None:
-    # If files are already provided, don't run the search tool
+def _get_force_search_settings(
+    new_msg_req: CreateChatMessageRequest, tools: list[Tool]
+) -> ForceUseTool:
+    internet_search_available = any(
+        isinstance(tool, InternetSearchTool) for tool in tools
+    )
+    search_tool_available = any(isinstance(tool, SearchTool) for tool in tools)
+
+    if not internet_search_available and not search_tool_available:
+        # Does not matter much which tool is set here as force is false and neither tool is available
+        return ForceUseTool(force_use=False, tool_name=SearchTool._NAME)
+
+    tool_name = SearchTool._NAME if search_tool_available else InternetSearchTool._NAME
+    # Currently, the internet search tool does not support query override
+    args = (
+        {"query": new_msg_req.query_override}
+        if new_msg_req.query_override and tool_name == SearchTool._NAME
+        else None
+    )
+
     if new_msg_req.file_descriptors:
-        return None
+        # If user has uploaded files they're using, don't run any of the search tools
+        return ForceUseTool(force_use=False, tool_name=tool_name)

-    if (
-        new_msg_req.query_override
-        or (
-            new_msg_req.retrieval_options
-            and new_msg_req.retrieval_options.run_search == OptionalSearchSetting.ALWAYS
-        )
-        or new_msg_req.search_doc_ids
-        or DISABLE_LLM_CHOOSE_SEARCH
-    ):
-        args = (
-            {"query": new_msg_req.query_override}
-            if new_msg_req.query_override
-            else None
-        )
-        # if we are using selected docs, just put something here so the Tool doesn't need
-        # to build its own args via an LLM call
-        if new_msg_req.search_doc_ids:
-            args = {"query": new_msg_req.message}
-
-        return ForceUseTool(
-            tool_name=SearchTool._NAME,
-            args=args,
-        )
-    return None
+    should_force_search = any(
+        [
+            new_msg_req.retrieval_options
+            and new_msg_req.retrieval_options.run_search
+            == OptionalSearchSetting.ALWAYS,
+            new_msg_req.search_doc_ids,
+            DISABLE_LLM_CHOOSE_SEARCH,
+        ]
+    )
+
+    if should_force_search:
+        # If we are using selected docs, just put something here so the Tool doesn't need to build its own args via an LLM call
+        args = {"query": new_msg_req.message} if new_msg_req.search_doc_ids else args
+        return ForceUseTool(force_use=True, tool_name=tool_name, args=args)
+
+    return ForceUseTool(force_use=False, tool_name=tool_name, args=args)
@@ -253,7 +262,6 @@ def stream_chat_message_objects(
     2. [conditional] LLM selected chunk indices if LLM chunk filtering is turned on
     3. [always] A set of streamed LLM tokens or an error anywhere along the line if something fails
     4. [always] Details on the final AI response message that is created
-
     """
     try:
         user_id = user.id if user is not None else None
@@ -274,7 +282,10 @@
         # use alternate persona if alternative assistant id is passed in
         if alternate_assistant_id is not None:
             persona = get_persona_by_id(
-                alternate_assistant_id, user=user, db_session=db_session
+                alternate_assistant_id,
+                user=user,
+                db_session=db_session,
+                is_for_edit=False,
             )
         else:
             persona = chat_session.persona
@@ -297,7 +308,13 @@
         except GenAIDisabledException:
             raise RuntimeError("LLM is disabled. Can't use chat flow without LLM.")

-        llm_tokenizer = get_default_llm_tokenizer()
+        llm_provider = llm.config.model_provider
+        llm_model_name = llm.config.model_name
+
+        llm_tokenizer = get_tokenizer(
+            model_name=llm_model_name,
+            provider_type=llm_provider,
+        )
         llm_tokenizer_encode_func = cast(
             Callable[[str], list[int]], llm_tokenizer.encode
         )
@@ -361,6 +378,14 @@
                 "when the last message is not a user message."
             )

+        # Disable Query Rephrasing for the first message
+        # This leads to a better first response since the LLM rephrasing the question
+        # leads to worse search quality
+        if not history_msgs:
+            new_msg_req.query_override = (
+                new_msg_req.query_override or new_msg_req.message
+            )
+
         # load all files needed for this chat chain in memory
         files = load_all_chat_files(
             history_msgs, new_msg_req.file_descriptors, db_session
@@ -544,9 +569,11 @@
         tools.extend(tool_list)

     # factor in tool definition size when pruning
-    document_pruning_config.tool_num_tokens = compute_all_tool_tokens(tools)
+    document_pruning_config.tool_num_tokens = compute_all_tool_tokens(
+        tools, llm_tokenizer
+    )
     document_pruning_config.using_tool_message = explicit_tool_calling_supported(
-        llm.config.model_provider, llm.config.model_name
+        llm_provider, llm_model_name
     )

     # LLM prompt building, response capturing, etc.
@@ -576,11 +603,7 @@
             PreviousMessage.from_chat_message(msg, files) for msg in history_msgs
         ],
         tools=tools,
-        force_use_tool=(
-            _check_should_force_search(new_msg_req)
-            if search_tool and len(tools) == 1
-            else None
-        ),
+        force_use_tool=_get_force_search_settings(new_msg_req, tools),
     )

     reference_db_search_docs = None
```
@@ -214,8 +214,8 @@ EXPERIMENTAL_CHECKPOINTING_ENABLED = (

DEFAULT_PRUNING_FREQ = 60 * 60 * 24  # Once a day

PREVENT_SIMULTANEOUS_PRUNING = (
os.environ.get("PREVENT_SIMULTANEOUS_PRUNING", "").lower() == "true"
ALLOW_SIMULTANEOUS_PRUNING = (
os.environ.get("ALLOW_SIMULTANEOUS_PRUNING", "").lower() == "true"
)

# This is the maximum rate at which documents are queried for a pruning job. 0 disables the limitation.
@@ -248,6 +248,9 @@ DISABLE_INDEX_UPDATE_ON_SWAP = (
# fairly large amount of memory in order to increase substantially, since
# each worker loads the embedding models into memory.
NUM_INDEXING_WORKERS = int(os.environ.get("NUM_INDEXING_WORKERS") or 1)
NUM_SECONDARY_INDEXING_WORKERS = int(
os.environ.get("NUM_SECONDARY_INDEXING_WORKERS") or NUM_INDEXING_WORKERS
)
# More accurate results at the expense of indexing speed and index size (stores additional 4 MINI_CHUNK vectors)
ENABLE_MINI_CHUNK = os.environ.get("ENABLE_MINI_CHUNK", "").lower() == "true"
# Finer grained chunking for more detail retention
@@ -3,6 +3,7 @@ import os

PROMPTS_YAML = "./danswer/chat/prompts.yaml"
PERSONAS_YAML = "./danswer/chat/personas.yaml"
INPUT_PROMPT_YAML = "./danswer/chat/input_prompts.yaml"

NUM_RETURNED_HITS = 50
# Used for LLM filtering and reranking

@@ -44,7 +44,6 @@ QUERY_EVENT_ID = "query_event_id"
LLM_CHUNKS = "llm_chunks"

# For chunking/processing chunks
MAX_CHUNK_TITLE_LEN = 1000
RETURN_SEPARATOR = "\n\r\n"
SECTION_SEPARATOR = "\n\n"
# For combining attributes, doesn't have to be unique/perfect to work
@@ -60,6 +59,14 @@ DISABLED_GEN_AI_MSG = (
"You can still use Danswer as a search engine."
)

# Postgres connection constants for application_name
POSTGRES_WEB_APP_NAME = "web"
POSTGRES_INDEXER_APP_NAME = "indexer"
POSTGRES_CELERY_APP_NAME = "celery"
POSTGRES_CELERY_BEAT_APP_NAME = "celery_beat"
POSTGRES_CELERY_WORKER_APP_NAME = "celery_worker"
POSTGRES_PERMISSIONS_APP_NAME = "permissions"
POSTGRES_UNKNOWN_APP_NAME = "unknown"

# API Keys
DANSWER_API_KEY_PREFIX = "API_KEY__"

@@ -12,7 +12,7 @@ import os
# The useable models configured as below must be SentenceTransformer compatible
# NOTE: DO NOT CHANGE SET THESE UNLESS YOU KNOW WHAT YOU ARE DOING
# IDEALLY, YOU SHOULD CHANGE EMBEDDING MODELS VIA THE UI
DEFAULT_DOCUMENT_ENCODER_MODEL = "intfloat/e5-base-v2"
DEFAULT_DOCUMENT_ENCODER_MODEL = "nomic-ai/nomic-embed-text-v1"
DOCUMENT_ENCODER_MODEL = (
os.environ.get("DOCUMENT_ENCODER_MODEL") or DEFAULT_DOCUMENT_ENCODER_MODEL
)
@@ -34,8 +34,8 @@ OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS = False
SIM_SCORE_RANGE_LOW = float(os.environ.get("SIM_SCORE_RANGE_LOW") or 0.0)
SIM_SCORE_RANGE_HIGH = float(os.environ.get("SIM_SCORE_RANGE_HIGH") or 1.0)
# Certain models like e5, BGE, etc use a prefix for asymmetric retrievals (query generally shorter than docs)
ASYM_QUERY_PREFIX = os.environ.get("ASYM_QUERY_PREFIX", "query: ")
ASYM_PASSAGE_PREFIX = os.environ.get("ASYM_PASSAGE_PREFIX", "passage: ")
ASYM_QUERY_PREFIX = os.environ.get("ASYM_QUERY_PREFIX", "search_query: ")
ASYM_PASSAGE_PREFIX = os.environ.get("ASYM_PASSAGE_PREFIX", "search_document: ")
# Purely an optimization, memory limitation consideration
BATCH_SIZE_ENCODE_CHUNKS = 8
# For score display purposes, only way is to know the expected ranges
@@ -56,6 +56,16 @@ def extract_text_from_content(content: dict) -> str:
return " ".join(texts)


def best_effort_get_field_from_issue(jira_issue: Issue, field: str) -> Any:
if hasattr(jira_issue.fields, field):
return getattr(jira_issue.fields, field)

try:
return jira_issue.raw["fields"][field]
except Exception:
return None

def _get_comment_strs(
jira: Issue, comment_email_blacklist: tuple[str, ...] = ()
) -> list[str]:
@@ -117,8 +127,10 @@ def fetch_jira_issues_batch(
continue

comments = _get_comment_strs(jira, comment_email_blacklist)
semantic_rep = f"{jira.fields.description}\n" + "\n".join(
[f"Comment: {comment}" for comment in comments]
semantic_rep = (
f"{jira.fields.description}\n"
if jira.fields.description
else "" + "\n".join([f"Comment: {comment}" for comment in comments])
)

page_url = f"{jira_client.client_info()}/browse/{jira.key}"
@@ -147,14 +159,18 @@ def fetch_jira_issues_batch(
pass

metadata_dict = {}
if jira.fields.priority:
metadata_dict["priority"] = jira.fields.priority.name
if jira.fields.status:
metadata_dict["status"] = jira.fields.status.name
if jira.fields.resolution:
metadata_dict["resolution"] = jira.fields.resolution.name
if jira.fields.labels:
metadata_dict["label"] = jira.fields.labels
priority = best_effort_get_field_from_issue(jira, "priority")
if priority:
metadata_dict["priority"] = priority.name
status = best_effort_get_field_from_issue(jira, "status")
if status:
metadata_dict["status"] = status.name
resolution = best_effort_get_field_from_issue(jira, "resolution")
if resolution:
metadata_dict["resolution"] = resolution.name
labels = best_effort_get_field_from_issue(jira, "labels")
if labels:
metadata_dict["label"] = labels

doc_batch.append(
Document(
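The best_effort_get_field_from_issue helper introduced above guards against Jira instances where a standard field is missing from the typed fields object but still present in the raw payload. A self-contained sketch of the same pattern (the SimpleNamespace stand-ins replace real jira.Issue objects):

from types import SimpleNamespace
from typing import Any


def best_effort_get_field(issue: Any, field: str) -> Any:
    # Prefer the typed attribute, fall back to the raw payload, else None
    if hasattr(issue.fields, field):
        return getattr(issue.fields, field)
    try:
        return issue.raw["fields"][field]
    except Exception:
        return None


issue = SimpleNamespace(
    fields=SimpleNamespace(status="Done"),
    raw={"fields": {"priority": "High"}},
)
assert best_effort_get_field(issue, "status") == "Done"    # typed attribute wins
assert best_effort_get_field(issue, "priority") == "High"  # raw fallback
assert best_effort_get_field(issue, "labels") is None      # neither present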
@@ -64,7 +64,7 @@ class DiscourseConnector(PollConnector):
self.permissions: DiscoursePerms | None = None
self.active_categories: set | None = None

@rate_limit_builder(max_calls=100, period=60)
@rate_limit_builder(max_calls=50, period=60)
def _make_request(self, endpoint: str, params: dict | None = None) -> Response:
if not self.permissions:
raise ConnectorMissingCredentialError("Discourse")

@@ -11,6 +11,7 @@ from google_auth_oauthlib.flow import InstalledAppFlow  # type: ignore
from sqlalchemy.orm import Session

from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import DocumentSource
from danswer.connectors.gmail.constants import CRED_KEY
from danswer.connectors.gmail.constants import (
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY,
@@ -146,6 +147,7 @@ def build_service_account_creds(
credential_dict[DB_CREDENTIALS_DICT_DELEGATED_USER_KEY] = delegated_user_email

return CredentialBase(
source=DocumentSource.GMAIL,
credential_json=credential_dict,
admin_public=True,
)

@@ -11,6 +11,7 @@ from google_auth_oauthlib.flow import InstalledAppFlow  # type: ignore
from sqlalchemy.orm import Session

from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import DocumentSource
from danswer.connectors.google_drive.constants import CRED_KEY
from danswer.connectors.google_drive.constants import (
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY,
@@ -118,6 +119,7 @@ def update_credential_access_tokens(


def build_service_account_creds(
source: DocumentSource,
delegated_user_email: str | None = None,
) -> CredentialBase:
service_account_key = get_service_account_key()
@@ -131,6 +133,7 @@ def build_service_account_creds(
return CredentialBase(
credential_json=credential_dict,
admin_public=True,
source=DocumentSource.GOOGLE_DRIVE,
)


@@ -86,7 +86,6 @@ class MediaWikiConnector(LoadConnector, PollConnector):
categories: The categories to include in the index.
pages: The pages to include in the index.
recurse_depth: The depth to recurse into categories. -1 means unbounded recursion.
connector_name: The name of the connector.
language_code: The language code of the wiki.
batch_size: The batch size for loading documents.

@@ -104,7 +103,6 @@ class MediaWikiConnector(LoadConnector, PollConnector):
categories: list[str],
pages: list[str],
recurse_depth: int,
connector_name: str,
language_code: str = "en",
batch_size: int = INDEX_BATCH_SIZE,
) -> None:
@@ -118,10 +116,8 @@ class MediaWikiConnector(LoadConnector, PollConnector):
self.batch_size = batch_size

# short names can only have ascii letters and digits
self.connector_name = connector_name
connector_name = "".join(ch for ch in connector_name if ch.isalnum())

self.family = family_class_dispatch(hostname, connector_name)()
self.family = family_class_dispatch(hostname, "Wikipedia Connector")()
self.site = pywikibot.Site(fam=self.family, code=language_code)
self.categories = [
pywikibot.Category(self.site, f"Category:{category.replace(' ', '_')}")
@@ -210,7 +206,6 @@ class MediaWikiConnector(LoadConnector, PollConnector):
if __name__ == "__main__":
HOSTNAME = "fallout.fandom.com"
test_connector = MediaWikiConnector(
connector_name="Fallout",
hostname=HOSTNAME,
categories=["Fallout:_New_Vegas_factions"],
pages=["Fallout: New Vegas"],

@@ -114,7 +114,9 @@ class DocumentBase(BaseModel):
title: str | None = None
from_ingestion_api: bool = False

def get_title_for_document_index(self) -> str | None:
def get_title_for_document_index(
self,
) -> str | None:
# If title is explicitly empty, return a None here for embedding purposes
if self.title == "":
return None

@@ -15,6 +15,7 @@ from playwright.sync_api import BrowserContext
from playwright.sync_api import Playwright
from playwright.sync_api import sync_playwright
from requests_oauthlib import OAuth2Session  # type:ignore
from urllib3.exceptions import MaxRetryError

from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.app_configs import WEB_CONNECTOR_OAUTH_CLIENT_ID
@@ -83,6 +84,13 @@ def check_internet_connection(url: str) -> None:
try:
response = requests.get(url, timeout=3)
response.raise_for_status()
except requests.exceptions.SSLError as e:
cause = (
e.args[0].reason
if isinstance(e.args, tuple) and isinstance(e.args[0], MaxRetryError)
else e.args
)
raise Exception(f"SSL error {str(cause)}")
except (requests.RequestException, ValueError):
raise Exception(f"Unable to reach {url} - check your internet connection")
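The new SSLError branch above digs the underlying reason out of urllib3's MaxRetryError wrapper so the raised message names the actual certificate problem. A hedged, standalone sketch of that unwrapping (the test URL is an illustrative assumption):

import requests
from urllib3.exceptions import MaxRetryError


def describe_ssl_error(e: requests.exceptions.SSLError) -> str:
    # requests wraps urllib3's MaxRetryError; its .reason holds the real cause
    cause = (
        e.args[0].reason
        if isinstance(e.args, tuple) and isinstance(e.args[0], MaxRetryError)
        else e.args
    )
    return f"SSL error {cause}"


try:
    requests.get("https://expired.badssl.com", timeout=3)  # example bad-cert host
except requests.exceptions.SSLError as e:
    print(describe_ssl_error(e))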
@@ -15,7 +15,6 @@ class WikipediaConnector(wiki.MediaWikiConnector):
categories: list[str],
pages: list[str],
recurse_depth: int,
connector_name: str,
language_code: str = "en",
batch_size: int = INDEX_BATCH_SIZE,
) -> None:
@@ -24,7 +23,6 @@ class WikipediaConnector(wiki.MediaWikiConnector):
categories=categories,
pages=pages,
recurse_depth=recurse_depth,
connector_name=connector_name,
language_code=language_code,
batch_size=batch_size,
)

@@ -50,9 +50,9 @@ from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.engine import get_sqlalchemy_engine
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.natural_language_processing.search_nlp_models import warm_up_encoders
from danswer.one_shot_answer.models import ThreadMessage
from danswer.search.retrieval.search_runner import download_nltk_data
from danswer.search.search_nlp_models import warm_up_encoders
from danswer.server.manage.models import SlackBotTokens
from danswer.utils.logger import setup_logger
from shared_configs.configs import MODEL_SERVER_HOST
@@ -471,8 +471,7 @@ if __name__ == "__main__":
embedding_model = get_current_db_embedding_model(db_session)
if embedding_model.cloud_provider_id is None:
warm_up_encoders(
model_name=embedding_model.model_name,
normalize=embedding_model.normalize,
embedding_model=embedding_model,
model_server_host=MODEL_SERVER_HOST,
model_server_port=MODEL_SERVER_PORT,
)
@@ -11,6 +11,7 @@ from danswer.configs.app_configs import DEFAULT_PRUNING_FREQ
from danswer.configs.constants import DocumentSource
from danswer.connectors.models import InputType
from danswer.db.models import Connector
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import IndexAttempt
from danswer.server.documents.models import ConnectorBase
from danswer.server.documents.models import ObjectCreationIdResponse
@@ -85,6 +86,7 @@ def create_connector(
input_type=connector_data.input_type,
connector_specific_config=connector_data.connector_specific_config,
refresh_freq=connector_data.refresh_freq,
indexing_start=connector_data.indexing_start,
prune_freq=connector_data.prune_freq
if connector_data.prune_freq is not None
else DEFAULT_PRUNING_FREQ,
@@ -191,7 +193,8 @@ def fetch_latest_index_attempt_by_connector(
for connector in connectors:
latest_index_attempt = (
db_session.query(IndexAttempt)
.filter(IndexAttempt.connector_id == connector.id)
.join(ConnectorCredentialPair)
.filter(ConnectorCredentialPair.connector_id == connector.id)
.order_by(IndexAttempt.time_updated.desc())
.first()
)
@@ -207,13 +210,11 @@ def fetch_latest_index_attempts_by_status(
) -> list[IndexAttempt]:
subquery = (
db_session.query(
IndexAttempt.connector_id,
IndexAttempt.credential_id,
IndexAttempt.connector_credential_pair_id,
IndexAttempt.status,
func.max(IndexAttempt.time_updated).label("time_updated"),
)
.group_by(IndexAttempt.connector_id)
.group_by(IndexAttempt.credential_id)
.group_by(IndexAttempt.connector_credential_pair_id)
.group_by(IndexAttempt.status)
.subquery()
)
@@ -223,12 +224,13 @@ def fetch_latest_index_attempts_by_status(
query = db_session.query(IndexAttempt).join(
alias,
and_(
IndexAttempt.connector_id == alias.connector_id,
IndexAttempt.credential_id == alias.credential_id,
IndexAttempt.connector_credential_pair_id
== alias.connector_credential_pair_id,
IndexAttempt.status == alias.status,
IndexAttempt.time_updated == alias.time_updated,
),
)

return cast(list[IndexAttempt], query.all())

@@ -6,6 +6,7 @@ from sqlalchemy import desc
from sqlalchemy import select
from sqlalchemy.orm import Session

from danswer.configs.constants import DocumentSource
from danswer.db.connector import fetch_connector_by_id
from danswer.db.credentials import fetch_credential_by_id
from danswer.db.models import ConnectorCredentialPair
@@ -42,6 +43,17 @@ def get_connector_credential_pair(
return result.scalar_one_or_none()


def get_connector_credential_source_from_id(
cc_pair_id: int,
db_session: Session,
) -> DocumentSource | None:
stmt = select(ConnectorCredentialPair)
stmt = stmt.where(ConnectorCredentialPair.id == cc_pair_id)
result = db_session.execute(stmt)
cc_pair = result.scalar_one_or_none()
return cc_pair.connector.source if cc_pair else None


def get_connector_credential_pair_from_id(
cc_pair_id: int,
db_session: Session,
@@ -75,17 +87,23 @@ def get_last_successful_attempt_time(
# For Secondary Index we don't keep track of the latest success, so have to calculate it live
attempt = (
db_session.query(IndexAttempt)
.join(
ConnectorCredentialPair,
IndexAttempt.connector_credential_pair_id == ConnectorCredentialPair.id,
)
.filter(
IndexAttempt.connector_id == connector_id,
IndexAttempt.credential_id == credential_id,
ConnectorCredentialPair.connector_id == connector_id,
ConnectorCredentialPair.credential_id == credential_id,
IndexAttempt.embedding_model_id == embedding_model.id,
IndexAttempt.status == IndexingStatus.SUCCESS,
)
.order_by(IndexAttempt.time_started.desc())
.first()
)

if not attempt or not attempt.time_started:
connector = fetch_connector_by_id(connector_id, db_session)
if connector and connector.indexing_start:
return connector.indexing_start.timestamp()
return 0.0

return attempt.time_started.timestamp()
@@ -241,6 +259,12 @@ def remove_credential_from_connector(
)


def fetch_connector_credential_pairs(
db_session: Session,
) -> list[ConnectorCredentialPair]:
return db_session.query(ConnectorCredentialPair).all()


def resync_cc_pair(
cc_pair: ConnectorCredentialPair,
db_session: Session,
@@ -253,10 +277,14 @@ def resync_cc_pair(
) -> IndexAttempt | None:
query = (
db_session.query(IndexAttempt)
.join(
ConnectorCredentialPair,
IndexAttempt.connector_credential_pair_id == ConnectorCredentialPair.id,
)
.join(EmbeddingModel, IndexAttempt.embedding_model_id == EmbeddingModel.id)
.filter(
IndexAttempt.connector_id == connector_id,
IndexAttempt.credential_id == credential_id,
ConnectorCredentialPair.connector_id == connector_id,
ConnectorCredentialPair.credential_id == credential_id,
EmbeddingModel.status == IndexModelStatus.PRESENT,
)
)
@@ -2,10 +2,13 @@ from typing import Any

from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import and_
from sqlalchemy.sql.expression import or_

from danswer.auth.schemas import UserRole
from danswer.configs.constants import DocumentSource
from danswer.connectors.gmail.constants import (
GMAIL_DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
)
@@ -14,8 +17,10 @@ from danswer.connectors.google_drive.constants import (
)
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import Credential
from danswer.db.models import DocumentByConnectorCredentialPair
from danswer.db.models import User
from danswer.server.documents.models import CredentialBase
from danswer.server.documents.models import CredentialDataUpdateRequest
from danswer.utils.logger import setup_logger


@@ -74,6 +79,69 @@ def fetch_credential_by_id(
return credential


def fetch_credentials_by_source(
db_session: Session,
user: User | None,
document_source: DocumentSource | None = None,
) -> list[Credential]:
base_query = select(Credential).where(Credential.source == document_source)
base_query = _attach_user_filters(base_query, user)
credentials = db_session.execute(base_query).scalars().all()
return list(credentials)


def swap_credentials_connector(
new_credential_id: int, connector_id: int, user: User | None, db_session: Session
) -> ConnectorCredentialPair:
# Check if the user has permission to use the new credential
new_credential = fetch_credential_by_id(new_credential_id, user, db_session)
if not new_credential:
raise ValueError(
f"No Credential found with id {new_credential_id} or user doesn't have permission to use it"
)

# Existing pair
existing_pair = db_session.execute(
select(ConnectorCredentialPair).where(
ConnectorCredentialPair.connector_id == connector_id
)
).scalar_one_or_none()

if not existing_pair:
raise ValueError(
f"No ConnectorCredentialPair found for connector_id {connector_id}"
)

# Check if the new credential is compatible with the connector
if new_credential.source != existing_pair.connector.source:
raise ValueError(
f"New credential source {new_credential.source} does not match connector source {existing_pair.connector.source}"
)

db_session.execute(
update(DocumentByConnectorCredentialPair)
.where(
and_(
DocumentByConnectorCredentialPair.connector_id == connector_id,
DocumentByConnectorCredentialPair.credential_id
== existing_pair.credential_id,
)
)
.values(credential_id=new_credential_id)
)

# Update the existing pair with the new credential
existing_pair.credential_id = new_credential_id
existing_pair.credential = new_credential

# Commit the changes
db_session.commit()

# Refresh the object to ensure all relationships are up-to-date
db_session.refresh(existing_pair)
return existing_pair


def create_credential(
credential_data: CredentialBase,
user: User | None,
@@ -83,6 +151,8 @@ def create_credential(
credential_json=credential_data.credential_json,
user_id=user.id if user else None,
admin_public=credential_data.admin_public,
source=credential_data.source,
name=credential_data.name,
)
db_session.add(credential)
db_session.commit()
@@ -90,6 +160,30 @@ def create_credential(
return credential

def alter_credential(
credential_id: int,
credential_data: CredentialDataUpdateRequest,
user: User,
db_session: Session,
) -> Credential | None:
credential = fetch_credential_by_id(credential_id, user, db_session)

if credential is None:
return None

credential.name = credential_data.name

# Update only the keys present in credential_data.credential_json
for key, value in credential_data.credential_json.items():
credential.credential_json[key] = value

credential.user_id = user.id if user is not None else None

db_session.commit()
return credential

def update_credential(
credential_id: int,
credential_data: CredentialBase,
@@ -136,6 +230,7 @@ def delete_credential(
credential_id: int,
user: User | None,
db_session: Session,
force: bool = False,
) -> None:
credential = fetch_credential_by_id(credential_id, user, db_session)
if credential is None:
@@ -149,11 +244,38 @@ def delete_credential(
.all()
)

if associated_connectors:
raise ValueError(
f"Cannot delete credential {credential_id} as it is still associated with {len(associated_connectors)} connector(s). "
"Please delete all associated connectors first."
)
associated_doc_cc_pairs = (
db_session.query(DocumentByConnectorCredentialPair)
.filter(DocumentByConnectorCredentialPair.credential_id == credential_id)
.all()
)

if associated_connectors or associated_doc_cc_pairs:
if force:
logger.warning(
f"Force deleting credential {credential_id} and its associated records"
)

# Delete DocumentByConnectorCredentialPair records first
for doc_cc_pair in associated_doc_cc_pairs:
db_session.delete(doc_cc_pair)

# Then delete ConnectorCredentialPair records
for connector in associated_connectors:
db_session.delete(connector)

# Commit these deletions before deleting the credential
db_session.flush()
else:
raise ValueError(
f"Cannot delete credential as it is still associated with "
f"{len(associated_connectors)} connector(s) and {len(associated_doc_cc_pairs)} document(s). "
)

if force:
logger.info(f"Force deleting credential {credential_id}")
else:
logger.info(f"Deleting credential {credential_id}")

db_session.delete(credential)
db_session.commit()
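A hedged usage sketch for the new force flag on delete_credential; the session setup via get_session_factory matches the engine hunk later in this diff, while the credential ID is a made-up example:

from danswer.db.credentials import delete_credential
from danswer.db.engine import get_session_factory

with get_session_factory()() as db_session:
    # Without force=True this raises ValueError while connector pairs or
    # document rows still reference the credential; with it, those dependent
    # rows are deleted (and flushed) before the credential itself.
    delete_credential(credential_id=42, user=None, db_session=db_session, force=True)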
@@ -15,7 +15,7 @@ from danswer.db.models import CloudEmbeddingProvider
from danswer.db.models import EmbeddingModel
from danswer.db.models import IndexModelStatus
from danswer.indexing.models import EmbeddingModelDetail
from danswer.search.search_nlp_models import clean_model_name
from danswer.natural_language_processing.search_nlp_models import clean_model_name
from danswer.server.manage.embedding.models import (
CloudEmbeddingProvider as ServerCloudEmbeddingProvider,
)
@@ -18,6 +18,7 @@ from danswer.configs.app_configs import POSTGRES_HOST
from danswer.configs.app_configs import POSTGRES_PASSWORD
from danswer.configs.app_configs import POSTGRES_PORT
from danswer.configs.app_configs import POSTGRES_USER
from danswer.configs.constants import POSTGRES_UNKNOWN_APP_NAME
from danswer.utils.logger import setup_logger

logger = setup_logger()
@@ -25,12 +26,18 @@ logger = setup_logger()
SYNC_DB_API = "psycopg2"
ASYNC_DB_API = "asyncpg"

POSTGRES_APP_NAME = (
POSTGRES_UNKNOWN_APP_NAME  # helps to diagnose open connections in postgres
)

# global so we don't create more than one engine per process
# outside of being best practice, this is needed so we can properly pool
# connections and not create a new pool on every request
_SYNC_ENGINE: Engine | None = None
_ASYNC_ENGINE: AsyncEngine | None = None

SessionFactory: sessionmaker[Session] | None = None


def get_db_current_time(db_session: Session) -> datetime:
"""Get the current time from Postgres representing the start of the transaction
@@ -51,14 +58,25 @@ def build_connection_string(
host: str = POSTGRES_HOST,
port: str = POSTGRES_PORT,
db: str = POSTGRES_DB,
app_name: str | None = None,
) -> str:
if app_name:
return f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}?application_name={app_name}"

return f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}"


def init_sqlalchemy_engine(app_name: str) -> None:
global POSTGRES_APP_NAME
POSTGRES_APP_NAME = app_name


def get_sqlalchemy_engine() -> Engine:
global _SYNC_ENGINE
if _SYNC_ENGINE is None:
connection_string = build_connection_string(db_api=SYNC_DB_API)
connection_string = build_connection_string(
db_api=SYNC_DB_API, app_name=POSTGRES_APP_NAME + "_sync"
)
_SYNC_ENGINE = create_engine(connection_string, pool_size=40, max_overflow=10)
return _SYNC_ENGINE

@@ -66,9 +84,16 @@ def get_sqlalchemy_engine() -> Engine:
def get_sqlalchemy_async_engine() -> AsyncEngine:
global _ASYNC_ENGINE
if _ASYNC_ENGINE is None:
# underlying asyncpg cannot accept application_name directly in the connection string
# https://github.com/MagicStack/asyncpg/issues/798
connection_string = build_connection_string()
_ASYNC_ENGINE = create_async_engine(
connection_string, pool_size=40, max_overflow=10
connection_string,
connect_args={
"server_settings": {"application_name": POSTGRES_APP_NAME + "_async"}
},
pool_size=40,
max_overflow=10,
)
return _ASYNC_ENGINE

@@ -115,4 +140,8 @@ async def warm_up_connections(
await async_conn.close()


SessionFactory = sessionmaker(bind=get_sqlalchemy_engine())
def get_session_factory() -> sessionmaker[Session]:
global SessionFactory
if SessionFactory is None:
SessionFactory = sessionmaker(bind=get_sqlalchemy_engine())
return SessionFactory
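Taken together, the engine changes let each process label its Postgres connections. A hedged sketch of the intended call order (the constant choice and resulting suffix follow from the code above, but the entrypoint shown is an assumption):

from danswer.configs.constants import POSTGRES_CELERY_WORKER_APP_NAME
from danswer.db.engine import get_sqlalchemy_engine, init_sqlalchemy_engine

# Must run before the first engine access, since engines are process-global
init_sqlalchemy_engine(POSTGRES_CELERY_WORKER_APP_NAME)
engine = get_sqlalchemy_engine()  # connects as application_name="celery_worker_sync"

# Then, in psql, open connections can be grouped by owner:
#   SELECT application_name, count(*) FROM pg_stat_activity GROUP BY 1;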
@@ -15,6 +15,7 @@ from danswer.db.models import EmbeddingModel
from danswer.db.models import IndexAttempt
from danswer.db.models import IndexingStatus
from danswer.db.models import IndexModelStatus
from danswer.server.documents.models import ConnectorCredentialPair
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
from danswer.utils.logger import setup_logger
from danswer.utils.telemetry import optional_telemetry
@@ -23,6 +24,22 @@ from danswer.utils.telemetry import RecordType
logger = setup_logger()


def get_last_attempt_for_cc_pair(
cc_pair_id: int,
embedding_model_id: int,
db_session: Session,
) -> IndexAttempt | None:
return (
db_session.query(IndexAttempt)
.filter(
IndexAttempt.connector_credential_pair_id == cc_pair_id,
IndexAttempt.embedding_model_id == embedding_model_id,
)
.order_by(IndexAttempt.time_updated.desc())
.first()
)


def get_index_attempt(
db_session: Session, index_attempt_id: int
) -> IndexAttempt | None:
@@ -31,15 +48,13 @@ def get_index_attempt(


def create_index_attempt(
connector_id: int,
credential_id: int,
connector_credential_pair_id: int,
embedding_model_id: int,
db_session: Session,
from_beginning: bool = False,
) -> int:
new_attempt = IndexAttempt(
connector_id=connector_id,
credential_id=credential_id,
connector_credential_pair_id=connector_credential_pair_id,
embedding_model_id=embedding_model_id,
from_beginning=from_beginning,
status=IndexingStatus.NOT_STARTED,
@@ -56,7 +71,9 @@ def get_inprogress_index_attempts(
) -> list[IndexAttempt]:
stmt = select(IndexAttempt)
if connector_id is not None:
stmt = stmt.where(IndexAttempt.connector_id == connector_id)
stmt = stmt.where(
IndexAttempt.connector_credential_pair.has(connector_id=connector_id)
)
stmt = stmt.where(IndexAttempt.status == IndexingStatus.IN_PROGRESS)

incomplete_attempts = db_session.scalars(stmt)
@@ -65,21 +82,31 @@

def get_not_started_index_attempts(db_session: Session) -> list[IndexAttempt]:
"""This eagerly loads the connector and credential so that the db_session can be expired
before running long-living indexing jobs, which causes increasing memory usage"""
before running long-living indexing jobs, which causes increasing memory usage.

Results are ordered by time_created (oldest to newest)."""
stmt = select(IndexAttempt)
stmt = stmt.where(IndexAttempt.status == IndexingStatus.NOT_STARTED)
stmt = stmt.order_by(IndexAttempt.time_created)
stmt = stmt.options(
joinedload(IndexAttempt.connector), joinedload(IndexAttempt.credential)
joinedload(IndexAttempt.connector_credential_pair).joinedload(
ConnectorCredentialPair.connector
),
joinedload(IndexAttempt.connector_credential_pair).joinedload(
ConnectorCredentialPair.credential
),
)
new_attempts = db_session.scalars(stmt)
return list(new_attempts.all())


def mark_attempt_in_progress__no_commit(
def mark_attempt_in_progress(
index_attempt: IndexAttempt,
db_session: Session,
) -> None:
index_attempt.status = IndexingStatus.IN_PROGRESS
index_attempt.time_started = index_attempt.time_started or func.now()  # type: ignore
db_session.commit()


def mark_attempt_succeeded(
@@ -103,7 +130,7 @@ def mark_attempt_failed(
db_session.add(index_attempt)
db_session.commit()

source = index_attempt.connector.source
source = index_attempt.connector_credential_pair.connector.source
optional_telemetry(record_type=RecordType.FAILURE, data={"connector": source})


@@ -128,11 +155,16 @@ def get_last_attempt(
embedding_model_id: int | None,
db_session: Session,
) -> IndexAttempt | None:
stmt = select(IndexAttempt).where(
IndexAttempt.connector_id == connector_id,
IndexAttempt.credential_id == credential_id,
IndexAttempt.embedding_model_id == embedding_model_id,
stmt = (
select(IndexAttempt)
.join(ConnectorCredentialPair)
.where(
ConnectorCredentialPair.connector_id == connector_id,
ConnectorCredentialPair.credential_id == credential_id,
IndexAttempt.embedding_model_id == embedding_model_id,
)
)

# Note, the below is using time_created instead of time_updated
stmt = stmt.order_by(desc(IndexAttempt.time_created))

@@ -145,8 +177,7 @@ def get_latest_index_attempts(
db_session: Session,
) -> Sequence[IndexAttempt]:
ids_stmt = select(
IndexAttempt.connector_id,
IndexAttempt.credential_id,
IndexAttempt.connector_credential_pair_id,
func.max(IndexAttempt.time_created).label("max_time_created"),
).join(EmbeddingModel, IndexAttempt.embedding_model_id == EmbeddingModel.id)

@@ -158,43 +189,95 @@ def get_latest_index_attempts(
where_stmts: list[ColumnElement] = []
for connector_credential_pair_identifier in connector_credential_pair_identifiers:
where_stmts.append(
and_(
IndexAttempt.connector_id
== connector_credential_pair_identifier.connector_id,
IndexAttempt.credential_id
== connector_credential_pair_identifier.credential_id,
IndexAttempt.connector_credential_pair_id
== (
select(ConnectorCredentialPair.id)
.where(
ConnectorCredentialPair.connector_id
== connector_credential_pair_identifier.connector_id,
ConnectorCredentialPair.credential_id
== connector_credential_pair_identifier.credential_id,
)
.scalar_subquery()
)
)
if where_stmts:
ids_stmt = ids_stmt.where(or_(*where_stmts))
ids_stmt = ids_stmt.group_by(IndexAttempt.connector_id, IndexAttempt.credential_id)
ids_subqery = ids_stmt.subquery()
ids_stmt = ids_stmt.group_by(IndexAttempt.connector_credential_pair_id)
ids_subquery = ids_stmt.subquery()

stmt = (
select(IndexAttempt)
.join(
ids_subqery,
and_(
ids_subqery.c.connector_id == IndexAttempt.connector_id,
ids_subqery.c.credential_id == IndexAttempt.credential_id,
),
ids_subquery,
IndexAttempt.connector_credential_pair_id
== ids_subquery.c.connector_credential_pair_id,
)
.where(IndexAttempt.time_created == ids_subqery.c.max_time_created)
.where(IndexAttempt.time_created == ids_subquery.c.max_time_created)
)

return db_session.execute(stmt).scalars().all()


def get_index_attempts_for_connector(
db_session: Session,
connector_id: int,
only_current: bool = True,
disinclude_finished: bool = False,
) -> Sequence[IndexAttempt]:
stmt = (
select(IndexAttempt)
.join(ConnectorCredentialPair)
.where(ConnectorCredentialPair.connector_id == connector_id)
)
if disinclude_finished:
stmt = stmt.where(
IndexAttempt.status.in_(
[IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS]
)
)
if only_current:
stmt = stmt.join(EmbeddingModel).where(
EmbeddingModel.status == IndexModelStatus.PRESENT
)

stmt = stmt.order_by(IndexAttempt.time_created.desc())
return db_session.execute(stmt).scalars().all()


def get_latest_finished_index_attempt_for_cc_pair(
connector_credential_pair_id: int,
db_session: Session,
) -> IndexAttempt | None:
stmt = (
select(IndexAttempt)
.where(
IndexAttempt.connector_credential_pair_id == connector_credential_pair_id,
IndexAttempt.status.not_in(
[IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS]
),
)
.order_by(desc(IndexAttempt.time_created))
.limit(1)
)
return db_session.execute(stmt).scalar_one_or_none()


def get_index_attempts_for_cc_pair(
db_session: Session,
cc_pair_identifier: ConnectorCredentialPairIdentifier,
only_current: bool = True,
disinclude_finished: bool = False,
) -> Sequence[IndexAttempt]:
stmt = select(IndexAttempt).where(
and_(
IndexAttempt.connector_id == cc_pair_identifier.connector_id,
IndexAttempt.credential_id == cc_pair_identifier.credential_id,
stmt = (
select(IndexAttempt)
.join(ConnectorCredentialPair)
.where(
and_(
ConnectorCredentialPair.connector_id == cc_pair_identifier.connector_id,
ConnectorCredentialPair.credential_id
== cc_pair_identifier.credential_id,
)
)
)
if disinclude_finished:
@@ -218,9 +301,11 @@ def delete_index_attempts(
db_session: Session,
) -> None:
stmt = delete(IndexAttempt).where(
IndexAttempt.connector_id == connector_id,
IndexAttempt.credential_id == credential_id,
IndexAttempt.connector_credential_pair_id == ConnectorCredentialPair.id,
ConnectorCredentialPair.connector_id == connector_id,
ConnectorCredentialPair.credential_id == credential_id,
)

db_session.execute(stmt)


@@ -254,9 +339,11 @@ def cancel_indexing_attempts_for_connector(
db_session: Session,
include_secondary_index: bool = False,
) -> None:
stmt = delete(IndexAttempt).where(
IndexAttempt.connector_id == connector_id,
IndexAttempt.status == IndexingStatus.NOT_STARTED,
stmt = (
delete(IndexAttempt)
.where(IndexAttempt.connector_credential_pair_id == ConnectorCredentialPair.id)
.where(ConnectorCredentialPair.connector_id == connector_id)
.where(IndexAttempt.status == IndexingStatus.NOT_STARTED)
)

if not include_secondary_index:
@@ -296,7 +383,8 @@ def count_unique_cc_pairs_with_successful_index_attempts(
Then do distinct by connector_id and credential_id which is equivalent to the cc-pair. Finally,
do a count to get the total number of unique cc-pairs with successful attempts"""
unique_pairs_count = (
db_session.query(IndexAttempt.connector_id, IndexAttempt.credential_id)
db_session.query(IndexAttempt.connector_credential_pair_id)
.join(ConnectorCredentialPair)
.filter(
IndexAttempt.embedding_model_id == embedding_model_id,
IndexAttempt.status == IndexingStatus.SUCCESS,
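The thread running through this whole file: IndexAttempt no longer carries connector_id/credential_id columns, so every lookup joins through ConnectorCredentialPair instead. A hedged sketch of that recurring query shape (the IDs are placeholders):

from sqlalchemy import select
from danswer.db.models import ConnectorCredentialPair, IndexAttempt

stmt = (
    select(IndexAttempt)
    .join(ConnectorCredentialPair)
    .where(
        ConnectorCredentialPair.connector_id == 1,  # placeholder IDs
        ConnectorCredentialPair.credential_id == 2,
    )
)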
202
backend/danswer/db/input_prompt.py
Normal file
@@ -0,0 +1,202 @@
from uuid import UUID

from fastapi import HTTPException
from sqlalchemy import select
from sqlalchemy.orm import Session

from danswer.db.models import InputPrompt
from danswer.db.models import User
from danswer.server.features.input_prompt.models import InputPromptSnapshot
from danswer.server.manage.models import UserInfo
from danswer.utils.logger import setup_logger


logger = setup_logger()


def insert_input_prompt_if_not_exists(
user: User | None,
input_prompt_id: int | None,
prompt: str,
content: str,
active: bool,
is_public: bool,
db_session: Session,
commit: bool = True,
) -> InputPrompt:
if input_prompt_id is not None:
input_prompt = (
db_session.query(InputPrompt).filter_by(id=input_prompt_id).first()
)
else:
query = db_session.query(InputPrompt).filter(InputPrompt.prompt == prompt)
if user:
query = query.filter(InputPrompt.user_id == user.id)
else:
query = query.filter(InputPrompt.user_id.is_(None))
input_prompt = query.first()

if input_prompt is None:
input_prompt = InputPrompt(
id=input_prompt_id,
prompt=prompt,
content=content,
active=active,
is_public=is_public or user is None,
user_id=user.id if user else None,
)
db_session.add(input_prompt)

if commit:
db_session.commit()

return input_prompt


def insert_input_prompt(
prompt: str,
content: str,
is_public: bool,
user: User | None,
db_session: Session,
) -> InputPrompt:
input_prompt = InputPrompt(
prompt=prompt,
content=content,
active=True,
is_public=is_public or user is None,
user_id=user.id if user is not None else None,
)
db_session.add(input_prompt)
db_session.commit()

return input_prompt


def update_input_prompt(
user: User | None,
input_prompt_id: int,
prompt: str,
content: str,
active: bool,
db_session: Session,
) -> InputPrompt:
input_prompt = db_session.scalar(
select(InputPrompt).where(InputPrompt.id == input_prompt_id)
)
if input_prompt is None:
raise ValueError(f"No input prompt with id {input_prompt_id}")

if not validate_user_prompt_authorization(user, input_prompt):
raise HTTPException(status_code=401, detail="You don't own this prompt")

input_prompt.prompt = prompt
input_prompt.content = content
input_prompt.active = active

db_session.commit()
return input_prompt


def validate_user_prompt_authorization(
user: User | None, input_prompt: InputPrompt
) -> bool:
prompt = InputPromptSnapshot.from_model(input_prompt=input_prompt)

if prompt.user_id is not None:
if user is None:
return False

user_details = UserInfo.from_model(user)
if str(user_details.id) != str(prompt.user_id):
return False
return True


def remove_public_input_prompt(input_prompt_id: int, db_session: Session) -> None:
input_prompt = db_session.scalar(
select(InputPrompt).where(InputPrompt.id == input_prompt_id)
)

if input_prompt is None:
raise ValueError(f"No input prompt with id {input_prompt_id}")

if not input_prompt.is_public:
raise HTTPException(status_code=400, detail="This prompt is not public")

db_session.delete(input_prompt)
db_session.commit()


def remove_input_prompt(
user: User | None, input_prompt_id: int, db_session: Session
) -> None:
input_prompt = db_session.scalar(
select(InputPrompt).where(InputPrompt.id == input_prompt_id)
)
if input_prompt is None:
raise ValueError(f"No input prompt with id {input_prompt_id}")

if input_prompt.is_public:
raise HTTPException(
status_code=400, detail="Cannot delete public prompts with this method"
)

if not validate_user_prompt_authorization(user, input_prompt):
raise HTTPException(status_code=401, detail="You do not own this prompt")

db_session.delete(input_prompt)
db_session.commit()


def fetch_input_prompt_by_id(
id: int, user_id: UUID | None, db_session: Session
) -> InputPrompt:
query = select(InputPrompt).where(InputPrompt.id == id)

if user_id:
query = query.where(
(InputPrompt.user_id == user_id) | (InputPrompt.user_id is None)
)
else:
# If no user_id is provided, only fetch prompts without a user_id (aka public)
query = query.where(InputPrompt.user_id == None)  # noqa

result = db_session.scalar(query)

if result is None:
raise HTTPException(422, "No input prompt found")

return result


def fetch_public_input_prompts(
db_session: Session,
) -> list[InputPrompt]:
query = select(InputPrompt).where(InputPrompt.is_public)
return list(db_session.scalars(query).all())


def fetch_input_prompts_by_user(
db_session: Session,
user_id: UUID | None,
active: bool | None = None,
include_public: bool = False,
) -> list[InputPrompt]:
query = select(InputPrompt)

if user_id is not None:
if include_public:
query = query.where(
(InputPrompt.user_id == user_id) | InputPrompt.is_public
)
else:
query = query.where(InputPrompt.user_id == user_id)

elif include_public:
query = query.where(InputPrompt.is_public)

if active is not None:
query = query.where(InputPrompt.active == active)

return list(db_session.scalars(query).all())
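A hedged usage sketch for the new input-prompt module; the function names and keyword arguments come from the file above, while the session handling and prompt text are illustrative assumptions:

from danswer.db.engine import get_session_factory
from danswer.db.input_prompt import (
    fetch_input_prompts_by_user,
    insert_input_prompt,
)

with get_session_factory()() as db_session:
    # user=None forces is_public=True per the insert logic above
    insert_input_prompt(
        prompt="Summarize",
        content="Summarize the following document:",
        is_public=True,
        user=None,
        db_session=db_session,
    )
    # With no user_id, only public prompts come back
    prompts = fetch_input_prompts_by_user(
        db_session=db_session, user_id=None, active=True, include_public=True
    )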
@@ -1,15 +1,41 @@
from sqlalchemy import delete
from sqlalchemy import or_
from sqlalchemy import select
from sqlalchemy.orm import Session

from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
from danswer.db.models import LLMProvider as LLMProviderModel
from danswer.db.models import LLMProvider__UserGroup
from danswer.db.models import User
from danswer.db.models import User__UserGroup
from danswer.server.manage.embedding.models import CloudEmbeddingProvider
from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
from danswer.server.manage.llm.models import FullLLMProvider
from danswer.server.manage.llm.models import LLMProviderUpsertRequest


def update_group_llm_provider_relationships__no_commit(
llm_provider_id: int,
group_ids: list[int] | None,
db_session: Session,
) -> None:
# Delete existing relationships
db_session.query(LLMProvider__UserGroup).filter(
LLMProvider__UserGroup.llm_provider_id == llm_provider_id
).delete(synchronize_session="fetch")

# Add new relationships from given group_ids
if group_ids:
new_relationships = [
LLMProvider__UserGroup(
llm_provider_id=llm_provider_id,
user_group_id=group_id,
)
for group_id in group_ids
]
db_session.add_all(new_relationships)


def upsert_cloud_embedding_provider(
db_session: Session, provider: CloudEmbeddingProviderCreationRequest
) -> CloudEmbeddingProvider:
@@ -36,36 +62,36 @@ def upsert_llm_provider(
existing_llm_provider = db_session.scalar(
select(LLMProviderModel).where(LLMProviderModel.name == llm_provider.name)
)
if existing_llm_provider:
existing_llm_provider.provider = llm_provider.provider
existing_llm_provider.api_key = llm_provider.api_key
existing_llm_provider.api_base = llm_provider.api_base
existing_llm_provider.api_version = llm_provider.api_version
existing_llm_provider.custom_config = llm_provider.custom_config
existing_llm_provider.default_model_name = llm_provider.default_model_name
existing_llm_provider.fast_default_model_name = (
llm_provider.fast_default_model_name
)
existing_llm_provider.model_names = llm_provider.model_names
db_session.commit()
return FullLLMProvider.from_model(existing_llm_provider)
# if it does not exist, create a new entry
llm_provider_model = LLMProviderModel(
name=llm_provider.name,
provider=llm_provider.provider,
api_key=llm_provider.api_key,
api_base=llm_provider.api_base,
api_version=llm_provider.api_version,
custom_config=llm_provider.custom_config,
default_model_name=llm_provider.default_model_name,
fast_default_model_name=llm_provider.fast_default_model_name,
model_names=llm_provider.model_names,
is_default_provider=None,

if not existing_llm_provider:
existing_llm_provider = LLMProviderModel(name=llm_provider.name)
db_session.add(existing_llm_provider)

existing_llm_provider.provider = llm_provider.provider
existing_llm_provider.api_key = llm_provider.api_key
existing_llm_provider.api_base = llm_provider.api_base
existing_llm_provider.api_version = llm_provider.api_version
existing_llm_provider.custom_config = llm_provider.custom_config
existing_llm_provider.default_model_name = llm_provider.default_model_name
existing_llm_provider.fast_default_model_name = llm_provider.fast_default_model_name
existing_llm_provider.model_names = llm_provider.model_names
existing_llm_provider.is_public = llm_provider.is_public
existing_llm_provider.display_model_names = llm_provider.display_model_names

if not existing_llm_provider.id:
# If its not already in the db, we need to generate an ID by flushing
db_session.flush()

# Make sure the relationship table stays up to date
update_group_llm_provider_relationships__no_commit(
llm_provider_id=existing_llm_provider.id,
group_ids=llm_provider.groups,
db_session=db_session,
)
db_session.add(llm_provider_model)

db_session.commit()

return FullLLMProvider.from_model(llm_provider_model)
return FullLLMProvider.from_model(existing_llm_provider)


def fetch_existing_embedding_providers(
@@ -74,8 +100,29 @@ def fetch_existing_embedding_providers(
return list(db_session.scalars(select(CloudEmbeddingProviderModel)).all())


def fetch_existing_llm_providers(db_session: Session) -> list[LLMProviderModel]:
return list(db_session.scalars(select(LLMProviderModel)).all())
def fetch_existing_llm_providers(
db_session: Session,
user: User | None = None,
) -> list[LLMProviderModel]:
if not user:
return list(db_session.scalars(select(LLMProviderModel)).all())
stmt = select(LLMProviderModel).distinct()
user_groups_subquery = (
select(User__UserGroup.user_group_id)
.where(User__UserGroup.user_id == user.id)
.subquery()
)
access_conditions = or_(
LLMProviderModel.is_public,
LLMProviderModel.id.in_(  # User is part of a group that has access
select(LLMProvider__UserGroup.llm_provider_id).where(
LLMProvider__UserGroup.user_group_id.in_(user_groups_subquery)  # type: ignore
)
),
)
stmt = stmt.where(access_conditions)

return list(db_session.scalars(stmt).all())


def fetch_embedding_provider(
@@ -119,6 +166,13 @@ def remove_embedding_provider(


def remove_llm_provider(db_session: Session, provider_id: int) -> None:
# Remove LLMProvider's dependent relationships
db_session.execute(
delete(LLMProvider__UserGroup).where(
LLMProvider__UserGroup.llm_provider_id == provider_id
)
)
# Remove LLMProvider
db_session.execute(
delete(LLMProviderModel).where(LLMProviderModel.id == provider_id)
)
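The rewritten upsert_llm_provider above follows a common SQLAlchemy shape: mutate the existing row or create a blank one, flush to obtain the primary key, sync the relationship table, then commit once. A generic, self-contained sketch of that shape (not danswer's code):

from sqlalchemy.orm import Session


def upsert_by_name(session: Session, model: type, name: str, **fields) -> object:
    obj = session.query(model).filter_by(name=name).one_or_none()
    if obj is None:
        obj = model(name=name)
        session.add(obj)
    for key, value in fields.items():
        setattr(obj, key, value)
    if not obj.id:
        session.flush()  # assigns the autoincrement ID without committing
    # ...sync any relationship rows keyed on obj.id here, then commit once...
    session.commit()
    return obj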
@@ -11,6 +11,7 @@ from uuid import UUID
|
||||
from fastapi_users_db_sqlalchemy import SQLAlchemyBaseOAuthAccountTableUUID
|
||||
from fastapi_users_db_sqlalchemy import SQLAlchemyBaseUserTableUUID
|
||||
from fastapi_users_db_sqlalchemy.access_token import SQLAlchemyBaseAccessTokenTableUUID
|
||||
from fastapi_users_db_sqlalchemy.generics import TIMESTAMPAware
|
||||
from sqlalchemy import Boolean
|
||||
from sqlalchemy import DateTime
|
||||
from sqlalchemy import Enum
|
||||
@@ -120,6 +121,10 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
postgresql.ARRAY(Integer), nullable=True
|
||||
)
|
||||
|
||||
oidc_expiry: Mapped[datetime.datetime] = mapped_column(
|
||||
TIMESTAMPAware(timezone=True), nullable=True
|
||||
)
|
||||
|
||||
# relationships
|
||||
credentials: Mapped[list["Credential"]] = relationship(
|
||||
"Credential", back_populates="user", lazy="joined"
|
||||
@@ -132,12 +137,39 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
)
|
||||
|
||||
prompts: Mapped[list["Prompt"]] = relationship("Prompt", back_populates="user")
|
||||
input_prompts: Mapped[list["InputPrompt"]] = relationship(
|
||||
"InputPrompt", back_populates="user"
|
||||
)
|
||||
|
||||
# Personas owned by this user
|
||||
personas: Mapped[list["Persona"]] = relationship("Persona", back_populates="user")
|
||||
# Custom tools created by this user
|
||||
custom_tools: Mapped[list["Tool"]] = relationship("Tool", back_populates="user")
|
||||
|
||||
|
||||
class InputPrompt(Base):
|
||||
__tablename__ = "inputprompt"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
prompt: Mapped[str] = mapped_column(String)
|
||||
content: Mapped[str] = mapped_column(String)
|
||||
active: Mapped[bool] = mapped_column(Boolean)
|
||||
user: Mapped[User | None] = relationship("User", back_populates="input_prompts")
|
||||
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id"))
|
||||
|
||||
|
||||
class InputPrompt__User(Base):
|
||||
__tablename__ = "inputprompt__user"
|
||||
|
||||
input_prompt_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("inputprompt.id"), primary_key=True
|
||||
)
|
||||
user_id: Mapped[UUID] = mapped_column(
|
||||
ForeignKey("inputprompt.id"), primary_key=True
|
||||
)
|
||||
|
||||
|
||||
class AccessToken(SQLAlchemyBaseAccessTokenTableUUID, Base):
|
||||
pass
|
||||
|
||||
@@ -337,6 +369,9 @@ class ConnectorCredentialPair(Base):
        back_populates="connector_credential_pairs",
        overlaps="document_set",
    )
    index_attempts: Mapped[list["IndexAttempt"]] = relationship(
        "IndexAttempt", back_populates="connector_credential_pair"
    )


class Document(Base):
@@ -416,6 +451,9 @@ class Connector(Base):
    connector_specific_config: Mapped[dict[str, Any]] = mapped_column(
        postgresql.JSONB()
    )
    indexing_start: Mapped[datetime.datetime | None] = mapped_column(
        DateTime, nullable=True
    )
    refresh_freq: Mapped[int | None] = mapped_column(Integer, nullable=True)
    prune_freq: Mapped[int | None] = mapped_column(Integer, nullable=True)
    time_created: Mapped[datetime.datetime] = mapped_column(
@@ -434,14 +472,17 @@ class Connector(Base):
    documents_by_connector: Mapped[
        list["DocumentByConnectorCredentialPair"]
    ] = relationship("DocumentByConnectorCredentialPair", back_populates="connector")
    index_attempts: Mapped[list["IndexAttempt"]] = relationship(
        "IndexAttempt", back_populates="connector"
    )


class Credential(Base):
    __tablename__ = "credential"

    name: Mapped[str] = mapped_column(String, nullable=True)

    source: Mapped[DocumentSource] = mapped_column(
        Enum(DocumentSource, native_enum=False)
    )

    id: Mapped[int] = mapped_column(primary_key=True)
    credential_json: Mapped[dict[str, Any]] = mapped_column(EncryptedJson())
    user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
@@ -462,9 +503,7 @@ class Credential(Base):
    documents_by_credential: Mapped[
        list["DocumentByConnectorCredentialPair"]
    ] = relationship("DocumentByConnectorCredentialPair", back_populates="credential")
    index_attempts: Mapped[list["IndexAttempt"]] = relationship(
        "IndexAttempt", back_populates="credential"
    )

    user: Mapped[User | None] = relationship("User", back_populates="credentials")

@@ -516,12 +555,12 @@ class EmbeddingModel(Base):
        cloud_provider='{self.cloud_provider.name if self.cloud_provider else 'None'}')>"

    @property
-    def api_key(self) -> str | None:
-        return self.cloud_provider.api_key if self.cloud_provider else None
+    def provider_type(self) -> str | None:
+        return self.cloud_provider.name if self.cloud_provider is not None else None

    @property
-    def provider_type(self) -> str | None:
-        return self.cloud_provider.name if self.cloud_provider else None
+    def api_key(self) -> str | None:
+        return self.cloud_provider.api_key if self.cloud_provider is not None else None


class IndexAttempt(Base):
@@ -534,13 +573,10 @@ class IndexAttempt(Base):
    __tablename__ = "index_attempt"

    id: Mapped[int] = mapped_column(primary_key=True)
-    connector_id: Mapped[int | None] = mapped_column(
-        ForeignKey("connector.id"),
-        nullable=True,
-    )
-    credential_id: Mapped[int | None] = mapped_column(
-        ForeignKey("credential.id"),
-        nullable=True,
-    )
+    connector_credential_pair_id: Mapped[int] = mapped_column(
+        ForeignKey("connector_credential_pair.id"),
+        nullable=False,
+    )

    # Some index attempts that run from beginning will still have this as False
@@ -578,12 +614,10 @@ class IndexAttempt(Base):
        onupdate=func.now(),
    )

-    connector: Mapped[Connector] = relationship(
-        "Connector", back_populates="index_attempts"
-    )
-    credential: Mapped[Credential] = relationship(
-        "Credential", back_populates="index_attempts"
-    )
+    connector_credential_pair: Mapped[ConnectorCredentialPair] = relationship(
+        "ConnectorCredentialPair", back_populates="index_attempts"
+    )

    embedding_model: Mapped[EmbeddingModel] = relationship(
        "EmbeddingModel", back_populates="index_attempts"
    )
@@ -591,8 +625,7 @@ class IndexAttempt(Base):
    __table_args__ = (
        Index(
            "ix_index_attempt_latest_for_connector_credential_pair",
-            "connector_id",
-            "credential_id",
+            "connector_credential_pair_id",
            "time_created",
        ),
    )
@@ -600,7 +633,6 @@ class IndexAttempt(Base):
    def __repr__(self) -> str:
        return (
            f"<IndexAttempt(id={self.id!r}, "
-            f"connector_id={self.connector_id!r}, "
            f"status={self.status!r}, "
            f"error_msg={self.error_msg!r})>"
            f"time_created={self.time_created!r}, "
@@ -821,6 +853,8 @@ class ChatMessage(Base):
        secondary="chat_message__search_doc",
        back_populates="chat_messages",
    )
    # NOTE: Should always be attached to the `assistant` message.
    # represents the tool calls used to generate this message
    tool_calls: Mapped[list["ToolCall"]] = relationship(
        "ToolCall",
        back_populates="message",
@@ -923,6 +957,11 @@ class LLMProvider(Base):
    default_model_name: Mapped[str] = mapped_column(String)
    fast_default_model_name: Mapped[str | None] = mapped_column(String, nullable=True)

    # Models to actually display to users
    # If nulled out, we assume in the application logic we should present all
    display_model_names: Mapped[list[str] | None] = mapped_column(
        postgresql.ARRAY(String), nullable=True
    )
    # The LLMs that are available for this provider. Only required if not a default provider.
    # If a default provider, then the LLM options are pulled from the `options.py` file.
    # If needed, can be pulled out as a separate table in the future.
@@ -932,6 +971,13 @@ class LLMProvider(Base):

    # should only be set for a single provider
    is_default_provider: Mapped[bool | None] = mapped_column(Boolean, unique=True)
    # EE only
    is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
    groups: Mapped[list["UserGroup"]] = relationship(
        "UserGroup",
        secondary="llm_provider__user_group",
        viewonly=True,
    )


class CloudEmbeddingProvider(Base):
@@ -1109,7 +1155,10 @@ class Persona(Base):
    # where lower value IDs (e.g. created earlier) are displayed first
    display_priority: Mapped[int] = mapped_column(Integer, nullable=True, default=None)
    deleted: Mapped[bool] = mapped_column(Boolean, default=False)
    is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)

    uploaded_image_id: Mapped[str | None] = mapped_column(String, nullable=True)
    icon_color: Mapped[str | None] = mapped_column(String, nullable=True)
    icon_shape: Mapped[int | None] = mapped_column(Integer, nullable=True)

    # These are only defaults, users can select from all if desired
    prompts: Mapped[list[Prompt]] = relationship(
@@ -1137,6 +1186,7 @@ class Persona(Base):
        viewonly=True,
    )
    # EE only
    is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
    groups: Mapped[list["UserGroup"]] = relationship(
        "UserGroup",
        secondary="persona__user_group",
@@ -1360,6 +1410,17 @@ class Persona__UserGroup(Base):
    )


class LLMProvider__UserGroup(Base):
    __tablename__ = "llm_provider__user_group"

    llm_provider_id: Mapped[int] = mapped_column(
        ForeignKey("llm_provider.id"), primary_key=True
    )
    user_group_id: Mapped[int] = mapped_column(
        ForeignKey("user_group.id"), primary_key=True
    )


class DocumentSet__UserGroup(Base):
    __tablename__ = "document_set__user_group"

@@ -9,6 +9,7 @@ from sqlalchemy import not_
from sqlalchemy import or_
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session

from danswer.auth.schemas import UserRole
@@ -24,6 +25,7 @@ from danswer.db.models import StarterMessage
from danswer.db.models import Tool
from danswer.db.models import User
from danswer.db.models import User__UserGroup
from danswer.db.models import UserGroup
from danswer.search.enums import RecencyBiasSetting
from danswer.server.features.persona.models import CreatePersonaRequest
from danswer.server.features.persona.models import PersonaSnapshot
@@ -80,6 +82,9 @@ def create_update_persona(
        starter_messages=create_persona_request.starter_messages,
        is_public=create_persona_request.is_public,
        db_session=db_session,
        icon_color=create_persona_request.icon_color,
        icon_shape=create_persona_request.icon_shape,
        uploaded_image_id=create_persona_request.uploaded_image_id,
    )

    versioned_make_persona_private = fetch_versioned_implementation(
@@ -328,6 +333,9 @@ def upsert_persona(
    persona_id: int | None = None,
    default_persona: bool = False,
    commit: bool = True,
    icon_color: str | None = None,
    icon_shape: int | None = None,
    uploaded_image_id: str | None = None,
) -> Persona:
    if persona_id is not None:
        persona = db_session.query(Persona).filter_by(id=persona_id).first()
@@ -383,6 +391,9 @@ def upsert_persona(
        persona.starter_messages = starter_messages
        persona.deleted = False  # Un-delete if previously deleted
        persona.is_public = is_public
        persona.icon_color = icon_color
        persona.icon_shape = icon_shape
        persona.uploaded_image_id = uploaded_image_id

        # Do not delete any associations manually added unless
        # a new updated list is provided
@@ -415,6 +426,9 @@ def upsert_persona(
            llm_model_version_override=llm_model_version_override,
            starter_messages=starter_messages,
            tools=tools or [],
            icon_shape=icon_shape,
            icon_color=icon_color,
            uploaded_image_id=uploaded_image_id,
        )
        db_session.add(persona)

@@ -548,6 +562,8 @@ def get_default_prompt__read_only() -> Prompt:
        return _get_default_prompt(db_session)


# TODO: since this gets called with every chat message, could it be more efficient to pregenerate
# a direct mapping indicating whether a user has access to a specific persona?
def get_persona_by_id(
    persona_id: int,
    # if user is `None` assume the user is an admin or auth is disabled
@@ -556,16 +572,38 @@ def get_persona_by_id(
    include_deleted: bool = False,
    is_for_edit: bool = True,  # NOTE: assume true for safety
) -> Persona:
-    stmt = select(Persona).where(Persona.id == persona_id)
+    stmt = (
+        select(Persona)
+        .options(selectinload(Persona.users), selectinload(Persona.groups))
+        .where(Persona.id == persona_id)
+    )

    or_conditions = []

    # if user is an admin, they should have access to all Personas
    # and will skip the following clause
    if user is not None and user.role != UserRole.ADMIN:
-        or_conditions.extend([Persona.user_id == user.id, Persona.user_id.is_(None)])
+        # the user is not an admin
+        isPersonaUnowned = Persona.user_id.is_(
+            None
+        )  # allow access if persona user id is None
+        isUserCreator = (
+            Persona.user_id == user.id
+        )  # allow access if user created the persona
+        or_conditions.extend([isPersonaUnowned, isUserCreator])

-        # if we aren't editing, also give access to all public personas
+        # if we aren't editing, also give access if:
+        # 1. the user is authorized for this persona
+        # 2. the user is in an authorized group for this persona
+        # 3. the persona is public
        if not is_for_edit:
            isSharedWithUser = Persona.users.any(
                id=user.id
            )  # allow access if user is in allowed users
            isSharedWithGroup = Persona.groups.any(
                UserGroup.users.any(id=user.id)
            )  # allow access if user is in any allowed group
            or_conditions.extend([isSharedWithUser, isSharedWithGroup])
            or_conditions.append(Persona.is_public.is_(True))

    if or_conditions:

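For readers following the hunk above, here is a minimal, self-contained sketch of the same gather-then-OR access pattern. It uses a simplified stand-in model: the column names mirror the diff, but the model, function name, and wiring are illustrative assumptions, not the real schema.

from sqlalchemy import Boolean, Column, Integer, or_, select
from sqlalchemy.orm import declarative_base

Base = declarative_base()

class PersonaSketch(Base):
    __tablename__ = "persona_sketch"
    id = Column(Integer, primary_key=True)
    user_id = Column(Integer, nullable=True)  # NULL means unowned
    is_public = Column(Boolean, default=True)

def build_access_stmt(persona_id: int, user_id: int | None, is_admin: bool, is_for_edit: bool):
    stmt = select(PersonaSketch).where(PersonaSketch.id == persona_id)
    or_conditions = []
    if not is_admin and user_id is not None:
        # non-admins may always reach unowned personas and their own
        or_conditions.extend(
            [PersonaSketch.user_id.is_(None), PersonaSketch.user_id == user_id]
        )
        if not is_for_edit:
            # read-only access additionally covers public personas
            or_conditions.append(PersonaSketch.is_public.is_(True))
    if or_conditions:
        stmt = stmt.where(or_(*or_conditions))
    return stmt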
@@ -7,6 +7,7 @@ from danswer.access.models import DocumentAccess
from danswer.indexing.models import DocMetadataAwareIndexChunk
from danswer.search.models import IndexFilters
from danswer.search.models import InferenceChunkUncleaned
from shared_configs.model_server_models import Embedding


@dataclass(frozen=True)
@@ -257,7 +258,7 @@ class VectorCapable(abc.ABC):
    def semantic_retrieval(
        self,
        query: str,  # Needed for matching purposes
-        query_embedding: list[float],
+        query_embedding: Embedding,
        filters: IndexFilters,
        time_decay_multiplier: float,
        num_to_retrieve: int,
@@ -292,7 +293,7 @@ class HybridCapable(abc.ABC):
    def hybrid_retrieval(
        self,
        query: str,
-        query_embedding: list[float],
+        query_embedding: Embedding,
        filters: IndexFilters,
        time_decay_multiplier: float,
        num_to_retrieve: int,
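The two hunks above swap the bare list[float] annotation for the shared Embedding type. A small sketch of why such an alias helps; the alias definition below is an assumption for illustration (the real one lives in shared_configs.model_server_models), and cosine_similarity is purely a demo consumer:

# a named alias keeps every retrieval signature in sync
Embedding = list[float]

def cosine_similarity(a: Embedding, b: Embedding) -> float:
    # plain-Python cosine similarity over two equal-length vectors
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = sum(x * x for x in a) ** 0.5
    norm_b = sum(y * y for y in b) ** 0.5
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0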
@@ -69,6 +69,7 @@ from danswer.search.retrieval.search_runner import query_processing
from danswer.search.retrieval.search_runner import remove_stop_words_and_punctuation
from danswer.utils.batching import batch_generator
from danswer.utils.logger import setup_logger
from shared_configs.model_server_models import Embedding

logger = setup_logger()

@@ -329,11 +330,13 @@ def _index_vespa_chunk(
        "Content-Type": "application/json",
    }
    document = chunk.source_document

    # No minichunk documents in vespa, minichunk vectors are stored in the chunk itself
    vespa_chunk_id = str(get_uuid_from_chunk(chunk))

    embeddings = chunk.embeddings

    embeddings_name_vector_map = {"full_chunk": embeddings.full_embedding}

    if embeddings.mini_chunk_embeddings:
        for ind, m_c_embed in enumerate(embeddings.mini_chunk_embeddings):
            embeddings_name_vector_map[f"mini_chunk_{ind}"] = m_c_embed
@@ -346,11 +349,15 @@ def _index_vespa_chunk(
        BLURB: remove_invalid_unicode_chars(chunk.blurb),
        TITLE: remove_invalid_unicode_chars(title) if title else None,
        SKIP_TITLE_EMBEDDING: not title,
-        CONTENT: remove_invalid_unicode_chars(chunk.content),
+        # For the BM25 index, the keyword suffix is used, the vector is already generated with the more
+        # natural language representation of the metadata section
+        CONTENT: remove_invalid_unicode_chars(
+            f"{chunk.title_prefix}{chunk.content}{chunk.metadata_suffix_keyword}"
+        ),
        # This duplication of `content` is needed for keyword highlighting
        # Note that it's not exactly the same as the actual content
        # which contains the title prefix and metadata suffix
-        CONTENT_SUMMARY: remove_invalid_unicode_chars(chunk.content_summary),
+        CONTENT_SUMMARY: remove_invalid_unicode_chars(chunk.content),
        SOURCE_TYPE: str(document.source.value),
        SOURCE_LINKS: json.dumps(chunk.source_links),
        SEMANTIC_IDENTIFIER: remove_invalid_unicode_chars(document.semantic_identifier),
@@ -358,7 +365,7 @@ def _index_vespa_chunk(
        METADATA: json.dumps(document.metadata),
        # Save as a list for efficient extraction as an Attribute
        METADATA_LIST: chunk.source_document.get_metadata_str_attributes(),
-        METADATA_SUFFIX: chunk.metadata_suffix,
+        METADATA_SUFFIX: chunk.metadata_suffix_keyword,
        EMBEDDINGS: embeddings_name_vector_map,
        TITLE_EMBEDDING: chunk.title_embedding,
        BOOST: chunk.boost,
@@ -1025,7 +1032,7 @@ class VespaIndex(DocumentIndex):
    def semantic_retrieval(
        self,
        query: str,
-        query_embedding: list[float],
+        query_embedding: Embedding,
        filters: IndexFilters,
        time_decay_multiplier: float,
        num_to_retrieve: int = NUM_RETURNED_HITS,
@@ -1067,7 +1074,7 @@ class VespaIndex(DocumentIndex):
    def hybrid_retrieval(
        self,
        query: str,
-        query_embedding: list[float],
+        query_embedding: Embedding,
        filters: IndexFilters,
        time_decay_multiplier: float,
        num_to_retrieve: int,

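The _index_vespa_chunk changes above index the keyword-decorated text in CONTENT while keeping the undecorated text in CONTENT_SUMMARY for highlighting. A rough sketch of that split; the field names and function below are illustrative stand-ins, not the real constants:

def build_vespa_text_fields(
    chunk_text: str, title_prefix: str, metadata_suffix_keyword: str
) -> dict[str, str]:
    return {
        # BM25 field: include the title prefix and keyword metadata suffix
        "content": f"{title_prefix}{chunk_text}{metadata_suffix_keyword}",
        # highlight field: raw chunk text only, so highlights line up
        # with what the user actually reads
        "content_summary": chunk_text,
    }

fields = build_vespa_text_fields("The quarterly numbers...", "Q3 Report\n", "\nfinance reports")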
@@ -103,6 +103,8 @@ def port_api_key_to_postgres() -> None:
        default_model_name=default_model_name,
        fast_default_model_name=default_fast_model_name,
        model_names=None,
        display_model_names=[],
        is_public=True,
    )
    llm_provider = upsert_llm_provider(db_session, llm_provider_upsert)
    update_default_provider(db_session, llm_provider.id)

@@ -8,7 +8,7 @@ from typing import cast
from filelock import FileLock
from sqlalchemy.orm import Session

-from danswer.db.engine import SessionFactory
+from danswer.db.engine import get_session_factory
from danswer.db.models import KVStore
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.dynamic_configs.interface import DynamicConfigStore
@@ -56,7 +56,8 @@ class FileSystemBackedDynamicConfigStore(DynamicConfigStore):
class PostgresBackedDynamicConfigStore(DynamicConfigStore):
    @contextmanager
    def get_session(self) -> Iterator[Session]:
-        session: Session = SessionFactory()
+        factory = get_session_factory()
+        session: Session = factory()
        try:
            yield session
        finally:

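The store change above stops binding a SessionFactory at import time and fetches the factory per call instead. A minimal sketch of the same context-manager pattern, with an assumed in-memory engine standing in for the real wiring:

from collections.abc import Iterator
from contextlib import contextmanager

from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker

def get_session_factory() -> sessionmaker:
    # stand-in for the real lazily-initialized engine/factory setup
    return sessionmaker(bind=create_engine("sqlite://"))

@contextmanager
def get_session() -> Iterator[Session]:
    factory = get_session_factory()  # resolved at call time, not import time
    session: Session = factory()
    try:
        yield session
    finally:
        session.close()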
@@ -1,12 +1,13 @@
import abc
from collections.abc import Callable
from typing import Optional
from typing import TYPE_CHECKING

from danswer.configs.app_configs import BLURB_SIZE
from danswer.configs.app_configs import ENABLE_MINI_CHUNK
from danswer.configs.app_configs import MINI_CHUNK_SIZE
from danswer.configs.app_configs import SKIP_METADATA_IN_CHUNK
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import MAX_CHUNK_TITLE_LEN
from danswer.configs.constants import RETURN_SEPARATOR
from danswer.configs.constants import SECTION_SEPARATOR
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
@@ -14,13 +15,14 @@ from danswer.connectors.cross_connector_utils.miscellaneous_utils import (
    get_metadata_keys_to_ignore,
)
from danswer.connectors.models import Document
from danswer.indexing.embedder import IndexingEmbedder
from danswer.indexing.models import DocAwareChunk
-from danswer.search.search_nlp_models import get_default_tokenizer
+from danswer.natural_language_processing.utils import get_tokenizer
from danswer.utils.logger import setup_logger
from danswer.utils.text_processing import shared_precompare_cleanup

if TYPE_CHECKING:
-    from transformers import AutoTokenizer  # type:ignore
+    from llama_index.text_splitter import SentenceSplitter  # type:ignore


# Not supporting overlaps, we need a clean combination of chunks and it is unclear if overlaps
@@ -28,6 +30,8 @@ if TYPE_CHECKING:
CHUNK_OVERLAP = 0
# Fairly arbitrary numbers but the general concept is we don't want the title/metadata to
# overwhelm the actual contents of the chunk
# For example in a rare case, this could be 128 tokens for the 512 chunk and title prefix
# could be another 128 tokens leaving 256 for the actual contents
MAX_METADATA_PERCENTAGE = 0.25
CHUNK_MIN_CONTENT = 256

@@ -36,15 +40,11 @@ logger = setup_logger()
ChunkFunc = Callable[[Document], list[DocAwareChunk]]


-def extract_blurb(text: str, blurb_size: int) -> str:
-    from llama_index.text_splitter import SentenceSplitter
-
-    token_count_func = get_default_tokenizer().tokenize
-    blurb_splitter = SentenceSplitter(
-        tokenizer=token_count_func, chunk_size=blurb_size, chunk_overlap=0
-    )
-
-    return blurb_splitter.split_text(text)[0]
+def extract_blurb(text: str, blurb_splitter: "SentenceSplitter") -> str:
+    texts = blurb_splitter.split_text(text)
+    if not texts:
+        return ""
+    return texts[0]

@@ -52,76 +52,129 @@ def chunk_large_section(
    section_link_text: str,
    document: Document,
    start_chunk_id: int,
-    tokenizer: "AutoTokenizer",
-    chunk_size: int = DOC_EMBEDDING_CONTEXT_SIZE,
-    chunk_overlap: int = CHUNK_OVERLAP,
-    blurb_size: int = BLURB_SIZE,
-    title_prefix: str = "",
-    metadata_suffix: str = "",
+    blurb: str,
+    chunk_splitter: "SentenceSplitter",
+    mini_chunk_splitter: Optional["SentenceSplitter"],
+    title_prefix: str,
+    metadata_suffix_semantic: str,
+    metadata_suffix_keyword: str,
) -> list[DocAwareChunk]:
-    from llama_index.text_splitter import SentenceSplitter
-
-    blurb = extract_blurb(section_text, blurb_size)
-
-    sentence_aware_splitter = SentenceSplitter(
-        tokenizer=tokenizer.tokenize, chunk_size=chunk_size, chunk_overlap=chunk_overlap
-    )
-
-    split_texts = sentence_aware_splitter.split_text(section_text)
+    split_texts = chunk_splitter.split_text(section_text)

    chunks = [
        DocAwareChunk(
            source_document=document,
            chunk_id=start_chunk_id + chunk_ind,
            blurb=blurb,
-            content=f"{title_prefix}{chunk_str}{metadata_suffix}",
-            content_summary=chunk_str,
+            content=chunk_text,
            source_links={0: section_link_text},
            section_continuation=(chunk_ind != 0),
-            metadata_suffix=metadata_suffix,
+            title_prefix=title_prefix,
+            metadata_suffix_semantic=metadata_suffix_semantic,
+            metadata_suffix_keyword=metadata_suffix_keyword,
+            mini_chunk_texts=mini_chunk_splitter.split_text(chunk_text)
+            if mini_chunk_splitter and chunk_text.strip()
+            else None,
        )
-        for chunk_ind, chunk_str in enumerate(split_texts)
+        for chunk_ind, chunk_text in enumerate(split_texts)
    ]
    return chunks


def _get_metadata_suffix_for_document_index(
-    metadata: dict[str, str | list[str]]
-) -> str:
+    metadata: dict[str, str | list[str]], include_separator: bool = False
+) -> tuple[str, str]:
    """
    Returns the metadata as a natural language string representation with all of the keys and values for the vector embedding
    and a string of all of the values for the keyword search

    For example, if we have the following metadata:
    {
        "author": "John Doe",
        "space": "Engineering"
    }
    The vector embedding string should include the relation between the key and value, whereas for keyword search we only want
    John Doe and Engineering. The keys are repeated there and would be much noisier.
    """
    if not metadata:
-        return ""
+        return "", ""

    metadata_str = "Metadata:\n"
    metadata_values = []
    for key, value in metadata.items():
        if key in get_metadata_keys_to_ignore():
            continue

        value_str = ", ".join(value) if isinstance(value, list) else value

        if isinstance(value, list):
            metadata_values.extend(value)
        else:
            metadata_values.append(value)

        metadata_str += f"\t{key} - {value_str}\n"
-    return metadata_str.strip()
+
+    metadata_semantic = metadata_str.strip()
+    metadata_keyword = " ".join(metadata_values)
+
+    if include_separator:
+        return RETURN_SEPARATOR + metadata_semantic, RETURN_SEPARATOR + metadata_keyword
+    return metadata_semantic, metadata_keyword

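A worked example of the semantic/keyword split the function above performs, mirroring the docstring's sample metadata. This is a reduced sketch (separator and ignore-list handling omitted), not the shipped implementation:

def metadata_suffixes(metadata: dict[str, str | list[str]]) -> tuple[str, str]:
    if not metadata:
        return "", ""
    semantic = "Metadata:\n"
    values: list[str] = []
    for key, value in metadata.items():
        value_str = ", ".join(value) if isinstance(value, list) else value
        values.extend(value if isinstance(value, list) else [value])
        semantic += f"\t{key} - {value_str}\n"
    return semantic.strip(), " ".join(values)

# the semantic form keeps key-value relations for the embedding text,
# the keyword form keeps only the values for BM25 matching
assert metadata_suffixes({"author": "John Doe", "space": "Engineering"}) == (
    "Metadata:\n\tauthor - John Doe\n\tspace - Engineering",
    "John Doe Engineering",
)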
def chunk_document(
    document: Document,
+    embedder: IndexingEmbedder,
    chunk_tok_size: int = DOC_EMBEDDING_CONTEXT_SIZE,
    subsection_overlap: int = CHUNK_OVERLAP,
-    blurb_size: int = BLURB_SIZE,
+    blurb_size: int = BLURB_SIZE,  # Used for both title and content
    include_metadata: bool = not SKIP_METADATA_IN_CHUNK,
    mini_chunk_size: int = MINI_CHUNK_SIZE,
+    enable_mini_chunk: bool = ENABLE_MINI_CHUNK,
) -> list[DocAwareChunk]:
-    tokenizer = get_default_tokenizer()
+    from llama_index.text_splitter import SentenceSplitter

-    title = document.get_title_for_document_index()
-    title_prefix = f"{title[:MAX_CHUNK_TITLE_LEN]}{RETURN_SEPARATOR}" if title else ""
+    tokenizer = get_tokenizer(
+        model_name=embedder.model_name,
+        provider_type=embedder.provider_type,
+    )
+
+    blurb_splitter = SentenceSplitter(
+        tokenizer=tokenizer.tokenize, chunk_size=blurb_size, chunk_overlap=0
+    )
+
+    chunk_splitter = SentenceSplitter(
+        tokenizer=tokenizer.tokenize,
+        chunk_size=chunk_tok_size,
+        chunk_overlap=subsection_overlap,
+    )
+
+    mini_chunk_splitter = SentenceSplitter(
+        tokenizer=tokenizer.tokenize,
+        chunk_size=mini_chunk_size,
+        chunk_overlap=0,
+    )
+
+    title = extract_blurb(document.get_title_for_document_index() or "", blurb_splitter)
+    title_prefix = title + RETURN_SEPARATOR if title else ""
    title_tokens = len(tokenizer.tokenize(title_prefix))

-    metadata_suffix = ""
+    metadata_suffix_semantic = ""
+    metadata_suffix_keyword = ""
    metadata_tokens = 0
    if include_metadata:
-        metadata = _get_metadata_suffix_for_document_index(document.metadata)
-        metadata_suffix = RETURN_SEPARATOR + metadata if metadata else ""
-        metadata_tokens = len(tokenizer.tokenize(metadata_suffix))
+        (
+            metadata_suffix_semantic,
+            metadata_suffix_keyword,
+        ) = _get_metadata_suffix_for_document_index(
+            document.metadata, include_separator=True
+        )
+        metadata_tokens = len(tokenizer.tokenize(metadata_suffix_semantic))

    if metadata_tokens >= chunk_tok_size * MAX_METADATA_PERCENTAGE:
-        metadata_suffix = ""
+        # Note: we can keep the keyword suffix even if the semantic suffix is too long to fit in the model
+        # context, there is no limit for the keyword component
+        metadata_suffix_semantic = ""
        metadata_tokens = 0

    content_token_limit = chunk_tok_size - title_tokens - metadata_tokens
@@ -130,7 +183,7 @@ def chunk_document(
    if content_token_limit <= CHUNK_MIN_CONTENT:
        content_token_limit = chunk_tok_size
        title_prefix = ""
-        metadata_suffix = ""
+        metadata_suffix_semantic = ""

    chunks: list[DocAwareChunk] = []
    link_offsets: dict[int, str] = {}
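The two hunks above implement a token budget: the title prefix and the semantic metadata suffix are carved out of the embedding context window, and both are dropped when they would squeeze the content below a minimum. A small numeric sketch with assumed config values (the constants below are illustrative, not the real settings):

CHUNK_SIZE = 512         # stand-in for DOC_EMBEDDING_CONTEXT_SIZE
MAX_METADATA_PCT = 0.25  # stand-in for MAX_METADATA_PERCENTAGE
MIN_CONTENT = 256        # stand-in for CHUNK_MIN_CONTENT

def content_budget(title_tokens: int, metadata_tokens: int) -> tuple[int, bool]:
    """Returns (tokens left for content, whether prefix/suffix survive)."""
    if metadata_tokens >= CHUNK_SIZE * MAX_METADATA_PCT:
        metadata_tokens = 0  # semantic metadata suffix dropped entirely
    limit = CHUNK_SIZE - title_tokens - metadata_tokens
    if limit <= MIN_CONTENT:
        return CHUNK_SIZE, False  # give the whole window to the content
    return limit, True

print(content_budget(title_tokens=40, metadata_tokens=60))    # (412, True)
print(content_budget(title_tokens=200, metadata_tokens=100))  # (512, False)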
@@ -151,12 +204,16 @@ def chunk_document(
                DocAwareChunk(
                    source_document=document,
                    chunk_id=len(chunks),
-                    blurb=extract_blurb(chunk_text, blurb_size),
-                    content=f"{title_prefix}{chunk_text}{metadata_suffix}",
-                    content_summary=chunk_text,
+                    blurb=extract_blurb(chunk_text, blurb_splitter),
+                    content=chunk_text,
                    source_links=link_offsets,
                    section_continuation=False,
-                    metadata_suffix=metadata_suffix,
+                    title_prefix=title_prefix,
+                    metadata_suffix_semantic=metadata_suffix_semantic,
+                    metadata_suffix_keyword=metadata_suffix_keyword,
+                    mini_chunk_texts=mini_chunk_splitter.split_text(chunk_text)
+                    if enable_mini_chunk and chunk_text.strip()
+                    else None,
                )
            )
            link_offsets = {}
@@ -167,12 +224,14 @@ def chunk_document(
                section_link_text=section_link_text,
                document=document,
                start_chunk_id=len(chunks),
-                tokenizer=tokenizer,
-                chunk_size=content_token_limit,
-                chunk_overlap=subsection_overlap,
-                blurb_size=blurb_size,
+                chunk_splitter=chunk_splitter,
+                mini_chunk_splitter=mini_chunk_splitter
+                if enable_mini_chunk and chunk_text.strip()
+                else None,
+                blurb=extract_blurb(section_text, blurb_splitter),
                title_prefix=title_prefix,
-                metadata_suffix=metadata_suffix,
+                metadata_suffix_semantic=metadata_suffix_semantic,
+                metadata_suffix_keyword=metadata_suffix_keyword,
            )
            chunks.extend(large_section_chunks)
            continue
@@ -193,60 +252,62 @@ def chunk_document(
                DocAwareChunk(
                    source_document=document,
                    chunk_id=len(chunks),
-                    blurb=extract_blurb(chunk_text, blurb_size),
-                    content=f"{title_prefix}{chunk_text}{metadata_suffix}",
-                    content_summary=chunk_text,
+                    blurb=extract_blurb(chunk_text, blurb_splitter),
+                    content=chunk_text,
                    source_links=link_offsets,
                    section_continuation=False,
-                    metadata_suffix=metadata_suffix,
+                    title_prefix=title_prefix,
+                    metadata_suffix_semantic=metadata_suffix_semantic,
+                    metadata_suffix_keyword=metadata_suffix_keyword,
+                    mini_chunk_texts=mini_chunk_splitter.split_text(chunk_text)
+                    if enable_mini_chunk and chunk_text.strip()
+                    else None,
                )
            )
            link_offsets = {0: section_link_text}
            chunk_text = section_text

-    # Once we hit the end, if we're still in the process of building a chunk, add what we have
-    # NOTE: if it's just whitespace, ignore it.
-    if chunk_text.strip():
+    # Once we hit the end, if we're still in the process of building a chunk, add what we have. If there is only whitespace left
+    # then don't include it. If there are no chunks at all from the doc, we can just create a single chunk with the title.
+    if chunk_text.strip() or not chunks:
        chunks.append(
            DocAwareChunk(
                source_document=document,
                chunk_id=len(chunks),
-                blurb=extract_blurb(chunk_text, blurb_size),
-                content=f"{title_prefix}{chunk_text}{metadata_suffix}",
-                content_summary=chunk_text,
+                blurb=extract_blurb(chunk_text, blurb_splitter),
+                content=chunk_text,
                source_links=link_offsets,
                section_continuation=False,
-                metadata_suffix=metadata_suffix,
+                title_prefix=title_prefix,
+                metadata_suffix_semantic=metadata_suffix_semantic,
+                metadata_suffix_keyword=metadata_suffix_keyword,
+                mini_chunk_texts=mini_chunk_splitter.split_text(chunk_text)
+                if enable_mini_chunk and chunk_text.strip()
+                else None,
            )
        )

    # If the chunk does not have any usable content, it will not be indexed
    return chunks


def split_chunk_text_into_mini_chunks(
    chunk_text: str, mini_chunk_size: int = MINI_CHUNK_SIZE
) -> list[str]:
    """The minichunks won't all have the title prefix or metadata suffix
    It could be a significant percentage of every minichunk so better to not include it
    """
    from llama_index.text_splitter import SentenceSplitter

    token_count_func = get_default_tokenizer().tokenize
    sentence_aware_splitter = SentenceSplitter(
        tokenizer=token_count_func, chunk_size=mini_chunk_size, chunk_overlap=0
    )

    return sentence_aware_splitter.split_text(chunk_text)


class Chunker:
    @abc.abstractmethod
-    def chunk(self, document: Document) -> list[DocAwareChunk]:
+    def chunk(
+        self,
+        document: Document,
+        embedder: IndexingEmbedder,
+    ) -> list[DocAwareChunk]:
        raise NotImplementedError


class DefaultChunker(Chunker):
-    def chunk(self, document: Document) -> list[DocAwareChunk]:
+    def chunk(
+        self,
+        document: Document,
+        embedder: IndexingEmbedder,
+    ) -> list[DocAwareChunk]:
        # Specifically for reproducing an issue with gmail
        if document.source == DocumentSource.GMAIL:
            logger.debug(f"Chunking {document.semantic_identifier}")
-        return chunk_document(document)
+        return chunk_document(document, embedder=embedder)

@@ -3,23 +3,20 @@ from abc import abstractmethod

from sqlalchemy.orm import Session

-from danswer.configs.app_configs import ENABLE_MINI_CHUNK
from danswer.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.embedding_model import get_secondary_db_embedding_model
from danswer.db.models import EmbeddingModel as DbEmbeddingModel
from danswer.db.models import IndexModelStatus
-from danswer.indexing.chunker import split_chunk_text_into_mini_chunks
from danswer.indexing.models import ChunkEmbedding
from danswer.indexing.models import DocAwareChunk
from danswer.indexing.models import IndexChunk
-from danswer.search.search_nlp_models import EmbeddingModel
-from danswer.utils.batching import batch_list
+from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.utils.logger import setup_logger
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
from shared_configs.configs import INDEXING_MODEL_SERVER_PORT
from shared_configs.enums import EmbedTextType
from shared_configs.model_server_models import Embedding


logger = setup_logger()
@@ -32,14 +29,21 @@ class IndexingEmbedder(ABC):
        normalize: bool,
        query_prefix: str | None,
        passage_prefix: str | None,
        provider_type: str | None,
        api_key: str | None,
    ):
        self.model_name = model_name
        self.normalize = normalize
        self.query_prefix = query_prefix
        self.passage_prefix = passage_prefix
        self.provider_type = provider_type
        self.api_key = api_key

    @abstractmethod
-    def embed_chunks(self, chunks: list[DocAwareChunk]) -> list[IndexChunk]:
+    def embed_chunks(
+        self,
+        chunks: list[DocAwareChunk],
+    ) -> list[IndexChunk]:
        raise NotImplementedError


@@ -50,10 +54,12 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
        normalize: bool,
        query_prefix: str | None,
        passage_prefix: str | None,
-        api_key: str | None = None,
        provider_type: str | None = None,
+        api_key: str | None = None,
    ):
-        super().__init__(model_name, normalize, query_prefix, passage_prefix)
+        super().__init__(
+            model_name, normalize, query_prefix, passage_prefix, provider_type, api_key
+        )
        self.max_seq_length = DOC_EMBEDDING_CONTEXT_SIZE  # Currently not customizable

        self.embedding_model = EmbeddingModel(
@@ -66,72 +72,63 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
            # The below are globally set, this flow always uses the indexing one
            server_host=INDEXING_MODEL_SERVER_HOST,
            server_port=INDEXING_MODEL_SERVER_PORT,
            retrim_content=True,
        )

    def embed_chunks(
        self,
        chunks: list[DocAwareChunk],
-        batch_size: int = BATCH_SIZE_ENCODE_CHUNKS,
-        enable_mini_chunk: bool = ENABLE_MINI_CHUNK,
    ) -> list[IndexChunk]:
-        # Cache the Title embeddings to only have to do it once
-        title_embed_dict: dict[str, list[float]] = {}
-        embedded_chunks: list[IndexChunk] = []
+        # All chunks at this point must have some non-empty content
+        flat_chunk_texts: list[str] = []
+        for chunk in chunks:
+            chunk_text = (
+                f"{chunk.title_prefix}{chunk.content}{chunk.metadata_suffix_semantic}"
+            ) or chunk.source_document.get_title_for_document_index()

-        # Create Mini Chunks for more precise matching of details
-        # Off by default with unedited settings
-        chunk_texts = []
-        chunk_mini_chunks_count = {}
-        for chunk_ind, chunk in enumerate(chunks):
-            chunk_texts.append(chunk.content)
-            mini_chunk_texts = (
-                split_chunk_text_into_mini_chunks(chunk.content_summary)
-                if enable_mini_chunk
-                else []
-            )
-            chunk_texts.extend(mini_chunk_texts)
-            chunk_mini_chunks_count[chunk_ind] = 1 + len(mini_chunk_texts)
+            if not chunk_text:
+                # This should never happen, the document would have been dropped
+                # before getting to this point
+                raise ValueError(f"Chunk has no content: {chunk.to_short_descriptor()}")

-        # Batching for embedding
-        text_batches = batch_list(chunk_texts, batch_size)
+            flat_chunk_texts.append(chunk_text)

-        embeddings: list[list[float]] = []
-        len_text_batches = len(text_batches)
-        for idx, text_batch in enumerate(text_batches, start=1):
-            logger.debug(f"Embedding Content Texts batch {idx} of {len_text_batches}")
-            # Normalize embeddings is only configured via model_configs.py, be sure to use right
-            # value for the set loss
-            embeddings.extend(
-                self.embedding_model.encode(text_batch, text_type=EmbedTextType.PASSAGE)
-            )
+            if chunk.mini_chunk_texts:
+                flat_chunk_texts.extend(chunk.mini_chunk_texts)

-            # Replace line above with the line below for easy debugging of indexing flow
-            # skipping the actual model
-            # embeddings.extend([[0.0] * 384 for _ in range(len(text_batch))])
+        embeddings = self.embedding_model.encode(
+            flat_chunk_texts, text_type=EmbedTextType.PASSAGE
+        )

        chunk_titles = {
            chunk.source_document.get_title_for_document_index() for chunk in chunks
        }

        # Drop any None or empty strings
        # If there is no title or the title is empty, the title embedding field will be null
        # which is ok, it just won't contribute at all to the scoring.
        chunk_titles_list = [title for title in chunk_titles if title]

-        # Embed Titles in batches
-        title_batches = batch_list(chunk_titles_list, batch_size)
-        len_title_batches = len(title_batches)
-        for ind_batch, title_batch in enumerate(title_batches, start=1):
-            logger.debug(f"Embedding Titles batch {ind_batch} of {len_title_batches}")
+        # Cache the Title embeddings to only have to do it once
+        title_embed_dict: dict[str, Embedding] = {}
+        if chunk_titles_list:
            title_embeddings = self.embedding_model.encode(
-                title_batch, text_type=EmbedTextType.PASSAGE
+                chunk_titles_list, text_type=EmbedTextType.PASSAGE
            )
            title_embed_dict.update(
-                {title: vector for title, vector in zip(title_batch, title_embeddings)}
+                {
+                    title: vector
+                    for title, vector in zip(chunk_titles_list, title_embeddings)
+                }
            )

        # Mapping embeddings to chunks
+        embedded_chunks: list[IndexChunk] = []
        embedding_ind_start = 0
-        for chunk_ind, chunk in enumerate(chunks):
-            num_embeddings = chunk_mini_chunks_count[chunk_ind]
+        for chunk in chunks:
+            num_embeddings = 1 + (
+                len(chunk.mini_chunk_texts) if chunk.mini_chunk_texts else 0
+            )
            chunk_embeddings = embeddings[
                embedding_ind_start : embedding_ind_start + num_embeddings
            ]
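The embed_chunks rewrite above flattens chunk and mini-chunk texts into one encode call, then slices the result back per chunk. A compact sketch of that bookkeeping with a dummy encoder (everything here is an illustrative stand-in for the real model call and chunk objects):

def encode(texts: list[str]) -> list[list[float]]:
    return [[float(len(t))] for t in texts]  # dummy embeddings

chunks = [
    {"content": "chunk one", "mini_chunk_texts": ["mini a", "mini b"]},
    {"content": "chunk two", "mini_chunk_texts": None},
]

flat_chunk_texts: list[str] = []
for chunk in chunks:
    flat_chunk_texts.append(chunk["content"])
    if chunk["mini_chunk_texts"]:
        flat_chunk_texts.extend(chunk["mini_chunk_texts"])

embeddings = encode(flat_chunk_texts)  # one round trip for everything

start = 0
for chunk in chunks:
    num = 1 + len(chunk["mini_chunk_texts"] or [])
    full_embedding, *mini_embeddings = embeddings[start : start + num]
    start += num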
@@ -184,4 +181,6 @@ def get_embedding_model_from_db_embedding_model(
        normalize=db_embedding_model.normalize,
        query_prefix=db_embedding_model.query_prefix,
        passage_prefix=db_embedding_model.passage_prefix,
        provider_type=db_embedding_model.provider_type,
        api_key=db_embedding_model.api_key,
    )

@@ -1,5 +1,4 @@
from functools import partial
-from itertools import chain
from typing import Protocol

from sqlalchemy.orm import Session
@@ -34,7 +33,9 @@ logger = setup_logger()

class IndexingPipelineProtocol(Protocol):
    def __call__(
-        self, documents: list[Document], index_attempt_metadata: IndexAttemptMetadata
+        self,
+        document_batch: list[Document],
+        index_attempt_metadata: IndexAttemptMetadata,
    ) -> tuple[int, int]:
        ...

@@ -116,7 +117,7 @@ def index_doc_batch(
    chunker: Chunker,
    embedder: IndexingEmbedder,
    document_index: DocumentIndex,
-    documents: list[Document],
+    document_batch: list[Document],
    index_attempt_metadata: IndexAttemptMetadata,
    db_session: Session,
    ignore_time_skip: bool = False,
@@ -124,6 +125,32 @@ def index_doc_batch(
    """Takes different pieces of the indexing pipeline and applies it to a batch of documents
    Note that the documents should already be batched at this point so that it does not inflate the
    memory requirements"""
    documents = []
    for document in document_batch:
        empty_contents = not any(section.text.strip() for section in document.sections)
        if (
            (not document.title or not document.title.strip())
            and not document.semantic_identifier.strip()
            and empty_contents
        ):
            # Skip documents that have neither title nor content
            # If the document doesn't have either, then there is no useful information in it
            # This is again verified later in the pipeline after chunking but at that point there should
            # already be no documents that are empty.
            logger.warning(
                f"Skipping document with ID {document.id} as it has neither title nor content."
            )
        elif (
            document.title is not None and not document.title.strip() and empty_contents
        ):
            # The title is explicitly empty ("" and not None) and the document is empty
            # so when building the chunk text representation, it will be empty and unusable
            logger.warning(
                f"Skipping document with ID {document.id} as the chunks will be empty."
            )
        else:
            documents.append(document)

    document_ids = [document.id for document in documents]
    db_docs = get_documents_by_ids(
        document_ids=document_ids,
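The filter added above drops documents with no usable text before chunking. The same predicate in isolation, using simplified stand-in types (the real Document model has more fields; this sketch only reproduces the two skip conditions):

from dataclasses import dataclass

@dataclass
class SectionSketch:
    text: str

@dataclass
class DocSketch:
    title: str | None
    semantic_identifier: str
    sections: list[SectionSketch]

def is_indexable(doc: DocSketch) -> bool:
    empty_contents = not any(s.text.strip() for s in doc.sections)
    no_title = not doc.title or not doc.title.strip()
    if no_title and not doc.semantic_identifier.strip() and empty_contents:
        return False  # neither title nor content: nothing useful to index
    if doc.title is not None and not doc.title.strip() and empty_contents:
        return False  # explicitly empty title and empty body: chunks would be empty
    return True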
@@ -138,6 +165,11 @@ def index_doc_batch(
        if not ignore_time_skip
        else documents
    )

    # No docs to update either because the batch is empty or every doc was already indexed
    if not updatable_docs:
        return 0, 0

    updatable_ids = [doc.id for doc in updatable_docs]

    # Create records in the source of truth about these documents,
@@ -149,14 +181,21 @@ def index_doc_batch(
    )

    logger.debug("Starting chunking")

    # The first chunk additionally contains the Title of the Document
-    chunks: list[DocAwareChunk] = list(
-        chain(*[chunker.chunk(document=document) for document in updatable_docs])
-    )
+    # The embedder is needed here to get the correct tokenizer
+    chunks: list[DocAwareChunk] = [
+        chunk
+        for document in updatable_docs
+        for chunk in chunker.chunk(document=document, embedder=embedder)
+    ]

    logger.debug("Starting embedding")
-    chunks_with_embeddings = embedder.embed_chunks(chunks=chunks)
+    chunks_with_embeddings = (
+        embedder.embed_chunks(
+            chunks=chunks,
+        )
+        if chunks
+        else []
+    )

    # Acquires a lock on the documents so that no other process can modify them
    # NOTE: don't need to acquire till here, since this is when the actual race condition
@@ -191,7 +230,7 @@ def index_doc_batch(
    ]

    logger.debug(
-        f"Indexing the following chunks: {[chunk.to_short_descriptor() for chunk in chunks]}"
+        f"Indexing the following chunks: {[chunk.to_short_descriptor() for chunk in access_aware_chunks]}"
    )
    # A document will not be spread across different batches, so all the
    # documents with chunks in this set, are fully represented by the chunks
@@ -215,7 +254,7 @@ def index_doc_batch(
    )

    return len([r for r in insertion_records if r.already_existed is False]), len(
-        chunks
+        access_aware_chunks
    )


@@ -5,6 +5,7 @@ from pydantic import BaseModel
from danswer.access.models import DocumentAccess
from danswer.connectors.models import Document
from danswer.utils.logger import setup_logger
from shared_configs.model_server_models import Embedding

if TYPE_CHECKING:
    from danswer.db.models import EmbeddingModel
@@ -13,9 +14,6 @@ if TYPE_CHECKING:
logger = setup_logger()


-Embedding = list[float]


class ChunkEmbedding(BaseModel):
    full_embedding: Embedding
    mini_chunk_embeddings: list[Embedding]
@@ -36,15 +34,17 @@ class DocAwareChunk(BaseChunk):
    # During inference we only have access to the document id and do not reconstruct the Document
    source_document: Document

-    # The Vespa documents require a separate highlight field. Since it is stored as a duplicate anyway,
-    # it's easier to just store a not prefixed/suffixed string for the highlighting
-    # Also during the chunking, this non-prefixed/suffixed string is used for mini-chunks
-    content_summary: str
+    # This could be an empty string if the title is too long and taking up too much of the chunk
+    # This does not necessarily mean that the document has no title
+    title_prefix: str

    # During indexing we also (optionally) build a metadata string from the metadata dict
    # This is also indexed so that we can strip it out after indexing, this way it supports
    # multiple iterations of metadata representation for backwards compatibility
-    metadata_suffix: str
+    metadata_suffix_semantic: str
+    metadata_suffix_keyword: str

+    mini_chunk_texts: list[str] | None

    def to_short_descriptor(self) -> str:
        """Used when logging the identity of a chunk"""

@@ -34,8 +34,8 @@ from danswer.llm.answering.stream_processing.quotes_processing import (
from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping
from danswer.llm.answering.stream_processing.utils import map_document_id_order
from danswer.llm.interfaces import LLM
-from danswer.llm.utils import get_default_llm_tokenizer
from danswer.llm.utils import message_generator_to_string_generator
+from danswer.natural_language_processing.utils import get_tokenizer
from danswer.tools.custom.custom_tool_prompt_builder import (
    build_user_message_for_custom_tool_for_non_tool_calling_llm,
)
@@ -99,6 +99,7 @@ class Answer:
        answer_style_config: AnswerStyleConfig,
        llm: LLM,
        prompt_config: PromptConfig,
        force_use_tool: ForceUseTool,
        # must be the same length as `docs`. If None, all docs are considered "relevant"
        message_history: list[PreviousMessage] | None = None,
        single_message_history: str | None = None,
@@ -107,10 +108,8 @@ class Answer:
        latest_query_files: list[InMemoryChatFile] | None = None,
        files: list[InMemoryChatFile] | None = None,
        tools: list[Tool] | None = None,
-        # if specified, tells the LLM to always use this tool
-        # NOTE: for native tool-calling, this is only supported by OpenAI atm,
-        # but we only support them anyways
-        force_use_tool: ForceUseTool | None = None,
        # if set to True, then never use the LLMs provided tool-calling functionality
        skip_explicit_tool_calling: bool = False,
        # Returns the full document sections text from the search tool
@@ -129,6 +128,7 @@ class Answer:

        self.tools = tools or []
        self.force_use_tool = force_use_tool

        self.skip_explicit_tool_calling = skip_explicit_tool_calling

        self.message_history = message_history or []
@@ -139,7 +139,10 @@ class Answer:
        self.prompt_config = prompt_config

        self.llm = llm
-        self.llm_tokenizer = get_default_llm_tokenizer()
+        self.llm_tokenizer = get_tokenizer(
+            provider_type=llm.config.model_provider,
+            model_name=llm.config.model_name,
+        )

        self._final_prompt: list[BaseMessage] | None = None

@@ -187,7 +190,7 @@ class Answer:
        prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)

        tool_call_chunk: AIMessageChunk | None = None
-        if self.force_use_tool and self.force_use_tool.args is not None:
+        if self.force_use_tool.force_use and self.force_use_tool.args is not None:
            # if we are forcing a tool WITH args specified, we don't need to check which tools to run
            # / need to generate the args
            tool_call_chunk = AIMessageChunk(
@@ -221,7 +224,7 @@ class Answer:
            for message in self.llm.stream(
                prompt=prompt,
                tools=final_tool_definitions if final_tool_definitions else None,
-                tool_choice="required" if self.force_use_tool else None,
+                tool_choice="required" if self.force_use_tool.force_use else None,
            ):
                if isinstance(message, AIMessageChunk) and (
                    message.tool_call_chunks or message.tool_calls
@@ -240,12 +243,26 @@ class Answer:
        # if we have a tool call, we need to call the tool
        tool_call_requests = tool_call_chunk.tool_calls
        for tool_call_request in tool_call_requests:
-            tool = [
+            known_tools_by_name = [
                tool for tool in self.tools if tool.name == tool_call_request["name"]
-            ][0]
+            ]
+
+            if not known_tools_by_name:
+                logger.error(
+                    "Tool call requested with unknown name field. \n"
+                    f"self.tools: {self.tools}"
+                    f"tool_call_request: {tool_call_request}"
+                )
+                if self.tools:
+                    tool = self.tools[0]
+                else:
+                    continue
+            else:
+                tool = known_tools_by_name[0]
            tool_args = (
                self.force_use_tool.args
-                if self.force_use_tool and self.force_use_tool.args
+                if self.force_use_tool.tool_name == tool.name
+                and self.force_use_tool.args
                else tool_call_request["args"]
            )

@@ -263,9 +280,13 @@ class Answer:
                if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}:
                    self._update_prompt_builder_for_search_tool(prompt_builder, [])
                elif tool.name == ImageGenerationTool._NAME:
+                    img_urls = [
+                        img_generation_result["url"]
+                        for img_generation_result in tool_runner.tool_final_result().tool_result
+                    ]
                    prompt_builder.update_user_prompt(
                        build_image_generation_user_prompt(
-                            query=self.question,
+                            query=self.question, img_urls=img_urls
                        )
                    )
            yield tool_runner.tool_final_result()
@@ -286,7 +307,7 @@ class Answer:
        prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
        chosen_tool_and_args: tuple[Tool, dict] | None = None

-        if self.force_use_tool:
+        if self.force_use_tool.force_use:
            # if we are forcing a tool, we don't need to check which tools to run
            tool = next(
                iter(
@@ -303,7 +324,7 @@ class Answer:

            tool_args = (
                self.force_use_tool.args
-                if self.force_use_tool.args
+                if self.force_use_tool.args is not None
                else tool.get_args_for_non_tool_calling_llm(
                    query=self.question,
                    history=self.message_history,

@@ -16,6 +16,7 @@ from danswer.configs.constants import MessageType
from danswer.file_store.models import InMemoryChatFile
from danswer.llm.override_models import PromptOverride
from danswer.llm.utils import build_content_with_imgs
from danswer.tools.models import ToolCallFinalResult

if TYPE_CHECKING:
    from danswer.db.models import ChatMessage
@@ -32,6 +33,7 @@ class PreviousMessage(BaseModel):
    token_count: int
    message_type: MessageType
    files: list[InMemoryChatFile]
    tool_calls: list[ToolCallFinalResult]

    @classmethod
    def from_chat_message(
@@ -49,6 +51,14 @@ class PreviousMessage(BaseModel):
                for file in available_files
                if str(file.file_id) in message_file_ids
            ],
            tool_calls=[
                ToolCallFinalResult(
                    tool_name=tool_call.tool_name,
                    tool_args=tool_call.tool_arguments,
                    tool_result=tool_call.tool_result,
                )
                for tool_call in chat_message.tool_calls
            ],
        )

    def to_langchain_msg(self) -> BaseMessage:

@@ -12,8 +12,8 @@ from danswer.llm.answering.prompts.citations_prompt import compute_max_llm_input
from danswer.llm.interfaces import LLMConfig
from danswer.llm.utils import build_content_with_imgs
from danswer.llm.utils import check_message_tokens
-from danswer.llm.utils import get_default_llm_tokenizer
from danswer.llm.utils import translate_history_to_basemessages
+from danswer.natural_language_processing.utils import get_tokenizer
from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
from danswer.prompts.prompt_utils import add_date_time_to_prompt
from danswer.prompts.prompt_utils import drop_messages_history_overflow
@@ -66,7 +66,10 @@ class AnswerPromptBuilder:
        self.system_message_and_token_cnt: tuple[SystemMessage, int] | None = None
        self.user_message_and_token_cnt: tuple[HumanMessage, int] | None = None

-        llm_tokenizer = get_default_llm_tokenizer()
+        llm_tokenizer = get_tokenizer(
+            provider_type=llm_config.model_provider,
+            model_name=llm_config.model_name,
+        )
        self.llm_tokenizer_encode_func = cast(
            Callable[[str], list[int]], llm_tokenizer.encode
        )
@@ -111,8 +114,24 @@ class AnswerPromptBuilder:
            final_messages_with_tokens.append(self.user_message_and_token_cnt)

        if tool_call_summary:
-            final_messages_with_tokens.append((tool_call_summary.tool_call_request, 0))
-            final_messages_with_tokens.append((tool_call_summary.tool_call_result, 0))
+            final_messages_with_tokens.append(
+                (
+                    tool_call_summary.tool_call_request,
+                    check_message_tokens(
+                        tool_call_summary.tool_call_request,
+                        self.llm_tokenizer_encode_func,
+                    ),
+                )
+            )
+            final_messages_with_tokens.append(
+                (
+                    tool_call_summary.tool_call_result,
+                    check_message_tokens(
+                        tool_call_summary.tool_call_result,
+                        self.llm_tokenizer_encode_func,
+                    ),
+                )
+            )

        return drop_messages_history_overflow(
            final_messages_with_tokens, self.max_tokens
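The builder change above replaces the hard-coded 0 token counts for tool-call messages with real counts, so history trimming sees their true size. A toy illustration of the arithmetic, using a whitespace "tokenizer" purely as an assumed stand-in for the real encode-based count:

def check_message_tokens_sketch(text: str) -> int:
    return len(text.split())  # stand-in for the real tokenizer encode

messages_with_tokens = [("What does the Q3 report say?", 6)]
tool_call_request = "search(query='Q3 report')"
tool_call_result = "3 documents returned"
messages_with_tokens.append(
    (tool_call_request, check_message_tokens_sketch(tool_call_request))
)
messages_with_tokens.append(
    (tool_call_result, check_message_tokens_sketch(tool_call_result))
)
# the overflow-dropping logic can now budget these messages accurately
total_tokens = sum(count for _, count in messages_with_tokens)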
@@ -14,8 +14,8 @@ from danswer.llm.answering.models import DocumentPruningConfig
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.prompts.citations_prompt import compute_max_document_tokens
from danswer.llm.interfaces import LLMConfig
-from danswer.llm.utils import get_default_llm_tokenizer
-from danswer.llm.utils import tokenizer_trim_content
+from danswer.natural_language_processing.utils import get_tokenizer
+from danswer.natural_language_processing.utils import tokenizer_trim_content
from danswer.prompts.prompt_utils import build_doc_context_str
from danswer.search.models import InferenceChunk
from danswer.search.models import InferenceSection
@@ -135,8 +135,12 @@ def _apply_pruning(
    is_manually_selected_docs: bool,
    use_sections: bool,
    using_tool_message: bool,
    llm_config: LLMConfig,
) -> list[InferenceSection]:
-    llm_tokenizer = get_default_llm_tokenizer()
+    llm_tokenizer = get_tokenizer(
+        provider_type=llm_config.model_provider,
+        model_name=llm_config.model_name,
+    )
    sections = deepcopy(sections)  # don't modify in place

    # re-order docs with all the "relevant" docs at the front
@@ -165,14 +169,15 @@ def _apply_pruning(
            )
        )

-        section_tokens = len(llm_tokenizer.encode(section_str))
+        section_token_count = len(llm_tokenizer.encode(section_str))
        # if not using sections (specifically, using Sections where each section maps exactly to the one center chunk),
        # truncate chunks that are way too long. This can happen if the embedding model tokenizer is different
        # than the LLM tokenizer
        if (
            not is_manually_selected_docs
            and not use_sections
-            and section_tokens > DOC_EMBEDDING_CONTEXT_SIZE + _METADATA_TOKEN_ESTIMATE
+            and section_token_count
+            > DOC_EMBEDDING_CONTEXT_SIZE + _METADATA_TOKEN_ESTIMATE
        ):
            logger.warning(
                "Found more tokens in Section than expected, "
@@ -183,9 +188,9 @@ def _apply_pruning(
                desired_length=DOC_EMBEDDING_CONTEXT_SIZE,
                tokenizer=llm_tokenizer,
            )
-            section_tokens = DOC_EMBEDDING_CONTEXT_SIZE
+            section_token_count = DOC_EMBEDDING_CONTEXT_SIZE

-        total_tokens += section_tokens
+        total_tokens += section_token_count
        if total_tokens > token_limit:
            final_section_ind = ind
            break
@@ -273,6 +278,7 @@ def prune_sections(
        is_manually_selected_docs=document_pruning_config.is_manually_selected_docs,
        use_sections=document_pruning_config.use_sections,  # Now default True
        using_tool_message=document_pruning_config.using_tool_message,
        llm_config=llm_config,
    )


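The pruning changes above thread the LLM config through so token counting uses the right tokenizer. The core accumulate-until-limit loop, reduced to a self-contained sketch (the whitespace token counter is an assumption; truncation of oversized sections is omitted):

def prune_sections_sketch(
    sections: list[str],
    token_limit: int,
    count_tokens=lambda s: len(s.split()),  # stand-in tokenizer
) -> list[str]:
    total_tokens = 0
    kept: list[str] = []
    for section in sections:
        section_token_count = count_tokens(section)
        if total_tokens + section_token_count > token_limit:
            break  # budget exhausted; drop this and all later sections
        total_tokens += section_token_count
        kept.append(section)
    return kept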
@@ -266,8 +266,14 @@ class DefaultMultiLLM(LLM):
stream=stream,
# model params
temperature=self._temperature,
max_tokens=self._max_output_tokens,
max_tokens=self._max_output_tokens
if self._max_output_tokens > 0
else None,
timeout=self._timeout,
# For now, we don't support parallel tool calls
# NOTE: we can't pass this in if tools are not specified
# or else OpenAI throws an error
**({"parallel_tool_calls": False} if tools else {}),
**self._model_kwargs,
)
except Exception as e:
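The `**({"parallel_tool_calls": False} if tools else {})` line uses a small Python idiom worth calling out: expanding a conditionally empty dict so a keyword argument is only passed when it is valid. A standalone illustration of the pattern (the `call_api` function and its parameters are hypothetical):

def call_api(**kwargs):
    # Stand-in for a client call that rejects unknown/invalid kwargs
    return kwargs

tools = []  # pretend no tools were supplied
kwargs = dict(
    temperature=0.0,
    # Only include the flag when tools exist; an empty dict expands to nothing
    **({"parallel_tool_calls": False} if tools else {}),
)
assert "parallel_tool_calls" not in call_api(**kwargs)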
@@ -70,6 +70,8 @@ def load_llm_providers(db_session: Session) -> None:
FAST_GEN_AI_MODEL_VERSION or well_known_provider.default_fast_model
),
model_names=model_names,
is_public=True,
display_model_names=[],
)
llm_provider = upsert_llm_provider(db_session, llm_provider_request)
update_default_provider(db_session, llm_provider.id)

@@ -31,9 +31,7 @@ OPEN_AI_MODEL_NAMES = [
"gpt-4-turbo-preview",
"gpt-4-1106-preview",
"gpt-4-vision-preview",
# "gpt-4-32k", # not EOL but still doesn't work
"gpt-4-0613",
# "gpt-4-32k-0613", # not EOL but still doesn't work
"gpt-4-0314",
"gpt-4-32k-0314",
"gpt-3.5-turbo",
@@ -48,9 +46,11 @@ OPEN_AI_MODEL_NAMES = [
BEDROCK_PROVIDER_NAME = "bedrock"
# need to remove all the weird "bedrock/eu-central-1/anthropic.claude-v1" named
# models
BEDROCK_MODEL_NAMES = [model for model in litellm.bedrock_models if "/" not in model][
::-1
]
BEDROCK_MODEL_NAMES = [
model
for model in litellm.bedrock_models
if "/" not in model and "embed" not in model
][::-1]

IGNORABLE_ANTHROPIC_MODELS = [
"claude-2",
@@ -84,7 +84,7 @@ def fetch_available_well_known_llms() -> list[WellKnownLLMProviderDescriptor]:
custom_config_keys=[],
llm_names=fetch_models_for_provider(OPENAI_PROVIDER_NAME),
default_model="gpt-4",
default_fast_model="gpt-3.5-turbo",
default_fast_model="gpt-4o-mini",
),
WellKnownLLMProviderDescriptor(
name=ANTHROPIC_PROVIDER_NAME,

@@ -1,6 +1,6 @@
import json
from collections.abc import Callable
from collections.abc import Iterator
from copy import copy
from typing import Any
from typing import cast
from typing import TYPE_CHECKING
@@ -16,10 +16,8 @@ from langchain.schema.messages import AIMessage
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage
from tiktoken.core import Encoding

from danswer.configs.constants import MessageType
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MAX_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
@@ -28,7 +26,6 @@ from danswer.file_store.models import ChatFileType
from danswer.file_store.models import InMemoryChatFile
from danswer.llm.interfaces import LLM
from danswer.prompts.constants import CODE_BLOCK_PAT
from danswer.search.models import InferenceChunk
from danswer.utils.logger import setup_logger
from shared_configs.configs import LOG_LEVEL

@@ -37,60 +34,17 @@ if TYPE_CHECKING:

logger = setup_logger()

_LLM_TOKENIZER: Any = None
_LLM_TOKENIZER_ENCODE: Callable[[str], Any] | None = None

def get_default_llm_tokenizer() -> Encoding:
"""Currently only supports the OpenAI default tokenizer: tiktoken"""
global _LLM_TOKENIZER
if _LLM_TOKENIZER is None:
_LLM_TOKENIZER = tiktoken.get_encoding("cl100k_base")
return _LLM_TOKENIZER

def get_default_llm_token_encode() -> Callable[[str], Any]:
global _LLM_TOKENIZER_ENCODE
if _LLM_TOKENIZER_ENCODE is None:
tokenizer = get_default_llm_tokenizer()
if isinstance(tokenizer, Encoding):
return tokenizer.encode  # type: ignore

# Currently only supports OpenAI encoder
raise ValueError("Invalid Encoder selected")

return _LLM_TOKENIZER_ENCODE

def tokenizer_trim_content(
content: str, desired_length: int, tokenizer: Encoding
) -> str:
tokens = tokenizer.encode(content)
if len(tokens) > desired_length:
content = tokenizer.decode(tokens[:desired_length])
return content

def tokenizer_trim_chunks(
chunks: list[InferenceChunk], max_chunk_toks: int = DOC_EMBEDDING_CONTEXT_SIZE
) -> list[InferenceChunk]:
tokenizer = get_default_llm_tokenizer()
new_chunks = copy(chunks)
for ind, chunk in enumerate(new_chunks):
new_content = tokenizer_trim_content(chunk.content, max_chunk_toks, tokenizer)
if len(new_content) != len(chunk.content):
new_chunk = copy(chunk)
new_chunk.content = new_content
new_chunks[ind] = new_chunk
return new_chunks

def translate_danswer_msg_to_langchain(
msg: Union[ChatMessage, "PreviousMessage"],
) -> BaseMessage:
files: list[InMemoryChatFile] = []

# If the message is a `ChatMessage`, it doesn't have the downloaded files
# attached. Just ignore them for now
files = [] if isinstance(msg, ChatMessage) else msg.files
# attached. Just ignore them for now. Also, OpenAI doesn't allow files to
# be attached to AI messages, so we must remove them
if not isinstance(msg, ChatMessage) and msg.message_type != MessageType.ASSISTANT:
files = msg.files
content = build_content_with_imgs(msg.message, files)

if msg.message_type == MessageType.SYSTEM:
@@ -271,6 +225,13 @@ def check_message_tokens(
elif part["type"] == "image_url":
total_tokens += _IMG_TOKENS

if isinstance(message, AIMessage) and message.tool_calls:
for tool_call in message.tool_calls:
total_tokens += check_number_of_tokens(
json.dumps(tool_call["args"]), encode_fn
)
total_tokens += check_number_of_tokens(tool_call["name"], encode_fn)

return total_tokens
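The `check_message_tokens` change above counts tool-call payloads by serializing the arguments to JSON and encoding both the JSON string and the tool name. A self-contained sketch of the same idea using tiktoken directly (the message shape here is a simplified stand-in, not the langchain type):

import json
import tiktoken

enc = tiktoken.get_encoding("cl100k_base")

def count_tool_call_tokens(tool_calls: list[dict]) -> int:
    # Each tool call contributes its serialized args plus its name
    total = 0
    for tool_call in tool_calls:
        total += len(enc.encode(json.dumps(tool_call["args"])))
        total += len(enc.encode(tool_call["name"]))
    return total

print(count_tool_call_tokens([{"name": "run_search", "args": {"query": "danswer"}}]))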
@@ -34,6 +34,7 @@ from danswer.configs.app_configs import USER_AUTH_SECRET
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.configs.constants import AuthType
from danswer.configs.constants import POSTGRES_WEB_APP_NAME
from danswer.db.connector import create_initial_default_connector
from danswer.db.connector_credential_pair import associate_default_cc_pair
from danswer.db.connector_credential_pair import get_connector_credential_pairs
@@ -42,6 +43,7 @@ from danswer.db.credentials import create_initial_public_credential
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.embedding_model import get_secondary_db_embedding_model
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.engine import init_sqlalchemy_engine
from danswer.db.engine import warm_up_connections
from danswer.db.index_attempt import cancel_indexing_attempts_past_model
from danswer.db.index_attempt import expire_index_attempts
@@ -50,8 +52,8 @@ from danswer.db.standard_answer import create_initial_default_standard_answer_ca
from danswer.db.swap_index import check_index_swap
from danswer.document_index.factory import get_default_document_index
from danswer.llm.llm_initialization import load_llm_providers
from danswer.natural_language_processing.search_nlp_models import warm_up_encoders
from danswer.search.retrieval.search_runner import download_nltk_data
from danswer.search.search_nlp_models import warm_up_encoders
from danswer.server.auth_check import check_router_auth
from danswer.server.danswer_api.ingestion import router as danswer_api_router
from danswer.server.documents.cc_pair import router as cc_pair_router
@@ -60,6 +62,10 @@ from danswer.server.documents.credential import router as credential_router
from danswer.server.documents.document import router as document_router
from danswer.server.features.document_set.api import router as document_set_router
from danswer.server.features.folder.api import router as folder_router
from danswer.server.features.input_prompt.api import (
admin_router as admin_input_prompt_router,
)
from danswer.server.features.input_prompt.api import basic_router as input_prompt_router
from danswer.server.features.persona.api import admin_router as admin_persona_router
from danswer.server.features.persona.api import basic_router as persona_router
from danswer.server.features.prompt.api import basic_router as prompt_router
@@ -154,6 +160,7 @@ def include_router_with_global_prefix_prepended(

@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator:
init_sqlalchemy_engine(POSTGRES_WEB_APP_NAME)
engine = get_sqlalchemy_engine()

verify_auth = fetch_versioned_implementation(
@@ -248,14 +255,13 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
logger.info(f"Waiting on Vespa, retrying in {wait_time} seconds...")
time.sleep(wait_time)

logger.info(f"Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}")
if db_embedding_model.cloud_provider_id is None:
warm_up_encoders(
model_name=db_embedding_model.model_name,
normalize=db_embedding_model.normalize,
model_server_host=MODEL_SERVER_HOST,
model_server_port=MODEL_SERVER_PORT,
)
logger.info(f"Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}")
if db_embedding_model.cloud_provider_id is None:
warm_up_encoders(
embedding_model=db_embedding_model,
model_server_host=MODEL_SERVER_HOST,
model_server_port=MODEL_SERVER_PORT,
)

optional_telemetry(record_type=RecordType.VERSION, data={"version": __version__})
yield
@@ -284,6 +290,8 @@ def get_application() -> FastAPI:
include_router_with_global_prefix_prepended(application, standard_answer_router)
include_router_with_global_prefix_prepended(application, persona_router)
include_router_with_global_prefix_prepended(application, admin_persona_router)
include_router_with_global_prefix_prepended(application, input_prompt_router)
include_router_with_global_prefix_prepended(application, admin_input_prompt_router)
include_router_with_global_prefix_prepended(application, prompt_router)
include_router_with_global_prefix_prepended(application, tool_router)
include_router_with_global_prefix_prepended(application, admin_tool_router)

@@ -1,18 +1,19 @@
import gc
import os
import time
from typing import Optional
from typing import TYPE_CHECKING

import requests
from transformers import logging as transformer_logging  # type:ignore
from httpx import HTTPError

from danswer.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
from danswer.db.models import EmbeddingModel as DBEmbeddingModel
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.natural_language_processing.utils import tokenizer_trim_content
from danswer.utils.batching import batch_list
from danswer.utils.logger import setup_logger
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.enums import EmbedTextType
from shared_configs.model_server_models import Embedding
from shared_configs.model_server_models import EmbedRequest
from shared_configs.model_server_models import EmbedResponse
from shared_configs.model_server_models import IntentRequest
@@ -20,50 +21,13 @@ from shared_configs.model_server_models import IntentResponse
from shared_configs.model_server_models import RerankRequest
from shared_configs.model_server_models import RerankResponse

transformer_logging.set_verbosity_error()

os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"

logger = setup_logger()

if TYPE_CHECKING:
from transformers import AutoTokenizer  # type: ignore

_TOKENIZER: tuple[Optional["AutoTokenizer"], str | None] = (None, None)

def clean_model_name(model_str: str) -> str:
return model_str.replace("/", "_").replace("-", "_").replace(".", "_")

# NOTE: If no model_name is specified, it may not be using the "correct" tokenizer
# for cases where this is more important, be sure to refresh with the actual model name
# One case where it is not particularly important is in the document chunking flow,
# they're basically all using the sentencepiece tokenizer and whether it's cased or
# uncased does not really matter, they'll all generally end up with the same chunk lengths.
def get_default_tokenizer(model_name: str = DOCUMENT_ENCODER_MODEL) -> "AutoTokenizer":
# NOTE: doing a local import here to reduce memory usage caused by
# processes importing this file despite not using any of this
from transformers import AutoTokenizer  # type: ignore

global _TOKENIZER
if _TOKENIZER[0] is None or _TOKENIZER[1] != model_name:
if _TOKENIZER[0] is not None:
del _TOKENIZER
gc.collect()

_TOKENIZER = (AutoTokenizer.from_pretrained(model_name), model_name)

if hasattr(_TOKENIZER[0], "is_fast") and _TOKENIZER[0].is_fast:
os.environ["TOKENIZERS_PARALLELISM"] = "false"

return _TOKENIZER[0]

def build_model_server_url(
model_server_host: str,
model_server_port: int,
@@ -91,6 +55,7 @@ class EmbeddingModel:
provider_type: str | None,
# The following globals are currently not configurable
max_seq_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
retrim_content: bool = False,
) -> None:
self.api_key = api_key
self.provider_type = provider_type
@@ -99,32 +64,95 @@ class EmbeddingModel:
self.passage_prefix = passage_prefix
self.normalize = normalize
self.model_name = model_name
self.retrim_content = retrim_content

model_server_url = build_model_server_url(server_host, server_port)
self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"

def encode(self, texts: list[str], text_type: EmbedTextType) -> list[list[float]]:
if text_type == EmbedTextType.QUERY and self.query_prefix:
prefixed_texts = [self.query_prefix + text for text in texts]
elif text_type == EmbedTextType.PASSAGE and self.passage_prefix:
prefixed_texts = [self.passage_prefix + text for text in texts]
else:
prefixed_texts = texts
def encode(
self,
texts: list[str],
text_type: EmbedTextType,
batch_size: int = BATCH_SIZE_ENCODE_CHUNKS,
) -> list[Embedding]:
if not texts or not all(texts):
raise ValueError(f"Empty or missing text for embedding: {texts}")

embed_request = EmbedRequest(
model_name=self.model_name,
texts=prefixed_texts,
max_context_length=self.max_seq_length,
normalize_embeddings=self.normalize,
api_key=self.api_key,
provider_type=self.provider_type,
text_type=text_type,
if self.retrim_content:
# This is applied during indexing as a catchall for overly long titles (or other uncapped fields)
# Note that this uses just the default tokenizer which may also lead to very minor miscountings
# However this slight miscounting is very unlikely to have any material impact.
texts = [
tokenizer_trim_content(
content=text,
desired_length=self.max_seq_length,
tokenizer=get_tokenizer(
model_name=self.model_name,
provider_type=self.provider_type,
),
)
for text in texts
]

if self.provider_type:
embed_request = EmbedRequest(
model_name=self.model_name,
texts=texts,
max_context_length=self.max_seq_length,
normalize_embeddings=self.normalize,
api_key=self.api_key,
provider_type=self.provider_type,
text_type=text_type,
manual_query_prefix=self.query_prefix,
manual_passage_prefix=self.passage_prefix,
)
response = requests.post(
self.embed_server_endpoint, json=embed_request.dict()
)
try:
response.raise_for_status()
except requests.HTTPError as e:
error_detail = response.json().get("detail", str(e))
raise HTTPError(f"HTTP error occurred: {error_detail}") from e
except requests.RequestException as e:
raise HTTPError(f"Request failed: {str(e)}") from e

return EmbedResponse(**response.json()).embeddings

# Batching for local embedding
text_batches = batch_list(texts, batch_size)
embeddings: list[Embedding] = []
logger.debug(
f"Encoding {len(texts)} texts in {len(text_batches)} batches for local model"
)
for idx, text_batch in enumerate(text_batches, start=1):
logger.debug(f"Encoding batch {idx} of {len(text_batches)}")
embed_request = EmbedRequest(
model_name=self.model_name,
texts=text_batch,
max_context_length=self.max_seq_length,
normalize_embeddings=self.normalize,
api_key=self.api_key,
provider_type=self.provider_type,
text_type=text_type,
manual_query_prefix=self.query_prefix,
manual_passage_prefix=self.passage_prefix,
)

response = requests.post(self.embed_server_endpoint, json=embed_request.dict())
response.raise_for_status()

return EmbedResponse(**response.json()).embeddings
response = requests.post(
self.embed_server_endpoint, json=embed_request.dict()
)
try:
response.raise_for_status()
except requests.HTTPError as e:
error_detail = response.json().get("detail", str(e))
raise HTTPError(f"HTTP error occurred: {error_detail}") from e
except requests.RequestException as e:
raise HTTPError(f"Request failed: {str(e)}") from e
# Normalize embeddings is only configured via model_configs.py, be sure to use the right
# value for the set loss
embeddings.extend(EmbedResponse(**response.json()).embeddings)
return embeddings
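The local-model branch above batches texts before posting to the model server. A minimal sketch of the batching helper it relies on (`batch_list` here is reimplemented for illustration; the real helper lives in danswer.utils.batching):

from typing import TypeVar

T = TypeVar("T")

def batch_list(items: list[T], batch_size: int) -> list[list[T]]:
    # Split a list into consecutive chunks of at most batch_size elements
    return [items[i : i + batch_size] for i in range(0, len(items), batch_size)]

assert batch_list([1, 2, 3, 4, 5], 2) == [[1, 2], [3, 4], [5]]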
class CrossEncoderEnsembleModel:
@@ -136,7 +164,7 @@ class CrossEncoderEnsembleModel:
model_server_url = build_model_server_url(model_server_host, model_server_port)
self.rerank_server_endpoint = model_server_url + "/encoder/cross-encoder-scores"

def predict(self, query: str, passages: list[str]) -> list[list[float]]:
def predict(self, query: str, passages: list[str]) -> list[list[float] | None]:
rerank_request = RerankRequest(query=query, documents=passages)

response = requests.post(
@@ -171,35 +199,40 @@ class IntentModel:

def warm_up_encoders(
model_name: str,
normalize: bool,
embedding_model: DBEmbeddingModel,
model_server_host: str = MODEL_SERVER_HOST,
model_server_port: int = MODEL_SERVER_PORT,
) -> None:
model_name = embedding_model.model_name
normalize = embedding_model.normalize
provider_type = embedding_model.provider_type
warm_up_str = (
"Danswer is amazing! Check out our easy deployment guide at "
"https://docs.danswer.dev/quickstart"
)

# May not be the exact same tokenizer used for the indexing flow
get_default_tokenizer(model_name=model_name)(warm_up_str)
logger.info(f"Warming up encoder model: {model_name}")
get_tokenizer(model_name=model_name, provider_type=provider_type).encode(
warm_up_str
)

embed_model = EmbeddingModel(
model_name=model_name,
normalize=normalize,
provider_type=provider_type,
# Not a big deal if prefix is incorrect
query_prefix=None,
passage_prefix=None,
server_host=model_server_host,
server_port=model_server_port,
api_key=None,
provider_type=None,
)

# First time downloading the models it may take even longer, but just in case,
# retry the whole server
wait_time = 5
for attempt in range(20):
for _ in range(20):
try:
embed_model.encode(texts=[warm_up_str], text_type=EmbedTextType.QUERY)
return
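The warm-up loop above probes the model server a fixed number of times with a fixed wait between attempts. A generic sketch of that retry shape (the `ping` callable is a hypothetical stand-in for the encode call):

import time
from collections.abc import Callable

def warm_up(ping: Callable[[], None], attempts: int = 20, wait_time: float = 5.0) -> None:
    # Keep probing until the server responds once, then stop
    for _ in range(attempts):
        try:
            ping()
            return
        except Exception:
            time.sleep(wait_time)
    raise RuntimeError("model server never became ready")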
149
backend/danswer/natural_language_processing/utils.py
Normal file
@@ -0,0 +1,149 @@
import os
from abc import ABC
from abc import abstractmethod
from copy import copy

from transformers import logging as transformer_logging  # type:ignore

from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
from danswer.search.models import InferenceChunk
from danswer.utils.logger import setup_logger

logger = setup_logger()
transformer_logging.set_verbosity_error()
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"

class BaseTokenizer(ABC):
@abstractmethod
def encode(self, string: str) -> list[int]:
pass

@abstractmethod
def tokenize(self, string: str) -> list[str]:
pass

@abstractmethod
def decode(self, tokens: list[int]) -> str:
pass

class TiktokenTokenizer(BaseTokenizer):
_instances: dict[str, "TiktokenTokenizer"] = {}

def __new__(cls, encoding_name: str = "cl100k_base") -> "TiktokenTokenizer":
if encoding_name not in cls._instances:
cls._instances[encoding_name] = super(TiktokenTokenizer, cls).__new__(cls)
return cls._instances[encoding_name]

def __init__(self, encoding_name: str = "cl100k_base"):
if not hasattr(self, "encoder"):
import tiktoken

self.encoder = tiktoken.get_encoding(encoding_name)

def encode(self, string: str) -> list[int]:
# this returns no special tokens
return self.encoder.encode_ordinary(string)

def tokenize(self, string: str) -> list[str]:
return [self.encoder.decode([token]) for token in self.encode(string)]

def decode(self, tokens: list[int]) -> str:
return self.encoder.decode(tokens)

class HuggingFaceTokenizer(BaseTokenizer):
def __init__(self, model_name: str):
from tokenizers import Tokenizer  # type: ignore

self.encoder = Tokenizer.from_pretrained(model_name)

def encode(self, string: str) -> list[int]:
# this returns no special tokens
return self.encoder.encode(string, add_special_tokens=False).ids

def tokenize(self, string: str) -> list[str]:
return self.encoder.encode(string, add_special_tokens=False).tokens

def decode(self, tokens: list[int]) -> str:
return self.encoder.decode(tokens)

_TOKENIZER_CACHE: dict[str, BaseTokenizer] = {}

def _check_tokenizer_cache(tokenizer_name: str) -> BaseTokenizer:
global _TOKENIZER_CACHE

if tokenizer_name not in _TOKENIZER_CACHE:
if tokenizer_name == "openai":
_TOKENIZER_CACHE[tokenizer_name] = TiktokenTokenizer("cl100k_base")
return _TOKENIZER_CACHE[tokenizer_name]
try:
logger.debug(f"Initializing HuggingFaceTokenizer for: {tokenizer_name}")
_TOKENIZER_CACHE[tokenizer_name] = HuggingFaceTokenizer(tokenizer_name)
except Exception as primary_error:
logger.error(
f"Error initializing HuggingFaceTokenizer for {tokenizer_name}: {primary_error}"
)
logger.warning(
f"Falling back to default embedding model: {DOCUMENT_ENCODER_MODEL}"
)

try:
# Cache this tokenizer name to the default so we don't have to try to load it again
# and fail again
_TOKENIZER_CACHE[tokenizer_name] = HuggingFaceTokenizer(
DOCUMENT_ENCODER_MODEL
)
except Exception as fallback_error:
logger.error(
f"Error initializing fallback HuggingFaceTokenizer: {fallback_error}"
)
raise ValueError(
f"Failed to initialize tokenizer for {tokenizer_name} and fallback model"
) from fallback_error

return _TOKENIZER_CACHE[tokenizer_name]

def get_tokenizer(model_name: str | None, provider_type: str | None) -> BaseTokenizer:
if provider_type:
if provider_type.lower() == "openai":
# Used across ada and text-embedding-3 models
return _check_tokenizer_cache("openai")
# If we are given a cloud provider_type that isn't OpenAI, we default to trying to use the model_name
# this means we are approximating the token count which may leave some performance on the table

if not model_name:
raise ValueError("Need to provide a model_name or provider_type")

return _check_tokenizer_cache(model_name)

def tokenizer_trim_content(
content: str, desired_length: int, tokenizer: BaseTokenizer
) -> str:
tokens = tokenizer.encode(content)
if len(tokens) > desired_length:
content = tokenizer.decode(tokens[:desired_length])
return content

def tokenizer_trim_chunks(
chunks: list[InferenceChunk],
tokenizer: BaseTokenizer,
max_chunk_toks: int = DOC_EMBEDDING_CONTEXT_SIZE,
) -> list[InferenceChunk]:
new_chunks = copy(chunks)
for ind, chunk in enumerate(new_chunks):
new_content = tokenizer_trim_content(chunk.content, max_chunk_toks, tokenizer)
if len(new_content) != len(chunk.content):
new_chunk = copy(chunk)
new_chunk.content = new_content
new_chunks[ind] = new_chunk
return new_chunks
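The new tokenizer module above picks tiktoken for OpenAI-style providers and a HuggingFace `tokenizers` tokenizer otherwise, caching instances by name. A short usage sketch under those assumptions (the model names are examples, not requirements):

from danswer.natural_language_processing.utils import get_tokenizer, tokenizer_trim_content

# OpenAI provider -> cl100k_base via tiktoken, regardless of model_name
openai_tok = get_tokenizer(model_name=None, provider_type="openai")

# Local/self-hosted model -> HuggingFace tokenizer loaded by model name
hf_tok = get_tokenizer(model_name="intfloat/e5-base-v2", provider_type=None)

text = "some very long passage ..."
trimmed = tokenizer_trim_content(content=text, desired_length=512, tokenizer=hf_tok)
print(len(hf_tok.encode(trimmed)))  # at most ~512 tokens after trimming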
@@ -34,7 +34,7 @@ from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.models import QuotesConfig
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.factory import get_main_llm_from_tuple
from danswer.llm.utils import get_default_llm_token_encode
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.one_shot_answer.models import DirectQARequest
from danswer.one_shot_answer.models import OneShotQAResponse
from danswer.one_shot_answer.models import QueryRephrase
@@ -117,8 +117,12 @@ def stream_answer_objects(
one_shot=True,
danswerbot_flow=danswerbot_flow,
)
llm, fast_llm = get_llms_for_persona(persona=chat_session.persona)

llm_tokenizer = get_default_llm_token_encode()
llm_tokenizer = get_tokenizer(
model_name=llm.config.model_name,
provider_type=llm.config.model_provider,
)

# Create a chat session which will just store the root message, the query, and the AI response
root_message = get_or_create_root_message(
@@ -126,10 +130,12 @@ def stream_answer_objects(
)

history_str = combine_message_thread(
messages=history, max_tokens=max_history_tokens
messages=history,
max_tokens=max_history_tokens,
llm_tokenizer=llm_tokenizer,
)

rephrased_query = thread_based_query_rephrase(
rephrased_query = query_req.query_override or thread_based_query_rephrase(
user_query=query_msg.message,
history_str=history_str,
)
@@ -158,13 +164,12 @@ def stream_answer_objects(
parent_message=root_message,
prompt_id=query_req.prompt_id,
message=query_msg.message,
token_count=len(llm_tokenizer(query_msg.message)),
token_count=len(llm_tokenizer.encode(query_msg.message)),
message_type=MessageType.USER,
db_session=db_session,
commit=True,
)

llm, fast_llm = get_llms_for_persona(persona=chat_session.persona)
prompt_config = PromptConfig.from_model(prompt)
document_pruning_config = DocumentPruningConfig(
max_chunks=int(
@@ -206,6 +211,7 @@ def stream_answer_objects(
single_message_history=history_str,
tools=[search_tool],
force_use_tool=ForceUseTool(
force_use=True,
tool_name=search_tool.name,
args={"query": rephrased_query},
),
@@ -256,6 +262,9 @@ def stream_answer_objects(
)
yield initial_response

elif packet.id == SEARCH_DOC_CONTENT_ID:
yield packet.response

elif packet.id == SECTION_RELEVANCE_LIST_ID:
chunk_indices = packet.response

@@ -268,9 +277,6 @@ def stream_answer_objects(

yield LLMRelevanceFilterResponse(relevant_chunk_indices=packet.response)

elif packet.id == SEARCH_DOC_CONTENT_ID:
yield packet.response

elif packet.id == SEARCH_EVALUATION_ID:
evaluation_response = LLMRelevanceSummaryResponse(
relevance_summaries=packet.response
@@ -291,7 +297,7 @@ def stream_answer_objects(
parent_message=new_user_message,
prompt_id=query_req.prompt_id,
message=answer.llm_answer,
token_count=len(llm_tokenizer(answer.llm_answer)),
token_count=len(llm_tokenizer.encode(answer.llm_answer)),
message_type=MessageType.ASSISTANT,
error=None,
reference_docs=reference_db_search_docs,

@@ -34,10 +34,17 @@ class DirectQARequest(ChunkContext):
skip_llm_chunk_filter: bool | None = None
chain_of_thought: bool = False
return_contexts: bool = False

# allows the caller to specify the exact search query they want to use
# can be used if the message sent to the LLM / query should not be the same
# will also disable Thread-based Rewording if specified
query_override: str | None = None

# This is to toggle agentic evaluation:
# 1. Evaluates whether each response is relevant or not
# 2. Provides a summary of the document's relevance in the results
llm_doc_eval: bool = False

# If True, skips generating an AI response to the search query
skip_gen_ai_answer_generation: bool = False

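The new `query_override` field lets a caller pin the exact retrieval query while sending a different message to the LLM; as the hunk above shows, it also bypasses the thread-based rephrase step. A hedged sketch of a request body using it (other required fields of the endpoint are omitted for brevity, and the values are illustrative):

# Sketch of a partial one-shot QA request payload
payload = {
    "query_override": "danswer deployment guide",  # pins the retrieval query
    "llm_doc_eval": False,
    "skip_gen_ai_answer_generation": True,  # search only, no generated answer
}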
@@ -1,8 +1,7 @@
from collections.abc import Callable
from collections.abc import Generator

from danswer.configs.constants import MessageType
from danswer.llm.utils import get_default_llm_token_encode
from danswer.natural_language_processing.utils import BaseTokenizer
from danswer.one_shot_answer.models import ThreadMessage
from danswer.utils.logger import setup_logger

@@ -18,7 +17,7 @@ def simulate_streaming_response(model_out: str) -> Generator[str, None, None]:
def combine_message_thread(
messages: list[ThreadMessage],
max_tokens: int | None,
llm_tokenizer: Callable | None = None,
llm_tokenizer: BaseTokenizer,
) -> str:
"""Used to create a single combined message context from threads"""
if not messages:
@@ -26,8 +25,6 @@ def combine_message_thread(

message_strs: list[str] = []
total_token_count = 0
if llm_tokenizer is None:
llm_tokenizer = get_default_llm_token_encode()

for message in reversed(messages):
if message.role == MessageType.USER:
@@ -42,7 +39,7 @@ def combine_message_thread(
role_str = message.role.value.upper()

msg_str = f"{role_str}:\n{message.message}"
message_token_count = len(llm_tokenizer(msg_str))
message_token_count = len(llm_tokenizer.encode(msg_str))

if (
max_tokens is not None
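`combine_message_thread` walks the thread newest-first and stops adding messages once the token budget is exhausted. A minimal sketch of that accumulation pattern (messages here are plain strings and `encode` is any tokenizer encode function; both are simplified stand-ins):

def combine_newest_first(messages: list[str], encode, max_tokens: int) -> str:
    # Walk from newest to oldest, keep what fits, then restore order
    kept: list[str] = []
    total = 0
    for msg in reversed(messages):
        count = len(encode(msg))
        if total + count > max_tokens:
            break
        kept.append(msg)
        total += count
    return "\n\n".join(reversed(kept))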
@@ -126,6 +126,7 @@ class InferenceChunk(BaseChunk):
document_id: str
source_type: DocumentSource
semantic_identifier: str
title: str | None  # Separate from Semantic Identifier though often same
boost: int
recency_bias: float
score: float | None
@@ -193,16 +194,16 @@ class InferenceChunk(BaseChunk):

class InferenceChunkUncleaned(InferenceChunk):
title: str | None  # Separate from Semantic Identifier though often same
metadata_suffix: str | None

def to_inference_chunk(self) -> InferenceChunk:
# Create a dict of all fields except 'title' and 'metadata_suffix'
# Create a dict of all fields except 'metadata_suffix'
# Assumes the cleaning has already been applied and just needs to translate to the right type
inference_chunk_data = {
k: v
for k, v in self.dict().items()
if k not in ["title", "metadata_suffix"]
if k
not in ["metadata_suffix"]  # May be other fields to throw out in the future
}
return InferenceChunk(**inference_chunk_data)

@@ -34,6 +34,7 @@ from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import FunctionCall
from danswer.utils.threadpool_concurrency import run_functions_in_parallel
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from danswer.utils.timing import log_function_time

logger = setup_logger()

@@ -154,6 +155,7 @@ class SearchPipeline:

return cast(list[InferenceChunk], self._retrieved_chunks)

@log_function_time(print_only=True)
def _get_sections(self) -> list[InferenceSection]:
"""Returns an expanded section from each of the chunks.
If whole docs (instead of above/below context) is specified then it will give back all of the whole docs
@@ -173,9 +175,11 @@ class SearchPipeline:
expanded_inference_sections = []

# Full doc setting takes priority

if self.search_query.full_doc:
seen_document_ids = set()
unique_chunks = []

# This preserves the ordering since the chunks are retrieved in score order
for chunk in retrieved_chunks:
if chunk.document_id not in seen_document_ids:
@@ -195,7 +199,6 @@ class SearchPipeline:
),
)
)

list_inference_chunks = run_functions_tuples_in_parallel(
functions_with_args, allow_failures=False
)
@@ -240,32 +243,35 @@ class SearchPipeline:
merged_ranges = [
merge_chunk_intervals(ranges) for ranges in doc_chunk_ranges_map.values()
]
flat_ranges = [r for ranges in merged_ranges for r in ranges]

flat_ranges: list[ChunkRange] = [r for ranges in merged_ranges for r in ranges]
flattened_inference_chunks: list[InferenceChunk] = []
parallel_functions_with_args = []

for chunk_range in flat_ranges:
functions_with_args.append(
(
# If Large Chunks are introduced, additional filters need to be added here
self.document_index.id_based_retrieval,
(
# Only need the document_id here, any chunk in the range is fine
chunk_range.chunks[0].document_id,
chunk_range.start,
chunk_range.end,
# There is no chunk level permissioning, this expansion around chunks
# can be assumed to be safe
IndexFilters(access_control_list=None),
),
)
)
# Don't need to fetch chunks within range for merging if chunk_above / below are 0.
if above == below == 0:
flattened_inference_chunks.extend(chunk_range.chunks)

# list of list of inference chunks where the inner list needs to be combined for content
list_inference_chunks = run_functions_tuples_in_parallel(
functions_with_args, allow_failures=False
)
flattened_inference_chunks = [
chunk for sublist in list_inference_chunks for chunk in sublist
]
else:
parallel_functions_with_args.append(
(
self.document_index.id_based_retrieval,
(
chunk_range.chunks[0].document_id,
chunk_range.start,
chunk_range.end,
IndexFilters(access_control_list=None),
),
)
)

if parallel_functions_with_args:
list_inference_chunks = run_functions_tuples_in_parallel(
parallel_functions_with_args, allow_failures=False
)
for inference_chunks in list_inference_chunks:
flattened_inference_chunks.extend(inference_chunks)

doc_chunk_ind_to_chunk = {
(chunk.document_id, chunk.chunk_id): chunk

@@ -4,7 +4,7 @@ from typing import cast

import numpy

from danswer.configs.constants import MAX_CHUNK_TITLE_LEN
from danswer.configs.app_configs import BLURB_SIZE
from danswer.configs.constants import RETURN_SEPARATOR
from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MAX
from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MIN
@@ -12,6 +12,9 @@ from danswer.document_index.document_index_utils import (
translate_boost_count_to_multiplier,
)
from danswer.llm.interfaces import LLM
from danswer.natural_language_processing.search_nlp_models import (
CrossEncoderEnsembleModel,
)
from danswer.search.models import ChunkMetric
from danswer.search.models import InferenceChunk
from danswer.search.models import InferenceChunkUncleaned
@@ -20,7 +23,6 @@ from danswer.search.models import MAX_METRICS_CONTENT
from danswer.search.models import RerankMetricsContainer
from danswer.search.models import SearchQuery
from danswer.search.models import SearchType
from danswer.search.search_nlp_models import CrossEncoderEnsembleModel
from danswer.secondary_llm_flows.chunk_usefulness import llm_batch_eval_sections
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import FunctionCall
@@ -58,8 +60,14 @@ def cleanup_chunks(chunks: list[InferenceChunkUncleaned]) -> list[InferenceChunk
if chunk.content.startswith(chunk.title):
return chunk.content[len(chunk.title) :].lstrip()

if chunk.content.startswith(chunk.title[:MAX_CHUNK_TITLE_LEN]):
return chunk.content[MAX_CHUNK_TITLE_LEN:].lstrip()
# BLURB_SIZE is by token instead of char but each token is at least 1 char
# If this prefix matches the content, it's assumed the title was prepended
if chunk.content.startswith(chunk.title[:BLURB_SIZE]):
return (
chunk.content.split(RETURN_SEPARATOR, 1)[-1]
if RETURN_SEPARATOR in chunk.content
else chunk.content
)

return chunk.content

@@ -91,7 +99,11 @@ def semantic_reranking(
Note: this updates the chunks in place, it updates the chunk scores which came from retrieval
"""
cross_encoders = CrossEncoderEnsembleModel()
passages = [chunk.content for chunk in chunks]

passages = [
f"{chunk.semantic_identifier or chunk.title or ''}\n{chunk.content}"
for chunk in chunks
]
sim_scores_floats = cross_encoders.predict(query=query, passages=passages)

sim_scores = [numpy.array(scores) for scores in sim_scores_floats]
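The `semantic_reranking` change prepends each chunk's semantic identifier or title to its content before cross-encoder scoring, so the reranker sees the same title context the index stored. A small sketch of building such passages (chunk objects here are simplified stand-ins with optional `semantic_identifier`/`title` attributes):

def build_rerank_passages(chunks) -> list[str]:
    # Title first, then content; falls back to empty string when untitled
    return [
        f"{(c.semantic_identifier or c.title or '')}\n{c.content}"
        for c in chunks
    ]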
@@ -224,6 +236,12 @@ def search_postprocessing(
) -> Iterator[list[InferenceSection] | list[int]]:
post_processing_tasks: list[FunctionCall] = []

if not retrieved_sections:
# Avoids trying to rerank an empty list which throws an error
yield []
yield []
return

rerank_task_id = None
sections_yielded = False
if should_rerank(search_query):

@@ -1,26 +1,20 @@
from typing import TYPE_CHECKING

from danswer.natural_language_processing.search_nlp_models import IntentModel
from danswer.natural_language_processing.utils import BaseTokenizer
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.search.enums import QueryFlow
from danswer.search.models import SearchType
from danswer.search.retrieval.search_runner import remove_stop_words_and_punctuation
from danswer.search.search_nlp_models import get_default_tokenizer
from danswer.search.search_nlp_models import IntentModel
from danswer.server.query_and_chat.models import HelperResponse
from danswer.utils.logger import setup_logger

logger = setup_logger()

if TYPE_CHECKING:
from transformers import AutoTokenizer  # type:ignore

def count_unk_tokens(text: str, tokenizer: "AutoTokenizer") -> int:
def count_unk_tokens(text: str, tokenizer: BaseTokenizer) -> int:
"""Unclear if the wordpiece/sentencepiece tokenizer used is actually tokenizing anything as the [UNK] token
It splits up even foreign characters and unicode emojis without using UNK"""
tokenized_text = tokenizer.tokenize(text)
num_unk_tokens = len(
[token for token in tokenized_text if token == tokenizer.unk_token]
)
num_unk_tokens = len([token for token in tokenized_text if token == "[UNK]"])
logger.debug(f"Total of {num_unk_tokens} UNKNOWN tokens found")
return num_unk_tokens

@@ -74,7 +68,12 @@ def recommend_search_flow(

# UNK tokens -> suggest Keyword (still may be valid QA)
# TODO do a better job with the classifier model and retire the heuristics
if count_unk_tokens(query, get_default_tokenizer(model_name=model_name)) > 0:
if (
count_unk_tokens(
query, get_tokenizer(model_name=model_name, provider_type=None)
)
> 0
):
if not keyword:
heuristic_search_type = SearchType.KEYWORD
message = "Unknown tokens in query."
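Note that the `count_unk_tokens` rewrite hard-codes the `[UNK]` literal instead of reading `tokenizer.unk_token`, since the new `BaseTokenizer` interface does not expose that attribute. A sketch of the heuristic against a HuggingFace tokenizer (the model name is an example, and whether a given string actually produces `[UNK]` depends on the vocabulary):

from tokenizers import Tokenizer

tok = Tokenizer.from_pretrained("bert-base-uncased")
tokens = tok.encode("hello \u2603 world", add_special_tokens=False).tokens
num_unk = sum(1 for t in tokens if t == "[UNK]")
print(f"{num_unk} UNKNOWN tokens found")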
@@ -11,6 +11,7 @@ from danswer.configs.chat_configs import HYBRID_ALPHA
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.document_index.interfaces import DocumentIndex
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.search.models import ChunkMetric
from danswer.search.models import IndexFilters
from danswer.search.models import InferenceChunk
@@ -20,7 +21,6 @@ from danswer.search.models import RetrievalMetricsContainer
from danswer.search.models import SearchQuery
from danswer.search.models import SearchType
from danswer.search.postprocessing.postprocessing import cleanup_chunks
from danswer.search.search_nlp_models import EmbeddingModel
from danswer.search.utils import inference_section_from_chunks
from danswer.secondary_llm_flows.query_expansion import multilingual_query_expansion
from danswer.utils.logger import setup_logger

@@ -94,7 +94,7 @@ def history_based_query_rephrase(
llm: LLM,
size_heuristic: int = 200,
punctuation_heuristic: int = 10,
skip_first_rephrase: bool = False,
skip_first_rephrase: bool = True,
prompt_template: str = HISTORY_QUERY_REPHRASE,
) -> str:
# Globally disabled, just use the exact user query

@@ -96,6 +96,8 @@ def upsert_ingestion_doc(
normalize=db_embedding_model.normalize,
query_prefix=db_embedding_model.query_prefix,
passage_prefix=db_embedding_model.passage_prefix,
api_key=db_embedding_model.api_key,
provider_type=db_embedding_model.provider_type,
)

indexing_pipeline = build_indexing_pipeline(
@@ -105,8 +107,8 @@ def upsert_ingestion_doc(
db_session=db_session,
)

new_doc, chunks = indexing_pipeline(
documents=[document],
new_doc, __chunk_count = indexing_pipeline(
document_batch=[document],
index_attempt_metadata=IndexAttemptMetadata(
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
@@ -132,6 +134,8 @@ def upsert_ingestion_doc(
normalize=sec_db_embedding_model.normalize,
query_prefix=sec_db_embedding_model.query_prefix,
passage_prefix=sec_db_embedding_model.passage_prefix,
api_key=sec_db_embedding_model.api_key,
provider_type=sec_db_embedding_model.provider_type,
)

sec_ind_pipeline = build_indexing_pipeline(
@@ -142,7 +146,7 @@ def upsert_ingestion_doc(
)

sec_ind_pipeline(
documents=[document],
document_batch=[document],
index_attempt_metadata=IndexAttemptMetadata(
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,

@@ -12,7 +12,7 @@ from danswer.db.connector_credential_pair import get_connector_credential_pair_f
from danswer.db.connector_credential_pair import remove_credential_from_connector
from danswer.db.document import get_document_cnts_for_cc_pairs
from danswer.db.engine import get_session
from danswer.db.index_attempt import get_index_attempts_for_cc_pair
from danswer.db.index_attempt import get_index_attempts_for_connector
from danswer.db.models import User
from danswer.server.documents.models import CCPairFullInfo
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
@@ -43,9 +43,9 @@ def get_cc_pair_full_info(
credential_id=cc_pair.credential_id,
)

index_attempts = get_index_attempts_for_cc_pair(
db_session=db_session,
cc_pair_identifier=cc_pair_identifier,
index_attempts = get_index_attempts_for_connector(
db_session,
cc_pair.connector_id,
)

document_count_info_list = list(

@@ -51,6 +51,8 @@ from danswer.db.connector import fetch_connector_by_id
from danswer.db.connector import fetch_connectors
from danswer.db.connector import get_connector_credential_ids
from danswer.db.connector import update_connector
from danswer.db.connector_credential_pair import add_credential_to_connector
from danswer.db.connector_credential_pair import get_connector_credential_pair
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.credentials import create_credential
from danswer.db.credentials import delete_gmail_service_account_credentials
@@ -64,6 +66,7 @@ from danswer.db.index_attempt import cancel_indexing_attempts_for_connector
from danswer.db.index_attempt import cancel_indexing_attempts_past_model
from danswer.db.index_attempt import create_index_attempt
from danswer.db.index_attempt import get_index_attempts_for_cc_pair
from danswer.db.index_attempt import get_latest_finished_index_attempt_for_cc_pair
from danswer.db.index_attempt import get_latest_index_attempts
from danswer.db.models import User
from danswer.dynamic_configs.interface import ConfigNotFoundError
@@ -74,6 +77,7 @@ from danswer.server.documents.models import ConnectorBase
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
from danswer.server.documents.models import ConnectorIndexingStatus
from danswer.server.documents.models import ConnectorSnapshot
from danswer.server.documents.models import CredentialBase
from danswer.server.documents.models import CredentialSnapshot
from danswer.server.documents.models import FileUploadResponse
from danswer.server.documents.models import GDriveCallback
@@ -263,7 +267,8 @@ def upsert_service_account_credential(
`Credential` table."""
try:
credential_base = build_service_account_creds(
delegated_user_email=service_account_credential_request.google_drive_delegated_user
DocumentSource.GOOGLE_DRIVE,
delegated_user_email=service_account_credential_request.google_drive_delegated_user,
)
except ConfigNotFoundError as e:
raise HTTPException(status_code=400, detail=str(e))
@@ -288,7 +293,8 @@ def upsert_gmail_service_account_credential(
`Credential` table."""
try:
credential_base = build_service_account_creds(
delegated_user_email=service_account_credential_request.gmail_delegated_user
DocumentSource.GMAIL,
delegated_user_email=service_account_credential_request.gmail_delegated_user,
)
except ConfigNotFoundError as e:
raise HTTPException(status_code=400, detail=str(e))
@@ -386,8 +392,12 @@ def get_connector_indexing_status(
secondary_index=secondary_index,
db_session=db_session,
)

cc_pair_to_latest_index_attempt = {
(index_attempt.connector_id, index_attempt.credential_id): index_attempt
(
index_attempt.connector_credential_pair.connector_id,
index_attempt.connector_credential_pair.credential_id,
): index_attempt
for index_attempt in latest_index_attempts
}

@@ -410,6 +420,11 @@ def get_connector_indexing_status(
latest_index_attempt = cc_pair_to_latest_index_attempt.get(
(connector.id, credential.id)
)

latest_finished_attempt = get_latest_finished_index_attempt_for_cc_pair(
connector_credential_pair_id=cc_pair.id, db_session=db_session
)

indexing_statuses.append(
ConnectorIndexingStatus(
cc_pair_id=cc_pair.id,
@@ -418,6 +433,9 @@ def get_connector_indexing_status(
credential=CredentialSnapshot.from_credential_db_model(credential),
public_doc=cc_pair.is_public,
owner=credential.user.email if credential.user else "",
last_finished_status=latest_finished_attempt.status
if latest_finished_attempt
else None,
last_status=latest_index_attempt.status
if latest_index_attempt
else None,
@@ -480,6 +498,35 @@ def create_connector_from_model(
raise HTTPException(status_code=400, detail=str(e))

@router.post("/admin/connector-with-mock-credential")
def create_connector_with_mock_credential(
connector_data: ConnectorBase,
user: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> StatusResponse:
try:
_validate_connector_allowed(connector_data.source)
connector_response = create_connector(connector_data, db_session)
mock_credential = CredentialBase(
credential_json={}, admin_public=True, source=connector_data.source
)
credential = create_credential(
mock_credential, user=user, db_session=db_session
)
response = add_credential_to_connector(
connector_id=cast(int, connector_response.id),  # will always be an int
credential_id=credential.id,
|
||||
is_public=True,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
cc_pair_name=connector_data.name,
|
||||
)
|
||||
return response
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@router.patch("/admin/connector/{connector_id}")
|
||||
def update_connector_from_model(
|
||||
connector_id: int,
|
||||
@@ -515,6 +562,7 @@ def update_connector_from_model(
|
||||
credential_ids=[
|
||||
association.credential.id for association in updated_connector.credentials
|
||||
],
|
||||
indexing_start=updated_connector.indexing_start,
|
||||
time_created=updated_connector.time_created,
|
||||
time_updated=updated_connector.time_updated,
|
||||
disabled=updated_connector.disabled,
|
||||
@@ -542,6 +590,7 @@ def connector_run_once(
|
||||
) -> StatusResponse[list[int]]:
|
||||
connector_id = run_info.connector_id
|
||||
specified_credential_ids = run_info.credential_ids
|
||||
|
||||
try:
|
||||
possible_credential_ids = get_connector_credential_ids(
|
||||
run_info.connector_id, db_session
|
||||
@@ -585,16 +634,21 @@ def connector_run_once(
|
||||
|
||||
embedding_model = get_current_db_embedding_model(db_session)
|
||||
|
||||
connector_credential_pairs = [
|
||||
get_connector_credential_pair(run_info.connector_id, credential_id, db_session)
|
||||
for credential_id in credential_ids
|
||||
if credential_id not in skipped_credentials
|
||||
]
|
||||
|
||||
index_attempt_ids = [
|
||||
create_index_attempt(
|
||||
connector_id=run_info.connector_id,
|
||||
credential_id=credential_id,
|
||||
connector_credential_pair_id=connector_credential_pair.id,
|
||||
embedding_model_id=embedding_model.id,
|
||||
from_beginning=run_info.from_beginning,
|
||||
db_session=db_session,
|
||||
)
|
||||
for credential_id in credential_ids
|
||||
if credential_id not in skipped_credentials
|
||||
for connector_credential_pair in connector_credential_pairs
|
||||
if connector_credential_pair is not None
|
||||
]
|
||||
|
||||
if not index_attempt_ids:
|
||||
@@ -724,6 +778,7 @@ def get_connector_by_id(
|
||||
id=connector.id,
|
||||
name=connector.name,
|
||||
source=connector.source,
|
||||
indexing_start=connector.indexing_start,
|
||||
input_type=connector.input_type,
|
||||
connector_specific_config=connector.connector_specific_config,
|
||||
refresh_freq=connector.refresh_freq,
|
||||
|
||||
@@ -6,15 +6,21 @@ from sqlalchemy.orm import Session
from danswer.auth.schemas import UserRole
from danswer.auth.users import current_admin_user
from danswer.auth.users import current_user
from danswer.db.credentials import alter_credential
from danswer.db.credentials import create_credential
from danswer.db.credentials import delete_credential
from danswer.db.credentials import fetch_credential_by_id
from danswer.db.credentials import fetch_credentials
from danswer.db.credentials import fetch_credentials_by_source
from danswer.db.credentials import swap_credentials_connector
from danswer.db.credentials import update_credential
from danswer.db.engine import get_session
from danswer.db.models import DocumentSource
from danswer.db.models import User
from danswer.server.documents.models import CredentialBase
from danswer.server.documents.models import CredentialDataUpdateRequest
from danswer.server.documents.models import CredentialSnapshot
from danswer.server.documents.models import CredentialSwapRequest
from danswer.server.documents.models import ObjectCreationIdResponse
from danswer.server.models import StatusResponse

@@ -38,6 +44,34 @@ def list_credentials_admin(
    ]


@router.get("/admin/similar-credentials/{source_type}")
def get_cc_source_full_info(
    source_type: DocumentSource,
    user: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> list[CredentialSnapshot]:
    credentials = fetch_credentials_by_source(
        db_session=db_session, user=user, document_source=source_type
    )

    return [
        CredentialSnapshot.from_credential_db_model(credential)
        for credential in credentials
    ]


@router.get("/credentials/{id}")
def list_credentials_by_id(
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> list[CredentialSnapshot]:
    credentials = fetch_credentials(db_session=db_session, user=user)
    return [
        CredentialSnapshot.from_credential_db_model(credential)
        for credential in credentials
    ]


@router.delete("/admin/credential/{credential_id}")
def delete_credential_by_id_admin(
    credential_id: int,
@@ -51,6 +85,26 @@ def delete_credential_by_id_admin(
    )


@router.put("/admin/credentials/swap")
def swap_credentials_for_connector(
    credentail_swap_req: CredentialSwapRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> StatusResponse:
    connector_credential_pair = swap_credentials_connector(
        new_credential_id=credentail_swap_req.new_credential_id,
        connector_id=credentail_swap_req.connector_id,
        db_session=db_session,
        user=user,
    )

    return StatusResponse(
        success=True,
        message="Credential swapped successfully",
        data=connector_credential_pair.id,
    )


"""Endpoints for all"""

@@ -79,7 +133,11 @@ def create_credential_from_model(
    )

    credential = create_credential(credential_info, user, db_session)
    return ObjectCreationIdResponse(id=credential.id)
    return ObjectCreationIdResponse(
        id=credential.id,
        credential=CredentialSnapshot.from_credential_db_model(credential),
    )


@router.get("/credential/{credential_id}")
@@ -98,6 +156,24 @@ def get_credential_by_id(
    return CredentialSnapshot.from_credential_db_model(credential)


@router.put("/admin/credentials/{credential_id}")
def update_credential_data(
    credential_id: int,
    credential_update: CredentialDataUpdateRequest,
    user: User = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> CredentialBase:
    credential = alter_credential(credential_id, credential_update, user, db_session)

    if credential is None:
        raise HTTPException(
            status_code=401,
            detail=f"Credential {credential_id} does not exist or does not belong to user",
        )

    return CredentialSnapshot.from_credential_db_model(credential)


@router.patch("/credential/{credential_id}")
def update_credential_from_model(
    credential_id: int,
@@ -115,6 +191,7 @@ def update_credential_from_model(
    )

    return CredentialSnapshot(
        source=updated_credential.source,
        id=updated_credential.id,
        credential_json=updated_credential.credential_json,
        user_id=updated_credential.user_id,
@@ -130,7 +207,25 @@ def delete_credential_by_id(
    user: User = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> StatusResponse:
    delete_credential(credential_id, user, db_session)
    delete_credential(
        credential_id,
        user,
        db_session,
    )

    return StatusResponse(
        success=True, message="Credential deleted successfully", data=credential_id
    )


@router.delete("/credential/force/{credential_id}")
def force_delete_credential_by_id(
    credential_id: int,
    user: User = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> StatusResponse:
    delete_credential(credential_id, user, db_session, True)

    return StatusResponse(
        success=True, message="Credential deleted successfully", data=credential_id
    )
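For reference, a minimal sketch of driving the new swap endpoint from a client. The host, the `/api/manage` mount point, and the cookie auth are assumptions for illustration, not taken from this diff:

import requests

# Hypothetical deployment details; adjust host and auth to your setup.
BASE = "http://localhost:8080/api/manage"
COOKIES = {"fastapiusersauth": "<session-cookie>"}  # placeholder auth

# Body mirrors CredentialSwapRequest: new_credential_id + connector_id.
resp = requests.put(
    f"{BASE}/admin/credentials/swap",
    json={"new_credential_id": 42, "connector_id": 7},
    cookies=COOKIES,
)
resp.raise_for_status()
# StatusResponse payload: success flag, message, and the cc-pair id in `data`.
print(resp.json())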
@@ -9,7 +9,7 @@ from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.document_index.factory import get_default_document_index
from danswer.llm.utils import get_default_llm_token_encode
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.prompts.prompt_utils import build_doc_context_str
from danswer.search.preprocessing.access_filters import build_access_filters_for_user
from danswer.server.documents.models import ChunkInfo
@@ -50,7 +50,10 @@ def get_document_info(

    # get actual document context used for LLM
    first_chunk = inference_chunks[0]
    tokenizer_encode = get_default_llm_token_encode()
    tokenizer_encode = get_tokenizer(
        provider_type=embedding_model.provider_type,
        model_name=embedding_model.model_name,
    ).encode
    full_context_str = build_doc_context_str(
        semantic_identifier=first_chunk.semantic_identifier,
        source_type=first_chunk.source_type,
@@ -92,7 +95,10 @@ def get_chunk_info(

    chunk_content = inference_chunks[0].content

    tokenizer_encode = get_default_llm_token_encode()
    tokenizer_encode = get_tokenizer(
        provider_type=embedding_model.provider_type,
        model_name=embedding_model.model_name,
    ).encode

    return ChunkInfo(
        content=chunk_content, num_tokens=len(tokenizer_encode(chunk_content))
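Both handlers now count tokens with the embedding model's own tokenizer instead of the default LLM tokenizer. The pattern as a standalone sketch, assuming a danswer dev environment (the model name is illustrative):

from danswer.natural_language_processing.utils import get_tokenizer

# provider_type may be None for locally hosted models; both values are illustrative.
tokenizer_encode = get_tokenizer(
    provider_type=None,
    model_name="intfloat/e5-base-v2",
).encode
print(len(tokenizer_encode("some chunk content")))  # token count for the chunk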
@@ -26,6 +26,92 @@ class ChunkInfo(BaseModel):
    num_tokens: int


class DeletionAttemptSnapshot(BaseModel):
    connector_id: int
    credential_id: int
    status: TaskStatus


class ConnectorBase(BaseModel):
    name: str
    source: DocumentSource
    input_type: InputType
    connector_specific_config: dict[str, Any]
    refresh_freq: int | None  # In seconds, None for one time index with no refresh
    prune_freq: int | None
    disabled: bool
    indexing_start: datetime | None


class ConnectorSnapshot(ConnectorBase):
    id: int
    credential_ids: list[int]
    time_created: datetime
    time_updated: datetime
    source: DocumentSource

    @classmethod
    def from_connector_db_model(cls, connector: Connector) -> "ConnectorSnapshot":
        return ConnectorSnapshot(
            id=connector.id,
            name=connector.name,
            source=connector.source,
            input_type=connector.input_type,
            connector_specific_config=connector.connector_specific_config,
            refresh_freq=connector.refresh_freq,
            prune_freq=connector.prune_freq,
            credential_ids=[
                association.credential.id for association in connector.credentials
            ],
            indexing_start=connector.indexing_start,
            time_created=connector.time_created,
            time_updated=connector.time_updated,
            disabled=connector.disabled,
        )


class CredentialSwapRequest(BaseModel):
    new_credential_id: int
    connector_id: int


class CredentialDataUpdateRequest(BaseModel):
    name: str
    credential_json: dict[str, Any]


class CredentialBase(BaseModel):
    credential_json: dict[str, Any]
    # if `true`, then all Admins will have access to the credential
    admin_public: bool
    source: DocumentSource
    name: str | None = None


class CredentialSnapshot(CredentialBase):
    id: int
    user_id: UUID | None
    time_created: datetime
    time_updated: datetime

    @classmethod
    def from_credential_db_model(cls, credential: Credential) -> "CredentialSnapshot":
        return CredentialSnapshot(
            id=credential.id,
            credential_json=(
                mask_credential_dict(credential.credential_json)
                if MASK_CREDENTIAL_PREFIX
                else credential.credential_json
            ),
            user_id=credential.user_id,
            admin_public=credential.admin_public,
            time_created=credential.time_created,
            time_updated=credential.time_updated,
            source=credential.source or DocumentSource.NOT_APPLICABLE,
            name=credential.name,
        )


class IndexAttemptSnapshot(BaseModel):
    id: int
    status: IndexingStatus | None
@@ -49,80 +135,15 @@ class IndexAttemptSnapshot(BaseModel):
        docs_removed_from_index=index_attempt.docs_removed_from_index or 0,
        error_msg=index_attempt.error_msg,
        full_exception_trace=index_attempt.full_exception_trace,
        time_started=index_attempt.time_started.isoformat()
        if index_attempt.time_started
        else None,
        time_started=(
            index_attempt.time_started.isoformat()
            if index_attempt.time_started
            else None
        ),
        time_updated=index_attempt.time_updated.isoformat(),
    )


class DeletionAttemptSnapshot(BaseModel):
    connector_id: int
    credential_id: int
    status: TaskStatus


class ConnectorBase(BaseModel):
    name: str
    source: DocumentSource
    input_type: InputType
    connector_specific_config: dict[str, Any]
    refresh_freq: int | None  # In seconds, None for one time index with no refresh
    prune_freq: int | None
    disabled: bool


class ConnectorSnapshot(ConnectorBase):
    id: int
    credential_ids: list[int]
    time_created: datetime
    time_updated: datetime

    @classmethod
    def from_connector_db_model(cls, connector: Connector) -> "ConnectorSnapshot":
        return ConnectorSnapshot(
            id=connector.id,
            name=connector.name,
            source=connector.source,
            input_type=connector.input_type,
            connector_specific_config=connector.connector_specific_config,
            refresh_freq=connector.refresh_freq,
            prune_freq=connector.prune_freq,
            credential_ids=[
                association.credential.id for association in connector.credentials
            ],
            time_created=connector.time_created,
            time_updated=connector.time_updated,
            disabled=connector.disabled,
        )


class CredentialBase(BaseModel):
    credential_json: dict[str, Any]
    # if `true`, then all Admins will have access to the credential
    admin_public: bool


class CredentialSnapshot(CredentialBase):
    id: int
    user_id: UUID | None
    time_created: datetime
    time_updated: datetime

    @classmethod
    def from_credential_db_model(cls, credential: Credential) -> "CredentialSnapshot":
        return CredentialSnapshot(
            id=credential.id,
            credential_json=mask_credential_dict(credential.credential_json)
            if MASK_CREDENTIAL_PREFIX
            else credential.credential_json,
            user_id=credential.user_id,
            admin_public=credential.admin_public,
            time_created=credential.time_created,
            time_updated=credential.time_updated,
        )


class CCPairFullInfo(BaseModel):
    id: int
    name: str
@@ -167,6 +188,7 @@ class ConnectorIndexingStatus(BaseModel):
    credential: CredentialSnapshot
    owner: str
    public_doc: bool
    last_finished_status: IndexingStatus | None
    last_status: IndexingStatus | None
    last_success: datetime | None
    docs_indexed: int
@@ -242,6 +264,7 @@ class FileUploadResponse(BaseModel):

class ObjectCreationIdResponse(BaseModel):
    id: int | str
    credential: CredentialSnapshot | None = None


class AuthStatus(BaseModel):
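A quick sketch of the new snapshot shape, showing the `DocumentSource.NOT_APPLICABLE` fallback and the now-optional `name` field (all field values are illustrative; assumes a danswer dev environment):

from datetime import datetime

from danswer.db.models import DocumentSource
from danswer.server.documents.models import CredentialSnapshot

snapshot = CredentialSnapshot(
    id=1,
    credential_json={"token": "****"},  # masked when MASK_CREDENTIAL_PREFIX is set
    admin_public=True,
    source=DocumentSource.NOT_APPLICABLE,  # fallback when the credential has no source
    name=None,
    user_id=None,
    time_created=datetime.now(),
    time_updated=datetime.now(),
)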
134  backend/danswer/server/features/input_prompt/api.py  Normal file
@@ -0,0 +1,134 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session

from danswer.auth.users import current_admin_user
from danswer.auth.users import current_user
from danswer.db.engine import get_session
from danswer.db.input_prompt import fetch_input_prompt_by_id
from danswer.db.input_prompt import fetch_input_prompts_by_user
from danswer.db.input_prompt import fetch_public_input_prompts
from danswer.db.input_prompt import insert_input_prompt
from danswer.db.input_prompt import remove_input_prompt
from danswer.db.input_prompt import remove_public_input_prompt
from danswer.db.input_prompt import update_input_prompt
from danswer.db.models import User
from danswer.server.features.input_prompt.models import CreateInputPromptRequest
from danswer.server.features.input_prompt.models import InputPromptSnapshot
from danswer.server.features.input_prompt.models import UpdateInputPromptRequest
from danswer.utils.logger import setup_logger

logger = setup_logger()

basic_router = APIRouter(prefix="/input_prompt")
admin_router = APIRouter(prefix="/admin/input_prompt")


@basic_router.get("")
def list_input_prompts(
    user: User | None = Depends(current_user),
    include_public: bool = False,
    db_session: Session = Depends(get_session),
) -> list[InputPromptSnapshot]:
    user_prompts = fetch_input_prompts_by_user(
        user_id=user.id if user is not None else None,
        db_session=db_session,
        include_public=include_public,
    )
    return [InputPromptSnapshot.from_model(prompt) for prompt in user_prompts]


@basic_router.get("/{input_prompt_id}")
def get_input_prompt(
    input_prompt_id: int,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> InputPromptSnapshot:
    input_prompt = fetch_input_prompt_by_id(
        id=input_prompt_id,
        user_id=user.id if user is not None else None,
        db_session=db_session,
    )
    return InputPromptSnapshot.from_model(input_prompt=input_prompt)


@basic_router.post("")
def create_input_prompt(
    create_input_prompt_request: CreateInputPromptRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> InputPromptSnapshot:
    input_prompt = insert_input_prompt(
        prompt=create_input_prompt_request.prompt,
        content=create_input_prompt_request.content,
        is_public=create_input_prompt_request.is_public,
        user=user,
        db_session=db_session,
    )
    return InputPromptSnapshot.from_model(input_prompt)


@basic_router.patch("/{input_prompt_id}")
def patch_input_prompt(
    input_prompt_id: int,
    update_input_prompt_request: UpdateInputPromptRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> InputPromptSnapshot:
    try:
        updated_input_prompt = update_input_prompt(
            user=user,
            input_prompt_id=input_prompt_id,
            prompt=update_input_prompt_request.prompt,
            content=update_input_prompt_request.content,
            active=update_input_prompt_request.active,
            db_session=db_session,
        )
    except ValueError as e:
        error_msg = "Error occurred while updating input prompt"
        logger.warn(f"{error_msg}. Stack trace: {e}")
        raise HTTPException(status_code=404, detail=error_msg)

    return InputPromptSnapshot.from_model(updated_input_prompt)


@basic_router.delete("/{input_prompt_id}")
def delete_input_prompt(
    input_prompt_id: int,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> None:
    try:
        remove_input_prompt(user, input_prompt_id, db_session)

    except ValueError as e:
        error_msg = "Error occurred while deleting input prompt"
        logger.warn(f"{error_msg}. Stack trace: {e}")
        raise HTTPException(status_code=404, detail=error_msg)


@admin_router.delete("/{input_prompt_id}")
def delete_public_input_prompt(
    input_prompt_id: int,
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> None:
    try:
        remove_public_input_prompt(input_prompt_id, db_session)

    except ValueError as e:
        error_msg = "Error occurred while deleting input prompt"
        logger.warn(f"{error_msg}. Stack trace: {e}")
        raise HTTPException(status_code=404, detail=error_msg)


@admin_router.get("")
def list_public_input_prompts(
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> list[InputPromptSnapshot]:
    user_prompts = fetch_public_input_prompts(
        db_session=db_session,
    )
    return [InputPromptSnapshot.from_model(prompt) for prompt in user_prompts]
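A sketch of exercising the new routes end to end; the host, the router mount point, and unauthenticated access are assumptions for illustration:

import requests

BASE = "http://localhost:8080"  # hypothetical deployment; router prefix is /input_prompt

# Create a prompt (CreateInputPromptRequest fields).
created = requests.post(
    f"{BASE}/input_prompt",
    json={"prompt": "Summarize", "content": "Summarize this document.", "is_public": False},
).json()

# List own prompts, optionally including public ones.
prompts = requests.get(f"{BASE}/input_prompt", params={"include_public": True}).json()

# Delete it again; a server-side ValueError surfaces as a 404.
requests.delete(f"{BASE}/input_prompt/{created['id']}")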
47  backend/danswer/server/features/input_prompt/models.py  Normal file
@@ -0,0 +1,47 @@
from uuid import UUID

from pydantic import BaseModel

from danswer.db.models import InputPrompt
from danswer.utils.logger import setup_logger

logger = setup_logger()


class CreateInputPromptRequest(BaseModel):
    prompt: str
    content: str
    is_public: bool


class UpdateInputPromptRequest(BaseModel):
    prompt: str
    content: str
    active: bool


class InputPromptResponse(BaseModel):
    id: int
    prompt: str
    content: str
    active: bool


class InputPromptSnapshot(BaseModel):
    id: int
    prompt: str
    content: str
    active: bool
    user_id: UUID | None
    is_public: bool

    @classmethod
    def from_model(cls, input_prompt: InputPrompt) -> "InputPromptSnapshot":
        return InputPromptSnapshot(
            id=input_prompt.id,
            prompt=input_prompt.prompt,
            content=input_prompt.content,
            active=input_prompt.active,
            user_id=input_prompt.user_id,
            is_public=input_prompt.is_public,
        )
@@ -1,12 +1,15 @@
import uuid
from uuid import UUID

from fastapi import APIRouter
from fastapi import Depends
from fastapi import UploadFile
from pydantic import BaseModel
from sqlalchemy.orm import Session

from danswer.auth.users import current_admin_user
from danswer.auth.users import current_user
from danswer.configs.constants import FileOrigin
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.db.persona import create_update_persona
@@ -17,6 +20,8 @@ from danswer.db.persona import mark_persona_as_not_deleted
from danswer.db.persona import update_all_personas_display_priority
from danswer.db.persona import update_persona_shared_users
from danswer.db.persona import update_persona_visibility
from danswer.file_store.file_store import get_default_file_store
from danswer.file_store.models import ChatFileType
from danswer.llm.answering.prompts.utils import build_dummy_prompt
from danswer.server.features.persona.models import CreatePersonaRequest
from danswer.server.features.persona.models import PersonaSnapshot
@@ -24,6 +29,7 @@ from danswer.server.features.persona.models import PromptTemplateResponse
from danswer.server.models import DisplayPriorityRequest
from danswer.utils.logger import setup_logger

logger = setup_logger()

@@ -90,6 +96,26 @@ def undelete_persona(
    )


# used for assistant profile pictures
@admin_router.post("/upload-image")
def upload_file(
    file: UploadFile,
    db_session: Session = Depends(get_session),
    _: User | None = Depends(current_user),
) -> dict[str, str]:
    file_store = get_default_file_store(db_session)
    file_type = ChatFileType.IMAGE
    file_id = str(uuid.uuid4())
    file_store.save_file(
        file_name=file_id,
        content=file.file,
        display_name=file.filename,
        file_origin=FileOrigin.CHAT_UPLOAD,
        file_type=file.content_type or file_type.value,
    )
    return {"file_id": file_id}


"""Endpoints for all"""
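The upload endpoint takes a standard multipart file; a minimal sketch of a client call (the host and the `/admin/persona` mount point are assumptions, not taken from this diff):

import requests

with open("icon.png", "rb") as f:
    resp = requests.post(
        "http://localhost:8080/admin/persona/upload-image",  # hypothetical URL
        files={"file": ("icon.png", f, "image/png")},
    )
# Pass this id as uploaded_image_id when creating or updating the persona.
file_id = resp.json()["file_id"]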
@@ -33,6 +33,9 @@ class CreatePersonaRequest(BaseModel):
    # For Private Personas, who should be able to access these
    users: list[UUID] | None = None
    groups: list[int] | None = None
    icon_color: str | None = None
    icon_shape: int | None = None
    uploaded_image_id: str | None = None  # New field for uploaded image


class PersonaSnapshot(BaseModel):
@@ -55,6 +58,9 @@ class PersonaSnapshot(BaseModel):
    document_sets: list[DocumentSet]
    users: list[MinimalUserSnapshot]
    groups: list[int]
    icon_color: str | None
    icon_shape: int | None
    uploaded_image_id: str | None = None

    @classmethod
    def from_model(
@@ -97,6 +103,9 @@ class PersonaSnapshot(BaseModel):
                for user in persona.users
            ],
            groups=[user_group.id for user_group in persona.groups],
            icon_color=persona.icon_color,
            icon_shape=persona.icon_shape,
            uploaded_image_id=persona.uploaded_image_id,
        )
@@ -10,7 +10,7 @@ from danswer.db.llm import fetch_existing_embedding_providers
from danswer.db.llm import remove_embedding_provider
from danswer.db.llm import upsert_cloud_embedding_provider
from danswer.db.models import User
from danswer.search.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.server.manage.embedding.models import CloudEmbeddingProvider
from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
from danswer.server.manage.embedding.models import TestEmbeddingRequest
@@ -42,7 +42,7 @@ def test_embedding_configuration(
            passage_prefix=None,
            model_name=None,
        )
        test_model.encode(["Test String"], text_type=EmbedTextType.QUERY)
        test_model.encode(["Testing Embedding"], text_type=EmbedTextType.QUERY)

    except ValueError as e:
        error_msg = f"Not a valid embedding model. Exception thrown: {e}"
@@ -147,10 +147,10 @@ def set_provider_as_default(

@basic_router.get("/provider")
def list_llm_provider_basics(
    _: User | None = Depends(current_user),
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> list[LLMProviderDescriptor]:
    return [
        LLMProviderDescriptor.from_model(llm_provider_model)
        for llm_provider_model in fetch_existing_llm_providers(db_session)
        for llm_provider_model in fetch_existing_llm_providers(db_session, user)
    ]
@@ -32,6 +32,7 @@ class LLMProviderDescriptor(BaseModel):
    default_model_name: str
    fast_default_model_name: str | None
    is_default_provider: bool | None
    display_model_names: list[str] | None

    @classmethod
    def from_model(
@@ -48,6 +49,7 @@ class LLMProviderDescriptor(BaseModel):
                or fetch_models_for_provider(llm_provider_model.provider)
                or [llm_provider_model.default_model_name]
            ),
            display_model_names=llm_provider_model.display_model_names,
        )


@@ -60,6 +62,9 @@ class LLMProvider(BaseModel):
    custom_config: dict[str, str] | None
    default_model_name: str
    fast_default_model_name: str | None
    is_public: bool = True
    groups: list[int] | None = None
    display_model_names: list[str] | None


class LLMProviderUpsertRequest(LLMProvider):
@@ -86,9 +91,12 @@ class FullLLMProvider(LLMProvider):
            default_model_name=llm_provider_model.default_model_name,
            fast_default_model_name=llm_provider_model.fast_default_model_name,
            is_default_provider=llm_provider_model.is_default_provider,
            display_model_names=llm_provider_model.display_model_names,
            model_names=(
                llm_provider_model.model_names
                or fetch_models_for_provider(llm_provider_model.provider)
                or [llm_provider_model.default_model_name]
            ),
            is_public=llm_provider_model.is_public,
            groups=[group.id for group in llm_provider_model.groups],
        )
@@ -1,3 +1,4 @@
from datetime import datetime
from typing import Any
from typing import TYPE_CHECKING

@@ -14,13 +15,15 @@ from danswer.db.models import SlackBotConfig as SlackBotConfigModel
from danswer.db.models import SlackBotResponseType
from danswer.db.models import StandardAnswer as StandardAnswerModel
from danswer.db.models import StandardAnswerCategory as StandardAnswerCategoryModel
from danswer.db.models import User
from danswer.indexing.models import EmbeddingModelDetail
from danswer.server.features.persona.models import PersonaSnapshot
from danswer.server.models import FullUserSnapshot
from danswer.server.models import InvitedUserSnapshot


if TYPE_CHECKING:
    from danswer.db.models import User as UserModel
    pass


class VersionResponse(BaseModel):
@@ -46,9 +49,17 @@ class UserInfo(BaseModel):
    is_verified: bool
    role: UserRole
    preferences: UserPreferences
    oidc_expiry: datetime | None = None
    current_token_created_at: datetime | None = None
    current_token_expiry_length: int | None = None

    @classmethod
    def from_model(cls, user: "UserModel") -> "UserInfo":
    def from_model(
        cls,
        user: User,
        current_token_created_at: datetime | None = None,
        expiry_length: int | None = None,
    ) -> "UserInfo":
        return cls(
            id=str(user.id),
            email=user.email,
@@ -57,6 +68,9 @@ class UserInfo(BaseModel):
            is_verified=user.is_verified,
            role=user.role,
            preferences=(UserPreferences(chosen_assistants=user.chosen_assistants)),
            oidc_expiry=user.oidc_expiry,
            current_token_created_at=current_token_created_at,
            current_token_expiry_length=expiry_length,
        )


@@ -151,7 +165,9 @@ class SlackBotConfigCreationRequest(BaseModel):
    # by an optional `PersonaSnapshot` object. Keeping it like this
    # for now for simplicity / speed of development
    document_sets: list[int] | None
    persona_id: int | None  # NOTE: only one of `document_sets` / `persona_id` should be set
    persona_id: (
        int | None
    )  # NOTE: only one of `document_sets` / `persona_id` should be set
    channel_names: list[str]
    respond_tag_only: bool = False
    respond_to_bots: bool = False
@@ -1,4 +1,6 @@
import re
from datetime import datetime
from datetime import timezone

from fastapi import APIRouter
from fastapi import Body
@@ -6,6 +8,9 @@ from fastapi import Depends
from fastapi import HTTPException
from fastapi import status
from pydantic import BaseModel
from sqlalchemy import Column
from sqlalchemy import desc
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.orm import Session

@@ -19,9 +24,11 @@ from danswer.auth.users import current_admin_user
from danswer.auth.users import current_user
from danswer.auth.users import optional_user
from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
from danswer.configs.app_configs import VALID_EMAIL_DOMAINS
from danswer.configs.constants import AuthType
from danswer.db.engine import get_session
from danswer.db.models import AccessToken
from danswer.db.models import User
from danswer.db.users import get_user_by_email
from danswer.db.users import list_users
@@ -117,9 +124,9 @@ def list_all_users(
            id=user.id,
            email=user.email,
            role=user.role,
            status=UserStatus.LIVE
            if user.is_active
            else UserStatus.DEACTIVATED,
            status=(
                UserStatus.LIVE if user.is_active else UserStatus.DEACTIVATED
            ),
        )
        for user in users
    ],
@@ -246,9 +253,35 @@ async def get_user_role(user: User = Depends(current_user)) -> UserRoleResponse:
    return UserRoleResponse(role=user.role)


def get_current_token_creation(
    user: User | None, db_session: Session
) -> datetime | None:
    if user is None:
        return None
    try:
        result = db_session.execute(
            select(AccessToken)
            .where(AccessToken.user_id == user.id)  # type: ignore
            .order_by(desc(Column("created_at")))
            .limit(1)
        )
        access_token = result.scalar_one_or_none()

        if access_token:
            return access_token.created_at
        else:
            logger.error("No AccessToken found for user")
            return None

    except Exception as e:
        logger.error(f"Error fetching AccessToken: {e}")
        return None


@router.get("/me")
def verify_user_logged_in(
    user: User | None = Depends(optional_user),
    db_session: Session = Depends(get_session),
) -> UserInfo:
    # NOTE: this does not use `current_user` / `current_admin_user` because we don't want
    # to enforce user verification here - the frontend always wants to get the info about
@@ -264,7 +297,20 @@ def verify_user_logged_in(
            status_code=status.HTTP_403_FORBIDDEN, detail="User Not Authenticated"
        )

    return UserInfo.from_model(user)
    if user.oidc_expiry and user.oidc_expiry < datetime.now(timezone.utc):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Access denied. User's OIDC token has expired.",
        )

    token_created_at = get_current_token_creation(user, db_session)
    user_info = UserInfo.from_model(
        user,
        current_token_created_at=token_created_at,
        expiry_length=SESSION_EXPIRE_TIME_SECONDS,
    )

    return user_info


"""APIs to adjust user preferences"""
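With these changes `/me` both rejects expired OIDC sessions with a 403 and reports token metadata. A sketch of what a client now receives (host and auth cookie are assumptions; field values illustrative):

import requests

info = requests.get(
    "http://localhost:8080/me",  # hypothetical URL
    cookies={"fastapiusersauth": "<session-cookie>"},
).json()
# New UserInfo fields alongside the existing ones:
print(info["current_token_created_at"])    # creation time of the newest AccessToken
print(info["current_token_expiry_length"])  # SESSION_EXPIRE_TIME_SECONDS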
@@ -44,8 +44,9 @@ from danswer.llm.answering.prompts.citations_prompt import (
)
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_default_llms
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.headers import get_litellm_additional_request_headers
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.secondary_llm_flows.chat_session_naming import (
    get_renamed_conversation_name,
)
@@ -442,6 +443,14 @@ def seed_chat(
    root_message = get_or_create_root_message(
        chat_session_id=new_chat_session.id, db_session=db_session
    )
    llm, fast_llm = get_llms_for_persona(persona=new_chat_session.persona)

    tokenizer = get_tokenizer(
        model_name=llm.config.model_name,
        provider_type=llm.config.model_provider,
    )
    token_count = len(tokenizer.encode(chat_seed_request.message))

    create_new_chat_message(
        chat_session_id=new_chat_session.id,
        parent_message=root_message,
@@ -452,9 +461,7 @@ def seed_chat(
            else None
        ),
        message=chat_seed_request.message,
        token_count=len(
            get_default_llm_tokenizer().encode(chat_seed_request.message)
        ),
        token_count=token_count,
        message_type=MessageType.USER,
        db_session=db_session,
    )
|
||||
parent_message_id: int | None
|
||||
# New message contents
|
||||
message: str
|
||||
# file's that we should attach to this message
|
||||
# Files that we should attach to this message
|
||||
file_descriptors: list[FileDescriptor]
|
||||
# If no prompt provided, uses the largest prompt of the chat session
|
||||
# but really this should be explicitly specified, only in the simplified APIs is this inferred
|
||||
|
||||
@@ -1,7 +1,7 @@
from typing import Any
from typing import cast

from openai import BaseModel
from pydantic import BaseModel

REQUEST_BODY = "requestBody"
@@ -1,13 +1,15 @@
from typing import Any

from langchain_core.messages import AIMessage
from langchain_core.messages import BaseMessage
from pydantic import BaseModel

from danswer.tools.tool import Tool


class ForceUseTool(BaseModel):
    # Could be not a forced usage of the tool but still have args, in which case
    # if the tool is called, then those args are applied instead of what the LLM
    # wanted to call it with
    force_use: bool
    tool_name: str
    args: dict[str, Any] | None = None

@@ -16,25 +18,10 @@ class ForceUseTool(BaseModel):
        return {"type": "function", "function": {"name": self.tool_name}}


def modify_message_chain_for_force_use_tool(
    messages: list[BaseMessage], force_use_tool: ForceUseTool | None = None
) -> list[BaseMessage]:
    """NOTE: modifies `messages` in place."""
    if not force_use_tool:
        return messages

    for message in messages:
        if isinstance(message, AIMessage) and message.tool_calls:
            for tool_call in message.tool_calls:
                tool_call["args"] = force_use_tool.args or {}

    return messages


def filter_tools_for_force_tool_use(
    tools: list[Tool], force_use_tool: ForceUseTool | None = None
    tools: list[Tool], force_use_tool: ForceUseTool
) -> list[Tool]:
    if not force_use_tool:
    if not force_use_tool.force_use:
        return tools

    return [tool for tool in tools if tool.name == force_use_tool.tool_name]
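After this change the caller always passes a `ForceUseTool` and its `force_use` flag decides the filtering. A small sketch of the semantics (the import path is an assumption, not confirmed by this diff):

from danswer.tools.force import ForceUseTool, filter_tools_for_force_tool_use  # path assumed

force = ForceUseTool(force_use=False, tool_name="run_image_generation", args=None)
# force_use=False: the tool list passes through untouched.
# force_use=True: only tools whose .name equals tool_name survive.
assert filter_tools_for_force_tool_use([], force) == []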
@@ -156,7 +156,9 @@ class ImageGenerationTool(Tool):
                for image_generation in image_generations
            ]
        ),
        img_urls=[image_generation.url for image_generation in image_generations],
        # NOTE: we can't pass in the image URLs here, since OpenAI doesn't allow
        # Tool messages to contain images
        # img_urls=[image_generation.url for image_generation in image_generations],
    )

    def _generate_image(self, prompt: str) -> ImageGenerationResponse:
@@ -3,31 +3,19 @@ from langchain_core.messages import HumanMessage
from danswer.llm.utils import build_content_with_imgs


NON_TOOL_CALLING_PROMPT = """
You have just created the attached images in response to the following query: "{{query}}".
IMG_GENERATION_SUMMARY_PROMPT = """
You have just created the attached images in response to the following query: "{query}".

Can you please summarize them in a sentence or two?
"""

TOOL_CALLING_PROMPT = """
Can you please summarize the two images you generate in a sentence or two?
Can you please summarize them in a sentence or two? Do NOT include image urls or bulleted lists.
"""


def build_image_generation_user_prompt(
    query: str, img_urls: list[str] | None = None
) -> HumanMessage:
    if img_urls:
        return HumanMessage(
            content=build_content_with_imgs(
                message=NON_TOOL_CALLING_PROMPT.format(query=query).strip(),
                img_urls=img_urls,
            )
        )

    return HumanMessage(
        content=build_content_with_imgs(
            message=TOOL_CALLING_PROMPT.strip(),
            message=IMG_GENERATION_SUMMARY_PROMPT.format(query=query).strip(),
            img_urls=img_urls,
        )
    )
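Worth noting: the old template wrote the placeholder as `{{query}}`, which `str.format` treats as escaped braces, so the user's query was never actually interpolated; the merged `IMG_GENERATION_SUMMARY_PROMPT` fixes that. A two-line demonstration in plain Python:

old = 'query: "{{query}}"'
new = 'query: "{query}"'
print(old.format(query="a red fox"))  # query: "{query}"  -- braces escaped, no interpolation
print(new.format(query="a red fox"))  # query: "a red fox"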
@@ -6,7 +6,7 @@ from langchain_core.messages.tool import ToolCall
from langchain_core.messages.tool import ToolMessage
from pydantic import BaseModel

from danswer.llm.utils import get_default_llm_tokenizer
from danswer.natural_language_processing.utils import BaseTokenizer


def build_tool_message(
@@ -24,9 +24,9 @@ class ToolCallSummary(BaseModel):
    tool_call_result: ToolMessage


def tool_call_tokens(tool_call_summary: ToolCallSummary) -> int:
    llm_tokenizer = get_default_llm_tokenizer()

def tool_call_tokens(
    tool_call_summary: ToolCallSummary, llm_tokenizer: BaseTokenizer
) -> int:
    request_tokens = len(
        llm_tokenizer.encode(
            json.dumps(tool_call_summary.tool_call_request.tool_calls[0]["args"])
@@ -1,8 +1,6 @@
import json

from tiktoken import Encoding

from danswer.llm.utils import get_default_llm_tokenizer
from danswer.natural_language_processing.utils import BaseTokenizer
from danswer.tools.tool import Tool

@@ -22,13 +20,9 @@ def explicit_tool_calling_supported(model_provider: str, model_name: str) -> bool
    return False


def compute_tool_tokens(tool: Tool, llm_tokenizer: Encoding | None = None) -> int:
    if not llm_tokenizer:
        llm_tokenizer = get_default_llm_tokenizer()
def compute_tool_tokens(tool: Tool, llm_tokenizer: BaseTokenizer) -> int:
    return len(llm_tokenizer.encode(json.dumps(tool.tool_definition())))


def compute_all_tool_tokens(
    tools: list[Tool], llm_tokenizer: Encoding | None = None
) -> int:
def compute_all_tool_tokens(tools: list[Tool], llm_tokenizer: BaseTokenizer) -> int:
    return sum(compute_tool_tokens(tool, llm_tokenizer) for tool in tools)
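The refactor makes callers supply a `BaseTokenizer` explicitly instead of falling back to a tiktoken default. The underlying arithmetic is just the tokenized length of the JSON tool definition, e.g. (tokenizer and schema values are illustrative):

import json

from danswer.natural_language_processing.utils import get_tokenizer

llm_tokenizer = get_tokenizer(model_name="gpt-4", provider_type="openai")  # illustrative
tool_definition = {"type": "function", "function": {"name": "search"}}  # stand-in schema
cost = len(llm_tokenizer.encode(json.dumps(tool_definition)))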
Some files were not shown because too many files have changed in this diff.