Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-02-17 07:45:47 +00:00)

Compare commits: eval/split ... v0.4.27 (180 commits)
SHA1:
680388537b, d9bcacfae7, 2ab192933b, 1c10f54294, 0530f4283e, 3540aa579b, 54732a83c9, 5e6365c449, 20369fc451, f15d6d2b59,
5dda047999, ffd9b0180b, 5ad54fec87, d636181aa5, e12ed7750a, bbb8c5ff0b, 83e945ba57, 26df869b91, 1a4df1d65e, 0a165aae0b,
e517f47a89, c7e5b11c63, 79523f2e0a, 7fae66b766, 386b229ed3, ce666f3320, d60fb15ad3, 7358ece008, 9c5d33e198, 7d5cfd2fa3,
a4caf66a35, 0a8d44b44c, cc8a6da8e3, 54d4526b73, c8ead6a0dc, 7bfa99766d, b230082891, 8cd1eda8b1, 7dcc42aa95, e59d1a0294,
384e61f4b0, f28b930475, 1d989f5343, c1e3a1b3e7, be9ed319d5, c630fcffee, f411b9cb55, bdaaebe955, 9eb48ca2c3, 509fa3a994,
5097c7f284, c4e1c62c00, eab82782ca, 53d976234a, 44d8e34b5a, d2e16a599d, 291e6c4198, bb7e1d6e55, fcc4c30ead, f20984ea1d,
e0f0cfd92e, 57aec7d02a, 6350219143, 3bc2cf9946, 7f7452dc98, dc2a50034d, ab564a9ec8, cc3856ef6d, a8a4ad9546, 5bfdecacad,
0bde66a888, 5825d01d53, cd22cca4e8, a3ea217f40, 66e4dded91, 6d67d472cd, 76b7792e69, 9d7100a287, 876feecd6f, 0261d689dc,
aa4a00cbc2, 52c505c210, ed455394fc, 57cc53ab94, 6a61331cba, 51731ad0dd, f280586e68, e31d6be4ce, e6a92aa936, a54ea9f9fa,
73a92c046d, 459bd46846, 445f7e70ba, ca893f9918, 1be1959d80, 1654378850, d6d391d244, 7c283b090d, 40226678af, 288e6fa606,
5307d38472, d619602a6f, 348a2176f0, 89b6da36a6, 036d5c737e, 60a87d9472, eb9bb56829, d151082871, e4b1f5b963, 3938a053aa,
7932e764d6, fb6695a983, 015f415b71, 96b582070b, 4a0a927a64, ea9a9cb553, 38af12ab97, 1b3154188d, 1f321826ad, cbfbe4e5d8,
3aa0e0124b, f2f60c9cc0, 6c32821ad4, d839595330, e422f96dff, d28f460330, 8e441d975d, 5c78af1f07, e325e063ed, c81b45300b,
26a1e963d1, 2a983263c7, 2a37c95a5e, c277a74f82, e4b31cd0d9, a40d2a1e2e, c9fb99d719, a4d71e08aa, 546bfbd24b, 27824d6cc6,
9d5c4ad634, 9b32003816, 8bc4123ed7, d58aaf7a59, a0056a1b3c, d2584c773a, 807bef8ada, 5afddacbb2, 4fb6a88f1e, 7057be6a88,
91be8e7bfb, 9651ea828b, 6ee74bd0d1, 48a0d29a5c, 6ff8e6c0ea, 2470c68506, 866bc803b1, 9c6084bd0d, a0b46c60c6, 4029233df0,
6c88c0156c, 33332d08f2, 17005fb705, 48a7fe80b1, 1276732409, f91b92a898, 6222f533be, 1b49d17239, 2f5f19642e, 6db4634871,
5cfed45cef, 581ffde35a, 6313e6d91d, c09c94bf32, 0e8ba111c8, 2ba24b1734, 44820b4909, eb3e7610fc, 7fbbb174bb, 3854ca11af
.github/pull_request_template.md (vendored, new file, 25 lines)
@@ -0,0 +1,25 @@

## Description

[Provide a brief description of the changes in this PR]

## How Has This Been Tested?

[Describe the tests you ran to verify your changes]

## Accepted Risk

[Any known risks or failure modes to point out to reviewers]

## Related Issue(s)

[If applicable, link to the issue(s) this PR addresses]

## Checklist:

- [ ] All of the automated tests pass
- [ ] All PR comments are addressed and marked resolved
- [ ] If there are migrations, they have been rebased to latest main
- [ ] If there are new dependencies, they are added to the requirements
- [ ] If there are new environment variables, they are added to all of the deployment methods
- [ ] If there are new APIs that don't require auth, they are added to PUBLIC_ENDPOINT_SPECS
- [ ] Docker images build and basic functionalities work
- [ ] Author has done a final read through of the PR right before merge
@@ -7,7 +7,8 @@ on:

jobs:
  build-and-push:
    runs-on: ubuntu-latest
    runs-on:
      group: amd64-image-builders

    steps:
      - name: Checkout code
.vscode/env_template.txt (vendored, 2 lines changed)
@@ -15,7 +15,7 @@ LOG_LEVEL=debug

# This passes top N results to LLM an additional time for reranking prior to answer generation
# This step is quite heavy on token usage so we disable it for dev generally
DISABLE_LLM_CHUNK_FILTER=True
DISABLE_LLM_DOC_RELEVANCE=True


# Useful if you want to toggle auth on/off (google_oauth/OIDC specifically)
.vscode/launch.template.jsonc (vendored, 26 lines changed)
@@ -39,7 +39,8 @@
        "--reload",
        "--port",
        "9000"
      ]
      ],
      "consoleTitle": "Model Server"
    },
    {
      "name": "API Server",
@@ -58,7 +59,8 @@
        "--reload",
        "--port",
        "8080"
      ]
      ],
      "consoleTitle": "API Server"
    },
    {
      "name": "Indexing",
@@ -68,11 +70,12 @@
      "cwd": "${workspaceFolder}/backend",
      "envFile": "${workspaceFolder}/.env",
      "env": {
        "ENABLE_MINI_CHUNK": "false",
        "ENABLE_MULTIPASS_INDEXING": "false",
        "LOG_LEVEL": "DEBUG",
        "PYTHONUNBUFFERED": "1",
        "PYTHONPATH": "."
      }
      },
      "consoleTitle": "Indexing"
    },
    // Celery and all async jobs, usually would include indexing as well but this is handled separately above for dev
    {
@@ -90,7 +93,8 @@
      },
      "args": [
        "--no-indexing"
      ]
      ],
      "consoleTitle": "Background Jobs"
    },
    // For the listener to access the Slack API,
    // DANSWER_BOT_SLACK_APP_TOKEN & DANSWER_BOT_SLACK_BOT_TOKEN need to be set in .env file located in the root of the project
@@ -125,5 +129,17 @@
        //"tests/unit/danswer/llm/answering/test_prune_and_merge.py"
      ]
    }
  ],
  "compounds": [
    {
      "name": "Run Danswer",
      "configurations": [
        "Web Server",
        "Model Server",
        "API Server",
        "Indexing",
        "Background Jobs",
      ]
    }
  ]
}
@@ -68,7 +68,9 @@ RUN apt-get update && \
    rm -f /usr/local/lib/python3.11/site-packages/tornado/test/test.key

# Pre-downloading models for setups with limited egress
RUN python -c "from transformers import AutoTokenizer; AutoTokenizer.from_pretrained('intfloat/e5-base-v2')"
RUN python -c "from tokenizers import Tokenizer; \
    Tokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1')"

# Pre-downloading NLTK for setups with limited egress
RUN python -c "import nltk; \
@@ -18,14 +18,18 @@ RUN apt-get remove -y --allow-remove-essential perl-base && \
    apt-get autoremove -y

# Pre-downloading models for setups with limited egress
RUN python -c "from transformers import AutoModel, AutoTokenizer, TFDistilBertForSequenceClassification; \
# Download tokenizers, distilbert for the Danswer model
# Download model weights
# Run Nomic to pull in the custom architecture and have it cached locally
RUN python -c "from transformers import AutoTokenizer; \
    AutoTokenizer.from_pretrained('distilbert-base-uncased', cache_folder='/root/.cache/temp_huggingface/hub/'); \
    AutoTokenizer.from_pretrained('mixedbread-ai/mxbai-rerank-xsmall-v1', cache_folder='/root/.cache/temp_huggingface/hub/'); \
    from huggingface_hub import snapshot_download; \
    AutoTokenizer.from_pretrained('danswer/intent-model'); \
    AutoTokenizer.from_pretrained('intfloat/e5-base-v2'); \
    AutoTokenizer.from_pretrained('mixedbread-ai/mxbai-rerank-xsmall-v1'); \
    snapshot_download('danswer/intent-model'); \
    snapshot_download('intfloat/e5-base-v2'); \
    snapshot_download('mixedbread-ai/mxbai-rerank-xsmall-v1')"
    snapshot_download(repo_id='danswer/hybrid-intent-token-classifier', revision='v1.0.3', cache_dir='/root/.cache/temp_huggingface/hub/'); \
    snapshot_download('nomic-ai/nomic-embed-text-v1', cache_dir='/root/.cache/temp_huggingface/hub/'); \
    snapshot_download('mixedbread-ai/mxbai-rerank-xsmall-v1', cache_dir='/root/.cache/temp_huggingface/hub/'); \
    from sentence_transformers import SentenceTransformer; \
    SentenceTransformer(model_name_or_path='nomic-ai/nomic-embed-text-v1', trust_remote_code=True, cache_folder='/root/.cache/temp_huggingface/hub/');"

WORKDIR /app
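The Dockerfile changes above bake the model downloads into the image so that deployments with limited egress never have to reach Hugging Face at runtime. The snippet below is a minimal standalone sketch of the same idea, assuming the `huggingface_hub` and `sentence_transformers` packages are installed; the cache path and the loop are illustrative, while the model IDs are the ones named in the diff.

```python
# Minimal sketch: pre-fetch model files into a local cache so a later,
# network-restricted runtime can load them from disk.
from huggingface_hub import snapshot_download
from sentence_transformers import SentenceTransformer

CACHE_DIR = "/root/.cache/temp_huggingface/hub/"  # assumption: mirrors the image layout

for repo_id in (
    "danswer/hybrid-intent-token-classifier",
    "nomic-ai/nomic-embed-text-v1",
    "mixedbread-ai/mxbai-rerank-xsmall-v1",
):
    # Downloads the full repository snapshot into the cache directory.
    snapshot_download(repo_id, cache_dir=CACHE_DIR)

# Instantiating the embedding model once also pulls in Nomic's custom
# architecture code so it is cached alongside the weights.
SentenceTransformer(
    model_name_or_path="nomic-ai/nomic-embed-text-v1",
    trust_remote_code=True,
    cache_folder=CACHE_DIR,
)
```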
@@ -17,15 +17,11 @@ depends_on: None = None


def upgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column(
        "chat_session",
        sa.Column("current_alternate_model", sa.String(), nullable=True),
    )
    # ### end Alembic commands ###


def downgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_column("chat_session", "current_alternate_model")
    # ### end Alembic commands ###
@@ -0,0 +1,26 @@

"""add_indexing_start_to_connector

Revision ID: 08a1eda20fe1
Revises: 8a87bd6ec550
Create Date: 2024-07-23 11:12:39.462397

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "08a1eda20fe1"
down_revision = "8a87bd6ec550"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.add_column(
        "connector", sa.Column("indexing_start", sa.DateTime(), nullable=True)
    )


def downgrade() -> None:
    op.drop_column("connector", "indexing_start")
backend/alembic/versions/213fd978c6d8_notifications.py (new file, 44 lines)
@@ -0,0 +1,44 @@

"""notifications

Revision ID: 213fd978c6d8
Revises: 5fc1f54cc252
Create Date: 2024-08-10 11:13:36.070790

"""
from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "213fd978c6d8"
down_revision = "5fc1f54cc252"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.create_table(
        "notification",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column(
            "notif_type",
            sa.String(),
            nullable=False,
        ),
        sa.Column(
            "user_id",
            sa.UUID(),
            nullable=True,
        ),
        sa.Column("dismissed", sa.Boolean(), nullable=False),
        sa.Column("last_shown", sa.DateTime(timezone=True), nullable=False),
        sa.Column("first_shown", sa.DateTime(timezone=True), nullable=False),
        sa.ForeignKeyConstraint(
            ["user_id"],
            ["user.id"],
        ),
        sa.PrimaryKeyConstraint("id"),
    )


def downgrade() -> None:
    op.drop_table("notification")
@@ -79,7 +79,7 @@ def downgrade() -> None:
    )
    op.create_foreign_key(
        "document_retrieval_feedback__chat_message_fk",
        "document_retrieval",
        "document_retrieval_feedback",
        "chat_message",
        ["chat_message_id"],
        ["id"],
@@ -160,12 +160,28 @@ def downgrade() -> None:
            nullable=False,
        ),
    )
    op.drop_constraint(
        "fk_index_attempt_credential_id", "index_attempt", type_="foreignkey"
    )
    op.drop_constraint(
        "fk_index_attempt_connector_id", "index_attempt", type_="foreignkey"
    )

    # Check if the constraint exists before dropping
    conn = op.get_bind()
    inspector = sa.inspect(conn)
    constraints = inspector.get_foreign_keys("index_attempt")

    if any(
        constraint["name"] == "fk_index_attempt_credential_id"
        for constraint in constraints
    ):
        op.drop_constraint(
            "fk_index_attempt_credential_id", "index_attempt", type_="foreignkey"
        )

    if any(
        constraint["name"] == "fk_index_attempt_connector_id"
        for constraint in constraints
    ):
        op.drop_constraint(
            "fk_index_attempt_connector_id", "index_attempt", type_="foreignkey"
        )

    op.drop_column("index_attempt", "credential_id")
    op.drop_column("index_attempt", "connector_id")
    op.drop_table("connector_credential_pair")
@@ -0,0 +1,70 @@

"""Add icon_color and icon_shape to Persona

Revision ID: 325975216eb3
Revises: 91ffac7e65b3
Create Date: 2024-07-24 21:29:31.784562

"""
import random
from alembic import op
import sqlalchemy as sa
from sqlalchemy.sql import table, column, select

# revision identifiers, used by Alembic.
revision = "325975216eb3"
down_revision = "91ffac7e65b3"
branch_labels: None = None
depends_on: None = None


colorOptions = [
    "#FF6FBF",
    "#6FB1FF",
    "#B76FFF",
    "#FFB56F",
    "#6FFF8D",
    "#FF6F6F",
    "#6FFFFF",
]


# Function to generate a random shape ensuring at least 3 of the middle 4 squares are filled
def generate_random_shape() -> int:
    center_squares = [12, 10, 6, 14, 13, 11, 7, 15]
    center_fill = random.choice(center_squares)
    remaining_squares = [i for i in range(16) if not (center_fill & (1 << i))]
    random.shuffle(remaining_squares)
    for i in range(10 - bin(center_fill).count("1")):
        center_fill |= 1 << remaining_squares[i]
    return center_fill


def upgrade() -> None:
    op.add_column("persona", sa.Column("icon_color", sa.String(), nullable=True))
    op.add_column("persona", sa.Column("icon_shape", sa.Integer(), nullable=True))
    op.add_column("persona", sa.Column("uploaded_image_id", sa.String(), nullable=True))

    persona = table(
        "persona",
        column("id", sa.Integer),
        column("icon_color", sa.String),
        column("icon_shape", sa.Integer),
    )

    conn = op.get_bind()
    personas = conn.execute(select(persona.c.id))

    for persona_id in personas:
        random_color = random.choice(colorOptions)
        random_shape = generate_random_shape()
        conn.execute(
            persona.update()
            .where(persona.c.id == persona_id[0])
            .values(icon_color=random_color, icon_shape=random_shape)
        )


def downgrade() -> None:
    op.drop_column("persona", "icon_shape")
    op.drop_column("persona", "uploaded_image_id")
    op.drop_column("persona", "icon_color")
@@ -18,7 +18,6 @@ depends_on: None = None


def upgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column(
        "chat_message", sa.Column("alternate_assistant_id", sa.Integer(), nullable=True)
    )
@@ -29,10 +28,8 @@ def upgrade() -> None:
        ["alternate_assistant_id"],
        ["id"],
    )
    # ### end Alembic commands ###


def downgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_constraint("fk_chat_message_persona", "chat_message", type_="foreignkey")
    op.drop_column("chat_message", "alternate_assistant_id")
@@ -0,0 +1,42 @@

"""Rename index_origin to index_recursively

Revision ID: 1d6ad76d1f37
Revises: e1392f05e840
Create Date: 2024-08-01 12:38:54.466081

"""
from alembic import op

# revision identifiers, used by Alembic.
revision = "1d6ad76d1f37"
down_revision = "e1392f05e840"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.execute(
        """
        UPDATE connector
        SET connector_specific_config = jsonb_set(
            connector_specific_config,
            '{index_recursively}',
            'true'::jsonb
        ) - 'index_origin'
        WHERE connector_specific_config ? 'index_origin'
        """
    )


def downgrade() -> None:
    op.execute(
        """
        UPDATE connector
        SET connector_specific_config = jsonb_set(
            connector_specific_config,
            '{index_origin}',
            connector_specific_config->'index_recursively'
        ) - 'index_recursively'
        WHERE connector_specific_config ? 'index_recursively'
        """
    )
@@ -0,0 +1,49 @@

"""Add display_model_names to llm_provider

Revision ID: 473a1a7ca408
Revises: 325975216eb3
Create Date: 2024-07-25 14:31:02.002917

"""

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "473a1a7ca408"
down_revision = "325975216eb3"
branch_labels: None = None
depends_on: None = None

default_models_by_provider = {
    "openai": ["gpt-4", "gpt-4o", "gpt-4o-mini"],
    "bedrock": [
        "meta.llama3-1-70b-instruct-v1:0",
        "meta.llama3-1-8b-instruct-v1:0",
        "anthropic.claude-3-opus-20240229-v1:0",
        "mistral.mistral-large-2402-v1:0",
        "anthropic.claude-3-5-sonnet-20240620-v1:0",
    ],
    "anthropic": ["claude-3-opus-20240229", "claude-3-5-sonnet-20240620"],
}


def upgrade() -> None:
    op.add_column(
        "llm_provider",
        sa.Column("display_model_names", postgresql.ARRAY(sa.String()), nullable=True),
    )

    connection = op.get_bind()
    for provider, models in default_models_by_provider.items():
        connection.execute(
            sa.text(
                "UPDATE llm_provider SET display_model_names = :models WHERE provider = :provider"
            ),
            {"models": models, "provider": provider},
        )


def downgrade() -> None:
    op.drop_column("llm_provider", "display_model_names")
@@ -0,0 +1,80 @@

"""Moved status to connector credential pair

Revision ID: 4a951134c801
Revises: 7477a5f5d728
Create Date: 2024-08-10 19:20:34.527559

"""
from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "4a951134c801"
down_revision = "7477a5f5d728"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column(
        "connector_credential_pair",
        sa.Column(
            "status",
            sa.Enum(
                "ACTIVE",
                "PAUSED",
                "DELETING",
                name="connectorcredentialpairstatus",
                native_enum=False,
            ),
            nullable=True,
        ),
    )

    # Update status of connector_credential_pair based on connector's disabled status
    op.execute(
        """
        UPDATE connector_credential_pair
        SET status = CASE
            WHEN (
                SELECT disabled
                FROM connector
                WHERE connector.id = connector_credential_pair.connector_id
            ) = FALSE THEN 'ACTIVE'
            ELSE 'PAUSED'
        END
        """
    )

    # Make the status column not nullable after setting values
    op.alter_column("connector_credential_pair", "status", nullable=False)

    op.drop_column("connector", "disabled")


def downgrade() -> None:
    op.add_column(
        "connector",
        sa.Column("disabled", sa.BOOLEAN(), autoincrement=False, nullable=True),
    )

    # Update disabled status of connector based on connector_credential_pair's status
    op.execute(
        """
        UPDATE connector
        SET disabled = CASE
            WHEN EXISTS (
                SELECT 1
                FROM connector_credential_pair
                WHERE connector_credential_pair.connector_id = connector.id
                AND connector_credential_pair.status = 'ACTIVE'
            ) THEN FALSE
            ELSE TRUE
        END
        """
    )

    # Make the disabled column not nullable after setting values
    op.alter_column("connector", "disabled", nullable=False)

    op.drop_column("connector_credential_pair", "status")
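The migration above follows a common Alembic pattern: add the new column as nullable, backfill it from existing data inside the same migration, then tighten it to NOT NULL and drop the old column. Below is a minimal, generic sketch of that pattern; the `widget` table and its columns are made-up names for illustration, not part of the schema being migrated.

```python
# Illustrative Alembic pattern only: add nullable -> backfill -> enforce NOT NULL.
from alembic import op
import sqlalchemy as sa


def upgrade() -> None:
    # 1. Add the replacement column, nullable so existing rows are valid.
    op.add_column("widget", sa.Column("state", sa.String(), nullable=True))

    # 2. Backfill from the legacy boolean before tightening the constraint.
    op.execute(
        """
        UPDATE widget
        SET state = CASE WHEN enabled THEN 'ACTIVE' ELSE 'PAUSED' END
        """
    )

    # 3. Only now enforce NOT NULL and retire the old column.
    op.alter_column("widget", "state", nullable=False)
    op.drop_column("widget", "enabled")
```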
@@ -0,0 +1,72 @@

"""Add type to credentials

Revision ID: 4ea2c93919c1
Revises: 473a1a7ca408
Create Date: 2024-07-18 13:07:13.655895

"""

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "4ea2c93919c1"
down_revision = "473a1a7ca408"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    # Add the new 'source' column to the 'credential' table
    op.add_column(
        "credential",
        sa.Column(
            "source",
            sa.String(length=100),  # Use String instead of Enum
            nullable=True,  # Initially allow NULL values
        ),
    )
    op.add_column(
        "credential",
        sa.Column(
            "name",
            sa.String(),
            nullable=True,
        ),
    )

    # Create a temporary table that maps each credential to a single connector source.
    # This is needed because a credential can be associated with multiple connectors,
    # but we want to assign a single source to each credential.
    # We use DISTINCT ON to ensure we only get one row per credential_id.
    op.execute(
        """
        CREATE TEMPORARY TABLE temp_connector_credential AS
        SELECT DISTINCT ON (cc.credential_id)
            cc.credential_id,
            c.source AS connector_source
        FROM connector_credential_pair cc
        JOIN connector c ON cc.connector_id = c.id
        """
    )

    # Update the 'source' column in the 'credential' table
    op.execute(
        """
        UPDATE credential cred
        SET source = COALESCE(
            (SELECT connector_source
             FROM temp_connector_credential temp
             WHERE cred.id = temp.credential_id),
            'NOT_APPLICABLE'
        )
        """
    )
    # If no exception was raised, alter the column
    op.alter_column("credential", "source", nullable=True)  # TODO modify
    # # ### end Alembic commands ###


def downgrade() -> None:
    op.drop_column("credential", "source")
    op.drop_column("credential", "name")
backend/alembic/versions/5fc1f54cc252_hybrid_enum.py (new file, 25 lines)
@@ -0,0 +1,25 @@

"""hybrid-enum

Revision ID: 5fc1f54cc252
Revises: 1d6ad76d1f37
Create Date: 2024-08-06 15:35:40.278485

"""
from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "5fc1f54cc252"
down_revision = "1d6ad76d1f37"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.drop_column("persona", "search_type")


def downgrade() -> None:
    op.add_column("persona", sa.Column("search_type", sa.String(), nullable=True))
    op.execute("UPDATE persona SET search_type = 'SEMANTIC'")
    op.alter_column("persona", "search_type", nullable=False)
@@ -0,0 +1,24 @@

"""Added model defaults for users

Revision ID: 7477a5f5d728
Revises: 213fd978c6d8
Create Date: 2024-08-04 19:00:04.512634

"""

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "7477a5f5d728"
down_revision = "213fd978c6d8"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.add_column("user", sa.Column("default_model", sa.Text(), nullable=True))


def downgrade() -> None:
    op.drop_column("user", "default_model")
@@ -28,5 +28,9 @@ def upgrade() -> None:


def downgrade() -> None:
    # This wasn't really required by the code either, no good reason to make it unique again
    pass
    op.create_unique_constraint(
        "connector_credential_pair__name__key", "connector_credential_pair", ["name"]
    )
    op.alter_column(
        "connector_credential_pair", "name", existing_type=sa.String(), nullable=True
    )
@@ -0,0 +1,41 @@

"""add_llm_group_permissions_control

Revision ID: 795b20b85b4b
Revises: 05c07bf07c00
Create Date: 2024-07-19 11:54:35.701558

"""
from alembic import op
import sqlalchemy as sa


revision = "795b20b85b4b"
down_revision = "05c07bf07c00"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.create_table(
        "llm_provider__user_group",
        sa.Column("llm_provider_id", sa.Integer(), nullable=False),
        sa.Column("user_group_id", sa.Integer(), nullable=False),
        sa.ForeignKeyConstraint(
            ["llm_provider_id"],
            ["llm_provider.id"],
        ),
        sa.ForeignKeyConstraint(
            ["user_group_id"],
            ["user_group.id"],
        ),
        sa.PrimaryKeyConstraint("llm_provider_id", "user_group_id"),
    )
    op.add_column(
        "llm_provider",
        sa.Column("is_public", sa.Boolean(), nullable=False, server_default="true"),
    )


def downgrade() -> None:
    op.drop_table("llm_provider__user_group")
    op.drop_column("llm_provider", "is_public")
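The new `llm_provider__user_group` table together with the `is_public` flag is what lets LLM providers be scoped to user groups. As a rough illustration of the kind of query this schema supports (not the application's actual code), a provider is visible when it is public or shared with one of the user's groups:

```python
# Illustrative query only; table and column names mirror the migration above,
# the surrounding Core table definitions and helper are assumptions.
import sqlalchemy as sa

metadata = sa.MetaData()
llm_provider = sa.Table(
    "llm_provider",
    metadata,
    sa.Column("id", sa.Integer, primary_key=True),
    sa.Column("is_public", sa.Boolean, nullable=False),
)
llm_provider__user_group = sa.Table(
    "llm_provider__user_group",
    metadata,
    sa.Column("llm_provider_id", sa.Integer, nullable=False),
    sa.Column("user_group_id", sa.Integer, nullable=False),
)


def visible_provider_ids_stmt(user_group_ids: list[int]):
    # Providers shared with any of the user's groups.
    shared = sa.select(llm_provider__user_group.c.llm_provider_id).where(
        llm_provider__user_group.c.user_group_id.in_(user_group_ids)
    )
    # Public providers are visible to everyone; group-shared ones only to members.
    return sa.select(llm_provider.c.id).where(
        sa.or_(llm_provider.c.is_public.is_(True), llm_provider.c.id.in_(shared))
    )
```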
@@ -0,0 +1,103 @@
|
||||
"""associate index attempts with ccpair
|
||||
|
||||
Revision ID: 8a87bd6ec550
|
||||
Revises: 4ea2c93919c1
|
||||
Create Date: 2024-07-22 15:15:52.558451
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "8a87bd6ec550"
|
||||
down_revision = "4ea2c93919c1"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add the new connector_credential_pair_id column
|
||||
op.add_column(
|
||||
"index_attempt",
|
||||
sa.Column("connector_credential_pair_id", sa.Integer(), nullable=True),
|
||||
)
|
||||
|
||||
# Create a foreign key constraint to the connector_credential_pair table
|
||||
op.create_foreign_key(
|
||||
"fk_index_attempt_connector_credential_pair_id",
|
||||
"index_attempt",
|
||||
"connector_credential_pair",
|
||||
["connector_credential_pair_id"],
|
||||
["id"],
|
||||
)
|
||||
|
||||
# Populate the new connector_credential_pair_id column using existing connector_id and credential_id
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE index_attempt ia
|
||||
SET connector_credential_pair_id =
|
||||
CASE
|
||||
WHEN ia.credential_id IS NULL THEN
|
||||
(SELECT id FROM connector_credential_pair
|
||||
WHERE connector_id = ia.connector_id
|
||||
LIMIT 1)
|
||||
ELSE
|
||||
(SELECT id FROM connector_credential_pair
|
||||
WHERE connector_id = ia.connector_id
|
||||
AND credential_id = ia.credential_id)
|
||||
END
|
||||
WHERE ia.connector_id IS NOT NULL
|
||||
"""
|
||||
)
|
||||
|
||||
# Make the new connector_credential_pair_id column non-nullable
|
||||
op.alter_column("index_attempt", "connector_credential_pair_id", nullable=False)
|
||||
|
||||
# Drop the old connector_id and credential_id columns
|
||||
op.drop_column("index_attempt", "connector_id")
|
||||
op.drop_column("index_attempt", "credential_id")
|
||||
|
||||
# Update the index to use connector_credential_pair_id
|
||||
op.create_index(
|
||||
"ix_index_attempt_latest_for_connector_credential_pair",
|
||||
"index_attempt",
|
||||
["connector_credential_pair_id", "time_created"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Add back the old connector_id and credential_id columns
|
||||
op.add_column(
|
||||
"index_attempt", sa.Column("connector_id", sa.Integer(), nullable=True)
|
||||
)
|
||||
op.add_column(
|
||||
"index_attempt", sa.Column("credential_id", sa.Integer(), nullable=True)
|
||||
)
|
||||
|
||||
# Populate the old connector_id and credential_id columns using the connector_credential_pair_id
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE index_attempt ia
|
||||
SET connector_id = ccp.connector_id, credential_id = ccp.credential_id
|
||||
FROM connector_credential_pair ccp
|
||||
WHERE ia.connector_credential_pair_id = ccp.id
|
||||
"""
|
||||
)
|
||||
|
||||
# Make the old connector_id and credential_id columns non-nullable
|
||||
op.alter_column("index_attempt", "connector_id", nullable=False)
|
||||
op.alter_column("index_attempt", "credential_id", nullable=False)
|
||||
|
||||
# Drop the new connector_credential_pair_id column
|
||||
op.drop_constraint(
|
||||
"fk_index_attempt_connector_credential_pair_id",
|
||||
"index_attempt",
|
||||
type_="foreignkey",
|
||||
)
|
||||
op.drop_column("index_attempt", "connector_credential_pair_id")
|
||||
|
||||
op.create_index(
|
||||
"ix_index_attempt_latest_for_connector_credential_pair",
|
||||
"index_attempt",
|
||||
["connector_id", "credential_id", "time_created"],
|
||||
)
|
||||
26
backend/alembic/versions/91ffac7e65b3_add_expiry_time.py
Normal file
26
backend/alembic/versions/91ffac7e65b3_add_expiry_time.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""add expiry time
|
||||
|
||||
Revision ID: 91ffac7e65b3
|
||||
Revises: bc9771dccadf
|
||||
Create Date: 2024-06-24 09:39:56.462242
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "91ffac7e65b3"
|
||||
down_revision = "795b20b85b4b"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"user", sa.Column("oidc_expiry", sa.DateTime(timezone=True), nullable=True)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("user", "oidc_expiry")
|
||||
@@ -16,7 +16,6 @@ depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.alter_column(
|
||||
"connector_credential_pair",
|
||||
"last_attempt_status",
|
||||
@@ -29,11 +28,9 @@ def upgrade() -> None:
|
||||
),
|
||||
nullable=True,
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.alter_column(
|
||||
"connector_credential_pair",
|
||||
"last_attempt_status",
|
||||
@@ -46,4 +43,3 @@ def downgrade() -> None:
|
||||
),
|
||||
nullable=False,
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
@@ -19,6 +19,9 @@ depends_on: None = None
|
||||
def upgrade() -> None:
|
||||
op.drop_table("deletion_attempt")
|
||||
|
||||
# Remove the DeletionStatus enum
|
||||
op.execute("DROP TYPE IF EXISTS deletionstatus;")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.create_table(
|
||||
|
||||
@@ -136,4 +136,4 @@ def downgrade() -> None:
|
||||
)
|
||||
op.drop_column("index_attempt", "embedding_model_id")
|
||||
op.drop_table("embedding_model")
|
||||
op.execute("DROP TYPE indexmodelstatus;")
|
||||
op.execute("DROP TYPE IF EXISTS indexmodelstatus;")
|
||||
|
||||
58
backend/alembic/versions/e1392f05e840_added_input_prompts.py
Normal file
58
backend/alembic/versions/e1392f05e840_added_input_prompts.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""Added input prompts
|
||||
|
||||
Revision ID: e1392f05e840
|
||||
Revises: 08a1eda20fe1
|
||||
Create Date: 2024-07-13 19:09:22.556224
|
||||
|
||||
"""
|
||||
|
||||
import fastapi_users_db_sqlalchemy
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "e1392f05e840"
|
||||
down_revision = "08a1eda20fe1"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"inputprompt",
|
||||
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
|
||||
sa.Column("prompt", sa.String(), nullable=False),
|
||||
sa.Column("content", sa.String(), nullable=False),
|
||||
sa.Column("active", sa.Boolean(), nullable=False),
|
||||
sa.Column("is_public", sa.Boolean(), nullable=False),
|
||||
sa.Column(
|
||||
"user_id",
|
||||
fastapi_users_db_sqlalchemy.generics.GUID(),
|
||||
nullable=True,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["user.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_table(
|
||||
"inputprompt__user",
|
||||
sa.Column("input_prompt_id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_id", sa.Integer(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["input_prompt_id"],
|
||||
["inputprompt.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["inputprompt.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("input_prompt_id", "user_id"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("inputprompt__user")
|
||||
op.drop_table("inputprompt")
|
||||
@@ -5,19 +5,16 @@ from danswer.access.utils import prefix_user
|
||||
from danswer.configs.constants import PUBLIC_DOC_PAT
|
||||
from danswer.db.document import get_acccess_info_for_documents
|
||||
from danswer.db.models import User
|
||||
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
|
||||
|
||||
def _get_access_for_documents(
|
||||
document_ids: list[str],
|
||||
db_session: Session,
|
||||
cc_pair_to_delete: ConnectorCredentialPairIdentifier | None = None,
|
||||
) -> dict[str, DocumentAccess]:
|
||||
document_access_info = get_acccess_info_for_documents(
|
||||
db_session=db_session,
|
||||
document_ids=document_ids,
|
||||
cc_pair_to_delete=cc_pair_to_delete,
|
||||
)
|
||||
return {
|
||||
document_id: DocumentAccess.build(user_ids, [], is_public)
|
||||
@@ -28,14 +25,13 @@ def _get_access_for_documents(
|
||||
def get_access_for_documents(
|
||||
document_ids: list[str],
|
||||
db_session: Session,
|
||||
cc_pair_to_delete: ConnectorCredentialPairIdentifier | None = None,
|
||||
) -> dict[str, DocumentAccess]:
|
||||
"""Fetches all access information for the given documents."""
|
||||
versioned_get_access_for_documents_fn = fetch_versioned_implementation(
|
||||
"danswer.access.access", "_get_access_for_documents"
|
||||
)
|
||||
return versioned_get_access_for_documents_fn(
|
||||
document_ids, db_session, cc_pair_to_delete
|
||||
document_ids, db_session
|
||||
) # type: ignore
|
||||
|
||||
|
||||
|
||||
@@ -1,21 +1,20 @@
|
||||
from typing import cast
|
||||
|
||||
from danswer.configs.constants import KV_USER_STORE_KEY
|
||||
from danswer.dynamic_configs.factory import get_dynamic_config_store
|
||||
from danswer.dynamic_configs.interface import ConfigNotFoundError
|
||||
from danswer.dynamic_configs.interface import JSON_ro
|
||||
|
||||
USER_STORE_KEY = "INVITED_USERS"
|
||||
|
||||
|
||||
def get_invited_users() -> list[str]:
|
||||
try:
|
||||
store = get_dynamic_config_store()
|
||||
return cast(list, store.load(USER_STORE_KEY))
|
||||
return cast(list, store.load(KV_USER_STORE_KEY))
|
||||
except ConfigNotFoundError:
|
||||
return list()
|
||||
|
||||
|
||||
def write_invited_users(emails: list[str]) -> int:
|
||||
store = get_dynamic_config_store()
|
||||
store.store(USER_STORE_KEY, cast(JSON_ro, emails))
|
||||
store.store(KV_USER_STORE_KEY, cast(JSON_ro, emails))
|
||||
return len(emails)
|
||||
|
||||
@@ -3,29 +3,27 @@ from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from danswer.auth.schemas import UserRole
|
||||
from danswer.configs.constants import KV_NO_AUTH_USER_PREFERENCES_KEY
|
||||
from danswer.dynamic_configs.store import ConfigNotFoundError
|
||||
from danswer.dynamic_configs.store import DynamicConfigStore
|
||||
from danswer.server.manage.models import UserInfo
|
||||
from danswer.server.manage.models import UserPreferences
|
||||
|
||||
|
||||
NO_AUTH_USER_PREFERENCES_KEY = "no_auth_user_preferences"
|
||||
|
||||
|
||||
def set_no_auth_user_preferences(
|
||||
store: DynamicConfigStore, preferences: UserPreferences
|
||||
) -> None:
|
||||
store.store(NO_AUTH_USER_PREFERENCES_KEY, preferences.dict())
|
||||
store.store(KV_NO_AUTH_USER_PREFERENCES_KEY, preferences.dict())
|
||||
|
||||
|
||||
def load_no_auth_user_preferences(store: DynamicConfigStore) -> UserPreferences:
|
||||
try:
|
||||
preferences_data = cast(
|
||||
Mapping[str, Any], store.load(NO_AUTH_USER_PREFERENCES_KEY)
|
||||
Mapping[str, Any], store.load(KV_NO_AUTH_USER_PREFERENCES_KEY)
|
||||
)
|
||||
return UserPreferences(**preferences_data)
|
||||
except ConfigNotFoundError:
|
||||
return UserPreferences(chosen_assistants=None)
|
||||
return UserPreferences(chosen_assistants=None, default_model=None)
|
||||
|
||||
|
||||
def fetch_no_auth_user(store: DynamicConfigStore) -> UserInfo:
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import smtplib
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from email.mime.text import MIMEText
|
||||
from typing import Optional
|
||||
@@ -50,8 +52,10 @@ from danswer.db.auth import get_default_admin_user_emails
|
||||
from danswer.db.auth import get_user_count
|
||||
from danswer.db.auth import get_user_db
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.models import AccessToken
|
||||
from danswer.db.models import User
|
||||
from danswer.db.users import get_user_by_email
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.telemetry import optional_telemetry
|
||||
from danswer.utils.telemetry import RecordType
|
||||
@@ -63,6 +67,14 @@ from danswer.utils.variable_functionality import (
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def is_user_admin(user: User | None) -> bool:
|
||||
if AUTH_TYPE == AuthType.DISABLED:
|
||||
return True
|
||||
if user and user.role == UserRole.ADMIN:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def verify_auth_setting() -> None:
|
||||
if AUTH_TYPE not in [AuthType.DISABLED, AuthType.BASIC, AuthType.GOOGLE_OAUTH]:
|
||||
raise ValueError(
|
||||
@@ -92,12 +104,18 @@ def user_needs_to_be_verified() -> bool:
|
||||
return AUTH_TYPE != AuthType.BASIC or REQUIRE_EMAIL_VERIFICATION
|
||||
|
||||
|
||||
def verify_email_in_whitelist(email: str) -> None:
|
||||
def verify_email_is_invited(email: str) -> None:
|
||||
whitelist = get_invited_users()
|
||||
if (whitelist and email not in whitelist) or not email:
|
||||
raise PermissionError("User not on allowed user whitelist")
|
||||
|
||||
|
||||
def verify_email_in_whitelist(email: str) -> None:
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
if not get_user_by_email(email, db_session):
|
||||
verify_email_is_invited(email)
|
||||
|
||||
|
||||
def verify_email_domain(email: str) -> None:
|
||||
if VALID_EMAIL_DOMAINS:
|
||||
if email.count("@") != 1:
|
||||
@@ -147,7 +165,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
safe: bool = False,
|
||||
request: Optional[Request] = None,
|
||||
) -> models.UP:
|
||||
verify_email_in_whitelist(user_create.email)
|
||||
verify_email_is_invited(user_create.email)
|
||||
verify_email_domain(user_create.email)
|
||||
if hasattr(user_create, "role"):
|
||||
user_count = await get_user_count()
|
||||
@@ -173,7 +191,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
verify_email_in_whitelist(account_email)
|
||||
verify_email_domain(account_email)
|
||||
|
||||
return await super().oauth_callback( # type: ignore
|
||||
user = await super().oauth_callback( # type: ignore
|
||||
oauth_name=oauth_name,
|
||||
access_token=access_token,
|
||||
account_id=account_id,
|
||||
@@ -185,6 +203,14 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
is_verified_by_default=is_verified_by_default,
|
||||
)
|
||||
|
||||
# NOTE: google oauth expires after 1hr. We don't want to force the user to
|
||||
# re-authenticate that frequently, so for now we'll just ignore this for
|
||||
# google oauth users
|
||||
if expires_at and AUTH_TYPE != AuthType.GOOGLE_OAUTH:
|
||||
oidc_expiry = datetime.fromtimestamp(expires_at, tz=timezone.utc)
|
||||
await self.user_db.update(user, update_dict={"oidc_expiry": oidc_expiry})
|
||||
return user
|
||||
|
||||
async def on_after_register(
|
||||
self, user: User, request: Optional[Request] = None
|
||||
) -> None:
|
||||
@@ -227,10 +253,12 @@ cookie_transport = CookieTransport(
|
||||
def get_database_strategy(
|
||||
access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db),
|
||||
) -> DatabaseStrategy:
|
||||
return DatabaseStrategy(
|
||||
strategy = DatabaseStrategy(
|
||||
access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS # type: ignore
|
||||
)
|
||||
|
||||
return strategy
|
||||
|
||||
|
||||
auth_backend = AuthenticationBackend(
|
||||
name="database",
|
||||
@@ -327,6 +355,12 @@ async def double_check_user(
|
||||
detail="Access denied. User is not verified.",
|
||||
)
|
||||
|
||||
if user.oidc_expiry and user.oidc_expiry < datetime.now(timezone.utc):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied. User's OIDC token has expired.",
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
@@ -345,4 +379,5 @@ async def current_admin_user(user: User | None = Depends(current_user)) -> User
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Access denied. User is not an admin.",
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
@@ -5,6 +5,7 @@ from celery import Celery # type: ignore
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector
|
||||
from danswer.background.celery.celery_utils import should_kick_off_deletion_of_cc_pair
|
||||
from danswer.background.celery.celery_utils import should_prune_cc_pair
|
||||
from danswer.background.celery.celery_utils import should_sync_doc_set
|
||||
from danswer.background.connector_deletion import delete_connector_credential_pair
|
||||
@@ -14,6 +15,7 @@ from danswer.background.task_utils import name_cc_cleanup_task
|
||||
from danswer.background.task_utils import name_cc_prune_task
|
||||
from danswer.background.task_utils import name_document_set_sync_task
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.configs.constants import POSTGRES_CELERY_APP_NAME
|
||||
from danswer.connectors.factory import instantiate_connector
|
||||
from danswer.connectors.models import InputType
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair
|
||||
@@ -38,7 +40,9 @@ from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
connection_string = build_connection_string(db_api=SYNC_DB_API)
|
||||
connection_string = build_connection_string(
|
||||
db_api=SYNC_DB_API, app_name=POSTGRES_CELERY_APP_NAME
|
||||
)
|
||||
celery_broker_url = f"sqla+{connection_string}"
|
||||
celery_backend_url = f"db+{connection_string}"
|
||||
celery_app = Celery(__name__, broker=celery_broker_url, backend=celery_backend_url)
|
||||
@@ -100,7 +104,7 @@ def cleanup_connector_credential_pair_task(
|
||||
@build_celery_task_wrapper(name_cc_prune_task)
|
||||
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
|
||||
def prune_documents_task(connector_id: int, credential_id: int) -> None:
|
||||
"""connector pruning task. For a cc pair, this task pulls all docuement IDs from the source
|
||||
"""connector pruning task. For a cc pair, this task pulls all document IDs from the source
|
||||
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
|
||||
from the most recently pulled document ID list"""
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
@@ -267,6 +271,26 @@ def check_for_document_sets_sync_task() -> None:
|
||||
)
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
name="check_for_cc_pair_deletion_task",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
)
|
||||
def check_for_cc_pair_deletion_task() -> None:
|
||||
"""Runs periodically to check if any deletion tasks should be run"""
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
# check if any document sets are not synced
|
||||
cc_pairs = get_connector_credential_pairs(db_session)
|
||||
for cc_pair in cc_pairs:
|
||||
if should_kick_off_deletion_of_cc_pair(cc_pair, db_session):
|
||||
logger.info(f"Deleting the {cc_pair.name} connector credential pair")
|
||||
cleanup_connector_credential_pair_task.apply_async(
|
||||
kwargs=dict(
|
||||
connector_id=cc_pair.connector.id,
|
||||
credential_id=cc_pair.credential.id,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
name="check_for_prune_task",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
@@ -302,6 +326,12 @@ celery_app.conf.beat_schedule = {
|
||||
"task": "check_for_document_sets_sync_task",
|
||||
"schedule": timedelta(seconds=5),
|
||||
},
|
||||
"check-for-cc-pair-deletion": {
|
||||
"task": "check_for_cc_pair_deletion_task",
|
||||
# don't need to check too often, since we kick off a deletion initially
|
||||
# during the API call that actually marks the CC pair for deletion
|
||||
"schedule": timedelta(minutes=1),
|
||||
},
|
||||
}
|
||||
celery_app.conf.beat_schedule.update(
|
||||
{
|
||||
|
||||
@@ -6,8 +6,8 @@ from sqlalchemy.orm import Session
|
||||
from danswer.background.task_utils import name_cc_cleanup_task
|
||||
from danswer.background.task_utils import name_cc_prune_task
|
||||
from danswer.background.task_utils import name_document_set_sync_task
|
||||
from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
|
||||
from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
|
||||
from danswer.configs.app_configs import PREVENT_SIMULTANEOUS_PRUNING
|
||||
from danswer.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
rate_limit_builder,
|
||||
)
|
||||
@@ -16,10 +16,14 @@ from danswer.connectors.interfaces import IdConnector
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
|
||||
from danswer.db.engine import get_db_current_time
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.models import Connector
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import Credential
|
||||
from danswer.db.models import DocumentSet
|
||||
from danswer.db.models import TaskQueueState
|
||||
from danswer.db.tasks import check_task_is_live_and_not_timed_out
|
||||
from danswer.db.tasks import get_latest_task
|
||||
from danswer.db.tasks import get_latest_task_by_type
|
||||
@@ -31,22 +35,52 @@ logger = setup_logger()
|
||||
|
||||
def get_deletion_status(
|
||||
connector_id: int, credential_id: int, db_session: Session
|
||||
) -> DeletionAttemptSnapshot | None:
|
||||
) -> TaskQueueState | None:
|
||||
cleanup_task_name = name_cc_cleanup_task(
|
||||
connector_id=connector_id, credential_id=credential_id
|
||||
)
|
||||
task_state = get_latest_task(task_name=cleanup_task_name, db_session=db_session)
|
||||
return get_latest_task(task_name=cleanup_task_name, db_session=db_session)
|
||||
|
||||
if not task_state:
|
||||
|
||||
def get_deletion_attempt_snapshot(
|
||||
connector_id: int, credential_id: int, db_session: Session
|
||||
) -> DeletionAttemptSnapshot | None:
|
||||
deletion_task = get_deletion_status(connector_id, credential_id, db_session)
|
||||
if not deletion_task:
|
||||
return None
|
||||
|
||||
return DeletionAttemptSnapshot(
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
status=task_state.status,
|
||||
status=deletion_task.status,
|
||||
)
|
||||
|
||||
|
||||
def should_kick_off_deletion_of_cc_pair(
|
||||
cc_pair: ConnectorCredentialPair, db_session: Session
|
||||
) -> bool:
|
||||
if cc_pair.status != ConnectorCredentialPairStatus.DELETING:
|
||||
return False
|
||||
|
||||
if check_deletion_attempt_is_allowed(cc_pair, db_session):
|
||||
return False
|
||||
|
||||
deletion_task = get_deletion_status(
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
if deletion_task and check_task_is_live_and_not_timed_out(
|
||||
deletion_task,
|
||||
db_session,
|
||||
# 1 hour timeout
|
||||
timeout=60 * 60,
|
||||
):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def should_sync_doc_set(document_set: DocumentSet, db_session: Session) -> bool:
|
||||
if document_set.is_up_to_date:
|
||||
return False
|
||||
@@ -80,7 +114,7 @@ def should_prune_cc_pair(
|
||||
return True
|
||||
return False
|
||||
|
||||
if PREVENT_SIMULTANEOUS_PRUNING:
|
||||
if not ALLOW_SIMULTANEOUS_PRUNING:
|
||||
pruning_type_task_name = name_cc_prune_task()
|
||||
last_pruning_type_task = get_latest_task_by_type(
|
||||
pruning_type_task_name, db_session
|
||||
@@ -89,11 +123,9 @@ def should_prune_cc_pair(
|
||||
if last_pruning_type_task and check_task_is_live_and_not_timed_out(
|
||||
last_pruning_type_task, db_session
|
||||
):
|
||||
logger.info("Another Connector is already pruning. Skipping.")
|
||||
return False
|
||||
|
||||
if check_task_is_live_and_not_timed_out(last_pruning_task, db_session):
|
||||
logger.info(f"Connector '{connector.name}' is already pruning. Skipping.")
|
||||
return False
|
||||
|
||||
if not last_pruning_task.start_time:
|
||||
|
||||
@@ -10,8 +10,6 @@ are multiple connector / credential pairs that have indexed it
|
||||
connector / credential pair from the access list
|
||||
(6) delete all relevant entries from postgres
|
||||
"""
|
||||
import time
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.access.access import get_access_for_documents
|
||||
@@ -24,10 +22,8 @@ from danswer.db.document import delete_documents_complete__no_commit
|
||||
from danswer.db.document import get_document_connector_cnts
|
||||
from danswer.db.document import get_documents_for_connector_credential_pair
|
||||
from danswer.db.document import prepare_to_modify_documents
|
||||
from danswer.db.document_set import get_document_sets_by_ids
|
||||
from danswer.db.document_set import (
|
||||
mark_cc_pair__document_set_relationships_to_be_deleted__no_commit,
|
||||
)
|
||||
from danswer.db.document_set import delete_document_set_cc_pair_relationship__no_commit
|
||||
from danswer.db.document_set import fetch_document_sets_for_documents
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.index_attempt import delete_index_attempts
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
@@ -35,6 +31,10 @@ from danswer.document_index.interfaces import DocumentIndex
|
||||
from danswer.document_index.interfaces import UpdateRequest
|
||||
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import (
|
||||
fetch_versioned_implementation_with_fallback,
|
||||
)
|
||||
from danswer.utils.variable_functionality import noop_fallback
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -78,25 +78,37 @@ def delete_connector_credential_pair_batch(
|
||||
document_ids_to_update = [
|
||||
document_id for document_id, cnt in document_connector_cnts if cnt > 1
|
||||
]
|
||||
|
||||
# maps document id to list of document set names
|
||||
new_doc_sets_for_documents: dict[str, set[str]] = {
|
||||
document_id_and_document_set_names_tuple[0]: set(
|
||||
document_id_and_document_set_names_tuple[1]
|
||||
)
|
||||
for document_id_and_document_set_names_tuple in fetch_document_sets_for_documents(
|
||||
db_session=db_session,
|
||||
document_ids=document_ids_to_update,
|
||||
)
|
||||
}
|
||||
|
||||
# determine future ACLs for documents in batch
|
||||
access_for_documents = get_access_for_documents(
|
||||
document_ids=document_ids_to_update,
|
||||
db_session=db_session,
|
||||
cc_pair_to_delete=ConnectorCredentialPairIdentifier(
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
),
|
||||
)
|
||||
|
||||
# update Vespa
|
||||
logger.debug(f"Updating documents: {document_ids_to_update}")
|
||||
update_requests = [
|
||||
UpdateRequest(
|
||||
document_ids=[document_id],
|
||||
access=access,
|
||||
document_sets=new_doc_sets_for_documents[document_id],
|
||||
)
|
||||
for document_id, access in access_for_documents.items()
|
||||
]
|
||||
logger.debug(f"Updating documents: {document_ids_to_update}")
|
||||
|
||||
document_index.update(update_requests=update_requests)
|
||||
|
||||
# clean up Postgres
|
||||
delete_document_by_connector_credential_pair__no_commit(
|
||||
db_session=db_session,
|
||||
document_ids=document_ids_to_update,
|
||||
@@ -108,48 +120,6 @@ def delete_connector_credential_pair_batch(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def cleanup_synced_entities(
|
||||
cc_pair: ConnectorCredentialPair, db_session: Session
|
||||
) -> None:
|
||||
"""Updates the document sets associated with the connector / credential pair,
|
||||
then relies on the document set sync script to kick off Celery jobs which will
|
||||
sync these updates to Vespa.
|
||||
|
||||
Waits until the document sets are synced before returning."""
|
||||
logger.info(f"Cleaning up Document Sets for CC Pair with ID: '{cc_pair.id}'")
|
||||
document_sets_ids_to_sync = list(
|
||||
mark_cc_pair__document_set_relationships_to_be_deleted__no_commit(
|
||||
cc_pair_id=cc_pair.id,
|
||||
db_session=db_session,
|
||||
)
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
# wait till all document sets are synced before continuing
|
||||
while True:
|
||||
all_synced = True
|
||||
document_sets = get_document_sets_by_ids(
|
||||
db_session=db_session, document_set_ids=document_sets_ids_to_sync
|
||||
)
|
||||
for document_set in document_sets:
|
||||
if not document_set.is_up_to_date:
|
||||
all_synced = False
|
||||
|
||||
if all_synced:
|
||||
break
|
||||
|
||||
# wait for 30 seconds before checking again
|
||||
db_session.commit() # end transaction
|
||||
logger.info(
|
||||
f"Document sets '{document_sets_ids_to_sync}' not synced yet, waiting 30s"
|
||||
)
|
||||
time.sleep(30)
|
||||
|
||||
logger.info(
|
||||
f"Finished cleaning up Document Sets for CC Pair with ID: '{cc_pair.id}'"
|
||||
)
|
||||
|
||||
|
||||
def delete_connector_credential_pair(
|
||||
db_session: Session,
|
||||
document_index: DocumentIndex,
|
||||
@@ -177,17 +147,33 @@ def delete_connector_credential_pair(
|
||||
)
|
||||
num_docs_deleted += len(documents)
|
||||
|
||||
# Clean up document sets / access information from Postgres
|
||||
# and sync these updates to Vespa
|
||||
# TODO: add user group cleanup with `fetch_versioned_implementation`
|
||||
cleanup_synced_entities(cc_pair, db_session)
|
||||
|
||||
# clean up the rest of the related Postgres entities
|
||||
# index attempts
|
||||
delete_index_attempts(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
|
||||
# document sets
|
||||
delete_document_set_cc_pair_relationship__no_commit(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
|
||||
# user groups
|
||||
cleanup_user_groups = fetch_versioned_implementation_with_fallback(
|
||||
"danswer.db.user_group",
|
||||
"delete_user_group_cc_pair_relationship__no_commit",
|
||||
noop_fallback,
|
||||
)
|
||||
cleanup_user_groups(
|
||||
cc_pair_id=cc_pair.id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# finally, delete the cc-pair
|
||||
delete_connector_credential_pair__no_commit(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
|
||||
@@ -41,6 +41,12 @@ def _initializer(
return func(*args, **kwargs)


def _run_in_process(
func: Callable, args: list | tuple, kwargs: dict[str, Any] | None = None
) -> None:
_initializer(func, args, kwargs)


@dataclass
class SimpleJob:
"""Drop-in replacement for `dask.distributed.Future`"""
@@ -113,7 +119,7 @@ class SimpleJobClient:
job_id = self.job_id_counter
self.job_id_counter += 1

process = Process(target=_initializer(func=func, args=args), daemon=True)
process = Process(target=_run_in_process, args=(func, args), daemon=True)
job = SimpleJob(id=job_id, process=process)
process.start()
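The one-line change in SimpleJobClient above fixes a bug: `Process(target=_initializer(func=func, args=args))` invoked `_initializer` immediately in the parent process and handed its return value to `Process` as the target. A minimal standalone sketch of the same pattern (illustration only, not repository code):

from multiprocessing import Process


def work(x: int) -> None:
    print(x * 2)


if __name__ == "__main__":
    # Wrong: work(21) runs immediately in the parent and Process receives its
    # return value (None) as the target, so nothing runs in the child.
    broken = Process(target=work(21), daemon=True)

    # Right: pass the callable and its arguments separately; the child invokes it.
    fixed = Process(target=work, args=(21,), daemon=True)
    fixed.start()
    fixed.join()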
||||
|
||||
@@ -7,6 +7,7 @@ from datetime import timezone
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.indexing.checkpointing import get_time_windows_for_index_attempt
|
||||
from danswer.configs.app_configs import INDEXING_SIZE_WARNING_THRESHOLD
|
||||
from danswer.configs.app_configs import POLL_CONNECTOR_OFFSET
|
||||
from danswer.connectors.factory import instantiate_connector
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
@@ -14,13 +15,13 @@ from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.models import IndexAttemptMetadata
|
||||
from danswer.connectors.models import InputType
|
||||
from danswer.db.connector import disable_connector
|
||||
from danswer.db.connector_credential_pair import get_last_successful_attempt_time
|
||||
from danswer.db.connector_credential_pair import update_connector_credential_pair
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.index_attempt import get_index_attempt
|
||||
from danswer.db.index_attempt import mark_attempt_failed
|
||||
from danswer.db.index_attempt import mark_attempt_in_progress__no_commit
|
||||
from danswer.db.index_attempt import mark_attempt_in_progress
|
||||
from danswer.db.index_attempt import mark_attempt_succeeded
|
||||
from danswer.db.index_attempt import update_docs_indexed
|
||||
from danswer.db.models import IndexAttempt
|
||||
@@ -49,19 +50,26 @@ def _get_document_generator(
|
||||
are the complete list of existing documents of the connector. If the task
is of type LOAD_STATE, the list will be considered complete; otherwise it is incomplete.
"""
|
||||
task = attempt.connector.input_type
|
||||
task = attempt.connector_credential_pair.connector.input_type
|
||||
|
||||
try:
|
||||
runnable_connector = instantiate_connector(
|
||||
attempt.connector.source,
|
||||
attempt.connector_credential_pair.connector.source,
|
||||
task,
|
||||
attempt.connector.connector_specific_config,
|
||||
attempt.credential,
|
||||
attempt.connector_credential_pair.connector.connector_specific_config,
|
||||
attempt.connector_credential_pair.credential,
|
||||
db_session,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Unable to instantiate connector due to {e}")
|
||||
disable_connector(attempt.connector.id, db_session)
|
||||
# since we failed to even instantiate the connector, we pause the CCPair since
|
||||
# it will never succeed
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=attempt.connector_credential_pair.connector.id,
|
||||
credential_id=attempt.connector_credential_pair.credential.id,
|
||||
status=ConnectorCredentialPairStatus.PAUSED,
|
||||
)
|
||||
raise e
|
||||
|
||||
if task == InputType.LOAD_STATE:
|
||||
@@ -70,7 +78,10 @@ def _get_document_generator(
|
||||
|
||||
elif task == InputType.POLL:
|
||||
assert isinstance(runnable_connector, PollConnector)
|
||||
if attempt.connector_id is None or attempt.credential_id is None:
|
||||
if (
attempt.connector_credential_pair.connector_id is None
or attempt.connector_credential_pair.credential_id is None
):
|
||||
raise ValueError(
|
||||
f"Polling attempt {attempt.id} is missing connector_id or credential_id, "
|
||||
f"can't fetch time range."
|
||||
@@ -110,13 +121,8 @@ def _run_indexing(
|
||||
primary_index_name=index_name, secondary_index_name=None
|
||||
)
|
||||
|
||||
embedding_model = DefaultIndexingEmbedder(
|
||||
model_name=db_embedding_model.model_name,
|
||||
normalize=db_embedding_model.normalize,
|
||||
query_prefix=db_embedding_model.query_prefix,
|
||||
passage_prefix=db_embedding_model.passage_prefix,
|
||||
api_key=db_embedding_model.api_key,
|
||||
provider_type=db_embedding_model.provider_type,
|
||||
embedding_model = DefaultIndexingEmbedder.from_db_embedding_model(
|
||||
db_embedding_model
|
||||
)
|
||||
|
||||
indexing_pipeline = build_indexing_pipeline(
|
||||
@@ -127,16 +133,22 @@ def _run_indexing(
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
db_connector = index_attempt.connector
|
||||
db_credential = index_attempt.credential
|
||||
db_cc_pair = index_attempt.connector_credential_pair
|
||||
db_connector = index_attempt.connector_credential_pair.connector
|
||||
db_credential = index_attempt.connector_credential_pair.credential
|
||||
|
||||
last_successful_index_time = (
|
||||
0.0
|
||||
if index_attempt.from_beginning
|
||||
else get_last_successful_attempt_time(
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
embedding_model=index_attempt.embedding_model,
|
||||
db_session=db_session,
|
||||
db_connector.indexing_start.timestamp()
|
||||
if index_attempt.from_beginning and db_connector.indexing_start is not None
|
||||
else (
|
||||
0.0
|
||||
if index_attempt.from_beginning
|
||||
else get_last_successful_attempt_time(
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
embedding_model=index_attempt.embedding_model,
|
||||
db_session=db_session,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
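The nested conditional above picks the point in time indexing should resume from. Restated as a small helper for readability (a sketch only; resolve_start_time is hypothetical and not part of the diff):

from datetime import datetime


def resolve_start_time(
    from_beginning: bool,
    indexing_start: datetime | None,
    last_successful_attempt_time: float,
) -> float:
    # Re-index from the connector's configured start date when one is set.
    if from_beginning and indexing_start is not None:
        return indexing_start.timestamp()
    # Re-index everything when restarting without a configured start date.
    if from_beginning:
        return 0.0
    # Otherwise continue from the last successful attempt for this embedding model.
    return last_successful_attempt_time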
@@ -173,7 +185,7 @@ def _run_indexing(
|
||||
# contents still need to be initially pulled.
|
||||
db_session.refresh(db_connector)
|
||||
if (
|
||||
db_connector.disabled
|
||||
db_cc_pair.status == ConnectorCredentialPairStatus.PAUSED
|
||||
and db_embedding_model.status != IndexModelStatus.FUTURE
|
||||
):
|
||||
# let the `except` block handle this
|
||||
@@ -184,12 +196,25 @@ def _run_indexing(
|
||||
# Likely due to user manually disabling it or model swap
|
||||
raise RuntimeError("Index Attempt was canceled")
|
||||
|
||||
logger.debug(
|
||||
f"Indexing batch of documents: {[doc.to_short_descriptor() for doc in doc_batch]}"
|
||||
)
|
||||
batch_description = []
|
||||
for doc in doc_batch:
|
||||
batch_description.append(doc.to_short_descriptor())
|
||||
|
||||
doc_size = 0
|
||||
for section in doc.sections:
|
||||
doc_size += len(section.text)
|
||||
|
||||
if doc_size > INDEXING_SIZE_WARNING_THRESHOLD:
|
||||
logger.warning(
|
||||
f"Document size: doc='{doc.to_short_descriptor()}' "
|
||||
f"size={doc_size} "
|
||||
f"threshold={INDEXING_SIZE_WARNING_THRESHOLD}"
|
||||
)
|
||||
|
||||
logger.debug(f"Indexing batch of documents: {batch_description}")
|
||||
|
||||
new_docs, total_batch_chunks = indexing_pipeline(
|
||||
documents=doc_batch,
|
||||
document_batch=doc_batch,
|
||||
index_attempt_metadata=IndexAttemptMetadata(
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
@@ -238,7 +263,7 @@ def _run_indexing(
|
||||
# to give better clarity in the UI, as the next run will never happen.
|
||||
if (
|
||||
ind == 0
|
||||
or db_connector.disabled
|
||||
or db_cc_pair.status == ConnectorCredentialPairStatus.PAUSED
|
||||
or index_attempt.status != IndexingStatus.IN_PROGRESS
|
||||
):
|
||||
mark_attempt_failed(
|
||||
@@ -250,8 +275,8 @@ def _run_indexing(
|
||||
if is_primary:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=index_attempt.connector.id,
|
||||
credential_id=index_attempt.credential.id,
|
||||
connector_id=db_connector.id,
|
||||
credential_id=db_credential.id,
|
||||
net_docs=net_doc_change,
|
||||
)
|
||||
raise e
|
||||
@@ -269,11 +294,9 @@ def _run_indexing(
|
||||
run_dt=run_end_dt,
|
||||
)
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
logger.info(
|
||||
f"Indexed or refreshed {document_count} total documents for a total of {chunk_count} indexed chunks"
|
||||
)
|
||||
logger.info(
|
||||
f"Connector successfully finished, elapsed time: {time.time() - start_time} seconds"
|
||||
f"Connector succeeded: docs={document_count} chunks={chunk_count} elapsed={elapsed_time:.2f}s"
|
||||
)
|
||||
|
||||
|
||||
@@ -299,9 +322,7 @@ def _prepare_index_attempt(db_session: Session, index_attempt_id: int) -> IndexA
|
||||
)
|
||||
|
||||
# only commit once, to make sure this all happens in a single transaction
|
||||
mark_attempt_in_progress__no_commit(attempt)
|
||||
if attempt.embedding_model.status != IndexModelStatus.PRESENT:
|
||||
db_session.commit()
|
||||
mark_attempt_in_progress(attempt, db_session)
|
||||
|
||||
return attempt
|
||||
|
||||
@@ -324,17 +345,19 @@ def run_indexing_entrypoint(index_attempt_id: int, is_ee: bool = False) -> None:
|
||||
attempt = _prepare_index_attempt(db_session, index_attempt_id)
|
||||
|
||||
logger.info(
|
||||
f"Running indexing attempt for connector: '{attempt.connector.name}', "
|
||||
f"with config: '{attempt.connector.connector_specific_config}', and "
|
||||
f"with credentials: '{attempt.credential_id}'"
|
||||
f"Indexing starting: "
|
||||
f"connector='{attempt.connector_credential_pair.connector.name}' "
|
||||
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
|
||||
f"credentials='{attempt.connector_credential_pair.connector_id}'"
|
||||
)
|
||||
|
||||
_run_indexing(db_session, attempt)
|
||||
|
||||
logger.info(
|
||||
f"Completed indexing attempt for connector: '{attempt.connector.name}', "
|
||||
f"with config: '{attempt.connector.connector_specific_config}', and "
|
||||
f"with credentials: '{attempt.credential_id}'"
|
||||
f"Indexing finished: "
|
||||
f"connector='{attempt.connector_credential_pair.connector.name}' "
|
||||
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
|
||||
f"credentials='{attempt.connector_credential_pair.connector_id}'"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Indexing job with ID '{index_attempt_id}' failed due to {e}")
|
||||
|
||||
@@ -16,24 +16,29 @@ from danswer.configs.app_configs import CLEANUP_INDEXING_JOBS_TIMEOUT
|
||||
from danswer.configs.app_configs import DASK_JOB_CLIENT_ENABLED
|
||||
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
|
||||
from danswer.configs.app_configs import NUM_INDEXING_WORKERS
|
||||
from danswer.configs.app_configs import NUM_SECONDARY_INDEXING_WORKERS
|
||||
from danswer.configs.constants import POSTGRES_INDEXER_APP_NAME
|
||||
from danswer.db.connector import fetch_connectors
|
||||
from danswer.db.connector_credential_pair import fetch_connector_credential_pairs
|
||||
from danswer.db.embedding_model import get_current_db_embedding_model
|
||||
from danswer.db.embedding_model import get_secondary_db_embedding_model
|
||||
from danswer.db.engine import get_db_current_time
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.engine import init_sqlalchemy_engine
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.index_attempt import create_index_attempt
|
||||
from danswer.db.index_attempt import get_index_attempt
|
||||
from danswer.db.index_attempt import get_inprogress_index_attempts
|
||||
from danswer.db.index_attempt import get_last_attempt
|
||||
from danswer.db.index_attempt import get_last_attempt_for_cc_pair
|
||||
from danswer.db.index_attempt import get_not_started_index_attempts
|
||||
from danswer.db.index_attempt import mark_attempt_failed
|
||||
from danswer.db.models import Connector
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import EmbeddingModel
|
||||
from danswer.db.models import IndexAttempt
|
||||
from danswer.db.models import IndexingStatus
|
||||
from danswer.db.models import IndexModelStatus
|
||||
from danswer.db.swap_index import check_index_swap
|
||||
from danswer.search.search_nlp_models import warm_up_encoders
|
||||
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import global_version
|
||||
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
@@ -53,12 +58,14 @@ _UNEXPECTED_STATE_FAILURE_REASON = (
|
||||
|
||||
|
||||
def _should_create_new_indexing(
|
||||
connector: Connector,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
last_index: IndexAttempt | None,
|
||||
model: EmbeddingModel,
|
||||
secondary_index_building: bool,
|
||||
db_session: Session,
|
||||
) -> bool:
|
||||
connector = cc_pair.connector
|
||||
|
||||
# User can still manually create single indexing attempts via the UI for the
|
||||
# currently in use index
|
||||
if DISABLE_INDEX_UPDATE_ON_SWAP:
|
||||
@@ -66,28 +73,46 @@ def _should_create_new_indexing(
|
||||
return False
|
||||
|
||||
# When switching over models, always index at least once
|
||||
if model.status == IndexModelStatus.FUTURE and not last_index:
|
||||
if connector.id == 0: # Ingestion API
|
||||
return False
|
||||
if model.status == IndexModelStatus.FUTURE:
|
||||
if last_index:
|
||||
# No new index if the last index attempt succeeded
|
||||
# Once is enough. The model will never be able to swap otherwise.
|
||||
if last_index.status == IndexingStatus.SUCCESS:
|
||||
return False
|
||||
|
||||
# No new index if the last index attempt is waiting to start
|
||||
if last_index.status == IndexingStatus.NOT_STARTED:
|
||||
return False
|
||||
|
||||
# No new index if the last index attempt is running
|
||||
if last_index.status == IndexingStatus.IN_PROGRESS:
|
||||
return False
|
||||
else:
|
||||
if connector.id == 0: # Ingestion API
|
||||
return False
|
||||
return True
|
||||
|
||||
# If the connector is disabled, don't index
|
||||
# NOTE: during an embedding model switch over, we ignore this
|
||||
# and index the disabled connectors as well (which is why this if
|
||||
# statement is below the first condition above)
|
||||
if connector.disabled:
|
||||
# If the connector is paused or is the ingestion API, don't index
|
||||
# NOTE: during an embedding model switch over, the following logic
|
||||
# is bypassed by the above check for a future model
|
||||
if cc_pair.status == ConnectorCredentialPairStatus.PAUSED or connector.id == 0:
|
||||
return False
|
||||
|
||||
if connector.refresh_freq is None:
|
||||
return False
|
||||
if not last_index:
|
||||
return True
|
||||
|
||||
# Only one scheduled job per connector at a time
|
||||
# Can schedule another one if the current one is already running however
|
||||
# Because the currently running one will not be until the latest time
|
||||
# Note, this last index is for the given embedding model
|
||||
if last_index.status == IndexingStatus.NOT_STARTED:
|
||||
if connector.refresh_freq is None:
|
||||
return False
|
||||
|
||||
# Only one scheduled/ongoing job per connector at a time
|
||||
# this prevents cases where
|
||||
# (1) the "latest" index_attempt is scheduled so we show
|
||||
# that in the UI despite another index_attempt being in-progress
|
||||
# (2) multiple scheduled index_attempts at a time
|
||||
if (
|
||||
last_index.status == IndexingStatus.NOT_STARTED
|
||||
or last_index.status == IndexingStatus.IN_PROGRESS
|
||||
):
|
||||
return False
|
||||
|
||||
current_db_time = get_db_current_time(db_session)
|
||||
@@ -111,8 +136,8 @@ def _mark_run_failed(
|
||||
"""Marks the `index_attempt` row as failed + updates the `
|
||||
connector_credential_pair` to reflect that the run failed"""
|
||||
logger.warning(
|
||||
f"Marking in-progress attempt 'connector: {index_attempt.connector_id}, "
|
||||
f"credential: {index_attempt.credential_id}' as failed due to {failure_reason}"
|
||||
f"Marking in-progress attempt 'connector: {index_attempt.connector_credential_pair.connector_id}, "
|
||||
f"credential: {index_attempt.connector_credential_pair.credential_id}' as failed due to {failure_reason}"
|
||||
)
|
||||
mark_attempt_failed(
|
||||
index_attempt=index_attempt,
|
||||
@@ -131,7 +156,7 @@ def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob]) -> None:
|
||||
3. There is not already an ongoing indexing attempt for this pair
|
||||
"""
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
ongoing: set[tuple[int | None, int | None, int]] = set()
|
||||
ongoing: set[tuple[int | None, int]] = set()
|
||||
for attempt_id in existing_jobs:
|
||||
attempt = get_index_attempt(
|
||||
db_session=db_session, index_attempt_id=attempt_id
|
||||
@@ -144,8 +169,7 @@ def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob]) -> None:
|
||||
continue
|
||||
ongoing.add(
|
||||
(
|
||||
attempt.connector_id,
|
||||
attempt.credential_id,
|
||||
attempt.connector_credential_pair_id,
|
||||
attempt.embedding_model_id,
|
||||
)
|
||||
)
|
||||
@@ -155,31 +179,26 @@ def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob]) -> None:
|
||||
if secondary_embedding_model is not None:
|
||||
embedding_models.append(secondary_embedding_model)
|
||||
|
||||
all_connectors = fetch_connectors(db_session)
|
||||
for connector in all_connectors:
|
||||
for association in connector.credentials:
|
||||
for model in embedding_models:
|
||||
credential = association.credential
|
||||
all_connector_credential_pairs = fetch_connector_credential_pairs(db_session)
|
||||
for cc_pair in all_connector_credential_pairs:
|
||||
for model in embedding_models:
|
||||
# Check if there is an ongoing indexing attempt for this connector credential pair
|
||||
if (cc_pair.id, model.id) in ongoing:
|
||||
continue
|
||||
|
||||
# Check if there is an ongoing indexing attempt for this connector + credential pair
|
||||
if (connector.id, credential.id, model.id) in ongoing:
|
||||
continue
|
||||
last_attempt = get_last_attempt_for_cc_pair(
|
||||
cc_pair.id, model.id, db_session
|
||||
)
|
||||
if not _should_create_new_indexing(
|
||||
cc_pair=cc_pair,
|
||||
last_index=last_attempt,
|
||||
model=model,
|
||||
secondary_index_building=len(embedding_models) > 1,
|
||||
db_session=db_session,
|
||||
):
|
||||
continue
|
||||
|
||||
last_attempt = get_last_attempt(
|
||||
connector.id, credential.id, model.id, db_session
|
||||
)
|
||||
if not _should_create_new_indexing(
|
||||
connector=connector,
|
||||
last_index=last_attempt,
|
||||
model=model,
|
||||
secondary_index_building=len(embedding_models) > 1,
|
||||
db_session=db_session,
|
||||
):
|
||||
continue
|
||||
|
||||
create_index_attempt(
|
||||
connector.id, credential.id, model.id, db_session
|
||||
)
|
||||
create_index_attempt(cc_pair.id, model.id, db_session)
|
||||
|
||||
|
||||
def cleanup_indexing_jobs(
|
||||
@@ -271,24 +290,28 @@ def kickoff_indexing_jobs(
|
||||
# Don't include jobs waiting in the Dask queue that just haven't started running
|
||||
# Also (rarely) don't include for jobs that started but haven't updated the indexing tables yet
|
||||
with Session(engine) as db_session:
|
||||
# get_not_started_index_attempts orders its returned results from oldest to newest
|
||||
# we must process attempts in a FIFO manner to prevent connector starvation
|
||||
new_indexing_attempts = [
|
||||
(attempt, attempt.embedding_model)
|
||||
for attempt in get_not_started_index_attempts(db_session)
|
||||
if attempt.id not in existing_jobs
|
||||
]
|
||||
|
||||
logger.info(f"Found {len(new_indexing_attempts)} new indexing tasks.")
|
||||
logger.debug(f"Found {len(new_indexing_attempts)} new indexing task(s).")
|
||||
|
||||
if not new_indexing_attempts:
|
||||
return existing_jobs
|
||||
|
||||
indexing_attempt_count = 0
|
||||
|
||||
for attempt, embedding_model in new_indexing_attempts:
|
||||
use_secondary_index = (
|
||||
embedding_model.status == IndexModelStatus.FUTURE
|
||||
if embedding_model is not None
|
||||
else False
|
||||
)
|
||||
if attempt.connector is None:
|
||||
if attempt.connector_credential_pair.connector is None:
|
||||
logger.warning(
|
||||
f"Skipping index attempt as Connector has been deleted: {attempt}"
|
||||
)
|
||||
@@ -297,7 +320,7 @@ def kickoff_indexing_jobs(
|
||||
attempt, db_session, failure_reason="Connector is null"
|
||||
)
|
||||
continue
|
||||
if attempt.credential is None:
|
||||
if attempt.connector_credential_pair.credential is None:
|
||||
logger.warning(
|
||||
f"Skipping index attempt as Credential has been deleted: {attempt}"
|
||||
)
|
||||
@@ -323,35 +346,53 @@ def kickoff_indexing_jobs(
|
||||
)
|
||||
|
||||
if run:
|
||||
secondary_str = "(secondary index) " if use_secondary_index else ""
|
||||
if indexing_attempt_count == 0:
|
||||
logger.info(
|
||||
f"Indexing dispatch starts: pending={len(new_indexing_attempts)}"
|
||||
)
|
||||
|
||||
indexing_attempt_count += 1
|
||||
secondary_str = " (secondary index)" if use_secondary_index else ""
|
||||
logger.info(
|
||||
f"Kicked off {secondary_str}"
|
||||
f"indexing attempt for connector: '{attempt.connector.name}', "
|
||||
f"with config: '{attempt.connector.connector_specific_config}', and "
|
||||
f"with credentials: '{attempt.credential_id}'"
|
||||
f"Indexing dispatched{secondary_str}: "
|
||||
f"attempt_id={attempt.id} "
|
||||
f"connector='{attempt.connector_credential_pair.connector.name}' "
|
||||
f"config='{attempt.connector_credential_pair.connector.connector_specific_config}' "
|
||||
f"credentials='{attempt.connector_credential_pair.credential_id}'"
|
||||
)
|
||||
existing_jobs_copy[attempt.id] = run
|
||||
|
||||
if indexing_attempt_count > 0:
|
||||
logger.info(
|
||||
f"Indexing dispatch results: "
|
||||
f"initial_pending={len(new_indexing_attempts)} "
|
||||
f"started={indexing_attempt_count} "
|
||||
f"remaining={len(new_indexing_attempts) - indexing_attempt_count}"
|
||||
)
|
||||
|
||||
return existing_jobs_copy
|
||||
|
||||
|
||||
def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> None:
|
||||
def update_loop(
|
||||
delay: int = 10,
|
||||
num_workers: int = NUM_INDEXING_WORKERS,
|
||||
num_secondary_workers: int = NUM_SECONDARY_INDEXING_WORKERS,
|
||||
) -> None:
|
||||
engine = get_sqlalchemy_engine()
|
||||
with Session(engine) as db_session:
|
||||
check_index_swap(db_session=db_session)
|
||||
db_embedding_model = get_current_db_embedding_model(db_session)
|
||||
|
||||
# So that the first time users aren't surprised by really slow speed of first
|
||||
# batch of documents indexed
|
||||
# So that the first time users aren't surprised by really slow speed of first
|
||||
# batch of documents indexed
|
||||
|
||||
if db_embedding_model.cloud_provider_id is None:
|
||||
logger.info("Running a first inference to warm up embedding model")
|
||||
warm_up_encoders(
|
||||
model_name=db_embedding_model.model_name,
|
||||
normalize=db_embedding_model.normalize,
|
||||
model_server_host=INDEXING_MODEL_SERVER_HOST,
|
||||
model_server_port=MODEL_SERVER_PORT,
|
||||
)
|
||||
if db_embedding_model.cloud_provider_id is None:
|
||||
logger.debug("Running a first inference to warm up embedding model")
|
||||
warm_up_bi_encoder(
|
||||
embedding_model=db_embedding_model,
|
||||
model_server_host=INDEXING_MODEL_SERVER_HOST,
|
||||
model_server_port=MODEL_SERVER_PORT,
|
||||
)
|
||||
|
||||
client_primary: Client | SimpleJobClient
|
||||
client_secondary: Client | SimpleJobClient
|
||||
@@ -366,7 +407,7 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non
|
||||
silence_logs=logging.ERROR,
|
||||
)
|
||||
cluster_secondary = LocalCluster(
|
||||
n_workers=num_workers,
|
||||
n_workers=num_secondary_workers,
|
||||
threads_per_worker=1,
|
||||
silence_logs=logging.ERROR,
|
||||
)
|
||||
@@ -376,18 +417,18 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non
|
||||
client_primary.register_worker_plugin(ResourceLogger())
|
||||
else:
|
||||
client_primary = SimpleJobClient(n_workers=num_workers)
|
||||
client_secondary = SimpleJobClient(n_workers=num_workers)
|
||||
client_secondary = SimpleJobClient(n_workers=num_secondary_workers)
|
||||
|
||||
existing_jobs: dict[int, Future | SimpleJob] = {}
|
||||
|
||||
while True:
|
||||
start = time.time()
|
||||
start_time_utc = datetime.utcfromtimestamp(start).strftime("%Y-%m-%d %H:%M:%S")
|
||||
logger.info(f"Running update, current UTC time: {start_time_utc}")
|
||||
logger.debug(f"Running update, current UTC time: {start_time_utc}")
|
||||
|
||||
if existing_jobs:
|
||||
# TODO: make this debug level once the "no jobs are being scheduled" issue is resolved
|
||||
logger.info(
|
||||
logger.debug(
|
||||
"Found existing indexing jobs: "
|
||||
f"{[(attempt_id, job.status) for attempt_id, job in existing_jobs.items()]}"
|
||||
)
|
||||
@@ -411,8 +452,9 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non
|
||||
|
||||
def update__main() -> None:
|
||||
set_is_ee_based_on_env_variable()
|
||||
init_sqlalchemy_engine(POSTGRES_INDEXER_APP_NAME)
|
||||
|
||||
logger.info("Starting Indexing Loop")
|
||||
logger.info("Starting indexing service")
|
||||
update_loop()
|
||||
|
||||
|
||||
|
||||
@@ -35,6 +35,7 @@ def llm_doc_from_inference_section(inference_section: InferenceSection) -> LlmDo
|
||||
def create_chat_chain(
|
||||
chat_session_id: int,
|
||||
db_session: Session,
|
||||
prefetch_tool_calls: bool = True,
|
||||
) -> tuple[ChatMessage, list[ChatMessage]]:
|
||||
"""Build the linear chain of messages without including the root message"""
|
||||
mainline_messages: list[ChatMessage] = []
|
||||
@@ -43,6 +44,7 @@ def create_chat_chain(
|
||||
user_id=None,
|
||||
db_session=db_session,
|
||||
skip_permission_check=True,
|
||||
prefetch_tool_calls=prefetch_tool_calls,
|
||||
)
|
||||
id_to_msg = {msg.id: msg for msg in all_chat_messages}
|
||||
|
||||
|
||||
backend/danswer/chat/input_prompts.yaml (new file, 24 lines)
@@ -0,0 +1,24 @@
|
||||
input_prompts:
|
||||
- id: -5
|
||||
prompt: "Elaborate"
|
||||
content: "Elaborate on the above, give me a more in depth explanation."
|
||||
active: true
|
||||
is_public: true
|
||||
|
||||
- id: -4
|
||||
prompt: "Reword"
|
||||
content: "Help me rewrite the following politely and concisely for professional communication:\n"
|
||||
active: true
|
||||
is_public: true
|
||||
|
||||
- id: -3
|
||||
prompt: "Email"
|
||||
content: "Write a professional email for me including a subject line, signature, etc. Template the parts that need editing with [ ]. The email should cover the following points:\n"
|
||||
active: true
|
||||
is_public: true
|
||||
|
||||
- id: -2
|
||||
prompt: "Debug"
|
||||
content: "Provide step-by-step troubleshooting instructions for the following issue:\n"
|
||||
active: true
|
||||
is_public: true
|
||||
@@ -1,13 +1,17 @@
|
||||
import yaml
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.chat_configs import INPUT_PROMPT_YAML
|
||||
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
|
||||
from danswer.configs.chat_configs import PERSONAS_YAML
|
||||
from danswer.configs.chat_configs import PROMPTS_YAML
|
||||
from danswer.db.document_set import get_or_create_document_set_by_name
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.input_prompt import insert_input_prompt_if_not_exists
|
||||
from danswer.db.models import DocumentSet as DocumentSetDBModel
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import Prompt as PromptDBModel
|
||||
from danswer.db.models import Tool as ToolDBModel
|
||||
from danswer.db.persona import get_prompt_by_name
|
||||
from danswer.db.persona import upsert_persona
|
||||
from danswer.db.persona import upsert_prompt
|
||||
@@ -76,9 +80,31 @@ def load_personas_from_yaml(
|
||||
prompt_ids = [prompt.id for prompt in prompts if prompt is not None]
|
||||
|
||||
p_id = persona.get("id")
|
||||
tool_ids = []
|
||||
if persona.get("image_generation"):
|
||||
image_gen_tool = (
|
||||
db_session.query(ToolDBModel)
|
||||
.filter(ToolDBModel.name == "ImageGenerationTool")
|
||||
.first()
|
||||
)
|
||||
if image_gen_tool:
|
||||
tool_ids.append(image_gen_tool.id)
|
||||
|
||||
llm_model_provider_override = persona.get("llm_model_provider_override")
|
||||
llm_model_version_override = persona.get("llm_model_version_override")
|
||||
|
||||
# Set specific overrides for image generation persona
|
||||
if persona.get("image_generation"):
|
||||
llm_model_version_override = "gpt-4o"
|
||||
|
||||
existing_persona = (
|
||||
db_session.query(Persona)
|
||||
.filter(Persona.name == persona["name"])
|
||||
.first()
|
||||
)
|
||||
|
||||
upsert_persona(
|
||||
user=None,
|
||||
# Negative to not conflict with existing personas
|
||||
persona_id=(-1 * p_id) if p_id is not None else None,
|
||||
name=persona["name"],
|
||||
description=persona["description"],
|
||||
@@ -88,20 +114,52 @@ def load_personas_from_yaml(
|
||||
llm_relevance_filter=persona.get("llm_relevance_filter"),
|
||||
starter_messages=persona.get("starter_messages"),
|
||||
llm_filter_extraction=persona.get("llm_filter_extraction"),
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
icon_shape=persona.get("icon_shape"),
|
||||
icon_color=persona.get("icon_color"),
|
||||
llm_model_provider_override=llm_model_provider_override,
|
||||
llm_model_version_override=llm_model_version_override,
|
||||
recency_bias=RecencyBiasSetting(persona["recency_bias"]),
|
||||
prompt_ids=prompt_ids,
|
||||
document_set_ids=doc_set_ids,
|
||||
tool_ids=tool_ids,
|
||||
default_persona=True,
|
||||
is_public=True,
|
||||
display_priority=existing_persona.display_priority
|
||||
if existing_persona is not None
|
||||
else persona.get("display_priority"),
|
||||
is_visible=existing_persona.is_visible
|
||||
if existing_persona is not None
|
||||
else persona.get("is_visible"),
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
|
||||
def load_input_prompts_from_yaml(input_prompts_yaml: str = INPUT_PROMPT_YAML) -> None:
|
||||
with open(input_prompts_yaml, "r") as file:
|
||||
data = yaml.safe_load(file)
|
||||
|
||||
all_input_prompts = data.get("input_prompts", [])
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
for input_prompt in all_input_prompts:
|
||||
# If these prompts are deleted (which is a hard delete in the DB), on server startup
|
||||
# they will be recreated, but the user can always just deactivate them; only a light inconvenience
|
||||
insert_input_prompt_if_not_exists(
|
||||
user=None,
|
||||
input_prompt_id=input_prompt.get("id"),
|
||||
prompt=input_prompt["prompt"],
|
||||
content=input_prompt["content"],
|
||||
is_public=input_prompt["is_public"],
|
||||
active=input_prompt.get("active", True),
|
||||
db_session=db_session,
|
||||
commit=True,
|
||||
)
|
||||
|
||||
|
||||
def load_chat_yamls(
|
||||
prompt_yaml: str = PROMPTS_YAML,
|
||||
personas_yaml: str = PERSONAS_YAML,
|
||||
input_prompts_yaml: str = INPUT_PROMPT_YAML,
|
||||
) -> None:
|
||||
load_prompts_from_yaml(prompt_yaml)
|
||||
load_personas_from_yaml(personas_yaml)
|
||||
load_input_prompts_from_yaml(input_prompts_yaml)
|
||||
|
||||
@@ -46,15 +46,22 @@ class LLMRelevanceFilterResponse(BaseModel):
|
||||
relevant_chunk_indices: list[int]
|
||||
|
||||
|
||||
class RelevanceChunk(BaseModel):
|
||||
# TODO make this document level. Also slight misnomer here as this is actually
|
||||
# done at the section level currently rather than the chunk
|
||||
relevant: bool | None = None
|
||||
class RelevanceAnalysis(BaseModel):
|
||||
relevant: bool
|
||||
content: str | None = None
|
||||
|
||||
|
||||
class LLMRelevanceSummaryResponse(BaseModel):
|
||||
relevance_summaries: dict[str, RelevanceChunk]
|
||||
class SectionRelevancePiece(RelevanceAnalysis):
|
||||
"""LLM analysis mapped to an Inference Section"""
|
||||
|
||||
document_id: str
|
||||
chunk_id: int # ID of the center chunk for a given inference section
|
||||
|
||||
|
||||
class DocumentRelevance(BaseModel):
|
||||
"""Contains all relevance information for a given search"""
|
||||
|
||||
relevance_summaries: dict[str, RelevanceAnalysis]
|
||||
|
||||
|
||||
class DanswerAnswerPiece(BaseModel):
|
||||
|
||||
@@ -5,7 +5,7 @@ personas:
|
||||
# this is for DanswerBot to use when tagged in a non-configured channel
|
||||
# Be careful setting specific IDs; this won't autoincrement the next ID value for Postgres
|
||||
- id: 0
|
||||
name: "Danswer"
|
||||
name: "Knowledge"
|
||||
description: >
|
||||
Assistant with access to documents from your Connected Sources.
|
||||
# Default Prompt objects attached to the persona, see prompts.yaml
|
||||
@@ -17,7 +17,7 @@ personas:
|
||||
num_chunks: 10
|
||||
# Enable/Disable usage of the LLM chunk filter feature whereby each chunk is passed to the LLM to determine
|
||||
# if the chunk is useful or not towards the latest user query
|
||||
# This feature can be overridden for all personas via DISABLE_LLM_CHUNK_FILTER env variable
# This feature can be overridden for all personas via DISABLE_LLM_DOC_RELEVANCE env variable
|
||||
llm_relevance_filter: true
|
||||
# Enable/Disable usage of the LLM to extract query time filters including source type and time range filters
|
||||
llm_filter_extraction: true
|
||||
@@ -37,12 +37,15 @@ personas:
|
||||
# - "Engineer Onboarding"
|
||||
# - "Benefits"
|
||||
document_sets: []
|
||||
|
||||
icon_shape: 23013
|
||||
icon_color: "#6FB1FF"
|
||||
display_priority: 1
|
||||
is_visible: true
|
||||
|
||||
- id: 1
|
||||
name: "GPT"
|
||||
name: "General"
|
||||
description: >
|
||||
Assistant with no access to documents. Chat with just the Language Model.
|
||||
Assistant with no access to documents. Chat with just the Large Language Model.
|
||||
prompts:
|
||||
- "OnlyLLM"
|
||||
num_chunks: 0
|
||||
@@ -50,7 +53,10 @@ personas:
|
||||
llm_filter_extraction: true
|
||||
recency_bias: "auto"
|
||||
document_sets: []
|
||||
|
||||
icon_shape: 50910
|
||||
icon_color: "#FF6F6F"
|
||||
display_priority: 0
|
||||
is_visible: true
|
||||
|
||||
- id: 2
|
||||
name: "Paraphrase"
|
||||
@@ -63,3 +69,25 @@ personas:
|
||||
llm_filter_extraction: true
|
||||
recency_bias: "auto"
|
||||
document_sets: []
|
||||
icon_shape: 45519
|
||||
icon_color: "#6FFF8D"
|
||||
display_priority: 2
|
||||
is_visible: false
|
||||
|
||||
|
||||
- id: 3
|
||||
name: "Art"
|
||||
description: >
|
||||
Assistant for generating images based on descriptions.
|
||||
prompts:
|
||||
- "ImageGeneration"
|
||||
num_chunks: 0
|
||||
llm_relevance_filter: false
|
||||
llm_filter_extraction: false
|
||||
recency_bias: "no_decay"
|
||||
document_sets: []
|
||||
icon_shape: 234124
|
||||
icon_color: "#9B59B6"
|
||||
image_generation: true
|
||||
display_priority: 3
|
||||
is_visible: true
|
||||
|
||||
@@ -51,7 +51,8 @@ from danswer.llm.exceptions import GenAIDisabledException
|
||||
from danswer.llm.factory import get_llms_for_persona
|
||||
from danswer.llm.factory import get_main_llm_from_tuple
|
||||
from danswer.llm.interfaces import LLMConfig
|
||||
from danswer.llm.utils import get_default_llm_tokenizer
|
||||
from danswer.natural_language_processing.utils import get_tokenizer
|
||||
from danswer.search.enums import LLMEvaluationType
|
||||
from danswer.search.enums import OptionalSearchSetting
|
||||
from danswer.search.enums import QueryFlow
|
||||
from danswer.search.enums import SearchType
|
||||
@@ -60,6 +61,7 @@ from danswer.search.retrieval.search_runner import inference_sections_from_ids
|
||||
from danswer.search.utils import chunks_or_sections_to_search_docs
|
||||
from danswer.search.utils import dedupe_documents
|
||||
from danswer.search.utils import drop_llm_indices
|
||||
from danswer.search.utils import relevant_sections_to_indices
|
||||
from danswer.server.query_and_chat.models import ChatMessageDetail
|
||||
from danswer.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from danswer.server.utils import get_json_line
|
||||
@@ -178,7 +180,7 @@ def _handle_internet_search_tool_response_summary(
|
||||
rephrased_query=internet_search_response.revised_query,
|
||||
top_documents=response_docs,
|
||||
predicted_flow=QueryFlow.QUESTION_ANSWER,
|
||||
predicted_search=SearchType.HYBRID,
|
||||
predicted_search=SearchType.SEMANTIC,
|
||||
applied_source_filters=[],
|
||||
applied_time_cutoff=None,
|
||||
recency_bias_multiplier=1.0,
|
||||
@@ -187,37 +189,46 @@ def _handle_internet_search_tool_response_summary(
|
||||
)
|
||||
|
||||
|
||||
def _check_should_force_search(
|
||||
new_msg_req: CreateChatMessageRequest,
|
||||
) -> ForceUseTool | None:
|
||||
# If files are already provided, don't run the search tool
|
||||
def _get_force_search_settings(
|
||||
new_msg_req: CreateChatMessageRequest, tools: list[Tool]
|
||||
) -> ForceUseTool:
|
||||
internet_search_available = any(
|
||||
isinstance(tool, InternetSearchTool) for tool in tools
|
||||
)
|
||||
search_tool_available = any(isinstance(tool, SearchTool) for tool in tools)
|
||||
|
||||
if not internet_search_available and not search_tool_available:
|
||||
# Does not matter much which tool is set here as force is false and neither tool is available
|
||||
return ForceUseTool(force_use=False, tool_name=SearchTool._NAME)
|
||||
|
||||
tool_name = SearchTool._NAME if search_tool_available else InternetSearchTool._NAME
|
||||
# Currently, the internet search tool does not support query override
|
||||
args = (
|
||||
{"query": new_msg_req.query_override}
|
||||
if new_msg_req.query_override and tool_name == SearchTool._NAME
|
||||
else None
|
||||
)
|
||||
|
||||
if new_msg_req.file_descriptors:
|
||||
return None
|
||||
# If user has uploaded files they're using, don't run any of the search tools
|
||||
return ForceUseTool(force_use=False, tool_name=tool_name)
|
||||
|
||||
if (
|
||||
new_msg_req.query_override
|
||||
or (
|
||||
should_force_search = any(
|
||||
[
|
||||
new_msg_req.retrieval_options
|
||||
and new_msg_req.retrieval_options.run_search == OptionalSearchSetting.ALWAYS
|
||||
)
|
||||
or new_msg_req.search_doc_ids
|
||||
or DISABLE_LLM_CHOOSE_SEARCH
|
||||
):
|
||||
args = (
|
||||
{"query": new_msg_req.query_override}
|
||||
if new_msg_req.query_override
|
||||
else None
|
||||
)
|
||||
# if we are using selected docs, just put something here so the Tool doesn't need
|
||||
# to build its own args via an LLM call
|
||||
if new_msg_req.search_doc_ids:
|
||||
args = {"query": new_msg_req.message}
|
||||
and new_msg_req.retrieval_options.run_search
|
||||
== OptionalSearchSetting.ALWAYS,
|
||||
new_msg_req.search_doc_ids,
|
||||
DISABLE_LLM_CHOOSE_SEARCH,
|
||||
]
|
||||
)
|
||||
|
||||
return ForceUseTool(
|
||||
tool_name=SearchTool._NAME,
|
||||
args=args,
|
||||
)
|
||||
return None
|
||||
if should_force_search:
|
||||
# If we are using selected docs, just put something here so the Tool doesn't need to build its own args via an LLM call
|
||||
args = {"query": new_msg_req.message} if new_msg_req.search_doc_ids else args
|
||||
return ForceUseTool(force_use=True, tool_name=tool_name, args=args)
|
||||
|
||||
return ForceUseTool(force_use=False, tool_name=tool_name, args=args)
|
||||
|
||||
|
||||
ChatPacket = (
|
||||
@@ -253,7 +264,6 @@ def stream_chat_message_objects(
|
||||
2. [conditional] LLM selected chunk indices if LLM chunk filtering is turned on
|
||||
3. [always] A set of streamed LLM tokens or an error anywhere along the line if something fails
|
||||
4. [always] Details on the final AI response message that is created
|
||||
|
||||
"""
|
||||
try:
|
||||
user_id = user.id if user is not None else None
|
||||
@@ -274,7 +284,10 @@ def stream_chat_message_objects(
|
||||
# use alternate persona if alternative assistant id is passed in
|
||||
if alternate_assistant_id is not None:
|
||||
persona = get_persona_by_id(
|
||||
alternate_assistant_id, user=user, db_session=db_session
|
||||
alternate_assistant_id,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
is_for_edit=False,
|
||||
)
|
||||
else:
|
||||
persona = chat_session.persona
|
||||
@@ -297,7 +310,13 @@ def stream_chat_message_objects(
|
||||
except GenAIDisabledException:
|
||||
raise RuntimeError("LLM is disabled. Can't use chat flow without LLM.")
|
||||
|
||||
llm_tokenizer = get_default_llm_tokenizer()
|
||||
llm_provider = llm.config.model_provider
|
||||
llm_model_name = llm.config.model_name
|
||||
|
||||
llm_tokenizer = get_tokenizer(
|
||||
model_name=llm_model_name,
|
||||
provider_type=llm_provider,
|
||||
)
|
||||
llm_tokenizer_encode_func = cast(
|
||||
Callable[[str], list[int]], llm_tokenizer.encode
|
||||
)
|
||||
@@ -361,6 +380,14 @@ def stream_chat_message_objects(
|
||||
"when the last message is not a user message."
|
||||
)
|
||||
|
||||
# Disable Query Rephrasing for the first message
|
||||
# This leads to a better first response since the LLM rephrasing the question
|
||||
# leads to worse search quality
|
||||
if not history_msgs:
|
||||
new_msg_req.query_override = (
|
||||
new_msg_req.query_override or new_msg_req.message
|
||||
)
|
||||
|
||||
# load all files needed for this chat chain in memory
|
||||
files = load_all_chat_files(
|
||||
history_msgs, new_msg_req.file_descriptors, db_session
|
||||
@@ -476,6 +503,9 @@ def stream_chat_message_objects(
|
||||
chunks_above=new_msg_req.chunks_above,
|
||||
chunks_below=new_msg_req.chunks_below,
|
||||
full_doc=new_msg_req.full_doc,
|
||||
evaluation_type=LLMEvaluationType.BASIC
|
||||
if persona.llm_relevance_filter
|
||||
else LLMEvaluationType.SKIP,
|
||||
)
|
||||
tool_dict[db_tool_model.id] = [search_tool]
|
||||
elif tool_cls.__name__ == ImageGenerationTool.__name__:
|
||||
@@ -544,9 +574,11 @@ def stream_chat_message_objects(
|
||||
tools.extend(tool_list)
|
||||
|
||||
# factor in tool definition size when pruning
|
||||
document_pruning_config.tool_num_tokens = compute_all_tool_tokens(tools)
|
||||
document_pruning_config.tool_num_tokens = compute_all_tool_tokens(
|
||||
tools, llm_tokenizer
|
||||
)
|
||||
document_pruning_config.using_tool_message = explicit_tool_calling_supported(
|
||||
llm.config.model_provider, llm.config.model_name
|
||||
llm_provider, llm_model_name
|
||||
)
|
||||
|
||||
# LLM prompt building, response capturing, etc.
|
||||
@@ -576,11 +608,7 @@ def stream_chat_message_objects(
|
||||
PreviousMessage.from_chat_message(msg, files) for msg in history_msgs
|
||||
],
|
||||
tools=tools,
|
||||
force_use_tool=(
|
||||
_check_should_force_search(new_msg_req)
|
||||
if search_tool and len(tools) == 1
|
||||
else None
|
||||
),
|
||||
force_use_tool=_get_force_search_settings(new_msg_req, tools),
|
||||
)
|
||||
|
||||
reference_db_search_docs = None
|
||||
@@ -606,18 +634,28 @@ def stream_chat_message_objects(
|
||||
)
|
||||
yield qa_docs_response
|
||||
elif packet.id == SECTION_RELEVANCE_LIST_ID:
|
||||
chunk_indices = packet.response
|
||||
relevance_sections = packet.response
|
||||
|
||||
if reference_db_search_docs is not None and dropped_indices:
|
||||
chunk_indices = drop_llm_indices(
|
||||
llm_indices=chunk_indices,
|
||||
search_docs=reference_db_search_docs,
|
||||
dropped_indices=dropped_indices,
|
||||
if reference_db_search_docs is not None:
|
||||
llm_indices = relevant_sections_to_indices(
|
||||
relevance_sections=relevance_sections,
|
||||
items=[
|
||||
translate_db_search_doc_to_server_search_doc(doc)
|
||||
for doc in reference_db_search_docs
|
||||
],
|
||||
)
|
||||
|
||||
if dropped_indices:
|
||||
llm_indices = drop_llm_indices(
|
||||
llm_indices=llm_indices,
|
||||
search_docs=reference_db_search_docs,
|
||||
dropped_indices=dropped_indices,
|
||||
)
|
||||
|
||||
yield LLMRelevanceFilterResponse(
|
||||
relevant_chunk_indices=llm_indices
|
||||
)
|
||||
|
||||
yield LLMRelevanceFilterResponse(
|
||||
relevant_chunk_indices=chunk_indices
|
||||
)
|
||||
elif packet.id == IMAGE_GENERATION_RESPONSE_ID:
|
||||
img_generation_response = cast(
|
||||
list[ImageGenerationResponse], packet.response
|
||||
@@ -655,18 +693,29 @@ def stream_chat_message_objects(
|
||||
yield cast(ChatPacket, packet)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to process chat message")
|
||||
|
||||
# Don't leak the API key
|
||||
error_msg = str(e)
|
||||
if llm.config.api_key and llm.config.api_key.lower() in error_msg.lower():
|
||||
|
||||
logger.exception(f"Failed to process chat message: {error_msg}")
|
||||
|
||||
if "Illegal header value b'Bearer '" in error_msg:
|
||||
error_msg = (
|
||||
f"LLM failed to respond. Invalid API "
|
||||
f"key error from '{llm.config.model_provider}'."
|
||||
f"Authentication error: Invalid or empty API key provided for '{llm.config.model_provider}'. "
|
||||
"Please check your API key configuration."
|
||||
)
|
||||
elif (
|
||||
"Invalid leading whitespace, reserved character(s), or return character(s) in header value"
|
||||
in error_msg
|
||||
):
|
||||
error_msg = (
|
||||
f"Authentication error: Invalid API key format for '{llm.config.model_provider}'. "
|
||||
"Please ensure your API key does not contain leading/trailing whitespace or invalid characters."
|
||||
)
|
||||
elif llm.config.api_key and llm.config.api_key.lower() in error_msg.lower():
|
||||
error_msg = f"LLM failed to respond. Invalid API key error from '{llm.config.model_provider}'."
|
||||
else:
|
||||
error_msg = "An unexpected error occurred while processing your request. Please try again later."
|
||||
|
||||
yield StreamingError(error=error_msg)
|
||||
# Cancel the transaction so that no messages are saved
|
||||
db_session.rollback()
|
||||
return
|
||||
|
||||
|
||||
@@ -30,7 +30,23 @@ prompts:
|
||||
# Prompts the LLM to include citations in the form [1], [2] etc.
|
||||
# which get parsed to match the passed in sources
|
||||
include_citations: true
|
||||
|
||||
|
||||
- name: "ImageGeneration"
|
||||
description: "Generates images based on user prompts!"
|
||||
system: >
|
||||
You are an advanced image generation system capable of creating diverse and detailed images.
|
||||
|
||||
You can interpret user prompts and generate high-quality, creative images that match their descriptions.
|
||||
|
||||
You always strive to create safe and appropriate content, avoiding any harmful or offensive imagery.
|
||||
task: >
|
||||
Generate an image based on the user's description.
|
||||
|
||||
Provide a detailed description of the generated image, including key elements, colors, and composition.
|
||||
|
||||
If the request is not possible or appropriate, explain why and suggest alternatives.
|
||||
datetime_aware: true
|
||||
include_citations: false
|
||||
|
||||
- name: "OnlyLLM"
|
||||
description: "Chat directly with the LLM!"
|
||||
|
||||
@@ -129,6 +129,17 @@ POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost"
|
||||
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432"
|
||||
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
|
||||
|
||||
# defaults to False
POSTGRES_POOL_PRE_PING = os.environ.get("POSTGRES_POOL_PRE_PING", "").lower() == "true"

# recycle timeout in seconds
POSTGRES_POOL_RECYCLE_DEFAULT = 60 * 20  # 20 minutes
try:
POSTGRES_POOL_RECYCLE = int(
os.environ.get("POSTGRES_POOL_RECYCLE", POSTGRES_POOL_RECYCLE_DEFAULT)
)
except ValueError:
POSTGRES_POOL_RECYCLE = POSTGRES_POOL_RECYCLE_DEFAULT
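POSTGRES_POOL_PRE_PING and POSTGRES_POOL_RECYCLE map naturally onto SQLAlchemy's connection-pool options; a minimal sketch of how an engine would typically consume the values defined above (assumed wiring, not taken from this diff):

from sqlalchemy import create_engine

# pool_pre_ping validates a connection with a lightweight probe before handing it out;
# pool_recycle closes pooled connections older than the given number of seconds.
engine = create_engine(
    "postgresql://user:password@localhost:5432/postgres",  # placeholder DSN
    pool_pre_ping=POSTGRES_POOL_PRE_PING,
    pool_recycle=POSTGRES_POOL_RECYCLE,
)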
||||
#####
|
||||
# Connector Configs
|
||||
@@ -191,6 +202,11 @@ CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING = (
|
||||
os.environ.get("CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING", "").lower() == "true"
|
||||
)
|
||||
|
||||
# Attachments exceeding this size will not be retrieved (in bytes)
|
||||
CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD = int(
|
||||
os.environ.get("CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD", 50 * 1024 * 1024)
|
||||
)
|
||||
|
||||
JIRA_CONNECTOR_LABELS_TO_SKIP = [
|
||||
ignored_tag
|
||||
for ignored_tag in os.environ.get("JIRA_CONNECTOR_LABELS_TO_SKIP", "").split(",")
|
||||
@@ -212,10 +228,11 @@ EXPERIMENTAL_CHECKPOINTING_ENABLED = (
|
||||
os.environ.get("EXPERIMENTAL_CHECKPOINTING_ENABLED", "").lower() == "true"
|
||||
)
|
||||
|
||||
PRUNING_DISABLED = -1
|
||||
DEFAULT_PRUNING_FREQ = 60 * 60 * 24 # Once a day
|
||||
|
||||
PREVENT_SIMULTANEOUS_PRUNING = (
|
||||
os.environ.get("PREVENT_SIMULTANEOUS_PRUNING", "").lower() == "true"
|
||||
ALLOW_SIMULTANEOUS_PRUNING = (
|
||||
os.environ.get("ALLOW_SIMULTANEOUS_PRUNING", "").lower() == "true"
|
||||
)
|
||||
|
||||
# This is the maximum rate at which documents are queried for a pruning job. 0 disables the limitation.
|
||||
@@ -248,8 +265,13 @@ DISABLE_INDEX_UPDATE_ON_SWAP = (
|
||||
# fairly large amount of memory in order to increase substantially, since
|
||||
# each worker loads the embedding models into memory.
|
||||
NUM_INDEXING_WORKERS = int(os.environ.get("NUM_INDEXING_WORKERS") or 1)
|
||||
NUM_SECONDARY_INDEXING_WORKERS = int(
|
||||
os.environ.get("NUM_SECONDARY_INDEXING_WORKERS") or NUM_INDEXING_WORKERS
|
||||
)
|
||||
# More accurate results at the expense of indexing speed and index size (stores additional 4 MINI_CHUNK vectors)
|
||||
ENABLE_MINI_CHUNK = os.environ.get("ENABLE_MINI_CHUNK", "").lower() == "true"
|
||||
ENABLE_MULTIPASS_INDEXING = (
|
||||
os.environ.get("ENABLE_MULTIPASS_INDEXING", "").lower() == "true"
|
||||
)
|
||||
# Finer grained chunking for more detail retention
|
||||
# Slightly larger since the sentence aware split is a max cutoff so most minichunks will be under MINI_CHUNK_SIZE
|
||||
# tokens. But we need it to be at least as big as 1/4th chunk size to avoid having a tiny mini-chunk at the end
|
||||
@@ -260,6 +282,10 @@ SKIP_METADATA_IN_CHUNK = os.environ.get("SKIP_METADATA_IN_CHUNK", "").lower() ==
|
||||
# Timeout to wait for job's last update before killing it, in hours
|
||||
CLEANUP_INDEXING_JOBS_TIMEOUT = int(os.environ.get("CLEANUP_INDEXING_JOBS_TIMEOUT", 3))
|
||||
|
||||
# The indexer will warn in the logs whenever a document exceeds this threshold (in bytes)
|
||||
INDEXING_SIZE_WARNING_THRESHOLD = int(
|
||||
os.environ.get("INDEXING_SIZE_WARNING_THRESHOLD", 100 * 1024 * 1024)
|
||||
)
|
||||
|
||||
#####
|
||||
# Miscellaneous
|
||||
|
||||
@@ -3,12 +3,13 @@ import os
|
||||
|
||||
PROMPTS_YAML = "./danswer/chat/prompts.yaml"
|
||||
PERSONAS_YAML = "./danswer/chat/personas.yaml"
|
||||
INPUT_PROMPT_YAML = "./danswer/chat/input_prompts.yaml"
|
||||
|
||||
NUM_RETURNED_HITS = 50
|
||||
# Used for LLM filtering and reranking
|
||||
# We want this to be approximately the number of results we want to show on the first page
|
||||
# It cannot be too large due to cost and latency implications
|
||||
NUM_RERANKED_RESULTS = 20
|
||||
NUM_POSTPROCESSED_RESULTS = 20
|
||||
|
||||
# May be less depending on model
|
||||
MAX_CHUNKS_FED_TO_CHAT = float(os.environ.get("MAX_CHUNKS_FED_TO_CHAT") or 10.0)
|
||||
@@ -32,11 +33,6 @@ DISABLE_LLM_QUERY_ANSWERABILITY = QA_PROMPT_OVERRIDE == "weak"
|
||||
# Note this is not in any of the deployment configs yet
|
||||
CONTEXT_CHUNKS_ABOVE = int(os.environ.get("CONTEXT_CHUNKS_ABOVE") or 0)
|
||||
CONTEXT_CHUNKS_BELOW = int(os.environ.get("CONTEXT_CHUNKS_BELOW") or 0)
|
||||
# Whether the LLM should evaluate all of the document chunks passed in for usefulness
|
||||
# in relation to the user query
|
||||
DISABLE_LLM_CHUNK_FILTER = (
|
||||
os.environ.get("DISABLE_LLM_CHUNK_FILTER", "").lower() == "true"
|
||||
)
|
||||
# Whether the LLM should be used to decide if a search would help given the chat history
|
||||
DISABLE_LLM_CHOOSE_SEARCH = (
|
||||
os.environ.get("DISABLE_LLM_CHOOSE_SEARCH", "").lower() == "true"
|
||||
@@ -47,15 +43,11 @@ DISABLE_LLM_QUERY_REPHRASE = (
|
||||
# 1 edit per 20 characters, currently unused due to fuzzy match being too slow
|
||||
QUOTE_ALLOWED_ERROR_PERCENT = 0.05
|
||||
QA_TIMEOUT = int(os.environ.get("QA_TIMEOUT") or "60") # 60 seconds
|
||||
# Keyword Search Drop Stopwords
|
||||
# If the user has changed the default model, it was most likely to use a multilingual
# model; the stopwords are NLTK English stopwords, so in that case we would not want to drop the keywords
|
||||
if os.environ.get("EDIT_KEYWORD_QUERY"):
|
||||
EDIT_KEYWORD_QUERY = os.environ.get("EDIT_KEYWORD_QUERY", "").lower() == "true"
|
||||
else:
|
||||
EDIT_KEYWORD_QUERY = not os.environ.get("DOCUMENT_ENCODER_MODEL")
|
||||
# Weighting factor between Vector and Keyword Search, 1 for completely vector search
HYBRID_ALPHA = max(0, min(1, float(os.environ.get("HYBRID_ALPHA") or 0.62)))
HYBRID_ALPHA_KEYWORD = max(
0, min(1, float(os.environ.get("HYBRID_ALPHA_KEYWORD") or 0.4))
)
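The alpha is, in all likelihood, applied as a simple convex combination of the two retrieval scores; a sketch of that interpretation (an assumption, not code from this diff):

def hybrid_score(vector_score: float, keyword_score: float, alpha: float) -> float:
    # alpha = 1.0 -> purely vector search; alpha = 0.0 -> purely keyword search
    return alpha * vector_score + (1 - alpha) * keyword_score

# e.g. with the default HYBRID_ALPHA of 0.62:
# hybrid_score(0.8, 0.5, 0.62) == 0.62 * 0.8 + 0.38 * 0.5 == 0.686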
||||
# Weighting factor between Title and Content of documents during search, 1 for completely
|
||||
# Title based. Default heavily favors Content because Title is also included at the top of
|
||||
# Content. This is to avoid cases where the Content is very relevant but it may not be clear
|
||||
@@ -63,6 +55,7 @@ HYBRID_ALPHA = max(0, min(1, float(os.environ.get("HYBRID_ALPHA") or 0.62)))
|
||||
TITLE_CONTENT_RATIO = max(
|
||||
0, min(1, float(os.environ.get("TITLE_CONTENT_RATIO") or 0.20))
|
||||
)
|
||||
|
||||
# A list of languages passed to the LLM to rephase the query
|
||||
# For example "English,French,Spanish", be sure to use the "," separator
|
||||
MULTILINGUAL_QUERY_EXPANSION = os.environ.get("MULTILINGUAL_QUERY_EXPANSION") or None
|
||||
@@ -75,16 +68,16 @@ LANGUAGE_CHAT_NAMING_HINT = (
|
||||
or "The name of the conversation must be in the same language as the user query."
|
||||
)
|
||||
|
||||
|
||||
# Agentic search takes significantly more tokens and therefore has much higher cost.
|
||||
# This configuration allows users to get a search-only experience with instant results
|
||||
# and no involvement from the LLM.
|
||||
# Additionally, some LLM providers have strict rate limits which may prohibit
|
||||
# sending many API requests at once (as is done in agentic search).
|
||||
DISABLE_AGENTIC_SEARCH = (
|
||||
os.environ.get("DISABLE_AGENTIC_SEARCH") or "false"
|
||||
).lower() == "true"
|
||||
|
||||
# Whether the LLM should evaluate all of the document chunks passed in for usefulness
|
||||
# in relation to the user query
|
||||
DISABLE_LLM_DOC_RELEVANCE = (
|
||||
os.environ.get("DISABLE_LLM_DOC_RELEVANCE", "").lower() == "true"
|
||||
)
|
||||
|
||||
# Stops streaming answers back to the UI if this pattern is seen:
|
||||
STOP_STREAM_PAT = os.environ.get("STOP_STREAM_PAT") or None
|
||||
|
||||
@@ -44,7 +44,6 @@ QUERY_EVENT_ID = "query_event_id"
LLM_CHUNKS = "llm_chunks"

# For chunking/processing chunks
MAX_CHUNK_TITLE_LEN = 1000
RETURN_SEPARATOR = "\n\r\n"
SECTION_SEPARATOR = "\n\n"
# For combining attributes, doesn't have to be unique/perfect to work
@@ -60,12 +59,37 @@ DISABLED_GEN_AI_MSG = (
    "You can still use Danswer as a search engine."
)

# Postgres connection constants for application_name
POSTGRES_WEB_APP_NAME = "web"
POSTGRES_INDEXER_APP_NAME = "indexer"
POSTGRES_CELERY_APP_NAME = "celery"
POSTGRES_CELERY_BEAT_APP_NAME = "celery_beat"
POSTGRES_CELERY_WORKER_APP_NAME = "celery_worker"
POSTGRES_PERMISSIONS_APP_NAME = "permissions"
POSTGRES_UNKNOWN_APP_NAME = "unknown"

# API Keys
DANSWER_API_KEY_PREFIX = "API_KEY__"
DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN = "danswerapikey.ai"
UNNAMED_KEY_PLACEHOLDER = "Unnamed"

# Key-Value store keys
KV_REINDEX_KEY = "needs_reindexing"
KV_SEARCH_SETTINGS = "search_settings"
KV_USER_STORE_KEY = "INVITED_USERS"
KV_NO_AUTH_USER_PREFERENCES_KEY = "no_auth_user_preferences"
KV_CRED_KEY = "credential_id_{}"
KV_GMAIL_CRED_KEY = "gmail_app_credential"
KV_GMAIL_SERVICE_ACCOUNT_KEY = "gmail_service_account_key"
KV_GOOGLE_DRIVE_CRED_KEY = "google_drive_app_credential"
KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY = "google_drive_service_account_key"
KV_SLACK_BOT_TOKENS_CONFIG_KEY = "slack_bot_tokens_config_key"
KV_GEN_AI_KEY_CHECK_TIME = "genai_api_key_last_check_time"
KV_SETTINGS_KEY = "danswer_settings"
KV_CUSTOMER_UUID_KEY = "customer_uuid"
KV_ENTERPRISE_SETTINGS_KEY = "danswer_enterprise_settings"
KV_CUSTOM_ANALYTICS_SCRIPT_KEY = "__custom_analytics_script__"


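These KV_* constants are consumed through the dynamic config store, as the Gmail and Google Drive changes later in this diff show. A minimal usage sketch; the credential id 42 and the stored value are placeholders:

```python
from danswer.configs.constants import KV_CRED_KEY
from danswer.dynamic_configs.factory import get_dynamic_config_store

store = get_dynamic_config_store()

# Per-credential keys are built by formatting the template, e.g. "credential_id_42"
store.store(KV_CRED_KEY.format(42), "oauth-csrf-state", encrypt=True)
csrf_state = store.load(KV_CRED_KEY.format(42))  # -> "oauth-csrf-state"
store.delete(KV_CRED_KEY.format(42))
```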
class DocumentSource(str, Enum):
    # Special case, document passed in via Danswer APIs without specifying a source type
@@ -109,6 +133,10 @@ class DocumentSource(str, Enum):
    NOT_APPLICABLE = "not_applicable"


class NotificationType(str, Enum):
    REINDEX = "reindex"


class BlobType(str, Enum):
    R2 = "r2"
    S3 = "s3"

@@ -12,7 +12,7 @@ import os
# The useable models configured as below must be SentenceTransformer compatible
# NOTE: DO NOT CHANGE THESE UNLESS YOU KNOW WHAT YOU ARE DOING
# IDEALLY, YOU SHOULD CHANGE EMBEDDING MODELS VIA THE UI
DEFAULT_DOCUMENT_ENCODER_MODEL = "intfloat/e5-base-v2"
DEFAULT_DOCUMENT_ENCODER_MODEL = "nomic-ai/nomic-embed-text-v1"
DOCUMENT_ENCODER_MODEL = (
    os.environ.get("DOCUMENT_ENCODER_MODEL") or DEFAULT_DOCUMENT_ENCODER_MODEL
)
@@ -34,17 +34,16 @@ OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS = False
SIM_SCORE_RANGE_LOW = float(os.environ.get("SIM_SCORE_RANGE_LOW") or 0.0)
SIM_SCORE_RANGE_HIGH = float(os.environ.get("SIM_SCORE_RANGE_HIGH") or 1.0)
# Certain models like e5, BGE, etc use a prefix for asymmetric retrievals (query generally shorter than docs)
ASYM_QUERY_PREFIX = os.environ.get("ASYM_QUERY_PREFIX", "query: ")
ASYM_PASSAGE_PREFIX = os.environ.get("ASYM_PASSAGE_PREFIX", "passage: ")
ASYM_QUERY_PREFIX = os.environ.get("ASYM_QUERY_PREFIX", "search_query: ")
ASYM_PASSAGE_PREFIX = os.environ.get("ASYM_PASSAGE_PREFIX", "search_document: ")
# Purely an optimization, memory limitation consideration
BATCH_SIZE_ENCODE_CHUNKS = 8
# don't send over too many chunks at once, as sending too many could cause timeouts
BATCH_SIZE_ENCODE_CHUNKS_FOR_API_EMBEDDING_SERVICES = 512
# For score display purposes, only way is to know the expected ranges
CROSS_ENCODER_RANGE_MAX = 1
CROSS_ENCODER_RANGE_MIN = 0

# Unused currently, can't be used with the current default encoder model due to its output range
SEARCH_DISTANCE_CUTOFF = 0


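The prefix change tracks the new default encoder: e5 models expect "query: " and "passage: ", while nomic-embed-text-v1 expects "search_query: " and "search_document: ". A hedged sketch of how such prefixes are applied at embedding time; the sentence-transformers call is illustrative, and trust_remote_code is typically required for the nomic model:

```python
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("nomic-ai/nomic-embed-text-v1", trust_remote_code=True)

# Queries and passages get different task prefixes for asymmetric retrieval.
query_emb = model.encode("search_query: how do I rotate an API key?")
passage_embs = model.encode(
    [
        "search_document: API keys can be rotated from the admin panel.",
        "search_document: Connectors poll their sources on a schedule.",
    ],
    batch_size=8,  # mirrors BATCH_SIZE_ENCODE_CHUNKS
)
```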
#####
# Generative AI Model Configs

@@ -169,7 +169,7 @@ class BlobStorageConnector(LoadConnector, PollConnector):
        end: datetime,
    ) -> GenerateDocumentsOutput:
        if self.s3_client is None:
            raise ConnectorMissingCredentialError("Blog storage")
            raise ConnectorMissingCredentialError("Blob storage")

        paginator = self.s3_client.get_paginator("list_objects_v2")
        pages = paginator.paginate(Bucket=self.bucket_name, Prefix=self.prefix)
@@ -230,7 +230,7 @@ class BlobStorageConnector(LoadConnector, PollConnector):
        self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
    ) -> GenerateDocumentsOutput:
        if self.s3_client is None:
            raise ConnectorMissingCredentialError("Blog storage")
            raise ConnectorMissingCredentialError("Blob storage")

        start_datetime = datetime.fromtimestamp(start, tz=timezone.utc)
        end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)

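The poll variant presumably filters the listed objects by their LastModified timestamps against this window; that filtering sits outside the hunk. A minimal boto3 sketch of the idea, with bucket and prefix as placeholders:

```python
import boto3
from datetime import datetime, timezone

s3 = boto3.client("s3")
paginator = s3.get_paginator("list_objects_v2")

start_dt = datetime.fromtimestamp(1700000000, tz=timezone.utc)
end_dt = datetime.fromtimestamp(1700086400, tz=timezone.utc)

for page in paginator.paginate(Bucket="my-bucket", Prefix="docs/"):
    for obj in page.get("Contents", []):
        # LastModified is a timezone-aware datetime returned by S3
        if start_dt <= obj["LastModified"] <= end_dt:
            print(obj["Key"])
```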
@@ -13,6 +13,7 @@ import bs4
|
||||
from atlassian import Confluence # type:ignore
|
||||
from requests import HTTPError
|
||||
|
||||
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
|
||||
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_INDEX_ONLY_ACTIVE_PAGES
|
||||
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_LABELS_TO_SKIP
|
||||
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_SKIP_LABEL_INDEXING
|
||||
@@ -217,16 +218,19 @@ class RecursiveIndexer:
        self,
        batch_size: int,
        confluence_client: Confluence,
        index_origin: bool,
        index_recursively: bool,
        origin_page_id: str,
    ) -> None:
        self.batch_size = 1
        # batch_size
        self.confluence_client = confluence_client
        self.index_origin = index_origin
        self.index_recursively = index_recursively
        self.origin_page_id = origin_page_id
        self.pages = self.recurse_children_pages(0, self.origin_page_id)

    def get_origin_page(self) -> list[dict[str, Any]]:
        return [self._fetch_origin_page()]

    def get_pages(self, ind: int, size: int) -> list[dict]:
        if ind * size > len(self.pages):
            return []
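The rest of get_pages is cut off by the hunk boundary. A plausible slice-based continuation, offered only as an assumption about what follows rather than the repository's actual code:

```python
def get_pages(self, ind: int, size: int) -> list[dict]:
    # Hypothetical continuation: return the ind-th batch of the collected pages.
    if ind * size > len(self.pages):
        return []
    return self.pages[ind * size : (ind + 1) * size]
```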
@@ -282,12 +286,11 @@ class RecursiveIndexer:
|
||||
current_level_pages = next_level_pages
|
||||
next_level_pages = []
|
||||
|
||||
if self.index_origin:
|
||||
try:
|
||||
origin_page = self._fetch_origin_page()
|
||||
pages.append(origin_page)
|
||||
except Exception as e:
|
||||
logger.warning(f"Appending origin page with id {page_id} failed: {e}")
|
||||
try:
|
||||
origin_page = self._fetch_origin_page()
|
||||
pages.append(origin_page)
|
||||
except Exception as e:
|
||||
logger.warning(f"Appending origin page with id {page_id} failed: {e}")
|
||||
|
||||
return pages
|
||||
|
||||
@@ -340,7 +343,7 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
def __init__(
|
||||
self,
|
||||
wiki_page_url: str,
|
||||
index_origin: bool = True,
|
||||
index_recursively: bool = True,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE,
|
||||
# if a page has one of the labels specified in this list, we will just
|
||||
@@ -352,7 +355,7 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
self.continue_on_failure = continue_on_failure
|
||||
self.labels_to_skip = set(labels_to_skip)
|
||||
self.recursive_indexer: RecursiveIndexer | None = None
|
||||
self.index_origin = index_origin
|
||||
self.index_recursively = index_recursively
|
||||
(
|
||||
self.wiki_base,
|
||||
self.space,
|
||||
@@ -369,7 +372,7 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
|
||||
logger.info(
|
||||
f"wiki_base: {self.wiki_base}, space: {self.space}, page_id: {self.page_id},"
|
||||
+ f" space_level_scan: {self.space_level_scan}, origin: {self.index_origin}"
|
||||
+ f" space_level_scan: {self.space_level_scan}, index_recursively: {self.index_recursively}"
|
||||
)
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
@@ -453,10 +456,13 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
origin_page_id=self.page_id,
|
||||
batch_size=self.batch_size,
|
||||
confluence_client=self.confluence_client,
|
||||
index_origin=self.index_origin,
|
||||
index_recursively=self.index_recursively,
|
||||
)
|
||||
|
||||
return self.recursive_indexer.get_pages(start_ind, batch_size)
|
||||
if self.index_recursively:
|
||||
return self.recursive_indexer.get_pages(start_ind, batch_size)
|
||||
else:
|
||||
return self.recursive_indexer.get_origin_page()
|
||||
|
||||
pages: list[dict[str, Any]] = []
|
||||
|
||||
@@ -555,6 +561,17 @@ class ConfluenceConnector(LoadConnector, PollConnector):
|
||||
if attachment["title"] not in files_in_used:
|
||||
continue
|
||||
|
||||
download_link = confluence_client.url + attachment["_links"]["download"]
|
||||
|
||||
attachment_size = attachment["extensions"]["fileSize"]
|
||||
if attachment_size > CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD:
|
||||
logger.warning(
|
||||
f"Skipping {download_link} due to size. "
|
||||
f"size={attachment_size} "
|
||||
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD}"
|
||||
)
|
||||
continue
|
||||
|
||||
download_link = confluence_client.url + attachment["_links"]["download"]
|
||||
response = confluence_client._session.get(download_link)
|
||||
|
||||
|
||||
@@ -56,6 +56,16 @@ def extract_text_from_content(content: dict) -> str:
|
||||
return " ".join(texts)
|
||||
|
||||
|
||||
def best_effort_get_field_from_issue(jira_issue: Issue, field: str) -> Any:
|
||||
if hasattr(jira_issue.fields, field):
|
||||
return getattr(jira_issue.fields, field)
|
||||
|
||||
try:
|
||||
return jira_issue.raw["fields"][field]
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _get_comment_strs(
|
||||
jira: Issue, comment_email_blacklist: tuple[str, ...] = ()
|
||||
) -> list[str]:
|
||||
@@ -117,8 +127,10 @@ def fetch_jira_issues_batch(
            continue

        comments = _get_comment_strs(jira, comment_email_blacklist)
        semantic_rep = f"{jira.fields.description}\n" + "\n".join(
            [f"Comment: {comment}" for comment in comments]
        semantic_rep = (
            f"{jira.fields.description}\n"
            if jira.fields.description
            else "" + "\n".join([f"Comment: {comment}" for comment in comments])
        )

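Note that in the new expression `+` binds tighter than the conditional, so the joined comments only end up in the `else` branch; issues that do have a description lose their comments entirely. A corrected sketch, not what this commit ships, would be:

```python
description = jira.fields.description or ""
comments_str = "\n".join(f"Comment: {comment}" for comment in comments)
semantic_rep = f"{description}\n{comments_str}" if description else comments_str
```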
page_url = f"{jira_client.client_info()}/browse/{jira.key}"
|
||||
@@ -147,14 +159,18 @@ def fetch_jira_issues_batch(
|
||||
pass
|
||||
|
||||
metadata_dict = {}
|
||||
if jira.fields.priority:
|
||||
metadata_dict["priority"] = jira.fields.priority.name
|
||||
if jira.fields.status:
|
||||
metadata_dict["status"] = jira.fields.status.name
|
||||
if jira.fields.resolution:
|
||||
metadata_dict["resolution"] = jira.fields.resolution.name
|
||||
if jira.fields.labels:
|
||||
metadata_dict["label"] = jira.fields.labels
|
||||
priority = best_effort_get_field_from_issue(jira, "priority")
|
||||
if priority:
|
||||
metadata_dict["priority"] = priority.name
|
||||
status = best_effort_get_field_from_issue(jira, "status")
|
||||
if status:
|
||||
metadata_dict["status"] = status.name
|
||||
resolution = best_effort_get_field_from_issue(jira, "resolution")
|
||||
if resolution:
|
||||
metadata_dict["resolution"] = resolution.name
|
||||
labels = best_effort_get_field_from_issue(jira, "labels")
|
||||
if labels:
|
||||
metadata_dict["label"] = labels
|
||||
|
||||
doc_batch.append(
|
||||
Document(
|
||||
|
||||
@@ -64,7 +64,7 @@ class DiscourseConnector(PollConnector):
|
||||
self.permissions: DiscoursePerms | None = None
|
||||
self.active_categories: set | None = None
|
||||
|
||||
@rate_limit_builder(max_calls=100, period=60)
|
||||
@rate_limit_builder(max_calls=50, period=60)
|
||||
def _make_request(self, endpoint: str, params: dict | None = None) -> Response:
|
||||
if not self.permissions:
|
||||
raise ConnectorMissingCredentialError("Discourse")
|
||||
|
||||
@@ -11,16 +11,17 @@ from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.app_configs import WEB_DOMAIN
|
||||
from danswer.connectors.gmail.constants import CRED_KEY
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import KV_CRED_KEY
|
||||
from danswer.configs.constants import KV_GMAIL_CRED_KEY
|
||||
from danswer.configs.constants import KV_GMAIL_SERVICE_ACCOUNT_KEY
|
||||
from danswer.connectors.gmail.constants import (
|
||||
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY,
|
||||
)
|
||||
from danswer.connectors.gmail.constants import DB_CREDENTIALS_DICT_TOKEN_KEY
|
||||
from danswer.connectors.gmail.constants import GMAIL_CRED_KEY
|
||||
from danswer.connectors.gmail.constants import (
|
||||
GMAIL_DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
|
||||
)
|
||||
from danswer.connectors.gmail.constants import GMAIL_SERVICE_ACCOUNT_KEY
|
||||
from danswer.connectors.gmail.constants import SCOPES
|
||||
from danswer.db.credentials import update_credential_json
|
||||
from danswer.db.models import User
|
||||
@@ -71,7 +72,7 @@ def get_gmail_creds_for_service_account(
|
||||
|
||||
|
||||
def verify_csrf(credential_id: int, state: str) -> None:
|
||||
csrf = get_dynamic_config_store().load(CRED_KEY.format(str(credential_id)))
|
||||
csrf = get_dynamic_config_store().load(KV_CRED_KEY.format(str(credential_id)))
|
||||
if csrf != state:
|
||||
raise PermissionError(
|
||||
"State from Gmail Connector callback does not match expected"
|
||||
@@ -79,7 +80,7 @@ def verify_csrf(credential_id: int, state: str) -> None:
|
||||
|
||||
|
||||
def get_gmail_auth_url(credential_id: int) -> str:
|
||||
creds_str = str(get_dynamic_config_store().load(GMAIL_CRED_KEY))
|
||||
creds_str = str(get_dynamic_config_store().load(KV_GMAIL_CRED_KEY))
|
||||
credential_json = json.loads(creds_str)
|
||||
flow = InstalledAppFlow.from_client_config(
|
||||
credential_json,
|
||||
@@ -91,12 +92,14 @@ def get_gmail_auth_url(credential_id: int) -> str:
|
||||
parsed_url = cast(ParseResult, urlparse(auth_url))
|
||||
params = parse_qs(parsed_url.query)
|
||||
|
||||
get_dynamic_config_store().store(CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True) # type: ignore
|
||||
get_dynamic_config_store().store(
|
||||
KV_CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True
|
||||
) # type: ignore
|
||||
return str(auth_url)
|
||||
|
||||
|
||||
def get_auth_url(credential_id: int) -> str:
|
||||
creds_str = str(get_dynamic_config_store().load(GMAIL_CRED_KEY))
|
||||
creds_str = str(get_dynamic_config_store().load(KV_GMAIL_CRED_KEY))
|
||||
credential_json = json.loads(creds_str)
|
||||
flow = InstalledAppFlow.from_client_config(
|
||||
credential_json,
|
||||
@@ -108,7 +111,9 @@ def get_auth_url(credential_id: int) -> str:
|
||||
parsed_url = cast(ParseResult, urlparse(auth_url))
|
||||
params = parse_qs(parsed_url.query)
|
||||
|
||||
get_dynamic_config_store().store(CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True) # type: ignore
|
||||
get_dynamic_config_store().store(
|
||||
KV_CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True
|
||||
) # type: ignore
|
||||
return str(auth_url)
|
||||
|
||||
|
||||
@@ -146,28 +151,29 @@ def build_service_account_creds(
|
||||
credential_dict[DB_CREDENTIALS_DICT_DELEGATED_USER_KEY] = delegated_user_email
|
||||
|
||||
return CredentialBase(
|
||||
source=DocumentSource.GMAIL,
|
||||
credential_json=credential_dict,
|
||||
admin_public=True,
|
||||
)
|
||||
|
||||
|
||||
def get_google_app_gmail_cred() -> GoogleAppCredentials:
|
||||
creds_str = str(get_dynamic_config_store().load(GMAIL_CRED_KEY))
|
||||
creds_str = str(get_dynamic_config_store().load(KV_GMAIL_CRED_KEY))
|
||||
return GoogleAppCredentials(**json.loads(creds_str))
|
||||
|
||||
|
||||
def upsert_google_app_gmail_cred(app_credentials: GoogleAppCredentials) -> None:
|
||||
get_dynamic_config_store().store(
|
||||
GMAIL_CRED_KEY, app_credentials.json(), encrypt=True
|
||||
KV_GMAIL_CRED_KEY, app_credentials.json(), encrypt=True
|
||||
)
|
||||
|
||||
|
||||
def delete_google_app_gmail_cred() -> None:
|
||||
get_dynamic_config_store().delete(GMAIL_CRED_KEY)
|
||||
get_dynamic_config_store().delete(KV_GMAIL_CRED_KEY)
|
||||
|
||||
|
||||
def get_gmail_service_account_key() -> GoogleServiceAccountKey:
|
||||
creds_str = str(get_dynamic_config_store().load(GMAIL_SERVICE_ACCOUNT_KEY))
|
||||
creds_str = str(get_dynamic_config_store().load(KV_GMAIL_SERVICE_ACCOUNT_KEY))
|
||||
return GoogleServiceAccountKey(**json.loads(creds_str))
|
||||
|
||||
|
||||
@@ -175,19 +181,19 @@ def upsert_gmail_service_account_key(
|
||||
service_account_key: GoogleServiceAccountKey,
|
||||
) -> None:
|
||||
get_dynamic_config_store().store(
|
||||
GMAIL_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
|
||||
KV_GMAIL_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
|
||||
)
|
||||
|
||||
|
||||
def upsert_service_account_key(service_account_key: GoogleServiceAccountKey) -> None:
|
||||
get_dynamic_config_store().store(
|
||||
GMAIL_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
|
||||
KV_GMAIL_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
|
||||
)
|
||||
|
||||
|
||||
def delete_gmail_service_account_key() -> None:
|
||||
get_dynamic_config_store().delete(GMAIL_SERVICE_ACCOUNT_KEY)
|
||||
get_dynamic_config_store().delete(KV_GMAIL_SERVICE_ACCOUNT_KEY)
|
||||
|
||||
|
||||
def delete_service_account_key() -> None:
|
||||
get_dynamic_config_store().delete(GMAIL_SERVICE_ACCOUNT_KEY)
|
||||
get_dynamic_config_store().delete(KV_GMAIL_SERVICE_ACCOUNT_KEY)
|
||||
|
||||
@@ -1,7 +1,4 @@
|
||||
DB_CREDENTIALS_DICT_TOKEN_KEY = "gmail_tokens"
|
||||
GMAIL_DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "gmail_service_account_key"
|
||||
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY = "gmail_delegated_user"
|
||||
CRED_KEY = "credential_id_{}"
|
||||
GMAIL_CRED_KEY = "gmail_app_credential"
|
||||
GMAIL_SERVICE_ACCOUNT_KEY = "gmail_service_account_key"
|
||||
SCOPES = ["https://www.googleapis.com/auth/gmail.readonly"]
|
||||
|
||||
@@ -306,24 +306,29 @@ def get_all_files_batched(
|
||||
|
||||
def extract_text(file: dict[str, str], service: discovery.Resource) -> str:
|
||||
mime_type = file["mimeType"]
|
||||
|
||||
if mime_type not in set(item.value for item in GDriveMimeType):
|
||||
# Unsupported file types can still have a title, finding this way is still useful
|
||||
return UNSUPPORTED_FILE_TYPE_CONTENT
|
||||
|
||||
if mime_type == GDriveMimeType.DOC.value:
|
||||
return (
|
||||
if mime_type in [
|
||||
GDriveMimeType.DOC.value,
|
||||
GDriveMimeType.PPT.value,
|
||||
GDriveMimeType.SPREADSHEET.value,
|
||||
]:
|
||||
export_mime_type = "text/plain"
|
||||
if mime_type == GDriveMimeType.SPREADSHEET.value:
|
||||
export_mime_type = "text/csv"
|
||||
elif mime_type == GDriveMimeType.PPT.value:
|
||||
export_mime_type = "text/plain"
|
||||
|
||||
response = (
|
||||
service.files()
|
||||
.export(fileId=file["id"], mimeType="text/plain")
|
||||
.export(fileId=file["id"], mimeType=export_mime_type)
|
||||
.execute()
|
||||
.decode("utf-8")
|
||||
)
|
||||
elif mime_type == GDriveMimeType.SPREADSHEET.value:
|
||||
return (
|
||||
service.files()
|
||||
.export(fileId=file["id"], mimeType="text/csv")
|
||||
.execute()
|
||||
.decode("utf-8")
|
||||
)
|
||||
return response.decode("utf-8")
|
||||
|
||||
elif mime_type == GDriveMimeType.WORD_DOC.value:
|
||||
response = service.files().get_media(fileId=file["id"]).execute()
|
||||
return docx_to_text(file=io.BytesIO(response))
|
||||
@@ -333,9 +338,6 @@ def extract_text(file: dict[str, str], service: discovery.Resource) -> str:
|
||||
elif mime_type == GDriveMimeType.POWERPOINT.value:
|
||||
response = service.files().get_media(fileId=file["id"]).execute()
|
||||
return pptx_to_text(file=io.BytesIO(response))
|
||||
elif mime_type == GDriveMimeType.PPT.value:
|
||||
response = service.files().get_media(fileId=file["id"]).execute()
|
||||
return pptx_to_text(file=io.BytesIO(response))
|
||||
|
||||
return UNSUPPORTED_FILE_TYPE_CONTENT
|
||||
|
||||
|
||||
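The extract_text hunk above interleaves the removed per-type branches with the added shared export path; the net result is a single Drive export call whose MIME type depends on the Google file type. The same mapping can be written compactly, as a sketch built from the code shown above rather than the committed version:

```python
EXPORT_MIME_TYPES = {
    GDriveMimeType.DOC.value: "text/plain",
    GDriveMimeType.PPT.value: "text/plain",
    GDriveMimeType.SPREADSHEET.value: "text/csv",
}

export_mime_type = EXPORT_MIME_TYPES.get(mime_type)
if export_mime_type is not None:
    response = (
        service.files()
        .export(fileId=file["id"], mimeType=export_mime_type)
        .execute()
    )
    return response.decode("utf-8")
```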
@@ -11,7 +11,10 @@ from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.app_configs import WEB_DOMAIN
|
||||
from danswer.connectors.google_drive.constants import CRED_KEY
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import KV_CRED_KEY
|
||||
from danswer.configs.constants import KV_GOOGLE_DRIVE_CRED_KEY
|
||||
from danswer.configs.constants import KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY
|
||||
from danswer.connectors.google_drive.constants import (
|
||||
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY,
|
||||
)
|
||||
@@ -19,8 +22,6 @@ from danswer.connectors.google_drive.constants import (
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
|
||||
)
|
||||
from danswer.connectors.google_drive.constants import DB_CREDENTIALS_DICT_TOKEN_KEY
|
||||
from danswer.connectors.google_drive.constants import GOOGLE_DRIVE_CRED_KEY
|
||||
from danswer.connectors.google_drive.constants import GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY
|
||||
from danswer.connectors.google_drive.constants import SCOPES
|
||||
from danswer.db.credentials import update_credential_json
|
||||
from danswer.db.models import User
|
||||
@@ -71,7 +72,7 @@ def get_google_drive_creds_for_service_account(
|
||||
|
||||
|
||||
def verify_csrf(credential_id: int, state: str) -> None:
|
||||
csrf = get_dynamic_config_store().load(CRED_KEY.format(str(credential_id)))
|
||||
csrf = get_dynamic_config_store().load(KV_CRED_KEY.format(str(credential_id)))
|
||||
if csrf != state:
|
||||
raise PermissionError(
|
||||
"State from Google Drive Connector callback does not match expected"
|
||||
@@ -79,7 +80,7 @@ def verify_csrf(credential_id: int, state: str) -> None:
|
||||
|
||||
|
||||
def get_auth_url(credential_id: int) -> str:
|
||||
creds_str = str(get_dynamic_config_store().load(GOOGLE_DRIVE_CRED_KEY))
|
||||
creds_str = str(get_dynamic_config_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
|
||||
credential_json = json.loads(creds_str)
|
||||
flow = InstalledAppFlow.from_client_config(
|
||||
credential_json,
|
||||
@@ -91,7 +92,9 @@ def get_auth_url(credential_id: int) -> str:
|
||||
parsed_url = cast(ParseResult, urlparse(auth_url))
|
||||
params = parse_qs(parsed_url.query)
|
||||
|
||||
get_dynamic_config_store().store(CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True) # type: ignore
|
||||
get_dynamic_config_store().store(
|
||||
KV_CRED_KEY.format(credential_id), params.get("state", [None])[0], encrypt=True
|
||||
) # type: ignore
|
||||
return str(auth_url)
|
||||
|
||||
|
||||
@@ -118,6 +121,7 @@ def update_credential_access_tokens(
|
||||
|
||||
|
||||
def build_service_account_creds(
|
||||
source: DocumentSource,
|
||||
delegated_user_email: str | None = None,
|
||||
) -> CredentialBase:
|
||||
service_account_key = get_service_account_key()
|
||||
@@ -131,34 +135,37 @@ def build_service_account_creds(
|
||||
return CredentialBase(
|
||||
credential_json=credential_dict,
|
||||
admin_public=True,
|
||||
source=DocumentSource.GOOGLE_DRIVE,
|
||||
)
|
||||
|
||||
|
||||
def get_google_app_cred() -> GoogleAppCredentials:
|
||||
creds_str = str(get_dynamic_config_store().load(GOOGLE_DRIVE_CRED_KEY))
|
||||
creds_str = str(get_dynamic_config_store().load(KV_GOOGLE_DRIVE_CRED_KEY))
|
||||
return GoogleAppCredentials(**json.loads(creds_str))
|
||||
|
||||
|
||||
def upsert_google_app_cred(app_credentials: GoogleAppCredentials) -> None:
|
||||
get_dynamic_config_store().store(
|
||||
GOOGLE_DRIVE_CRED_KEY, app_credentials.json(), encrypt=True
|
||||
KV_GOOGLE_DRIVE_CRED_KEY, app_credentials.json(), encrypt=True
|
||||
)
|
||||
|
||||
|
||||
def delete_google_app_cred() -> None:
|
||||
get_dynamic_config_store().delete(GOOGLE_DRIVE_CRED_KEY)
|
||||
get_dynamic_config_store().delete(KV_GOOGLE_DRIVE_CRED_KEY)
|
||||
|
||||
|
||||
def get_service_account_key() -> GoogleServiceAccountKey:
|
||||
creds_str = str(get_dynamic_config_store().load(GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY))
|
||||
creds_str = str(
|
||||
get_dynamic_config_store().load(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY)
|
||||
)
|
||||
return GoogleServiceAccountKey(**json.loads(creds_str))
|
||||
|
||||
|
||||
def upsert_service_account_key(service_account_key: GoogleServiceAccountKey) -> None:
|
||||
get_dynamic_config_store().store(
|
||||
GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
|
||||
KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True
|
||||
)
|
||||
|
||||
|
||||
def delete_service_account_key() -> None:
|
||||
get_dynamic_config_store().delete(GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY)
|
||||
get_dynamic_config_store().delete(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY)
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
DB_CREDENTIALS_DICT_TOKEN_KEY = "google_drive_tokens"
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "google_drive_service_account_key"
|
||||
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY = "google_drive_delegated_user"
|
||||
CRED_KEY = "credential_id_{}"
|
||||
GOOGLE_DRIVE_CRED_KEY = "google_drive_app_credential"
|
||||
GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY = "google_drive_service_account_key"
|
||||
SCOPES = [
|
||||
"https://www.googleapis.com/auth/drive.readonly",
|
||||
"https://www.googleapis.com/auth/drive.metadata.readonly",
|
||||
|
||||
@@ -86,7 +86,6 @@ class MediaWikiConnector(LoadConnector, PollConnector):
|
||||
categories: The categories to include in the index.
|
||||
pages: The pages to include in the index.
|
||||
recurse_depth: The depth to recurse into categories. -1 means unbounded recursion.
|
||||
connector_name: The name of the connector.
|
||||
language_code: The language code of the wiki.
|
||||
batch_size: The batch size for loading documents.
|
||||
|
||||
@@ -104,7 +103,6 @@ class MediaWikiConnector(LoadConnector, PollConnector):
|
||||
categories: list[str],
|
||||
pages: list[str],
|
||||
recurse_depth: int,
|
||||
connector_name: str,
|
||||
language_code: str = "en",
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
) -> None:
|
||||
@@ -118,10 +116,8 @@ class MediaWikiConnector(LoadConnector, PollConnector):
|
||||
self.batch_size = batch_size
|
||||
|
||||
# short names can only have ascii letters and digits
|
||||
self.connector_name = connector_name
|
||||
connector_name = "".join(ch for ch in connector_name if ch.isalnum())
|
||||
|
||||
self.family = family_class_dispatch(hostname, connector_name)()
|
||||
self.family = family_class_dispatch(hostname, "Wikipedia Connector")()
|
||||
self.site = pywikibot.Site(fam=self.family, code=language_code)
|
||||
self.categories = [
|
||||
pywikibot.Category(self.site, f"Category:{category.replace(' ', '_')}")
|
||||
@@ -210,7 +206,6 @@ class MediaWikiConnector(LoadConnector, PollConnector):
|
||||
if __name__ == "__main__":
|
||||
HOSTNAME = "fallout.fandom.com"
|
||||
test_connector = MediaWikiConnector(
|
||||
connector_name="Fallout",
|
||||
hostname=HOSTNAME,
|
||||
categories=["Fallout:_New_Vegas_factions"],
|
||||
pages=["Fallout: New Vegas"],
|
||||
|
||||
@@ -114,7 +114,9 @@ class DocumentBase(BaseModel):
|
||||
title: str | None = None
|
||||
from_ingestion_api: bool = False
|
||||
|
||||
def get_title_for_document_index(self) -> str | None:
|
||||
def get_title_for_document_index(
|
||||
self,
|
||||
) -> str | None:
|
||||
# If title is explicitly empty, return a None here for embedding purposes
|
||||
if self.title == "":
|
||||
return None
|
||||
|
||||
@@ -68,12 +68,13 @@ def make_slack_api_call_paginated(
|
||||
|
||||
|
||||
def make_slack_api_rate_limited(
|
||||
call: Callable[..., SlackResponse], max_retries: int = 3
|
||||
call: Callable[..., SlackResponse], max_retries: int = 7
|
||||
) -> Callable[..., SlackResponse]:
|
||||
"""Wraps calls to slack API so that they automatically handle rate limiting"""
|
||||
|
||||
@wraps(call)
|
||||
def rate_limited_call(**kwargs: Any) -> SlackResponse:
|
||||
last_exception = None
|
||||
for _ in range(max_retries):
|
||||
try:
|
||||
# Make the API call
|
||||
@@ -85,14 +86,20 @@ def make_slack_api_rate_limited(
|
||||
return response
|
||||
|
||||
except SlackApiError as e:
|
||||
if e.response["error"] == "ratelimited":
|
||||
last_exception = e
|
||||
try:
|
||||
error = e.response["error"]
|
||||
except KeyError:
|
||||
error = "unknown error"
|
||||
|
||||
if error == "ratelimited":
|
||||
# Handle rate limiting: get the 'Retry-After' header value and sleep for that duration
|
||||
retry_after = int(e.response.headers.get("Retry-After", 1))
|
||||
logger.info(
|
||||
f"Slack call rate limited, retrying after {retry_after} seconds. Exception: {e}"
|
||||
)
|
||||
time.sleep(retry_after)
|
||||
elif e.response["error"] in ["already_reacted", "no_reaction"]:
|
||||
elif error in ["already_reacted", "no_reaction"]:
|
||||
# The response isn't used for reactions, this is basically just a pass
|
||||
return e.response
|
||||
else:
|
||||
@@ -100,7 +107,11 @@ def make_slack_api_rate_limited(
|
||||
raise
|
||||
|
||||
        # If the code reaches this point, all retries have been exhausted
        raise Exception(f"Max retries ({max_retries}) exceeded")
        msg = f"Max retries ({max_retries}) exceeded"
        if last_exception:
            raise Exception(msg) from last_exception
        else:
            raise Exception(msg)

    return rate_limited_call


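A hedged usage sketch: the wrapper accepts any callable returning a SlackResponse, so slack_sdk WebClient methods can be wrapped directly; the token and channel are placeholders:

```python
from slack_sdk import WebClient

client = WebClient(token="xoxb-...")
post_message = make_slack_api_rate_limited(client.chat_postMessage, max_retries=7)

# Retries automatically on Slack's "ratelimited" error, honoring Retry-After.
post_message(channel="#support", text="Indexing finished.")
```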
@@ -15,6 +15,7 @@ from playwright.sync_api import BrowserContext
|
||||
from playwright.sync_api import Playwright
|
||||
from playwright.sync_api import sync_playwright
|
||||
from requests_oauthlib import OAuth2Session # type:ignore
|
||||
from urllib3.exceptions import MaxRetryError
|
||||
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.app_configs import WEB_CONNECTOR_OAUTH_CLIENT_ID
|
||||
@@ -83,6 +84,13 @@ def check_internet_connection(url: str) -> None:
|
||||
try:
|
||||
response = requests.get(url, timeout=3)
|
||||
response.raise_for_status()
|
||||
except requests.exceptions.SSLError as e:
|
||||
cause = (
|
||||
e.args[0].reason
|
||||
if isinstance(e.args, tuple) and isinstance(e.args[0], MaxRetryError)
|
||||
else e.args
|
||||
)
|
||||
raise Exception(f"SSL error {str(cause)}")
|
||||
except (requests.RequestException, ValueError):
|
||||
raise Exception(f"Unable to reach {url} - check your internet connection")
|
||||
|
||||
|
||||
@@ -15,7 +15,6 @@ class WikipediaConnector(wiki.MediaWikiConnector):
|
||||
categories: list[str],
|
||||
pages: list[str],
|
||||
recurse_depth: int,
|
||||
connector_name: str,
|
||||
language_code: str = "en",
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
) -> None:
|
||||
@@ -24,7 +23,6 @@ class WikipediaConnector(wiki.MediaWikiConnector):
|
||||
categories=categories,
|
||||
pages=pages,
|
||||
recurse_depth=recurse_depth,
|
||||
connector_name=connector_name,
|
||||
language_code=language_code,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from retry import retry
|
||||
from zenpy import Zenpy # type: ignore
|
||||
from zenpy.lib.api_objects.help_centre_objects import Article # type: ignore
|
||||
|
||||
@@ -19,12 +21,24 @@ from danswer.connectors.models import Section
|
||||
from danswer.file_processing.html_utils import parse_html_page_basic
|
||||
|
||||
|
||||
def _article_to_document(article: Article) -> Document:
|
||||
def _article_to_document(article: Article, content_tags: dict[str, str]) -> Document:
|
||||
author = BasicExpertInfo(
|
||||
display_name=article.author.name, email=article.author.email
|
||||
)
|
||||
update_time = time_str_to_utc(article.updated_at)
|
||||
labels = [str(label) for label in article.label_names]
|
||||
|
||||
# build metadata
|
||||
metadata: dict[str, str | list[str]] = {
|
||||
"labels": [str(label) for label in article.label_names if label],
|
||||
"content_tags": [
|
||||
content_tags[tag_id]
|
||||
for tag_id in article.content_tag_ids
|
||||
if tag_id in content_tags
|
||||
],
|
||||
}
|
||||
|
||||
# remove empty values
|
||||
metadata = {k: v for k, v in metadata.items() if v}
|
||||
|
||||
return Document(
|
||||
id=f"article:{article.id}",
|
||||
@@ -35,7 +49,7 @@ def _article_to_document(article: Article) -> Document:
|
||||
semantic_identifier=article.title,
|
||||
doc_updated_at=update_time,
|
||||
primary_owners=[author],
|
||||
metadata={"labels": labels} if labels else {},
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
|
||||
@@ -48,6 +62,42 @@ class ZendeskConnector(LoadConnector, PollConnector):
|
||||
def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
|
||||
self.batch_size = batch_size
|
||||
self.zendesk_client: Zenpy | None = None
|
||||
self.content_tags: dict[str, str] = {}
|
||||
|
||||
@retry(tries=3, delay=2, backoff=2)
|
||||
def _set_content_tags(
|
||||
self, subdomain: str, email: str, token: str, page_size: int = 30
|
||||
) -> None:
|
||||
# Construct the base URL
|
||||
base_url = f"https://{subdomain}.zendesk.com/api/v2/guide/content_tags"
|
||||
|
||||
# Set up authentication
|
||||
auth = (f"{email}/token", token)
|
||||
|
||||
# Set up pagination parameters
|
||||
params = {"page[size]": page_size}
|
||||
|
||||
try:
|
||||
while True:
|
||||
# Make the GET request
|
||||
response = requests.get(base_url, auth=auth, params=params)
|
||||
|
||||
# Check if the request was successful
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
content_tag_list = data.get("records", [])
|
||||
for tag in content_tag_list:
|
||||
self.content_tags[tag["id"]] = tag["name"]
|
||||
|
||||
# Check if there are more pages
|
||||
if data.get("meta", {}).get("has_more", False):
|
||||
params["page[after]"] = data["meta"]["after_cursor"]
|
||||
else:
|
||||
break
|
||||
else:
|
||||
raise Exception(f"Error: {response.status_code}\n{response.text}")
|
||||
except Exception as e:
|
||||
raise Exception(f"Error fetching content tags: {str(e)}")
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
# Subdomain is actually the whole URL
|
||||
@@ -62,6 +112,11 @@ class ZendeskConnector(LoadConnector, PollConnector):
|
||||
email=credentials["zendesk_email"],
|
||||
token=credentials["zendesk_token"],
|
||||
)
|
||||
self._set_content_tags(
|
||||
subdomain,
|
||||
credentials["zendesk_email"],
|
||||
credentials["zendesk_token"],
|
||||
)
|
||||
return None
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
@@ -92,10 +147,30 @@ class ZendeskConnector(LoadConnector, PollConnector):
|
||||
):
|
||||
continue
|
||||
|
||||
doc_batch.append(_article_to_document(article))
|
||||
doc_batch.append(_article_to_document(article, self.content_tags))
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
yield doc_batch
|
||||
doc_batch.clear()
|
||||
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
import time
|
||||
|
||||
connector = ZendeskConnector()
|
||||
connector.load_credentials(
|
||||
{
|
||||
"zendesk_subdomain": os.environ["ZENDESK_SUBDOMAIN"],
|
||||
"zendesk_email": os.environ["ZENDESK_EMAIL"],
|
||||
"zendesk_token": os.environ["ZENDESK_TOKEN"],
|
||||
}
|
||||
)
|
||||
|
||||
current = time.time()
|
||||
one_day_ago = current - 24 * 60 * 60 # 1 day
|
||||
document_batches = connector.poll_source(one_day_ago, current)
|
||||
|
||||
print(next(document_batches))
|
||||
|
||||
@@ -70,6 +70,10 @@ def _process_citations_for_slack(text: str) -> str:
    def slack_link_format(match: Match) -> str:
        link_text = match.group(1)
        link_url = match.group(2)

        # Account for empty link citations
        if link_url == "":
            return f"[{link_text}]"
        return f"<{link_url}|[{link_text}]>"

    # Substitute all matches in the input text
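For context, this converts Markdown-style citation links into Slack mrkdwn links. Assuming the surrounding regex, which is not shown in the hunk, matches `[text](url)` pairs, the substitution behaves roughly like this sketch:

```python
import re

_CITATION_PAT = re.compile(r"\[(.*?)\]\((.*?)\)")  # assumed pattern, not from the diff


def to_slack_links(text: str) -> str:
    def slack_link_format(match: re.Match) -> str:
        link_text, link_url = match.group(1), match.group(2)
        if link_url == "":
            return f"[{link_text}]"
        return f"<{link_url}|[{link_text}]>"

    return _CITATION_PAT.sub(slack_link_format, text)


to_slack_links("See [1](https://docs.example.com/page)")
# -> 'See <https://docs.example.com/page|[1]>'
```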
@@ -299,7 +303,9 @@ def build_sources_blocks(
|
||||
else []
|
||||
)
|
||||
+ [
|
||||
MarkdownTextObject(
|
||||
MarkdownTextObject(text=f"{document_title}")
|
||||
if d.link == ""
|
||||
else MarkdownTextObject(
|
||||
text=f"*<{d.link}|[{citation_num}] {document_title}>*\n{final_metadata_str}"
|
||||
),
|
||||
]
|
||||
|
||||
@@ -50,7 +50,7 @@ from danswer.one_shot_answer.models import OneShotQAResponse
|
||||
from danswer.search.enums import OptionalSearchSetting
|
||||
from danswer.search.models import BaseFilters
|
||||
from danswer.search.models import RetrievalDetails
|
||||
from shared_configs.configs import ENABLE_RERANKING_ASYNC_FLOW
|
||||
from danswer.search.search_settings import get_search_settings
|
||||
|
||||
|
||||
srl = SlackRateLimiter()
|
||||
@@ -223,15 +223,23 @@ def handle_regular_answer(
|
||||
enable_auto_detect_filters=auto_detect_filters,
|
||||
)
|
||||
|
||||
# Always apply reranking settings if it exists, this is the non-streaming flow
|
||||
saved_search_settings = get_search_settings()
|
||||
|
||||
# This includes throwing out answer via reflexion
|
||||
answer = _get_answer(
|
||||
DirectQARequest(
|
||||
messages=messages,
|
||||
multilingual_query_expansion=saved_search_settings.multilingual_expansion
|
||||
if saved_search_settings
|
||||
else None,
|
||||
prompt_id=prompt.id if prompt else None,
|
||||
persona_id=persona.id if persona is not None else 0,
|
||||
retrieval_options=retrieval_details,
|
||||
chain_of_thought=not disable_cot,
|
||||
skip_rerank=not ENABLE_RERANKING_ASYNC_FLOW,
|
||||
rerank_settings=saved_search_settings.to_reranking_detail()
|
||||
if saved_search_settings
|
||||
else None,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
|
||||
@@ -50,9 +50,9 @@ from danswer.danswerbot.slack.utils import respond_in_thread
|
||||
from danswer.db.embedding_model import get_current_db_embedding_model
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.dynamic_configs.interface import ConfigNotFoundError
|
||||
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
|
||||
from danswer.one_shot_answer.models import ThreadMessage
|
||||
from danswer.search.retrieval.search_runner import download_nltk_data
|
||||
from danswer.search.search_nlp_models import warm_up_encoders
|
||||
from danswer.server.manage.models import SlackBotTokens
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import MODEL_SERVER_HOST
|
||||
@@ -470,9 +470,8 @@ if __name__ == "__main__":
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
embedding_model = get_current_db_embedding_model(db_session)
|
||||
if embedding_model.cloud_provider_id is None:
|
||||
warm_up_encoders(
|
||||
model_name=embedding_model.model_name,
|
||||
normalize=embedding_model.normalize,
|
||||
warm_up_bi_encoder(
|
||||
embedding_model=embedding_model,
|
||||
model_server_host=MODEL_SERVER_HOST,
|
||||
model_server_port=MODEL_SERVER_PORT,
|
||||
)
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
import os
|
||||
from typing import cast
|
||||
|
||||
from danswer.configs.constants import KV_SLACK_BOT_TOKENS_CONFIG_KEY
|
||||
from danswer.dynamic_configs.factory import get_dynamic_config_store
|
||||
from danswer.server.manage.models import SlackBotTokens
|
||||
|
||||
|
||||
_SLACK_BOT_TOKENS_CONFIG_KEY = "slack_bot_tokens_config_key"
|
||||
|
||||
|
||||
def fetch_tokens() -> SlackBotTokens:
|
||||
# first check env variables
|
||||
app_token = os.environ.get("DANSWER_BOT_SLACK_APP_TOKEN")
|
||||
@@ -17,7 +15,7 @@ def fetch_tokens() -> SlackBotTokens:
|
||||
|
||||
dynamic_config_store = get_dynamic_config_store()
|
||||
return SlackBotTokens(
|
||||
**cast(dict, dynamic_config_store.load(key=_SLACK_BOT_TOKENS_CONFIG_KEY))
|
||||
**cast(dict, dynamic_config_store.load(key=KV_SLACK_BOT_TOKENS_CONFIG_KEY))
|
||||
)
|
||||
|
||||
|
||||
@@ -26,5 +24,5 @@ def save_tokens(
|
||||
) -> None:
|
||||
dynamic_config_store = get_dynamic_config_store()
|
||||
dynamic_config_store.store(
|
||||
key=_SLACK_BOT_TOKENS_CONFIG_KEY, val=dict(tokens), encrypt=True
|
||||
key=KV_SLACK_BOT_TOKENS_CONFIG_KEY, val=dict(tokens), encrypt=True
|
||||
)
|
||||
|
||||
@@ -16,7 +16,7 @@ from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.schemas import UserRole
|
||||
from danswer.chat.models import LLMRelevanceSummaryResponse
|
||||
from danswer.chat.models import DocumentRelevance
|
||||
from danswer.configs.chat_configs import HARD_DELETE_CHATS
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.db.models import ChatMessage
|
||||
@@ -541,11 +541,11 @@ def get_doc_query_identifiers_from_model(
|
||||
def update_search_docs_table_with_relevance(
|
||||
db_session: Session,
|
||||
reference_db_search_docs: list[SearchDoc],
|
||||
relevance_summary: LLMRelevanceSummaryResponse,
|
||||
relevance_summary: DocumentRelevance,
|
||||
) -> None:
|
||||
for search_doc in reference_db_search_docs:
|
||||
relevance_data = relevance_summary.relevance_summaries.get(
|
||||
f"{search_doc.document_id}-{search_doc.chunk_ind}"
|
||||
search_doc.document_id
|
||||
)
|
||||
if relevance_data is not None:
|
||||
db_session.execute(
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import cast
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import and_
|
||||
from sqlalchemy import exists
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import aliased
|
||||
@@ -11,6 +11,7 @@ from danswer.configs.app_configs import DEFAULT_PRUNING_FREQ
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.models import InputType
|
||||
from danswer.db.models import Connector
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import IndexAttempt
|
||||
from danswer.server.documents.models import ConnectorBase
|
||||
from danswer.server.documents.models import ObjectCreationIdResponse
|
||||
@@ -20,19 +21,24 @@ from danswer.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def check_connectors_exist(db_session: Session) -> bool:
|
||||
# Connector 0 is created on server startup as a default for ingestion
|
||||
# it will always exist and we don't need to count it for this
|
||||
stmt = select(exists(Connector).where(Connector.id > 0))
|
||||
result = db_session.execute(stmt)
|
||||
return result.scalar() or False
|
||||
|
||||
|
||||
def fetch_connectors(
|
||||
db_session: Session,
|
||||
sources: list[DocumentSource] | None = None,
|
||||
input_types: list[InputType] | None = None,
|
||||
disabled_status: bool | None = None,
|
||||
) -> list[Connector]:
|
||||
stmt = select(Connector)
|
||||
if sources is not None:
|
||||
stmt = stmt.where(Connector.source.in_(sources))
|
||||
if input_types is not None:
|
||||
stmt = stmt.where(Connector.input_type.in_(input_types))
|
||||
if disabled_status is not None:
|
||||
stmt = stmt.where(Connector.disabled == disabled_status)
|
||||
results = db_session.scalars(stmt)
|
||||
return list(results.all())
|
||||
|
||||
@@ -85,10 +91,8 @@ def create_connector(
|
||||
input_type=connector_data.input_type,
|
||||
connector_specific_config=connector_data.connector_specific_config,
|
||||
refresh_freq=connector_data.refresh_freq,
|
||||
prune_freq=connector_data.prune_freq
|
||||
if connector_data.prune_freq is not None
|
||||
else DEFAULT_PRUNING_FREQ,
|
||||
disabled=connector_data.disabled,
|
||||
indexing_start=connector_data.indexing_start,
|
||||
prune_freq=connector_data.prune_freq,
|
||||
)
|
||||
db_session.add(connector)
|
||||
db_session.commit()
|
||||
@@ -122,33 +126,18 @@ def update_connector(
|
||||
if connector_data.prune_freq is not None
|
||||
else DEFAULT_PRUNING_FREQ
|
||||
)
|
||||
connector.disabled = connector_data.disabled
|
||||
|
||||
db_session.commit()
|
||||
return connector
|
||||
|
||||
|
||||
def disable_connector(
|
||||
connector_id: int,
|
||||
db_session: Session,
|
||||
) -> StatusResponse[int]:
|
||||
connector = fetch_connector_by_id(connector_id, db_session)
|
||||
if connector is None:
|
||||
raise HTTPException(status_code=404, detail="Connector does not exist")
|
||||
|
||||
connector.disabled = True
|
||||
db_session.commit()
|
||||
return StatusResponse(
|
||||
success=True, message="Connector deleted successfully", data=connector_id
|
||||
)
|
||||
|
||||
|
||||
def delete_connector(
|
||||
connector_id: int,
|
||||
db_session: Session,
|
||||
) -> StatusResponse[int]:
|
||||
"""Currently unused due to foreign key restriction from IndexAttempt
|
||||
Use disable_connector instead"""
|
||||
"""Only used in special cases (e.g. a connector is in a bad state and we need to delete it).
|
||||
Be VERY careful using this, as it could lead to a bad state if not used correctly.
|
||||
"""
|
||||
connector = fetch_connector_by_id(connector_id, db_session)
|
||||
if connector is None:
|
||||
return StatusResponse(
|
||||
@@ -179,11 +168,9 @@ def fetch_latest_index_attempt_by_connector(
|
||||
latest_index_attempts: list[IndexAttempt] = []
|
||||
|
||||
if source:
|
||||
connectors = fetch_connectors(
|
||||
db_session, sources=[source], disabled_status=False
|
||||
)
|
||||
connectors = fetch_connectors(db_session, sources=[source])
|
||||
else:
|
||||
connectors = fetch_connectors(db_session, disabled_status=False)
|
||||
connectors = fetch_connectors(db_session)
|
||||
|
||||
if not connectors:
|
||||
return []
|
||||
@@ -191,7 +178,8 @@ def fetch_latest_index_attempt_by_connector(
|
||||
for connector in connectors:
|
||||
latest_index_attempt = (
|
||||
db_session.query(IndexAttempt)
|
||||
.filter(IndexAttempt.connector_id == connector.id)
|
||||
.join(ConnectorCredentialPair)
|
||||
.filter(ConnectorCredentialPair.connector_id == connector.id)
|
||||
.order_by(IndexAttempt.time_updated.desc())
|
||||
.first()
|
||||
)
|
||||
@@ -207,13 +195,11 @@ def fetch_latest_index_attempts_by_status(
|
||||
) -> list[IndexAttempt]:
|
||||
subquery = (
|
||||
db_session.query(
|
||||
IndexAttempt.connector_id,
|
||||
IndexAttempt.credential_id,
|
||||
IndexAttempt.connector_credential_pair_id,
|
||||
IndexAttempt.status,
|
||||
func.max(IndexAttempt.time_updated).label("time_updated"),
|
||||
)
|
||||
.group_by(IndexAttempt.connector_id)
|
||||
.group_by(IndexAttempt.credential_id)
|
||||
.group_by(IndexAttempt.connector_credential_pair_id)
|
||||
.group_by(IndexAttempt.status)
|
||||
.subquery()
|
||||
)
|
||||
@@ -223,12 +209,13 @@ def fetch_latest_index_attempts_by_status(
|
||||
query = db_session.query(IndexAttempt).join(
|
||||
alias,
|
||||
and_(
|
||||
IndexAttempt.connector_id == alias.connector_id,
|
||||
IndexAttempt.credential_id == alias.credential_id,
|
||||
IndexAttempt.connector_credential_pair_id
|
||||
== alias.connector_credential_pair_id,
|
||||
IndexAttempt.status == alias.status,
|
||||
IndexAttempt.time_updated == alias.time_updated,
|
||||
),
|
||||
)
|
||||
|
||||
return cast(list[IndexAttempt], query.all())
|
||||
|
||||
|
||||
@@ -247,20 +234,29 @@ def fetch_unique_document_sources(db_session: Session) -> list[DocumentSource]:
|
||||
def create_initial_default_connector(db_session: Session) -> None:
|
||||
default_connector_id = 0
|
||||
default_connector = fetch_connector_by_id(default_connector_id, db_session)
|
||||
|
||||
if default_connector is not None:
|
||||
if (
|
||||
default_connector.source != DocumentSource.INGESTION_API
|
||||
or default_connector.input_type != InputType.LOAD_STATE
|
||||
or default_connector.refresh_freq is not None
|
||||
or default_connector.disabled
|
||||
or default_connector.name != "Ingestion API"
|
||||
or default_connector.connector_specific_config != {}
|
||||
or default_connector.prune_freq is not None
|
||||
):
|
||||
raise ValueError(
|
||||
"DB is not in a valid initial state. "
|
||||
"Default connector does not have expected values."
|
||||
logger.warning(
|
||||
"Default connector does not have expected values. Updating to proper state."
|
||||
)
|
||||
# Ensure default connector has correct values
|
||||
default_connector.source = DocumentSource.INGESTION_API
|
||||
default_connector.input_type = InputType.LOAD_STATE
|
||||
default_connector.refresh_freq = None
|
||||
default_connector.name = "Ingestion API"
|
||||
default_connector.connector_specific_config = {}
|
||||
default_connector.prune_freq = None
|
||||
db_session.commit()
|
||||
return
|
||||
|
||||
# Create a new default connector if it doesn't exist
|
||||
connector = Connector(
|
||||
id=default_connector_id,
|
||||
name="Ingestion API",
|
||||
|
||||
@@ -6,8 +6,10 @@ from sqlalchemy import desc
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.db.connector import fetch_connector_by_id
|
||||
from danswer.db.credentials import fetch_credential_by_id
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import EmbeddingModel
|
||||
from danswer.db.models import IndexAttempt
|
||||
@@ -25,7 +27,9 @@ def get_connector_credential_pairs(
|
||||
) -> list[ConnectorCredentialPair]:
|
||||
stmt = select(ConnectorCredentialPair)
|
||||
if not include_disabled:
|
||||
stmt = stmt.where(ConnectorCredentialPair.connector.disabled == False) # noqa
|
||||
stmt = stmt.where(
|
||||
ConnectorCredentialPair.status == ConnectorCredentialPairStatus.ACTIVE
|
||||
) # noqa
|
||||
results = db_session.scalars(stmt)
|
||||
return list(results.all())
|
||||
|
||||
@@ -42,6 +46,17 @@ def get_connector_credential_pair(
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
def get_connector_credential_source_from_id(
|
||||
cc_pair_id: int,
|
||||
db_session: Session,
|
||||
) -> DocumentSource | None:
|
||||
stmt = select(ConnectorCredentialPair)
|
||||
stmt = stmt.where(ConnectorCredentialPair.id == cc_pair_id)
|
||||
result = db_session.execute(stmt)
|
||||
cc_pair = result.scalar_one_or_none()
|
||||
return cc_pair.connector.source if cc_pair else None
|
||||
|
||||
|
||||
def get_connector_credential_pair_from_id(
|
||||
cc_pair_id: int,
|
||||
db_session: Session,
|
||||
@@ -75,26 +90,78 @@ def get_last_successful_attempt_time(
|
||||
# For Secondary Index we don't keep track of the latest success, so have to calculate it live
|
||||
attempt = (
|
||||
db_session.query(IndexAttempt)
|
||||
.join(
|
||||
ConnectorCredentialPair,
|
||||
IndexAttempt.connector_credential_pair_id == ConnectorCredentialPair.id,
|
||||
)
|
||||
.filter(
|
||||
IndexAttempt.connector_id == connector_id,
|
||||
IndexAttempt.credential_id == credential_id,
|
||||
ConnectorCredentialPair.connector_id == connector_id,
|
||||
ConnectorCredentialPair.credential_id == credential_id,
|
||||
IndexAttempt.embedding_model_id == embedding_model.id,
|
||||
IndexAttempt.status == IndexingStatus.SUCCESS,
|
||||
)
|
||||
.order_by(IndexAttempt.time_started.desc())
|
||||
.first()
|
||||
)
|
||||
|
||||
if not attempt or not attempt.time_started:
|
||||
connector = fetch_connector_by_id(connector_id, db_session)
|
||||
if connector and connector.indexing_start:
|
||||
return connector.indexing_start.timestamp()
|
||||
return 0.0
|
||||
|
||||
return attempt.time_started.timestamp()
|
||||
|
||||
|
||||
"""Updates"""
|
||||
|
||||
|
||||
def _update_connector_credential_pair(
|
||||
db_session: Session,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
status: ConnectorCredentialPairStatus | None = None,
|
||||
net_docs: int | None = None,
|
||||
run_dt: datetime | None = None,
|
||||
) -> None:
|
||||
# simply don't update last_successful_index_time if run_dt is not specified
|
||||
# at worst, this would result in re-indexing documents that were already indexed
|
||||
if run_dt is not None:
|
||||
cc_pair.last_successful_index_time = run_dt
|
||||
if net_docs is not None:
|
||||
cc_pair.total_docs_indexed += net_docs
|
||||
if status is not None:
|
||||
cc_pair.status = status
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def update_connector_credential_pair_from_id(
|
||||
db_session: Session,
|
||||
cc_pair_id: int,
|
||||
status: ConnectorCredentialPairStatus | None = None,
|
||||
net_docs: int | None = None,
|
||||
run_dt: datetime | None = None,
|
||||
) -> None:
|
||||
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
|
||||
if not cc_pair:
|
||||
logger.warning(
|
||||
f"Attempted to update pair for Connector Credential Pair '{cc_pair_id}'"
|
||||
f" but it does not exist"
|
||||
)
|
||||
return
|
||||
|
||||
_update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
cc_pair=cc_pair,
|
||||
status=status,
|
||||
net_docs=net_docs,
|
||||
run_dt=run_dt,
|
||||
)
|
||||
|
||||
|
||||
def update_connector_credential_pair(
|
||||
db_session: Session,
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
status: ConnectorCredentialPairStatus | None = None,
|
||||
net_docs: int | None = None,
|
||||
run_dt: datetime | None = None,
|
||||
) -> None:
|
||||
@@ -105,13 +172,14 @@ def update_connector_credential_pair(
|
||||
f"and credential id {credential_id}"
|
||||
)
|
||||
return
|
||||
# simply don't update last_successful_index_time if run_dt is not specified
|
||||
# at worst, this would result in re-indexing documents that were already indexed
|
||||
if run_dt is not None:
|
||||
cc_pair.last_successful_index_time = run_dt
|
||||
if net_docs is not None:
|
||||
cc_pair.total_docs_indexed += net_docs
|
||||
db_session.commit()
|
||||
|
||||
_update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
cc_pair=cc_pair,
|
||||
status=status,
|
||||
net_docs=net_docs,
|
||||
run_dt=run_dt,
|
||||
)
|
||||
|
||||
|
||||
def delete_connector_credential_pair__no_commit(
|
||||
@@ -142,6 +210,8 @@ def associate_default_cc_pair(db_session: Session) -> None:
|
||||
connector_id=0,
|
||||
credential_id=0,
|
||||
name="DefaultCCPair",
|
||||
status=ConnectorCredentialPairStatus.ACTIVE,
|
||||
is_public=True,
|
||||
)
|
||||
db_session.add(association)
|
||||
db_session.commit()
|
||||
@@ -186,6 +256,7 @@ def add_credential_to_connector(
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
name=cc_pair_name,
|
||||
status=ConnectorCredentialPairStatus.ACTIVE,
|
||||
is_public=is_public,
|
||||
)
|
||||
db_session.add(association)
|
||||
@@ -241,6 +312,12 @@ def remove_credential_from_connector(
|
||||
)
|
||||
|
||||
|
||||
def fetch_connector_credential_pairs(
|
||||
db_session: Session,
|
||||
) -> list[ConnectorCredentialPair]:
|
||||
return db_session.query(ConnectorCredentialPair).all()
|
||||
|
||||
|
||||
def resync_cc_pair(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
db_session: Session,
|
||||
@@ -253,10 +330,14 @@ def resync_cc_pair(
|
||||
) -> IndexAttempt | None:
|
||||
query = (
|
||||
db_session.query(IndexAttempt)
|
||||
.join(
|
||||
ConnectorCredentialPair,
|
||||
IndexAttempt.connector_credential_pair_id == ConnectorCredentialPair.id,
|
||||
)
|
||||
.join(EmbeddingModel, IndexAttempt.embedding_model_id == EmbeddingModel.id)
|
||||
.filter(
|
||||
IndexAttempt.connector_id == connector_id,
|
||||
IndexAttempt.credential_id == credential_id,
|
||||
ConnectorCredentialPair.connector_id == connector_id,
|
||||
ConnectorCredentialPair.credential_id == credential_id,
|
||||
EmbeddingModel.status == IndexModelStatus.PRESENT,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -2,10 +2,13 @@ from typing import Any
|
||||
|
||||
from sqlalchemy import Select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql.expression import and_
|
||||
from sqlalchemy.sql.expression import or_
|
||||
|
||||
from danswer.auth.schemas import UserRole
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.gmail.constants import (
|
||||
GMAIL_DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
|
||||
)
|
||||
@@ -14,8 +17,10 @@ from danswer.connectors.google_drive.constants import (
|
||||
)
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import Credential
|
||||
from danswer.db.models import DocumentByConnectorCredentialPair
|
||||
from danswer.db.models import User
|
||||
from danswer.server.documents.models import CredentialBase
|
||||
from danswer.server.documents.models import CredentialDataUpdateRequest
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
@@ -74,6 +79,69 @@ def fetch_credential_by_id(
|
||||
return credential
|
||||
|
||||
|
||||
def fetch_credentials_by_source(
|
||||
db_session: Session,
|
||||
user: User | None,
|
||||
document_source: DocumentSource | None = None,
|
||||
) -> list[Credential]:
|
||||
base_query = select(Credential).where(Credential.source == document_source)
|
||||
base_query = _attach_user_filters(base_query, user)
|
||||
credentials = db_session.execute(base_query).scalars().all()
|
||||
return list(credentials)
|
||||


def swap_credentials_connector(
    new_credential_id: int, connector_id: int, user: User | None, db_session: Session
) -> ConnectorCredentialPair:
    # Check if the user has permission to use the new credential
    new_credential = fetch_credential_by_id(new_credential_id, user, db_session)
    if not new_credential:
        raise ValueError(
            f"No Credential found with id {new_credential_id} or user doesn't have permission to use it"
        )

    # Existing pair
    existing_pair = db_session.execute(
        select(ConnectorCredentialPair).where(
            ConnectorCredentialPair.connector_id == connector_id
        )
    ).scalar_one_or_none()

    if not existing_pair:
        raise ValueError(
            f"No ConnectorCredentialPair found for connector_id {connector_id}"
        )

    # Check if the new credential is compatible with the connector
    if new_credential.source != existing_pair.connector.source:
        raise ValueError(
            f"New credential source {new_credential.source} does not match connector source {existing_pair.connector.source}"
        )

    db_session.execute(
        update(DocumentByConnectorCredentialPair)
        .where(
            and_(
                DocumentByConnectorCredentialPair.connector_id == connector_id,
                DocumentByConnectorCredentialPair.credential_id
                == existing_pair.credential_id,
            )
        )
        .values(credential_id=new_credential_id)
    )

    # Update the existing pair with the new credential
    existing_pair.credential_id = new_credential_id
    existing_pair.credential = new_credential

    # Commit the changes
    db_session.commit()

    # Refresh the object to ensure all relationships are up-to-date
    db_session.refresh(existing_pair)
    return existing_pair

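A small, hedged sketch of how a caller might drive swap_credentials_connector and surface its failure modes; the import path and the calling shape are assumed for illustration.

# Illustrative sketch only; the module path is an assumption.
from sqlalchemy.orm import Session

from danswer.db.credentials import swap_credentials_connector
from danswer.db.models import User


def try_swap_credential(
    connector_id: int, new_credential_id: int, user: User | None, db_session: Session
) -> None:
    try:
        cc_pair = swap_credentials_connector(
            new_credential_id=new_credential_id,
            connector_id=connector_id,
            user=user,
            db_session=db_session,
        )
    except ValueError as e:
        # raised when the credential is missing or inaccessible, when no pair
        # exists for the connector, or when the credential source mismatches
        print(f"Swap rejected: {e}")
        return
    print(f"Connector {connector_id} now uses credential {cc_pair.credential_id}")
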
def create_credential(
|
||||
credential_data: CredentialBase,
|
||||
user: User | None,
|
||||
@@ -83,6 +151,8 @@ def create_credential(
|
||||
credential_json=credential_data.credential_json,
|
||||
user_id=user.id if user else None,
|
||||
admin_public=credential_data.admin_public,
|
||||
source=credential_data.source,
|
||||
name=credential_data.name,
|
||||
)
|
||||
db_session.add(credential)
|
||||
db_session.commit()
|
||||
@@ -90,6 +160,28 @@ def create_credential(
|
||||
return credential
|
||||
|
||||
|
||||
def alter_credential(
|
||||
credential_id: int,
|
||||
credential_data: CredentialDataUpdateRequest,
|
||||
user: User,
|
||||
db_session: Session,
|
||||
) -> Credential | None:
|
||||
credential = fetch_credential_by_id(credential_id, user, db_session)
|
||||
|
||||
if credential is None:
|
||||
return None
|
||||
|
||||
credential.name = credential_data.name
|
||||
|
||||
# Update only the keys present in credential_data.credential_json
|
||||
for key, value in credential_data.credential_json.items():
|
||||
credential.credential_json[key] = value
|
||||
|
||||
credential.user_id = user.id if user is not None else None
|
||||
db_session.commit()
|
||||
return credential
|
||||
|
||||
|
||||
def update_credential(
|
||||
credential_id: int,
|
||||
credential_data: CredentialBase,
|
||||
@@ -136,6 +228,7 @@ def delete_credential(
|
||||
credential_id: int,
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
force: bool = False,
|
||||
) -> None:
|
||||
credential = fetch_credential_by_id(credential_id, user, db_session)
|
||||
if credential is None:
|
||||
@@ -149,11 +242,38 @@ def delete_credential(
|
||||
.all()
|
||||
)
|
||||
|
||||
if associated_connectors:
|
||||
raise ValueError(
|
||||
f"Cannot delete credential {credential_id} as it is still associated with {len(associated_connectors)} connector(s). "
|
||||
"Please delete all associated connectors first."
|
||||
)
|
||||
associated_doc_cc_pairs = (
|
||||
db_session.query(DocumentByConnectorCredentialPair)
|
||||
.filter(DocumentByConnectorCredentialPair.credential_id == credential_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
if associated_connectors or associated_doc_cc_pairs:
|
||||
if force:
|
||||
logger.warning(
|
||||
f"Force deleting credential {credential_id} and its associated records"
|
||||
)
|
||||
|
||||
# Delete DocumentByConnectorCredentialPair records first
|
||||
for doc_cc_pair in associated_doc_cc_pairs:
|
||||
db_session.delete(doc_cc_pair)
|
||||
|
||||
# Then delete ConnectorCredentialPair records
|
||||
for connector in associated_connectors:
|
||||
db_session.delete(connector)
|
||||
|
||||
# Commit these deletions before deleting the credential
|
||||
db_session.flush()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Cannot delete credential as it is still associated with "
|
||||
f"{len(associated_connectors)} connector(s) and {len(associated_doc_cc_pairs)} document(s). "
|
||||
)
|
||||
|
||||
if force:
|
||||
logger.info(f"Force deleting credential {credential_id}")
|
||||
else:
|
||||
logger.info(f"Deleting credential {credential_id}")
|
||||
|
||||
db_session.delete(credential)
|
||||
db_session.commit()
|
||||
|
||||
@@ -1,6 +1,7 @@
from sqlalchemy.orm import Session

from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.index_attempt import get_last_attempt
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import IndexingStatus
@@ -13,7 +14,7 @@ def check_deletion_attempt_is_allowed(
) -> str | None:
    """
    To be deletable:
    (1) connector should be disabled
    (1) connector should be paused
    (2) there should be no in-progress/planned index attempts

    Returns an error message if the deletion attempt is not allowed, otherwise None.
@@ -23,7 +24,10 @@ def check_deletion_attempt_is_allowed(
        f"'{connector_credential_pair.credential_id}' is not deletable."
    )

    if not connector_credential_pair.connector.disabled:
    if (
        connector_credential_pair.status != ConnectorCredentialPairStatus.PAUSED
        and connector_credential_pair.status != ConnectorCredentialPairStatus.DELETING
    ):
        return base_error_msg + " Connector must be paused."

    connector_id = connector_credential_pair.connector_id

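A brief sketch of how the paused/deleting check above is typically consumed; anything beyond what the hunk shows (a ConnectorCredentialPair in, an error string or None out) is an assumption, including the module path.

# Illustrative sketch only; module path and full signature are assumptions.
from sqlalchemy.orm import Session

from danswer.db.deletion import check_deletion_attempt_is_allowed
from danswer.db.models import ConnectorCredentialPair


def can_delete_cc_pair(cc_pair: ConnectorCredentialPair, db_session: Session) -> bool:
    error_msg = check_deletion_attempt_is_allowed(cc_pair, db_session)
    if error_msg is not None:
        # e.g. "... Connector must be paused." while the pair is still ACTIVE
        print(error_msg)
        return False
    return True
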
@@ -7,6 +7,7 @@ from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import exists
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy import select
|
||||
@@ -16,6 +17,7 @@ from sqlalchemy.exc import OperationalError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.constants import DEFAULT_BOOST
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.feedback import delete_document_feedback_for_documents__no_commit
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import Credential
|
||||
@@ -30,6 +32,12 @@ from danswer.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def check_docs_exist(db_session: Session) -> bool:
|
||||
stmt = select(exists(DbDocument))
|
||||
result = db_session.execute(stmt)
|
||||
return result.scalar() or False
|
||||
|
||||
|
||||
def get_documents_for_connector_credential_pair(
|
||||
db_session: Session, connector_id: int, credential_id: int, limit: int | None = None
|
||||
) -> Sequence[DbDocument]:
|
||||
@@ -103,36 +111,19 @@ def get_document_cnts_for_cc_pairs(
|
||||
def get_acccess_info_for_documents(
|
||||
db_session: Session,
|
||||
document_ids: list[str],
|
||||
cc_pair_to_delete: ConnectorCredentialPairIdentifier | None = None,
|
||||
) -> Sequence[tuple[str, list[UUID | None], bool]]:
|
||||
"""Gets back all relevant access info for the given documents. This includes
|
||||
the user_ids for cc pairs that the document is associated with + whether any
|
||||
of the associated cc pairs are intending to make the document globally public.
|
||||
|
||||
If `cc_pair_to_delete` is specified, gets the above access info as if that
|
||||
pair had been deleted. This is needed since we want to delete from the Vespa
|
||||
before deleting from Postgres to ensure that the state of Postgres never "loses"
|
||||
documents that still exist in Vespa.
|
||||
"""
|
||||
stmt = select(
|
||||
DocumentByConnectorCredentialPair.id,
|
||||
func.array_agg(Credential.user_id).label("user_ids"),
|
||||
func.bool_or(ConnectorCredentialPair.is_public).label("public_doc"),
|
||||
).where(DocumentByConnectorCredentialPair.id.in_(document_ids))
|
||||
|
||||
# pretend that the specified cc pair doesn't exist
|
||||
if cc_pair_to_delete:
|
||||
stmt = stmt.where(
|
||||
and_(
|
||||
DocumentByConnectorCredentialPair.connector_id
|
||||
!= cc_pair_to_delete.connector_id,
|
||||
DocumentByConnectorCredentialPair.credential_id
|
||||
!= cc_pair_to_delete.credential_id,
|
||||
)
|
||||
)
|
||||
|
||||
stmt = (
|
||||
stmt.join(
|
||||
select(
|
||||
DocumentByConnectorCredentialPair.id,
|
||||
func.array_agg(Credential.user_id).label("user_ids"),
|
||||
func.bool_or(ConnectorCredentialPair.is_public).label("public_doc"),
|
||||
)
|
||||
.where(DocumentByConnectorCredentialPair.id.in_(document_ids))
|
||||
.join(
|
||||
Credential,
|
||||
DocumentByConnectorCredentialPair.credential_id == Credential.id,
|
||||
)
|
||||
@@ -145,6 +136,9 @@ def get_acccess_info_for_documents(
|
||||
== ConnectorCredentialPair.credential_id,
|
||||
),
|
||||
)
|
||||
# don't include CC pairs that are being deleted
|
||||
# NOTE: CC pairs can never go from DELETING to any other state -> it's safe to ignore them
|
||||
.where(ConnectorCredentialPair.status != ConnectorCredentialPairStatus.DELETING)
|
||||
.group_by(DocumentByConnectorCredentialPair.id)
|
||||
)
|
||||
return db_session.execute(stmt).all() # type: ignore
|
||||
@@ -311,7 +305,7 @@ def acquire_document_locks(db_session: Session, document_ids: list[str]) -> bool
|
||||
|
||||
|
||||
_NUM_LOCK_ATTEMPTS = 10
|
||||
_LOCK_RETRY_DELAY = 30
|
||||
_LOCK_RETRY_DELAY = 10
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
|
||||
@@ -9,6 +9,7 @@ from sqlalchemy import or_
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import Document
|
||||
from danswer.db.models import DocumentByConnectorCredentialPair
|
||||
@@ -270,37 +271,20 @@ def mark_document_set_as_to_be_deleted(
|
||||
raise
|
||||
|
||||
|
||||
def mark_cc_pair__document_set_relationships_to_be_deleted__no_commit(
|
||||
cc_pair_id: int, db_session: Session
|
||||
) -> set[int]:
|
||||
"""Marks all CC Pair -> Document Set relationships for the specified
|
||||
`cc_pair_id` as not current and returns the list of all document set IDs
|
||||
affected.
|
||||
|
||||
NOTE: raises a `ValueError` if any of the document sets are currently syncing
|
||||
to avoid getting into a bad state."""
|
||||
document_set__cc_pair_relationships = db_session.scalars(
|
||||
select(DocumentSet__ConnectorCredentialPair).where(
|
||||
def delete_document_set_cc_pair_relationship__no_commit(
|
||||
connector_id: int, credential_id: int, db_session: Session
|
||||
) -> None:
|
||||
"""Deletes all rows from DocumentSet__ConnectorCredentialPair where the
|
||||
connector_credential_pair_id matches the given cc_pair_id."""
|
||||
delete_stmt = delete(DocumentSet__ConnectorCredentialPair).where(
|
||||
and_(
|
||||
ConnectorCredentialPair.connector_id == connector_id,
|
||||
ConnectorCredentialPair.credential_id == credential_id,
|
||||
DocumentSet__ConnectorCredentialPair.connector_credential_pair_id
|
||||
== cc_pair_id
|
||||
== ConnectorCredentialPair.id,
|
||||
)
|
||||
).all()
|
||||
|
||||
document_set_ids_touched: set[int] = set()
|
||||
for document_set__cc_pair_relationship in document_set__cc_pair_relationships:
|
||||
document_set__cc_pair_relationship.is_current = False
|
||||
|
||||
if not document_set__cc_pair_relationship.document_set.is_up_to_date:
|
||||
raise ValueError(
|
||||
"Cannot delete CC pair while it is attached to a document set "
|
||||
"that is syncing. Please wait for the document set to finish "
|
||||
"syncing, and then try again."
|
||||
)
|
||||
|
||||
document_set__cc_pair_relationship.document_set.is_up_to_date = False
|
||||
document_set_ids_touched.add(document_set__cc_pair_relationship.document_set_id)
|
||||
|
||||
return document_set_ids_touched
|
||||
)
|
||||
db_session.execute(delete_stmt)
|
||||
|
||||
|
||||
def fetch_document_sets(
|
||||
@@ -431,8 +415,10 @@ def fetch_documents_for_document_set_paginated(
|
||||
|
||||
|
||||
def fetch_document_sets_for_documents(
|
||||
document_ids: list[str], db_session: Session
|
||||
document_ids: list[str],
|
||||
db_session: Session,
|
||||
) -> Sequence[tuple[str, list[str]]]:
|
||||
"""Gives back a list of (document_id, list[document_set_names]) tuples"""
|
||||
stmt = (
|
||||
select(Document.id, func.array_agg(DocumentSetDBModel.name))
|
||||
.join(
|
||||
@@ -459,6 +445,10 @@ def fetch_document_sets_for_documents(
|
||||
Document.id == DocumentByConnectorCredentialPair.id,
|
||||
)
|
||||
.where(Document.id.in_(document_ids))
|
||||
# don't include CC pairs that are being deleted
|
||||
# NOTE: CC pairs can never go from DELETING to any other state -> it's safe to ignore them
|
||||
# as we can assume their document sets are no longer relevant
|
||||
.where(ConnectorCredentialPair.status != ConnectorCredentialPairStatus.DELETING)
|
||||
.where(DocumentSet__ConnectorCredentialPair.is_current == True) # noqa: E712
|
||||
.group_by(Document.id)
|
||||
)
|
||||
|
||||
@@ -15,7 +15,7 @@ from danswer.db.models import CloudEmbeddingProvider
from danswer.db.models import EmbeddingModel
from danswer.db.models import IndexModelStatus
from danswer.indexing.models import EmbeddingModelDetail
from danswer.search.search_nlp_models import clean_model_name
from danswer.natural_language_processing.search_nlp_models import clean_model_name
from danswer.server.manage.embedding.models import (
    CloudEmbeddingProvider as ServerCloudEmbeddingProvider,
)

@@ -16,8 +16,11 @@ from sqlalchemy.orm import sessionmaker
from danswer.configs.app_configs import POSTGRES_DB
from danswer.configs.app_configs import POSTGRES_HOST
from danswer.configs.app_configs import POSTGRES_PASSWORD
from danswer.configs.app_configs import POSTGRES_POOL_PRE_PING
from danswer.configs.app_configs import POSTGRES_POOL_RECYCLE
from danswer.configs.app_configs import POSTGRES_PORT
from danswer.configs.app_configs import POSTGRES_USER
from danswer.configs.constants import POSTGRES_UNKNOWN_APP_NAME
from danswer.utils.logger import setup_logger

logger = setup_logger()
@@ -25,12 +28,18 @@ logger = setup_logger()
SYNC_DB_API = "psycopg2"
ASYNC_DB_API = "asyncpg"

POSTGRES_APP_NAME = (
    POSTGRES_UNKNOWN_APP_NAME  # helps to diagnose open connections in postgres
)

# global so we don't create more than one engine per process
# outside of being best practice, this is needed so we can properly pool
# connections and not create a new pool on every request
_SYNC_ENGINE: Engine | None = None
_ASYNC_ENGINE: AsyncEngine | None = None

SessionFactory: sessionmaker[Session] | None = None


def get_db_current_time(db_session: Session) -> datetime:
    """Get the current time from Postgres representing the start of the transaction
@@ -51,24 +60,50 @@ def build_connection_string(
    host: str = POSTGRES_HOST,
    port: str = POSTGRES_PORT,
    db: str = POSTGRES_DB,
    app_name: str | None = None,
) -> str:
    if app_name:
        return f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}?application_name={app_name}"

    return f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}"


def init_sqlalchemy_engine(app_name: str) -> None:
    global POSTGRES_APP_NAME
    POSTGRES_APP_NAME = app_name


def get_sqlalchemy_engine() -> Engine:
    global _SYNC_ENGINE
    if _SYNC_ENGINE is None:
        connection_string = build_connection_string(db_api=SYNC_DB_API)
        _SYNC_ENGINE = create_engine(connection_string, pool_size=40, max_overflow=10)
        connection_string = build_connection_string(
            db_api=SYNC_DB_API, app_name=POSTGRES_APP_NAME + "_sync"
        )
        _SYNC_ENGINE = create_engine(
            connection_string,
            pool_size=40,
            max_overflow=10,
            pool_pre_ping=POSTGRES_POOL_PRE_PING,
            pool_recycle=POSTGRES_POOL_RECYCLE,
        )
    return _SYNC_ENGINE


def get_sqlalchemy_async_engine() -> AsyncEngine:
    global _ASYNC_ENGINE
    if _ASYNC_ENGINE is None:
        # underlying asyncpg cannot accept application_name directly in the connection string
        # https://github.com/MagicStack/asyncpg/issues/798
        connection_string = build_connection_string()
        _ASYNC_ENGINE = create_async_engine(
            connection_string, pool_size=40, max_overflow=10
            connection_string,
            connect_args={
                "server_settings": {"application_name": POSTGRES_APP_NAME + "_async"}
            },
            pool_size=40,
            max_overflow=10,
            pool_pre_ping=POSTGRES_POOL_PRE_PING,
            pool_recycle=POSTGRES_POOL_RECYCLE,
        )
    return _ASYNC_ENGINE

@@ -115,4 +150,8 @@ async def warm_up_connections(
    await async_conn.close()


SessionFactory = sessionmaker(bind=get_sqlalchemy_engine())
def get_session_factory() -> sessionmaker[Session]:
    global SessionFactory
    if SessionFactory is None:
        SessionFactory = sessionmaker(bind=get_sqlalchemy_engine())
    return SessionFactory

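The hunk above replaces the module-level SessionFactory with lazy construction and lets each process register an application_name before the engines are built. A minimal sketch of the startup pattern this implies follows; the worker name and the import path are illustrative assumptions.

# Illustrative sketch only; "danswer_indexing_worker" is a made-up app name.
from sqlalchemy import text

from danswer.db.engine import get_session_factory
from danswer.db.engine import init_sqlalchemy_engine


def run_worker() -> None:
    # Must run before the first engine/session-factory access so the name is
    # baked into the sync connection string and the async server_settings.
    init_sqlalchemy_engine("danswer_indexing_worker")

    SessionLocal = get_session_factory()
    with SessionLocal() as db_session:
        db_session.execute(text("SELECT 1"))  # placeholder for real work
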
@@ -33,3 +33,9 @@ class IndexModelStatus(str, PyEnum):
class ChatSessionSharedStatus(str, PyEnum):
    PUBLIC = "public"
    PRIVATE = "private"


class ConnectorCredentialPairStatus(str, PyEnum):
    ACTIVE = "ACTIVE"
    PAUSED = "PAUSED"
    DELETING = "DELETING"

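The new ConnectorCredentialPairStatus enum above is what the many `status != ConnectorCredentialPairStatus.DELETING` filters elsewhere in this diff key off. A tiny sketch of such a filter, mirroring those call sites:

# Illustrative sketch only; mirrors filters used elsewhere in this diff.
from sqlalchemy import select
from sqlalchemy.orm import Session

from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import ConnectorCredentialPair


def fetch_non_deleting_cc_pairs(db_session: Session) -> list[ConnectorCredentialPair]:
    # CC pairs never leave the DELETING state, so excluding them is safe
    stmt = select(ConnectorCredentialPair).where(
        ConnectorCredentialPair.status != ConnectorCredentialPairStatus.DELETING
    )
    return list(db_session.scalars(stmt).all())
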
@@ -15,6 +15,7 @@ from danswer.db.models import EmbeddingModel
|
||||
from danswer.db.models import IndexAttempt
|
||||
from danswer.db.models import IndexingStatus
|
||||
from danswer.db.models import IndexModelStatus
|
||||
from danswer.server.documents.models import ConnectorCredentialPair
|
||||
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.telemetry import optional_telemetry
|
||||
@@ -23,6 +24,22 @@ from danswer.utils.telemetry import RecordType
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def get_last_attempt_for_cc_pair(
|
||||
cc_pair_id: int,
|
||||
embedding_model_id: int,
|
||||
db_session: Session,
|
||||
) -> IndexAttempt | None:
|
||||
return (
|
||||
db_session.query(IndexAttempt)
|
||||
.filter(
|
||||
IndexAttempt.connector_credential_pair_id == cc_pair_id,
|
||||
IndexAttempt.embedding_model_id == embedding_model_id,
|
||||
)
|
||||
.order_by(IndexAttempt.time_updated.desc())
|
||||
.first()
|
||||
)
|
||||
|
||||
|
||||
def get_index_attempt(
|
||||
db_session: Session, index_attempt_id: int
|
||||
) -> IndexAttempt | None:
|
||||
@@ -31,15 +48,13 @@ def get_index_attempt(
|
||||
|
||||
|
||||
def create_index_attempt(
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
connector_credential_pair_id: int,
|
||||
embedding_model_id: int,
|
||||
db_session: Session,
|
||||
from_beginning: bool = False,
|
||||
) -> int:
|
||||
new_attempt = IndexAttempt(
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
connector_credential_pair_id=connector_credential_pair_id,
|
||||
embedding_model_id=embedding_model_id,
|
||||
from_beginning=from_beginning,
|
||||
status=IndexingStatus.NOT_STARTED,
|
||||
@@ -56,7 +71,9 @@ def get_inprogress_index_attempts(
|
||||
) -> list[IndexAttempt]:
|
||||
stmt = select(IndexAttempt)
|
||||
if connector_id is not None:
|
||||
stmt = stmt.where(IndexAttempt.connector_id == connector_id)
|
||||
stmt = stmt.where(
|
||||
IndexAttempt.connector_credential_pair.has(connector_id=connector_id)
|
||||
)
|
||||
stmt = stmt.where(IndexAttempt.status == IndexingStatus.IN_PROGRESS)
|
||||
|
||||
incomplete_attempts = db_session.scalars(stmt)
|
||||
@@ -65,21 +82,31 @@ def get_inprogress_index_attempts(
|
||||
|
||||
def get_not_started_index_attempts(db_session: Session) -> list[IndexAttempt]:
|
||||
"""This eagerly loads the connector and credential so that the db_session can be expired
|
||||
before running long-living indexing jobs, which causes increasing memory usage"""
|
||||
before running long-living indexing jobs, which causes increasing memory usage.
|
||||
|
||||
Results are ordered by time_created (oldest to newest)."""
|
||||
stmt = select(IndexAttempt)
|
||||
stmt = stmt.where(IndexAttempt.status == IndexingStatus.NOT_STARTED)
|
||||
stmt = stmt.order_by(IndexAttempt.time_created)
|
||||
stmt = stmt.options(
|
||||
joinedload(IndexAttempt.connector), joinedload(IndexAttempt.credential)
|
||||
joinedload(IndexAttempt.connector_credential_pair).joinedload(
|
||||
ConnectorCredentialPair.connector
|
||||
),
|
||||
joinedload(IndexAttempt.connector_credential_pair).joinedload(
|
||||
ConnectorCredentialPair.credential
|
||||
),
|
||||
)
|
||||
new_attempts = db_session.scalars(stmt)
|
||||
return list(new_attempts.all())
|
||||
|
||||
|
||||
def mark_attempt_in_progress__no_commit(
|
||||
def mark_attempt_in_progress(
|
||||
index_attempt: IndexAttempt,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
index_attempt.status = IndexingStatus.IN_PROGRESS
|
||||
index_attempt.time_started = index_attempt.time_started or func.now() # type: ignore
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def mark_attempt_succeeded(
|
||||
@@ -103,7 +130,7 @@ def mark_attempt_failed(
|
||||
db_session.add(index_attempt)
|
||||
db_session.commit()
|
||||
|
||||
source = index_attempt.connector.source
|
||||
source = index_attempt.connector_credential_pair.connector.source
|
||||
optional_telemetry(record_type=RecordType.FAILURE, data={"connector": source})
|
||||
|
||||
|
||||
@@ -128,11 +155,16 @@ def get_last_attempt(
|
||||
embedding_model_id: int | None,
|
||||
db_session: Session,
|
||||
) -> IndexAttempt | None:
|
||||
stmt = select(IndexAttempt).where(
|
||||
IndexAttempt.connector_id == connector_id,
|
||||
IndexAttempt.credential_id == credential_id,
|
||||
IndexAttempt.embedding_model_id == embedding_model_id,
|
||||
stmt = (
|
||||
select(IndexAttempt)
|
||||
.join(ConnectorCredentialPair)
|
||||
.where(
|
||||
ConnectorCredentialPair.connector_id == connector_id,
|
||||
ConnectorCredentialPair.credential_id == credential_id,
|
||||
IndexAttempt.embedding_model_id == embedding_model_id,
|
||||
)
|
||||
)
|
||||
|
||||
# Note, the below is using time_created instead of time_updated
|
||||
stmt = stmt.order_by(desc(IndexAttempt.time_created))
|
||||
|
||||
@@ -145,8 +177,7 @@ def get_latest_index_attempts(
|
||||
db_session: Session,
|
||||
) -> Sequence[IndexAttempt]:
|
||||
ids_stmt = select(
|
||||
IndexAttempt.connector_id,
|
||||
IndexAttempt.credential_id,
|
||||
IndexAttempt.connector_credential_pair_id,
|
||||
func.max(IndexAttempt.time_created).label("max_time_created"),
|
||||
).join(EmbeddingModel, IndexAttempt.embedding_model_id == EmbeddingModel.id)
|
||||
|
||||
@@ -158,43 +189,101 @@ def get_latest_index_attempts(
|
||||
where_stmts: list[ColumnElement] = []
|
||||
for connector_credential_pair_identifier in connector_credential_pair_identifiers:
|
||||
where_stmts.append(
|
||||
and_(
|
||||
IndexAttempt.connector_id
|
||||
== connector_credential_pair_identifier.connector_id,
|
||||
IndexAttempt.credential_id
|
||||
== connector_credential_pair_identifier.credential_id,
|
||||
IndexAttempt.connector_credential_pair_id
|
||||
== (
|
||||
select(ConnectorCredentialPair.id)
|
||||
.where(
|
||||
ConnectorCredentialPair.connector_id
|
||||
== connector_credential_pair_identifier.connector_id,
|
||||
ConnectorCredentialPair.credential_id
|
||||
== connector_credential_pair_identifier.credential_id,
|
||||
)
|
||||
.scalar_subquery()
|
||||
)
|
||||
)
|
||||
if where_stmts:
|
||||
ids_stmt = ids_stmt.where(or_(*where_stmts))
|
||||
ids_stmt = ids_stmt.group_by(IndexAttempt.connector_id, IndexAttempt.credential_id)
|
||||
ids_subqery = ids_stmt.subquery()
|
||||
ids_stmt = ids_stmt.group_by(IndexAttempt.connector_credential_pair_id)
|
||||
ids_subquery = ids_stmt.subquery()
|
||||
|
||||
stmt = (
|
||||
select(IndexAttempt)
|
||||
.join(
|
||||
ids_subqery,
|
||||
and_(
|
||||
ids_subqery.c.connector_id == IndexAttempt.connector_id,
|
||||
ids_subqery.c.credential_id == IndexAttempt.credential_id,
|
||||
),
|
||||
ids_subquery,
|
||||
IndexAttempt.connector_credential_pair_id
|
||||
== ids_subquery.c.connector_credential_pair_id,
|
||||
)
|
||||
.where(IndexAttempt.time_created == ids_subqery.c.max_time_created)
|
||||
.where(IndexAttempt.time_created == ids_subquery.c.max_time_created)
|
||||
)
|
||||
|
||||
return db_session.execute(stmt).scalars().all()
|
||||
|
||||
|
||||
def get_index_attempts_for_connector(
|
||||
db_session: Session,
|
||||
connector_id: int,
|
||||
only_current: bool = True,
|
||||
disinclude_finished: bool = False,
|
||||
) -> Sequence[IndexAttempt]:
|
||||
stmt = (
|
||||
select(IndexAttempt)
|
||||
.join(ConnectorCredentialPair)
|
||||
.where(ConnectorCredentialPair.connector_id == connector_id)
|
||||
)
|
||||
if disinclude_finished:
|
||||
stmt = stmt.where(
|
||||
IndexAttempt.status.in_(
|
||||
[IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS]
|
||||
)
|
||||
)
|
||||
if only_current:
|
||||
stmt = stmt.join(EmbeddingModel).where(
|
||||
EmbeddingModel.status == IndexModelStatus.PRESENT
|
||||
)
|
||||
|
||||
stmt = stmt.order_by(IndexAttempt.time_created.desc())
|
||||
return db_session.execute(stmt).scalars().all()
|
||||
|
||||
|
||||
def get_latest_finished_index_attempt_for_cc_pair(
|
||||
connector_credential_pair_id: int,
|
||||
secondary_index: bool,
|
||||
db_session: Session,
|
||||
) -> IndexAttempt | None:
|
||||
stmt = select(IndexAttempt).where(
|
||||
IndexAttempt.connector_credential_pair_id == connector_credential_pair_id,
|
||||
IndexAttempt.status.not_in(
|
||||
[IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS]
|
||||
),
|
||||
)
|
||||
if secondary_index:
|
||||
stmt = stmt.join(EmbeddingModel).where(
|
||||
EmbeddingModel.status == IndexModelStatus.FUTURE
|
||||
)
|
||||
else:
|
||||
stmt = stmt.join(EmbeddingModel).where(
|
||||
EmbeddingModel.status == IndexModelStatus.PRESENT
|
||||
)
|
||||
stmt = stmt.order_by(desc(IndexAttempt.time_created))
|
||||
stmt = stmt.limit(1)
|
||||
return db_session.execute(stmt).scalar_one_or_none()
|
||||
|
||||
|
||||
def get_index_attempts_for_cc_pair(
|
||||
db_session: Session,
|
||||
cc_pair_identifier: ConnectorCredentialPairIdentifier,
|
||||
only_current: bool = True,
|
||||
disinclude_finished: bool = False,
|
||||
) -> Sequence[IndexAttempt]:
|
||||
stmt = select(IndexAttempt).where(
|
||||
and_(
|
||||
IndexAttempt.connector_id == cc_pair_identifier.connector_id,
|
||||
IndexAttempt.credential_id == cc_pair_identifier.credential_id,
|
||||
stmt = (
|
||||
select(IndexAttempt)
|
||||
.join(ConnectorCredentialPair)
|
||||
.where(
|
||||
and_(
|
||||
ConnectorCredentialPair.connector_id == cc_pair_identifier.connector_id,
|
||||
ConnectorCredentialPair.credential_id
|
||||
== cc_pair_identifier.credential_id,
|
||||
)
|
||||
)
|
||||
)
|
||||
if disinclude_finished:
|
||||
@@ -218,9 +307,11 @@ def delete_index_attempts(
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
stmt = delete(IndexAttempt).where(
|
||||
IndexAttempt.connector_id == connector_id,
|
||||
IndexAttempt.credential_id == credential_id,
|
||||
IndexAttempt.connector_credential_pair_id == ConnectorCredentialPair.id,
|
||||
ConnectorCredentialPair.connector_id == connector_id,
|
||||
ConnectorCredentialPair.credential_id == credential_id,
|
||||
)
|
||||
|
||||
db_session.execute(stmt)
|
||||
|
||||
|
||||
@@ -249,14 +340,15 @@ def expire_index_attempts(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def cancel_indexing_attempts_for_connector(
|
||||
connector_id: int,
|
||||
def cancel_indexing_attempts_for_ccpair(
|
||||
cc_pair_id: int,
|
||||
db_session: Session,
|
||||
include_secondary_index: bool = False,
|
||||
) -> None:
|
||||
stmt = delete(IndexAttempt).where(
|
||||
IndexAttempt.connector_id == connector_id,
|
||||
IndexAttempt.status == IndexingStatus.NOT_STARTED,
|
||||
stmt = (
|
||||
delete(IndexAttempt)
|
||||
.where(IndexAttempt.connector_credential_pair_id == cc_pair_id)
|
||||
.where(IndexAttempt.status == IndexingStatus.NOT_STARTED)
|
||||
)
|
||||
|
||||
if not include_secondary_index:
|
||||
@@ -273,6 +365,8 @@ def cancel_indexing_attempts_for_connector(
|
||||
def cancel_indexing_attempts_past_model(
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""Stops all indexing attempts that are in progress or not started for
|
||||
any embedding model that not present/future"""
|
||||
db_session.execute(
|
||||
update(IndexAttempt)
|
||||
.where(
|
||||
@@ -296,7 +390,8 @@ def count_unique_cc_pairs_with_successful_index_attempts(
|
||||
Then do distinct by connector_id and credential_id which is equivalent to the cc-pair. Finally,
|
||||
do a count to get the total number of unique cc-pairs with successful attempts"""
|
||||
unique_pairs_count = (
|
||||
db_session.query(IndexAttempt.connector_id, IndexAttempt.credential_id)
|
||||
db_session.query(IndexAttempt.connector_credential_pair_id)
|
||||
.join(ConnectorCredentialPair)
|
||||
.filter(
|
||||
IndexAttempt.embedding_model_id == embedding_model_id,
|
||||
IndexAttempt.status == IndexingStatus.SUCCESS,
|
||||
|
||||
202
backend/danswer/db/input_prompt.py
Normal file
@@ -0,0 +1,202 @@
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.db.models import InputPrompt
|
||||
from danswer.db.models import User
|
||||
from danswer.server.features.input_prompt.models import InputPromptSnapshot
|
||||
from danswer.server.manage.models import UserInfo
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def insert_input_prompt_if_not_exists(
|
||||
user: User | None,
|
||||
input_prompt_id: int | None,
|
||||
prompt: str,
|
||||
content: str,
|
||||
active: bool,
|
||||
is_public: bool,
|
||||
db_session: Session,
|
||||
commit: bool = True,
|
||||
) -> InputPrompt:
|
||||
if input_prompt_id is not None:
|
||||
input_prompt = (
|
||||
db_session.query(InputPrompt).filter_by(id=input_prompt_id).first()
|
||||
)
|
||||
else:
|
||||
query = db_session.query(InputPrompt).filter(InputPrompt.prompt == prompt)
|
||||
if user:
|
||||
query = query.filter(InputPrompt.user_id == user.id)
|
||||
else:
|
||||
query = query.filter(InputPrompt.user_id.is_(None))
|
||||
input_prompt = query.first()
|
||||
|
||||
if input_prompt is None:
|
||||
input_prompt = InputPrompt(
|
||||
id=input_prompt_id,
|
||||
prompt=prompt,
|
||||
content=content,
|
||||
active=active,
|
||||
is_public=is_public or user is None,
|
||||
user_id=user.id if user else None,
|
||||
)
|
||||
db_session.add(input_prompt)
|
||||
|
||||
if commit:
|
||||
db_session.commit()
|
||||
|
||||
return input_prompt
|
||||
|
||||
|
||||
def insert_input_prompt(
|
||||
prompt: str,
|
||||
content: str,
|
||||
is_public: bool,
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
) -> InputPrompt:
|
||||
input_prompt = InputPrompt(
|
||||
prompt=prompt,
|
||||
content=content,
|
||||
active=True,
|
||||
is_public=is_public or user is None,
|
||||
user_id=user.id if user is not None else None,
|
||||
)
|
||||
db_session.add(input_prompt)
|
||||
db_session.commit()
|
||||
|
||||
return input_prompt
|
||||
|
||||
|
||||
def update_input_prompt(
|
||||
user: User | None,
|
||||
input_prompt_id: int,
|
||||
prompt: str,
|
||||
content: str,
|
||||
active: bool,
|
||||
db_session: Session,
|
||||
) -> InputPrompt:
|
||||
input_prompt = db_session.scalar(
|
||||
select(InputPrompt).where(InputPrompt.id == input_prompt_id)
|
||||
)
|
||||
if input_prompt is None:
|
||||
raise ValueError(f"No input prompt with id {input_prompt_id}")
|
||||
|
||||
if not validate_user_prompt_authorization(user, input_prompt):
|
||||
raise HTTPException(status_code=401, detail="You don't own this prompt")
|
||||
|
||||
input_prompt.prompt = prompt
|
||||
input_prompt.content = content
|
||||
input_prompt.active = active
|
||||
|
||||
db_session.commit()
|
||||
return input_prompt
|
||||
|
||||
|
||||
def validate_user_prompt_authorization(
|
||||
user: User | None, input_prompt: InputPrompt
|
||||
) -> bool:
|
||||
prompt = InputPromptSnapshot.from_model(input_prompt=input_prompt)
|
||||
|
||||
if prompt.user_id is not None:
|
||||
if user is None:
|
||||
return False
|
||||
|
||||
user_details = UserInfo.from_model(user)
|
||||
if str(user_details.id) != str(prompt.user_id):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def remove_public_input_prompt(input_prompt_id: int, db_session: Session) -> None:
|
||||
input_prompt = db_session.scalar(
|
||||
select(InputPrompt).where(InputPrompt.id == input_prompt_id)
|
||||
)
|
||||
|
||||
if input_prompt is None:
|
||||
raise ValueError(f"No input prompt with id {input_prompt_id}")
|
||||
|
||||
if not input_prompt.is_public:
|
||||
raise HTTPException(status_code=400, detail="This prompt is not public")
|
||||
|
||||
db_session.delete(input_prompt)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def remove_input_prompt(
|
||||
user: User | None, input_prompt_id: int, db_session: Session
|
||||
) -> None:
|
||||
input_prompt = db_session.scalar(
|
||||
select(InputPrompt).where(InputPrompt.id == input_prompt_id)
|
||||
)
|
||||
if input_prompt is None:
|
||||
raise ValueError(f"No input prompt with id {input_prompt_id}")
|
||||
|
||||
if input_prompt.is_public:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Cannot delete public prompts with this method"
|
||||
)
|
||||
|
||||
if not validate_user_prompt_authorization(user, input_prompt):
|
||||
raise HTTPException(status_code=401, detail="You do not own this prompt")
|
||||
|
||||
db_session.delete(input_prompt)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def fetch_input_prompt_by_id(
|
||||
id: int, user_id: UUID | None, db_session: Session
|
||||
) -> InputPrompt:
|
||||
query = select(InputPrompt).where(InputPrompt.id == id)
|
||||
|
||||
if user_id:
|
||||
query = query.where(
|
||||
(InputPrompt.user_id == user_id) | (InputPrompt.user_id is None)
|
||||
)
|
||||
else:
|
||||
# If no user_id is provided, only fetch prompts without a user_id (aka public)
|
||||
query = query.where(InputPrompt.user_id == None) # noqa
|
||||
|
||||
result = db_session.scalar(query)
|
||||
|
||||
if result is None:
|
||||
raise HTTPException(422, "No input prompt found")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def fetch_public_input_prompts(
|
||||
db_session: Session,
|
||||
) -> list[InputPrompt]:
|
||||
query = select(InputPrompt).where(InputPrompt.is_public)
|
||||
return list(db_session.scalars(query).all())
|
||||
|
||||
|
||||
def fetch_input_prompts_by_user(
|
||||
db_session: Session,
|
||||
user_id: UUID | None,
|
||||
active: bool | None = None,
|
||||
include_public: bool = False,
|
||||
) -> list[InputPrompt]:
|
||||
query = select(InputPrompt)
|
||||
|
||||
if user_id is not None:
|
||||
if include_public:
|
||||
query = query.where(
|
||||
(InputPrompt.user_id == user_id) | InputPrompt.is_public
|
||||
)
|
||||
else:
|
||||
query = query.where(InputPrompt.user_id == user_id)
|
||||
|
||||
elif include_public:
|
||||
query = query.where(InputPrompt.is_public)
|
||||
|
||||
if active is not None:
|
||||
query = query.where(InputPrompt.active == active)
|
||||
|
||||
return list(db_session.scalars(query).all())
|
||||
@@ -1,15 +1,41 @@
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
|
||||
from danswer.db.models import LLMProvider as LLMProviderModel
|
||||
from danswer.db.models import LLMProvider__UserGroup
|
||||
from danswer.db.models import User
|
||||
from danswer.db.models import User__UserGroup
|
||||
from danswer.server.manage.embedding.models import CloudEmbeddingProvider
|
||||
from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
|
||||
from danswer.server.manage.llm.models import FullLLMProvider
|
||||
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
|
||||
|
||||
def update_group_llm_provider_relationships__no_commit(
|
||||
llm_provider_id: int,
|
||||
group_ids: list[int] | None,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
# Delete existing relationships
|
||||
db_session.query(LLMProvider__UserGroup).filter(
|
||||
LLMProvider__UserGroup.llm_provider_id == llm_provider_id
|
||||
).delete(synchronize_session="fetch")
|
||||
|
||||
# Add new relationships from given group_ids
|
||||
if group_ids:
|
||||
new_relationships = [
|
||||
LLMProvider__UserGroup(
|
||||
llm_provider_id=llm_provider_id,
|
||||
user_group_id=group_id,
|
||||
)
|
||||
for group_id in group_ids
|
||||
]
|
||||
db_session.add_all(new_relationships)
|
||||
|
||||
|
||||
def upsert_cloud_embedding_provider(
|
||||
db_session: Session, provider: CloudEmbeddingProviderCreationRequest
|
||||
) -> CloudEmbeddingProvider:
|
||||
@@ -36,36 +62,36 @@ def upsert_llm_provider(
|
||||
existing_llm_provider = db_session.scalar(
|
||||
select(LLMProviderModel).where(LLMProviderModel.name == llm_provider.name)
|
||||
)
|
||||
if existing_llm_provider:
|
||||
existing_llm_provider.provider = llm_provider.provider
|
||||
existing_llm_provider.api_key = llm_provider.api_key
|
||||
existing_llm_provider.api_base = llm_provider.api_base
|
||||
existing_llm_provider.api_version = llm_provider.api_version
|
||||
existing_llm_provider.custom_config = llm_provider.custom_config
|
||||
existing_llm_provider.default_model_name = llm_provider.default_model_name
|
||||
existing_llm_provider.fast_default_model_name = (
|
||||
llm_provider.fast_default_model_name
|
||||
)
|
||||
existing_llm_provider.model_names = llm_provider.model_names
|
||||
db_session.commit()
|
||||
return FullLLMProvider.from_model(existing_llm_provider)
|
||||
# if it does not exist, create a new entry
|
||||
llm_provider_model = LLMProviderModel(
|
||||
name=llm_provider.name,
|
||||
provider=llm_provider.provider,
|
||||
api_key=llm_provider.api_key,
|
||||
api_base=llm_provider.api_base,
|
||||
api_version=llm_provider.api_version,
|
||||
custom_config=llm_provider.custom_config,
|
||||
default_model_name=llm_provider.default_model_name,
|
||||
fast_default_model_name=llm_provider.fast_default_model_name,
|
||||
model_names=llm_provider.model_names,
|
||||
is_default_provider=None,
|
||||
|
||||
if not existing_llm_provider:
|
||||
existing_llm_provider = LLMProviderModel(name=llm_provider.name)
|
||||
db_session.add(existing_llm_provider)
|
||||
|
||||
existing_llm_provider.provider = llm_provider.provider
|
||||
existing_llm_provider.api_key = llm_provider.api_key
|
||||
existing_llm_provider.api_base = llm_provider.api_base
|
||||
existing_llm_provider.api_version = llm_provider.api_version
|
||||
existing_llm_provider.custom_config = llm_provider.custom_config
|
||||
existing_llm_provider.default_model_name = llm_provider.default_model_name
|
||||
existing_llm_provider.fast_default_model_name = llm_provider.fast_default_model_name
|
||||
existing_llm_provider.model_names = llm_provider.model_names
|
||||
existing_llm_provider.is_public = llm_provider.is_public
|
||||
existing_llm_provider.display_model_names = llm_provider.display_model_names
|
||||
|
||||
if not existing_llm_provider.id:
|
||||
# If its not already in the db, we need to generate an ID by flushing
|
||||
db_session.flush()
|
||||
|
||||
# Make sure the relationship table stays up to date
|
||||
update_group_llm_provider_relationships__no_commit(
|
||||
llm_provider_id=existing_llm_provider.id,
|
||||
group_ids=llm_provider.groups,
|
||||
db_session=db_session,
|
||||
)
|
||||
db_session.add(llm_provider_model)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
return FullLLMProvider.from_model(llm_provider_model)
|
||||
return FullLLMProvider.from_model(existing_llm_provider)
|
||||
|
||||
|
||||
def fetch_existing_embedding_providers(
|
||||
@@ -74,8 +100,29 @@ def fetch_existing_embedding_providers(
|
||||
return list(db_session.scalars(select(CloudEmbeddingProviderModel)).all())
|
||||
|
||||
|
||||
def fetch_existing_llm_providers(db_session: Session) -> list[LLMProviderModel]:
|
||||
return list(db_session.scalars(select(LLMProviderModel)).all())
|
||||
def fetch_existing_llm_providers(
|
||||
db_session: Session,
|
||||
user: User | None = None,
|
||||
) -> list[LLMProviderModel]:
|
||||
if not user:
|
||||
return list(db_session.scalars(select(LLMProviderModel)).all())
|
||||
stmt = select(LLMProviderModel).distinct()
|
||||
user_groups_subquery = (
|
||||
select(User__UserGroup.user_group_id)
|
||||
.where(User__UserGroup.user_id == user.id)
|
||||
.subquery()
|
||||
)
|
||||
access_conditions = or_(
|
||||
LLMProviderModel.is_public,
|
||||
LLMProviderModel.id.in_( # User is part of a group that has access
|
||||
select(LLMProvider__UserGroup.llm_provider_id).where(
|
||||
LLMProvider__UserGroup.user_group_id.in_(user_groups_subquery) # type: ignore
|
||||
)
|
||||
),
|
||||
)
|
||||
stmt = stmt.where(access_conditions)
|
||||
|
||||
return list(db_session.scalars(stmt).all())
|
||||
|
||||
|
||||
def fetch_embedding_provider(
|
||||
@@ -119,6 +166,13 @@ def remove_embedding_provider(
|
||||
|
||||
|
||||
def remove_llm_provider(db_session: Session, provider_id: int) -> None:
|
||||
# Remove LLMProvider's dependent relationships
|
||||
db_session.execute(
|
||||
delete(LLMProvider__UserGroup).where(
|
||||
LLMProvider__UserGroup.llm_provider_id == provider_id
|
||||
)
|
||||
)
|
||||
# Remove LLMProvider
|
||||
db_session.execute(
|
||||
delete(LLMProviderModel).where(LLMProviderModel.id == provider_id)
|
||||
)
|
||||
|
||||
@@ -11,6 +11,7 @@ from uuid import UUID
|
||||
from fastapi_users_db_sqlalchemy import SQLAlchemyBaseOAuthAccountTableUUID
|
||||
from fastapi_users_db_sqlalchemy import SQLAlchemyBaseUserTableUUID
|
||||
from fastapi_users_db_sqlalchemy.access_token import SQLAlchemyBaseAccessTokenTableUUID
|
||||
from fastapi_users_db_sqlalchemy.generics import TIMESTAMPAware
|
||||
from sqlalchemy import Boolean
|
||||
from sqlalchemy import DateTime
|
||||
from sqlalchemy import Enum
|
||||
@@ -37,10 +38,12 @@ from danswer.configs.constants import DEFAULT_BOOST
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import FileOrigin
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.configs.constants import NotificationType
|
||||
from danswer.configs.constants import SearchFeedbackType
|
||||
from danswer.configs.constants import TokenRateLimitScope
|
||||
from danswer.connectors.models import InputType
|
||||
from danswer.db.enums import ChatSessionSharedStatus
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.enums import IndexingStatus
|
||||
from danswer.db.enums import IndexModelStatus
|
||||
from danswer.db.enums import TaskStatus
|
||||
@@ -50,9 +53,9 @@ from danswer.file_store.models import FileDescriptor
|
||||
from danswer.llm.override_models import LLMOverride
|
||||
from danswer.llm.override_models import PromptOverride
|
||||
from danswer.search.enums import RecencyBiasSetting
|
||||
from danswer.search.enums import SearchType
|
||||
from danswer.utils.encryption import decrypt_bytes_to_string
|
||||
from danswer.utils.encryption import encrypt_string_to_bytes
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
@@ -120,6 +123,14 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
postgresql.ARRAY(Integer), nullable=True
|
||||
)
|
||||
|
||||
oidc_expiry: Mapped[datetime.datetime] = mapped_column(
|
||||
TIMESTAMPAware(timezone=True), nullable=True
|
||||
)
|
||||
|
||||
default_model: Mapped[str] = mapped_column(Text, nullable=True)
|
||||
# organized in typical structured fashion
|
||||
# formatted as `displayName__provider__modelName`
|
||||
|
||||
# relationships
|
||||
credentials: Mapped[list["Credential"]] = relationship(
|
||||
"Credential", back_populates="user", lazy="joined"
|
||||
@@ -132,10 +143,41 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
)
|
||||
|
||||
prompts: Mapped[list["Prompt"]] = relationship("Prompt", back_populates="user")
|
||||
input_prompts: Mapped[list["InputPrompt"]] = relationship(
|
||||
"InputPrompt", back_populates="user"
|
||||
)
|
||||
|
||||
# Personas owned by this user
|
||||
personas: Mapped[list["Persona"]] = relationship("Persona", back_populates="user")
|
||||
# Custom tools created by this user
|
||||
custom_tools: Mapped[list["Tool"]] = relationship("Tool", back_populates="user")
|
||||
# Notifications for the UI
|
||||
notifications: Mapped[list["Notification"]] = relationship(
|
||||
"Notification", back_populates="user"
|
||||
)
|
||||
|
||||
|
||||
class InputPrompt(Base):
|
||||
__tablename__ = "inputprompt"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
prompt: Mapped[str] = mapped_column(String)
|
||||
content: Mapped[str] = mapped_column(String)
|
||||
active: Mapped[bool] = mapped_column(Boolean)
|
||||
user: Mapped[User | None] = relationship("User", back_populates="input_prompts")
|
||||
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
|
||||
|
||||
|
||||
class InputPrompt__User(Base):
|
||||
__tablename__ = "inputprompt__user"
|
||||
|
||||
input_prompt_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("inputprompt.id"), primary_key=True
|
||||
)
|
||||
user_id: Mapped[UUID | None] = mapped_column(
|
||||
ForeignKey("inputprompt.id"), primary_key=True
|
||||
)
|
||||
|
||||
|
||||
class AccessToken(SQLAlchemyBaseAccessTokenTableUUID, Base):
|
||||
@@ -157,6 +199,24 @@ class ApiKey(Base):
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
|
||||
# Add this relationship to access the User object via user_id
|
||||
user: Mapped["User"] = relationship("User", foreign_keys=[user_id])
|
||||
|
||||
|
||||
class Notification(Base):
|
||||
__tablename__ = "notification"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
notif_type: Mapped[NotificationType] = mapped_column(
|
||||
Enum(NotificationType, native_enum=False)
|
||||
)
|
||||
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
|
||||
dismissed: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
last_shown: Mapped[datetime.datetime] = mapped_column(DateTime(timezone=True))
|
||||
first_shown: Mapped[datetime.datetime] = mapped_column(DateTime(timezone=True))
|
||||
|
||||
user: Mapped[User] = relationship("User", back_populates="notifications")
|
||||
|
||||
|
||||
"""
|
||||
Association Tables
|
||||
@@ -184,7 +244,9 @@ class Persona__User(Base):
|
||||
__tablename__ = "persona__user"
|
||||
|
||||
persona_id: Mapped[int] = mapped_column(ForeignKey("persona.id"), primary_key=True)
|
||||
user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id"), primary_key=True)
|
||||
user_id: Mapped[UUID | None] = mapped_column(
|
||||
ForeignKey("user.id"), primary_key=True, nullable=True
|
||||
)
|
||||
|
||||
|
||||
class DocumentSet__User(Base):
|
||||
@@ -193,7 +255,9 @@ class DocumentSet__User(Base):
|
||||
document_set_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("document_set.id"), primary_key=True
|
||||
)
|
||||
user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id"), primary_key=True)
|
||||
user_id: Mapped[UUID | None] = mapped_column(
|
||||
ForeignKey("user.id"), primary_key=True, nullable=True
|
||||
)
|
||||
|
||||
|
||||
class DocumentSet__ConnectorCredentialPair(Base):
|
||||
@@ -301,6 +365,9 @@ class ConnectorCredentialPair(Base):
|
||||
nullable=False,
|
||||
)
|
||||
name: Mapped[str] = mapped_column(String, nullable=False)
|
||||
status: Mapped[ConnectorCredentialPairStatus] = mapped_column(
|
||||
Enum(ConnectorCredentialPairStatus, native_enum=False), nullable=False
|
||||
)
|
||||
connector_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("connector.id"), primary_key=True
|
||||
)
|
||||
@@ -337,6 +404,9 @@ class ConnectorCredentialPair(Base):
|
||||
back_populates="connector_credential_pairs",
|
||||
overlaps="document_set",
|
||||
)
|
||||
index_attempts: Mapped[list["IndexAttempt"]] = relationship(
|
||||
"IndexAttempt", back_populates="connector_credential_pair"
|
||||
)
|
||||
|
||||
|
||||
class Document(Base):
|
||||
@@ -416,6 +486,9 @@ class Connector(Base):
|
||||
connector_specific_config: Mapped[dict[str, Any]] = mapped_column(
|
||||
postgresql.JSONB()
|
||||
)
|
||||
indexing_start: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime, nullable=True
|
||||
)
|
||||
refresh_freq: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
prune_freq: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
@@ -424,7 +497,6 @@ class Connector(Base):
|
||||
time_updated: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
disabled: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
|
||||
credentials: Mapped[list["ConnectorCredentialPair"]] = relationship(
|
||||
"ConnectorCredentialPair",
|
||||
@@ -434,14 +506,17 @@ class Connector(Base):
|
||||
documents_by_connector: Mapped[
|
||||
list["DocumentByConnectorCredentialPair"]
|
||||
] = relationship("DocumentByConnectorCredentialPair", back_populates="connector")
|
||||
index_attempts: Mapped[list["IndexAttempt"]] = relationship(
|
||||
"IndexAttempt", back_populates="connector"
|
||||
)
|
||||
|
||||
|
||||
class Credential(Base):
|
||||
__tablename__ = "credential"
|
||||
|
||||
name: Mapped[str] = mapped_column(String, nullable=True)
|
||||
|
||||
source: Mapped[DocumentSource] = mapped_column(
|
||||
Enum(DocumentSource, native_enum=False)
|
||||
)
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
credential_json: Mapped[dict[str, Any]] = mapped_column(EncryptedJson())
|
||||
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
|
||||
@@ -462,9 +537,7 @@ class Credential(Base):
|
||||
documents_by_credential: Mapped[
|
||||
list["DocumentByConnectorCredentialPair"]
|
||||
] = relationship("DocumentByConnectorCredentialPair", back_populates="credential")
|
||||
index_attempts: Mapped[list["IndexAttempt"]] = relationship(
|
||||
"IndexAttempt", back_populates="credential"
|
||||
)
|
||||
|
||||
user: Mapped[User | None] = relationship("User", back_populates="credentials")
|
||||
|
||||
|
||||
@@ -516,12 +589,16 @@ class EmbeddingModel(Base):
|
||||
cloud_provider='{self.cloud_provider.name if self.cloud_provider else 'None'}')>"
|
||||
|
||||
@property
|
||||
def api_key(self) -> str | None:
|
||||
return self.cloud_provider.api_key if self.cloud_provider else None
|
||||
def provider_type(self) -> EmbeddingProvider | None:
|
||||
return (
|
||||
EmbeddingProvider(self.cloud_provider.name.lower())
|
||||
if self.cloud_provider is not None
|
||||
else None
|
||||
)
|
||||
|
||||
@property
|
||||
def provider_type(self) -> str | None:
|
||||
return self.cloud_provider.name if self.cloud_provider else None
|
||||
def api_key(self) -> str | None:
|
||||
return self.cloud_provider.api_key if self.cloud_provider is not None else None
|
||||
|
||||
|
||||
class IndexAttempt(Base):
|
||||
@@ -534,13 +611,10 @@ class IndexAttempt(Base):
|
||||
__tablename__ = "index_attempt"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
connector_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("connector.id"),
|
||||
nullable=True,
|
||||
)
|
||||
credential_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("credential.id"),
|
||||
nullable=True,
|
||||
|
||||
connector_credential_pair_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("connector_credential_pair.id"),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
# Some index attempts that run from beginning will still have this as False
|
||||
@@ -578,12 +652,10 @@ class IndexAttempt(Base):
|
||||
onupdate=func.now(),
|
||||
)
|
||||
|
||||
connector: Mapped[Connector] = relationship(
|
||||
"Connector", back_populates="index_attempts"
|
||||
)
|
||||
credential: Mapped[Credential] = relationship(
|
||||
"Credential", back_populates="index_attempts"
|
||||
connector_credential_pair: Mapped[ConnectorCredentialPair] = relationship(
|
||||
"ConnectorCredentialPair", back_populates="index_attempts"
|
||||
)
|
||||
|
||||
embedding_model: Mapped[EmbeddingModel] = relationship(
|
||||
"EmbeddingModel", back_populates="index_attempts"
|
||||
)
|
||||
@@ -591,8 +663,7 @@ class IndexAttempt(Base):
|
||||
__table_args__ = (
|
||||
Index(
|
||||
"ix_index_attempt_latest_for_connector_credential_pair",
|
||||
"connector_id",
|
||||
"credential_id",
|
||||
"connector_credential_pair_id",
|
||||
"time_created",
|
||||
),
|
||||
)
|
||||
@@ -600,7 +671,6 @@ class IndexAttempt(Base):
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<IndexAttempt(id={self.id!r}, "
|
||||
f"connector_id={self.connector_id!r}, "
|
||||
f"status={self.status!r}, "
|
||||
f"error_msg={self.error_msg!r})>"
|
||||
f"time_created={self.time_created!r}, "
|
||||
@@ -821,6 +891,8 @@ class ChatMessage(Base):
|
||||
secondary="chat_message__search_doc",
|
||||
back_populates="chat_messages",
|
||||
)
|
||||
# NOTE: Should always be attached to the `assistant` message.
|
||||
# represents the tool calls used to generate this message
|
||||
tool_calls: Mapped[list["ToolCall"]] = relationship(
|
||||
"ToolCall",
|
||||
back_populates="message",
|
||||
@@ -923,6 +995,11 @@ class LLMProvider(Base):
|
||||
default_model_name: Mapped[str] = mapped_column(String)
|
||||
fast_default_model_name: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
# Models to actually display to users
|
||||
# If nulled out, we assume in the application logic we should present all
|
||||
display_model_names: Mapped[list[str] | None] = mapped_column(
|
||||
postgresql.ARRAY(String), nullable=True
|
||||
)
|
||||
# The LLMs that are available for this provider. Only required if not a default provider.
|
||||
# If a default provider, then the LLM options are pulled from the `options.py` file.
|
||||
# If needed, can be pulled out as a separate table in the future.
|
||||
@@ -932,6 +1009,13 @@ class LLMProvider(Base):
|
||||
|
||||
# should only be set for a single provider
|
||||
is_default_provider: Mapped[bool | None] = mapped_column(Boolean, unique=True)
|
||||
# EE only
|
||||
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
groups: Mapped[list["UserGroup"]] = relationship(
|
||||
"UserGroup",
|
||||
secondary="llm_provider__user_group",
|
||||
viewonly=True,
|
||||
)
|
||||
|
||||
|
||||
class CloudEmbeddingProvider(Base):
|
||||
@@ -1071,10 +1155,6 @@ class Persona(Base):
|
||||
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
|
||||
name: Mapped[str] = mapped_column(String)
|
||||
description: Mapped[str] = mapped_column(String)
|
||||
# Currently stored but unused, all flows use hybrid
|
||||
search_type: Mapped[SearchType] = mapped_column(
|
||||
Enum(SearchType, native_enum=False), default=SearchType.HYBRID
|
||||
)
|
||||
# Number of chunks to pass to the LLM for generation.
|
||||
num_chunks: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
# Pass every chunk through LLM for evaluation, fairly expensive
|
||||
@@ -1107,9 +1187,14 @@ class Persona(Base):
|
||||
# controls the ordering of personas in the UI
|
||||
# higher priority personas are displayed first, ties are resolved by the ID,
|
||||
# where lower value IDs (e.g. created earlier) are displayed first
|
||||
display_priority: Mapped[int] = mapped_column(Integer, nullable=True, default=None)
|
||||
display_priority: Mapped[int | None] = mapped_column(
|
||||
Integer, nullable=True, default=None
|
||||
)
|
||||
deleted: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
|
||||
uploaded_image_id: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
icon_color: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
icon_shape: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
|
||||
# These are only defaults, users can select from all if desired
|
||||
prompts: Mapped[list[Prompt]] = relationship(
|
||||
@@ -1137,6 +1222,7 @@ class Persona(Base):
|
||||
viewonly=True,
|
||||
)
|
||||
# EE only
|
||||
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
groups: Mapped[list["UserGroup"]] = relationship(
|
||||
"UserGroup",
|
||||
secondary="persona__user_group",
|
||||
@@ -1323,7 +1409,9 @@ class User__UserGroup(Base):
|
||||
user_group_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("user_group.id"), primary_key=True
|
||||
)
|
||||
user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id"), primary_key=True)
|
||||
user_id: Mapped[UUID | None] = mapped_column(
|
||||
ForeignKey("user.id"), primary_key=True, nullable=True
|
||||
)
|
||||
|
||||
|
||||
class UserGroup__ConnectorCredentialPair(Base):
|
||||
@@ -1360,6 +1448,17 @@ class Persona__UserGroup(Base):
|
||||
)
|
||||
|
||||
|
||||
class LLMProvider__UserGroup(Base):
|
||||
__tablename__ = "llm_provider__user_group"
|
||||
|
||||
llm_provider_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("llm_provider.id"), primary_key=True
|
||||
)
|
||||
user_group_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("user_group.id"), primary_key=True
|
||||
)
|
||||
|
||||
|
||||
class DocumentSet__UserGroup(Base):
|
||||
__tablename__ = "document_set__user_group"
|
||||
|
||||
|
||||
backend/danswer/db/notification.py (new file, 76 lines added)
@@ -0,0 +1,76 @@
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from danswer.configs.constants import NotificationType
|
||||
from danswer.db.models import Notification
|
||||
from danswer.db.models import User
|
||||
|
||||
|
||||
def create_notification(
|
||||
user: User | None,
|
||||
notif_type: NotificationType,
|
||||
db_session: Session,
|
||||
) -> Notification:
|
||||
notification = Notification(
|
||||
user_id=user.id if user else None,
|
||||
notif_type=notif_type,
|
||||
dismissed=False,
|
||||
last_shown=func.now(),
|
||||
first_shown=func.now(),
|
||||
)
|
||||
db_session.add(notification)
|
||||
db_session.commit()
|
||||
return notification
|
||||
|
||||
|
||||
def get_notification_by_id(
|
||||
notification_id: int, user: User | None, db_session: Session
|
||||
) -> Notification:
|
||||
user_id = user.id if user else None
|
||||
notif = db_session.get(Notification, notification_id)
|
||||
if not notif:
|
||||
raise ValueError(f"No notification found with id {notification_id}")
|
||||
if notif.user_id != user_id:
|
||||
raise PermissionError(
|
||||
f"User {user_id} is not authorized to access notification {notification_id}"
|
||||
)
|
||||
return notif
|
||||
|
||||
|
||||
def get_notifications(
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
notif_type: NotificationType | None = None,
|
||||
include_dismissed: bool = True,
|
||||
) -> list[Notification]:
|
||||
query = select(Notification).where(
|
||||
Notification.user_id == user.id if user else Notification.user_id.is_(None)
|
||||
)
|
||||
if not include_dismissed:
|
||||
query = query.where(Notification.dismissed.is_(False))
|
||||
if notif_type:
|
||||
query = query.where(Notification.notif_type == notif_type)
|
||||
return list(db_session.execute(query).scalars().all())
|
||||
|
||||
|
||||
def dismiss_all_notifications(
|
||||
notif_type: NotificationType,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
db_session.query(Notification).filter(Notification.notif_type == notif_type).update(
|
||||
{"dismissed": True}
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def dismiss_notification(notification: Notification, db_session: Session) -> None:
|
||||
notification.dismissed = True
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def update_notification_last_shown(
|
||||
notification: Notification, db_session: Session
|
||||
) -> None:
|
||||
notification.last_shown = func.now()
|
||||
db_session.commit()
|
||||
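For illustration, a minimal usage sketch of the notification helpers defined in this new file, assuming a SQLAlchemy Session and that NotificationType has a member such as REINDEX (the specific member name is an assumption):

from sqlalchemy.orm import Session

from danswer.configs.constants import NotificationType
from danswer.db.models import User
from danswer.db.notification import create_notification
from danswer.db.notification import get_notifications
from danswer.db.notification import update_notification_last_shown


def surface_reindex_notification(user: User | None, db_session: Session) -> None:
    # Reuse an existing undismissed notification if there is one,
    # otherwise create a new one (user=None means a global notification).
    existing = get_notifications(
        user=user,
        db_session=db_session,
        notif_type=NotificationType.REINDEX,  # assumed member name
        include_dismissed=False,
    )
    notif = existing[0] if existing else create_notification(
        user=user,
        notif_type=NotificationType.REINDEX,  # assumed member name
        db_session=db_session,
    )
    # Record that the notification was shown again; dismiss_notification
    # would be called once the user acknowledges it.
    update_notification_last_shown(notif, db_session)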
@@ -9,6 +9,7 @@ from sqlalchemy import not_
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.schemas import UserRole
|
||||
@@ -24,6 +25,7 @@ from danswer.db.models import StarterMessage
|
||||
from danswer.db.models import Tool
|
||||
from danswer.db.models import User
|
||||
from danswer.db.models import User__UserGroup
|
||||
from danswer.db.models import UserGroup
|
||||
from danswer.search.enums import RecencyBiasSetting
|
||||
from danswer.server.features.persona.models import CreatePersonaRequest
|
||||
from danswer.server.features.persona.models import PersonaSnapshot
|
||||
@@ -80,6 +82,9 @@ def create_update_persona(
|
||||
starter_messages=create_persona_request.starter_messages,
|
||||
is_public=create_persona_request.is_public,
|
||||
db_session=db_session,
|
||||
icon_color=create_persona_request.icon_color,
|
||||
icon_shape=create_persona_request.icon_shape,
|
||||
uploaded_image_id=create_persona_request.uploaded_image_id,
|
||||
)
|
||||
|
||||
versioned_make_persona_private = fetch_versioned_implementation(
|
||||
@@ -328,6 +333,11 @@ def upsert_persona(
|
||||
persona_id: int | None = None,
|
||||
default_persona: bool = False,
|
||||
commit: bool = True,
|
||||
icon_color: str | None = None,
|
||||
icon_shape: int | None = None,
|
||||
uploaded_image_id: str | None = None,
|
||||
display_priority: int | None = None,
|
||||
is_visible: bool = True,
|
||||
) -> Persona:
|
||||
if persona_id is not None:
|
||||
persona = db_session.query(Persona).filter_by(id=persona_id).first()
|
||||
@@ -383,6 +393,11 @@ def upsert_persona(
|
||||
persona.starter_messages = starter_messages
|
||||
persona.deleted = False # Un-delete if previously deleted
|
||||
persona.is_public = is_public
|
||||
persona.icon_color = icon_color
|
||||
persona.icon_shape = icon_shape
|
||||
persona.uploaded_image_id = uploaded_image_id
|
||||
persona.display_priority = display_priority
|
||||
persona.is_visible = is_visible
|
||||
|
||||
# Do not delete any associations manually added unless
|
||||
# a new updated list is provided
|
||||
@@ -415,6 +430,11 @@ def upsert_persona(
|
||||
llm_model_version_override=llm_model_version_override,
|
||||
starter_messages=starter_messages,
|
||||
tools=tools or [],
|
||||
icon_shape=icon_shape,
|
||||
icon_color=icon_color,
|
||||
uploaded_image_id=uploaded_image_id,
|
||||
display_priority=display_priority,
|
||||
is_visible=is_visible,
|
||||
)
|
||||
db_session.add(persona)
|
||||
|
||||
@@ -548,6 +568,8 @@ def get_default_prompt__read_only() -> Prompt:
|
||||
return _get_default_prompt(db_session)
|
||||
|
||||
|
||||
# TODO: since this gets called with every chat message, could it be more efficient to pregenerate
|
||||
# a direct mapping indicating whether a user has access to a specific persona?
|
||||
def get_persona_by_id(
|
||||
persona_id: int,
|
||||
# if user is `None` assume the user is an admin or auth is disabled
|
||||
@@ -556,16 +578,38 @@ def get_persona_by_id(
|
||||
include_deleted: bool = False,
|
||||
is_for_edit: bool = True, # NOTE: assume true for safety
|
||||
) -> Persona:
|
||||
stmt = select(Persona).where(Persona.id == persona_id)
|
||||
stmt = (
|
||||
select(Persona)
|
||||
.options(selectinload(Persona.users), selectinload(Persona.groups))
|
||||
.where(Persona.id == persona_id)
|
||||
)
|
||||
|
||||
or_conditions = []
|
||||
|
||||
# if user is an admin, they should have access to all Personas
|
||||
# and will skip the following clause
|
||||
if user is not None and user.role != UserRole.ADMIN:
|
||||
or_conditions.extend([Persona.user_id == user.id, Persona.user_id.is_(None)])
|
||||
# the user is not an admin
|
||||
isPersonaUnowned = Persona.user_id.is_(
|
||||
None
|
||||
) # allow access if persona user id is None
|
||||
isUserCreator = (
|
||||
Persona.user_id == user.id
|
||||
) # allow access if user created the persona
|
||||
or_conditions.extend([isPersonaUnowned, isUserCreator])
|
||||
|
||||
# if we aren't editing, also give access to all public personas
|
||||
# if we aren't editing, also give access if:
|
||||
# 1. the user is authorized for this persona
|
||||
# 2. the user is in an authorized group for this persona
|
||||
# 3. if the persona is public
|
||||
if not is_for_edit:
|
||||
isSharedWithUser = Persona.users.any(
|
||||
id=user.id
|
||||
) # allow access if user is in allowed users
|
||||
isSharedWithGroup = Persona.groups.any(
|
||||
UserGroup.users.any(id=user.id)
|
||||
) # allow access if user is in any allowed group
|
||||
or_conditions.extend([isSharedWithUser, isSharedWithGroup])
|
||||
or_conditions.append(Persona.is_public.is_(True))
|
||||
|
||||
if or_conditions:
|
||||
|
||||
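For illustration, a minimal sketch (not the repository's code) of how a list of OR-able access conditions like the one built above is typically folded into a single SQLAlchemy filter; the select construction mirrors the statement at the top of this function:

from sqlalchemy import or_
from sqlalchemy import select

from danswer.db.models import Persona


def build_persona_access_stmt(persona_id: int, or_conditions: list):  # illustrative helper
    stmt = select(Persona).where(Persona.id == persona_id)
    if or_conditions:
        # The persona is visible if any one condition holds:
        # creator, unowned, shared user, shared group, or public (read-only flows).
        stmt = stmt.where(or_(*or_conditions))
    return stmt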
@@ -7,6 +7,7 @@ from danswer.access.models import DocumentAccess
|
||||
from danswer.indexing.models import DocMetadataAwareIndexChunk
|
||||
from danswer.search.models import IndexFilters
|
||||
from danswer.search.models import InferenceChunkUncleaned
|
||||
from shared_configs.model_server_models import Embedding
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -209,80 +210,6 @@ class IdRetrievalCapable(abc.ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class KeywordCapable(abc.ABC):
|
||||
"""
|
||||
Class must implement the keyword search functionality
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def keyword_retrieval(
|
||||
self,
|
||||
query: str,
|
||||
filters: IndexFilters,
|
||||
time_decay_multiplier: float,
|
||||
num_to_retrieve: int,
|
||||
offset: int = 0,
|
||||
) -> list[InferenceChunkUncleaned]:
|
||||
"""
|
||||
Run keyword search and return a list of chunks. Inference chunks are chunks with all of the
|
||||
information required for query time purposes. For example, some details of the document
|
||||
required at indexing time are no longer needed past this point. At the same time, the
|
||||
matching keywords need to be highlighted.
|
||||
|
||||
NOTE: the query passed in here is the unprocessed plain text query. Preprocessing is
|
||||
expected to be handled by this function as it may depend on the index implementation.
|
||||
Things like query expansion, synonym injection, stop word removal, lemmatization, etc. are
|
||||
done here.
|
||||
|
||||
Parameters:
|
||||
- query: unmodified user query
|
||||
- filters: standard filter object
|
||||
- time_decay_multiplier: how much to decay the document scores as they age. Some queries
|
||||
based on the persona settings, will have this be a 2x or 3x of the default
|
||||
- num_to_retrieve: number of highest matching chunks to return
|
||||
- offset: number of highest matching chunks to skip (kind of like pagination)
|
||||
|
||||
Returns:
|
||||
best matching chunks based on keyword matching (should be BM25 algorithm ideally)
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class VectorCapable(abc.ABC):
|
||||
"""
|
||||
Class must implement the vector/semantic search functionality
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def semantic_retrieval(
|
||||
self,
|
||||
query: str, # Needed for matching purposes
|
||||
query_embedding: list[float],
|
||||
filters: IndexFilters,
|
||||
time_decay_multiplier: float,
|
||||
num_to_retrieve: int,
|
||||
offset: int = 0,
|
||||
) -> list[InferenceChunkUncleaned]:
|
||||
"""
|
||||
Run vector/semantic search and return a list of inference chunks.
|
||||
|
||||
Parameters:
|
||||
- query: unmodified user query. This is needed for getting the matching highlighted
|
||||
keywords
|
||||
- query_embedding: vector representation of the query, must be of the correct
|
||||
dimensionality for the primary index
|
||||
- filters: standard filter object
|
||||
- time_decay_multiplier: how much to decay the document scores as they age. Some queries
|
||||
based on the persona settings, will have this be a 2x or 3x of the default
|
||||
- num_to_retrieve: number of highest matching chunks to return
|
||||
- offset: number of highest matching chunks to skip (kind of like pagination)
|
||||
|
||||
Returns:
|
||||
best matching chunks based on vector similarity
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class HybridCapable(abc.ABC):
|
||||
"""
|
||||
Class must implement hybrid (keyword + vector) search functionality
|
||||
@@ -292,12 +219,13 @@ class HybridCapable(abc.ABC):
|
||||
def hybrid_retrieval(
|
||||
self,
|
||||
query: str,
|
||||
query_embedding: list[float],
|
||||
query_embedding: Embedding,
|
||||
final_keywords: list[str] | None,
|
||||
filters: IndexFilters,
|
||||
hybrid_alpha: float,
|
||||
time_decay_multiplier: float,
|
||||
num_to_retrieve: int,
|
||||
offset: int = 0,
|
||||
hybrid_alpha: float | None = None,
|
||||
) -> list[InferenceChunkUncleaned]:
|
||||
"""
|
||||
Run hybrid search and return a list of inference chunks.
|
||||
@@ -312,15 +240,16 @@ class HybridCapable(abc.ABC):
|
||||
keywords
|
||||
- query_embedding: vector representation of the query, must be of the correct
|
||||
dimensionality for the primary index
|
||||
- final_keywords: Final keywords to be used from the query, defaults to query if not set
|
||||
- filters: standard filter object
|
||||
- time_decay_multiplier: how much to decay the document scores as they age. Some queries
|
||||
based on the persona settings, will have this be a 2x or 3x of the default
|
||||
- num_to_retrieve: number of highest matching chunks to return
|
||||
- offset: number of highest matching chunks to skip (kind of like pagination)
|
||||
- hybrid_alpha: weighting between the keyword and vector search results. It is important
|
||||
that the two scores are normalized to the same range so that a meaningful
|
||||
comparison can be made. 1 for 100% weighting on vector score, 0 for 100% weighting
|
||||
on keyword score.
|
||||
- time_decay_multiplier: how much to decay the document scores as they age. Some queries
|
||||
based on the persona settings, will have this be a 2x or 3x of the default
|
||||
- num_to_retrieve: number of highest matching chunks to return
|
||||
- offset: number of highest matching chunks to skip (kind of like pagination)
|
||||
|
||||
Returns:
|
||||
best matching chunks based on weighted sum of keyword and vector/semantic search scores
|
||||
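For illustration, the hybrid_alpha contract described above reduces to a convex combination of the two normalized scores; a minimal sketch of that arithmetic (not the actual Vespa ranking expression):

def hybrid_score(keyword_score: float, vector_score: float, hybrid_alpha: float) -> float:
    # alpha=1.0 -> 100% vector similarity, alpha=0.0 -> 100% keyword (BM25) score.
    # Both scores are assumed to already be normalized to the same range.
    return hybrid_alpha * vector_score + (1.0 - hybrid_alpha) * keyword_score


# Example: with hybrid_alpha=0.6, vector_score=0.8 and keyword_score=0.5
# give 0.6 * 0.8 + 0.4 * 0.5 = 0.68.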
@@ -386,7 +315,7 @@ class BaseIndex(
|
||||
"""
|
||||
|
||||
|
||||
class DocumentIndex(KeywordCapable, VectorCapable, HybridCapable, BaseIndex, abc.ABC):
|
||||
class DocumentIndex(HybridCapable, BaseIndex, abc.ABC):
|
||||
"""
|
||||
A valid document index that can plug into all Danswer flows must implement all of these
|
||||
functionalities, though "technically" it does not need to be keyword or vector capable as
|
||||
|
||||
@@ -219,28 +219,4 @@ schema DANSWER_CHUNK_NAME {
|
||||
expression: bm25(content) + (5 * bm25(title))
|
||||
}
|
||||
}
|
||||
|
||||
# THE ONES BELOW ARE OUT OF DATE, DO NOT USE
|
||||
# THEY MIGHT NOT EVEN WORK AT ALL
|
||||
rank-profile keyword_search inherits default, default_rank {
|
||||
first-phase {
|
||||
expression: bm25(content) * document_boost * recency_bias
|
||||
}
|
||||
|
||||
match-features: recency_bias document_boost bm25(content)
|
||||
}
|
||||
|
||||
rank-profile semantic_searchVARIABLE_DIM inherits default, default_rank {
|
||||
inputs {
|
||||
query(query_embedding) tensor<float>(x[VARIABLE_DIM])
|
||||
}
|
||||
|
||||
first-phase {
|
||||
# Cannot do boost with the chosen embedding model because of high default similarity
|
||||
# This depends on the embedding model chosen
|
||||
expression: closeness(field, embeddings)
|
||||
}
|
||||
|
||||
match-features: recency_bias document_boost closest(embeddings)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -24,8 +24,6 @@ from danswer.configs.app_configs import VESPA_HOST
|
||||
from danswer.configs.app_configs import VESPA_PORT
|
||||
from danswer.configs.app_configs import VESPA_TENANT_PORT
|
||||
from danswer.configs.chat_configs import DOC_TIME_DECAY
|
||||
from danswer.configs.chat_configs import EDIT_KEYWORD_QUERY
|
||||
from danswer.configs.chat_configs import HYBRID_ALPHA
|
||||
from danswer.configs.chat_configs import NUM_RETURNED_HITS
|
||||
from danswer.configs.chat_configs import TITLE_CONTENT_RATIO
|
||||
from danswer.configs.constants import ACCESS_CONTROL_LIST
|
||||
@@ -52,7 +50,6 @@ from danswer.configs.constants import SOURCE_LINKS
|
||||
from danswer.configs.constants import SOURCE_TYPE
|
||||
from danswer.configs.constants import TITLE
|
||||
from danswer.configs.constants import TITLE_EMBEDDING
|
||||
from danswer.configs.model_configs import SEARCH_DISTANCE_CUTOFF
|
||||
from danswer.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||
get_experts_stores_representations,
|
||||
)
|
||||
@@ -65,10 +62,9 @@ from danswer.document_index.vespa.utils import replace_invalid_doc_id_characters
|
||||
from danswer.indexing.models import DocMetadataAwareIndexChunk
|
||||
from danswer.search.models import IndexFilters
|
||||
from danswer.search.models import InferenceChunkUncleaned
|
||||
from danswer.search.retrieval.search_runner import query_processing
|
||||
from danswer.search.retrieval.search_runner import remove_stop_words_and_punctuation
|
||||
from danswer.utils.batching import batch_generator
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.model_server_models import Embedding
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -329,11 +325,13 @@ def _index_vespa_chunk(
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
document = chunk.source_document
|
||||
|
||||
# No minichunk documents in vespa, minichunk vectors are stored in the chunk itself
|
||||
vespa_chunk_id = str(get_uuid_from_chunk(chunk))
|
||||
|
||||
embeddings = chunk.embeddings
|
||||
|
||||
embeddings_name_vector_map = {"full_chunk": embeddings.full_embedding}
|
||||
|
||||
if embeddings.mini_chunk_embeddings:
|
||||
for ind, m_c_embed in enumerate(embeddings.mini_chunk_embeddings):
|
||||
embeddings_name_vector_map[f"mini_chunk_{ind}"] = m_c_embed
|
||||
@@ -346,11 +344,15 @@ def _index_vespa_chunk(
|
||||
BLURB: remove_invalid_unicode_chars(chunk.blurb),
|
||||
TITLE: remove_invalid_unicode_chars(title) if title else None,
|
||||
SKIP_TITLE_EMBEDDING: not title,
|
||||
CONTENT: remove_invalid_unicode_chars(chunk.content),
|
||||
# For the BM25 index, the keyword suffix is used, the vector is already generated with the more
|
||||
# natural language representation of the metadata section
|
||||
CONTENT: remove_invalid_unicode_chars(
|
||||
f"{chunk.title_prefix}{chunk.content}{chunk.metadata_suffix_keyword}"
|
||||
),
|
||||
# This duplication of `content` is needed for keyword highlighting
|
||||
# Note that it's not exactly the same as the actual content
|
||||
# which contains the title prefix and metadata suffix
|
||||
CONTENT_SUMMARY: remove_invalid_unicode_chars(chunk.content_summary),
|
||||
CONTENT_SUMMARY: remove_invalid_unicode_chars(chunk.content),
|
||||
SOURCE_TYPE: str(document.source.value),
|
||||
SOURCE_LINKS: json.dumps(chunk.source_links),
|
||||
SEMANTIC_IDENTIFIER: remove_invalid_unicode_chars(document.semantic_identifier),
|
||||
@@ -358,7 +360,7 @@ def _index_vespa_chunk(
|
||||
METADATA: json.dumps(document.metadata),
|
||||
# Save as a list for efficient extraction as an Attribute
|
||||
METADATA_LIST: chunk.source_document.get_metadata_str_attributes(),
|
||||
METADATA_SUFFIX: chunk.metadata_suffix,
|
||||
METADATA_SUFFIX: chunk.metadata_suffix_keyword,
|
||||
EMBEDDINGS: embeddings_name_vector_map,
|
||||
TITLE_EMBEDDING: chunk.title_embedding,
|
||||
BOOST: chunk.boost,
|
||||
@@ -601,7 +603,7 @@ def _vespa_hit_to_inference_chunk(
|
||||
chunk_id=fields[CHUNK_ID],
|
||||
blurb=fields.get(BLURB, ""), # Unused
|
||||
content=fields[CONTENT], # Includes extra title prefix and metadata suffix
|
||||
source_links=source_links_dict,
|
||||
source_links=source_links_dict or {0: ""},
|
||||
section_continuation=fields[SECTION_CONTINUATION],
|
||||
document_id=fields[DOCUMENT_ID],
|
||||
source_type=fields[SOURCE_TYPE],
|
||||
@@ -987,95 +989,17 @@ class VespaIndex(DocumentIndex):
|
||||
inference_chunks.sort(key=lambda chunk: chunk.chunk_id)
|
||||
return inference_chunks
|
||||
|
||||
def keyword_retrieval(
|
||||
self,
|
||||
query: str,
|
||||
filters: IndexFilters,
|
||||
time_decay_multiplier: float,
|
||||
num_to_retrieve: int = NUM_RETURNED_HITS,
|
||||
offset: int = 0,
|
||||
edit_keyword_query: bool = EDIT_KEYWORD_QUERY,
|
||||
) -> list[InferenceChunkUncleaned]:
|
||||
# IMPORTANT: THIS FUNCTION IS NOT UP TO DATE, DOES NOT WORK CORRECTLY
|
||||
vespa_where_clauses = _build_vespa_filters(filters)
|
||||
yql = (
|
||||
VespaIndex.yql_base.format(index_name=self.index_name)
|
||||
+ vespa_where_clauses
|
||||
# `({defaultIndex: "content_summary"}userInput(@query))` section is
|
||||
# needed for highlighting while the N-gram highlighting is broken /
|
||||
# not working as desired
|
||||
+ '({grammar: "weakAnd"}userInput(@query) '
|
||||
+ f'or ({{defaultIndex: "{CONTENT_SUMMARY}"}}userInput(@query)))'
|
||||
)
|
||||
|
||||
final_query = query_processing(query) if edit_keyword_query else query
|
||||
|
||||
params: dict[str, str | int] = {
|
||||
"yql": yql,
|
||||
"query": final_query,
|
||||
"input.query(decay_factor)": str(DOC_TIME_DECAY * time_decay_multiplier),
|
||||
"hits": num_to_retrieve,
|
||||
"offset": offset,
|
||||
"ranking.profile": "keyword_search",
|
||||
"timeout": _VESPA_TIMEOUT,
|
||||
}
|
||||
|
||||
return _query_vespa(params)
|
||||
|
||||
def semantic_retrieval(
|
||||
self,
|
||||
query: str,
|
||||
query_embedding: list[float],
|
||||
filters: IndexFilters,
|
||||
time_decay_multiplier: float,
|
||||
num_to_retrieve: int = NUM_RETURNED_HITS,
|
||||
offset: int = 0,
|
||||
distance_cutoff: float | None = SEARCH_DISTANCE_CUTOFF,
|
||||
edit_keyword_query: bool = EDIT_KEYWORD_QUERY,
|
||||
) -> list[InferenceChunkUncleaned]:
|
||||
# IMPORTANT: THIS FUNCTION IS NOT UP TO DATE, DOES NOT WORK CORRECTLY
|
||||
vespa_where_clauses = _build_vespa_filters(filters)
|
||||
yql = (
|
||||
VespaIndex.yql_base.format(index_name=self.index_name)
|
||||
+ vespa_where_clauses
|
||||
+ f"(({{targetHits: {10 * num_to_retrieve}}}nearestNeighbor(embeddings, query_embedding)) "
|
||||
# `({defaultIndex: "content_summary"}userInput(@query))` section is
|
||||
# needed for highlighting while the N-gram highlighting is broken /
|
||||
# not working as desired
|
||||
+ f'or ({{defaultIndex: "{CONTENT_SUMMARY}"}}userInput(@query)))'
|
||||
)
|
||||
|
||||
query_keywords = (
|
||||
" ".join(remove_stop_words_and_punctuation(query))
|
||||
if edit_keyword_query
|
||||
else query
|
||||
)
|
||||
|
||||
params: dict[str, str | int] = {
|
||||
"yql": yql,
|
||||
"query": query_keywords, # Needed for highlighting
|
||||
"input.query(query_embedding)": str(query_embedding),
|
||||
"input.query(decay_factor)": str(DOC_TIME_DECAY * time_decay_multiplier),
|
||||
"hits": num_to_retrieve,
|
||||
"offset": offset,
|
||||
"ranking.profile": f"hybrid_search{len(query_embedding)}",
|
||||
"timeout": _VESPA_TIMEOUT,
|
||||
}
|
||||
|
||||
return _query_vespa(params)
|
||||
|
||||
def hybrid_retrieval(
|
||||
self,
|
||||
query: str,
|
||||
query_embedding: list[float],
|
||||
query_embedding: Embedding,
|
||||
final_keywords: list[str] | None,
|
||||
filters: IndexFilters,
|
||||
hybrid_alpha: float,
|
||||
time_decay_multiplier: float,
|
||||
num_to_retrieve: int,
|
||||
offset: int = 0,
|
||||
hybrid_alpha: float | None = HYBRID_ALPHA,
|
||||
title_content_ratio: float | None = TITLE_CONTENT_RATIO,
|
||||
distance_cutoff: float | None = SEARCH_DISTANCE_CUTOFF,
|
||||
edit_keyword_query: bool = EDIT_KEYWORD_QUERY,
|
||||
) -> list[InferenceChunkUncleaned]:
|
||||
vespa_where_clauses = _build_vespa_filters(filters)
|
||||
# Needs to be at least as much as the value set in Vespa schema config
|
||||
@@ -1089,20 +1013,14 @@ class VespaIndex(DocumentIndex):
|
||||
+ f'or ({{defaultIndex: "{CONTENT_SUMMARY}"}}userInput(@query)))'
|
||||
)
|
||||
|
||||
query_keywords = (
|
||||
" ".join(remove_stop_words_and_punctuation(query))
|
||||
if edit_keyword_query
|
||||
else query
|
||||
)
|
||||
final_query = " ".join(final_keywords) if final_keywords else query
|
||||
|
||||
params: dict[str, str | int | float] = {
|
||||
"yql": yql,
|
||||
"query": query_keywords,
|
||||
"query": final_query,
|
||||
"input.query(query_embedding)": str(query_embedding),
|
||||
"input.query(decay_factor)": str(DOC_TIME_DECAY * time_decay_multiplier),
|
||||
"input.query(alpha)": hybrid_alpha
|
||||
if hybrid_alpha is not None
|
||||
else HYBRID_ALPHA,
|
||||
"input.query(alpha)": hybrid_alpha,
|
||||
"input.query(title_content_ratio)": title_content_ratio
|
||||
if title_content_ratio is not None
|
||||
else TITLE_CONTENT_RATIO,
|
||||
|
||||
@@ -1,115 +0,0 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
|
||||
from danswer.configs.constants import GEN_AI_API_KEY_STORAGE_KEY
|
||||
from danswer.configs.model_configs import FAST_GEN_AI_MODEL_VERSION
|
||||
from danswer.configs.model_configs import GEN_AI_API_ENDPOINT
|
||||
from danswer.configs.model_configs import GEN_AI_API_KEY
|
||||
from danswer.configs.model_configs import GEN_AI_API_VERSION
|
||||
from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
|
||||
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
|
||||
from danswer.db.engine import get_session_context_manager
|
||||
from danswer.db.llm import fetch_existing_llm_providers
|
||||
from danswer.db.llm import update_default_provider
|
||||
from danswer.db.llm import upsert_llm_provider
|
||||
from danswer.dynamic_configs.factory import get_dynamic_config_store
|
||||
from danswer.dynamic_configs.factory import PostgresBackedDynamicConfigStore
|
||||
from danswer.dynamic_configs.interface import ConfigNotFoundError
|
||||
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def read_file_system_store(directory_path: str) -> dict:
|
||||
store = {}
|
||||
base_path = Path(directory_path)
|
||||
for file_path in base_path.iterdir():
|
||||
if file_path.is_file() and "." not in file_path.name:
|
||||
with open(file_path, "r") as file:
|
||||
key = file_path.stem
|
||||
value = json.load(file)
|
||||
|
||||
if value:
|
||||
store[key] = value
|
||||
return store
|
||||
|
||||
|
||||
def insert_into_postgres(store_data: dict) -> None:
|
||||
port_once_key = "file_store_ported"
|
||||
config_store = PostgresBackedDynamicConfigStore()
|
||||
try:
|
||||
config_store.load(port_once_key)
|
||||
return
|
||||
except ConfigNotFoundError:
|
||||
pass
|
||||
|
||||
for key, value in store_data.items():
|
||||
config_store.store(key, value)
|
||||
|
||||
config_store.store(port_once_key, True)
|
||||
|
||||
|
||||
def port_filesystem_to_postgres(directory_path: str) -> None:
|
||||
store_data = read_file_system_store(directory_path)
|
||||
insert_into_postgres(store_data)
|
||||
|
||||
|
||||
def port_api_key_to_postgres() -> None:
|
||||
# can't port over custom, no longer supported
|
||||
if GEN_AI_MODEL_PROVIDER == "custom":
|
||||
return
|
||||
|
||||
with get_session_context_manager() as db_session:
|
||||
# if we already have ported things over / setup providers in the db, don't do anything
|
||||
if len(fetch_existing_llm_providers(db_session)) > 0:
|
||||
return
|
||||
|
||||
api_key = GEN_AI_API_KEY
|
||||
try:
|
||||
api_key = cast(
|
||||
str, get_dynamic_config_store().load(GEN_AI_API_KEY_STORAGE_KEY)
|
||||
)
|
||||
except ConfigNotFoundError:
|
||||
pass
|
||||
|
||||
# if no API key set, don't port anything over
|
||||
if not api_key:
|
||||
return
|
||||
|
||||
default_model_name = GEN_AI_MODEL_VERSION
|
||||
if GEN_AI_MODEL_PROVIDER == "openai" and not default_model_name:
|
||||
default_model_name = "gpt-4"
|
||||
|
||||
# if no default model name found, don't port anything over
|
||||
if not default_model_name:
|
||||
return
|
||||
|
||||
default_fast_model_name = FAST_GEN_AI_MODEL_VERSION
|
||||
if GEN_AI_MODEL_PROVIDER == "openai" and not default_fast_model_name:
|
||||
default_fast_model_name = "gpt-3.5-turbo"
|
||||
|
||||
llm_provider_upsert = LLMProviderUpsertRequest(
|
||||
name=GEN_AI_MODEL_PROVIDER,
|
||||
provider=GEN_AI_MODEL_PROVIDER,
|
||||
api_key=api_key,
|
||||
api_base=GEN_AI_API_ENDPOINT,
|
||||
api_version=GEN_AI_API_VERSION,
|
||||
# can't port over any custom configs, since we don't know
|
||||
# all the possible keys and values that could be in there
|
||||
custom_config=None,
|
||||
default_model_name=default_model_name,
|
||||
fast_default_model_name=default_fast_model_name,
|
||||
model_names=None,
|
||||
)
|
||||
llm_provider = upsert_llm_provider(db_session, llm_provider_upsert)
|
||||
update_default_provider(db_session, llm_provider.id)
|
||||
logger.info(f"Ported over LLM provider:\n\n{llm_provider}")
|
||||
|
||||
# delete the old API key
|
||||
try:
|
||||
get_dynamic_config_store().delete(GEN_AI_API_KEY_STORAGE_KEY)
|
||||
except ConfigNotFoundError:
|
||||
pass
|
||||
@@ -8,7 +8,7 @@ from typing import cast
|
||||
from filelock import FileLock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.db.engine import SessionFactory
|
||||
from danswer.db.engine import get_session_factory
|
||||
from danswer.db.models import KVStore
|
||||
from danswer.dynamic_configs.interface import ConfigNotFoundError
|
||||
from danswer.dynamic_configs.interface import DynamicConfigStore
|
||||
@@ -56,7 +56,8 @@ class FileSystemBackedDynamicConfigStore(DynamicConfigStore):
|
||||
class PostgresBackedDynamicConfigStore(DynamicConfigStore):
|
||||
@contextmanager
|
||||
def get_session(self) -> Iterator[Session]:
|
||||
session: Session = SessionFactory()
|
||||
factory = get_session_factory()
|
||||
session: Session = factory()
|
||||
try:
|
||||
yield session
|
||||
finally:
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import abc
|
||||
from collections.abc import Callable
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from danswer.configs.app_configs import BLURB_SIZE
|
||||
from danswer.configs.app_configs import MINI_CHUNK_SIZE
|
||||
from danswer.configs.app_configs import SKIP_METADATA_IN_CHUNK
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import MAX_CHUNK_TITLE_LEN
|
||||
from danswer.configs.constants import RETURN_SEPARATOR
|
||||
from danswer.configs.constants import SECTION_SEPARATOR
|
||||
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
|
||||
@@ -15,12 +15,13 @@ from danswer.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||
)
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.indexing.models import DocAwareChunk
|
||||
from danswer.search.search_nlp_models import get_default_tokenizer
|
||||
from danswer.natural_language_processing.utils import get_tokenizer
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.text_processing import shared_precompare_cleanup
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import AutoTokenizer # type:ignore
|
||||
from llama_index.text_splitter import SentenceSplitter # type:ignore
|
||||
|
||||
|
||||
# Not supporting overlaps, we need a clean combination of chunks and it is unclear if overlaps
|
||||
@@ -28,6 +29,8 @@ if TYPE_CHECKING:
|
||||
CHUNK_OVERLAP = 0
|
||||
# Fairly arbitrary numbers but the general concept is we don't want the title/metadata to
|
||||
# overwhelm the actual contents of the chunk
|
||||
# For example in a rare case, this could be 128 tokens for the 512 chunk and title prefix
|
||||
# could be another 128 tokens leaving 256 for the actual contents
|
||||
MAX_METADATA_PERCENTAGE = 0.25
|
||||
CHUNK_MIN_CONTENT = 256
|
||||
|
||||
@@ -36,15 +39,11 @@ logger = setup_logger()
|
||||
ChunkFunc = Callable[[Document], list[DocAwareChunk]]
|
||||
|
||||
|
||||
def extract_blurb(text: str, blurb_size: int) -> str:
|
||||
from llama_index.text_splitter import SentenceSplitter
|
||||
|
||||
token_count_func = get_default_tokenizer().tokenize
|
||||
blurb_splitter = SentenceSplitter(
|
||||
tokenizer=token_count_func, chunk_size=blurb_size, chunk_overlap=0
|
||||
)
|
||||
|
||||
return blurb_splitter.split_text(text)[0]
|
||||
def extract_blurb(text: str, blurb_splitter: "SentenceSplitter") -> str:
|
||||
texts = blurb_splitter.split_text(text)
|
||||
if not texts:
|
||||
return ""
|
||||
return texts[0]
|
||||
|
||||
|
||||
def chunk_large_section(
|
||||
@@ -52,76 +51,130 @@ def chunk_large_section(
|
||||
section_link_text: str,
|
||||
document: Document,
|
||||
start_chunk_id: int,
|
||||
tokenizer: "AutoTokenizer",
|
||||
chunk_size: int = DOC_EMBEDDING_CONTEXT_SIZE,
|
||||
chunk_overlap: int = CHUNK_OVERLAP,
|
||||
blurb_size: int = BLURB_SIZE,
|
||||
title_prefix: str = "",
|
||||
metadata_suffix: str = "",
|
||||
blurb: str,
|
||||
chunk_splitter: "SentenceSplitter",
|
||||
mini_chunk_splitter: Optional["SentenceSplitter"],
|
||||
title_prefix: str,
|
||||
metadata_suffix_semantic: str,
|
||||
metadata_suffix_keyword: str,
|
||||
) -> list[DocAwareChunk]:
|
||||
from llama_index.text_splitter import SentenceSplitter
|
||||
|
||||
blurb = extract_blurb(section_text, blurb_size)
|
||||
|
||||
sentence_aware_splitter = SentenceSplitter(
|
||||
tokenizer=tokenizer.tokenize, chunk_size=chunk_size, chunk_overlap=chunk_overlap
|
||||
)
|
||||
|
||||
split_texts = sentence_aware_splitter.split_text(section_text)
|
||||
split_texts = chunk_splitter.split_text(section_text)
|
||||
|
||||
chunks = [
|
||||
DocAwareChunk(
|
||||
source_document=document,
|
||||
chunk_id=start_chunk_id + chunk_ind,
|
||||
blurb=blurb,
|
||||
content=f"{title_prefix}{chunk_str}{metadata_suffix}",
|
||||
content_summary=chunk_str,
|
||||
content=chunk_text,
|
||||
source_links={0: section_link_text},
|
||||
section_continuation=(chunk_ind != 0),
|
||||
metadata_suffix=metadata_suffix,
|
||||
title_prefix=title_prefix,
|
||||
metadata_suffix_semantic=metadata_suffix_semantic,
|
||||
metadata_suffix_keyword=metadata_suffix_keyword,
|
||||
mini_chunk_texts=mini_chunk_splitter.split_text(chunk_text)
|
||||
if mini_chunk_splitter and chunk_text.strip()
|
||||
else None,
|
||||
)
|
||||
for chunk_ind, chunk_str in enumerate(split_texts)
|
||||
for chunk_ind, chunk_text in enumerate(split_texts)
|
||||
]
|
||||
return chunks
|
||||
|
||||
|
||||
def _get_metadata_suffix_for_document_index(
|
||||
metadata: dict[str, str | list[str]]
|
||||
) -> str:
|
||||
metadata: dict[str, str | list[str]], include_separator: bool = False
|
||||
) -> tuple[str, str]:
|
||||
"""
|
||||
Returns the metadata as a natural language string representation with all of the keys and values for the vector embedding
|
||||
and a string of all of the values for the keyword search
|
||||
|
||||
For example, if we have the following metadata:
|
||||
{
|
||||
"author": "John Doe",
|
||||
"space": "Engineering"
|
||||
}
|
||||
The vector embedding string should include the relation between the key and value, whereas for keyword search we only want John Doe
|
||||
and Engineering. The keys are repeated and much more noisy.
|
||||
"""
|
||||
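For illustration, the concrete return values implied by the docstring's example, assuming neither key is in get_metadata_keys_to_ignore() and include_separator is False:

metadata = {"author": "John Doe", "space": "Engineering"}

semantic_suffix, keyword_suffix = _get_metadata_suffix_for_document_index(metadata)

# semantic_suffix == "Metadata:\n\tauthor - John Doe\n\tspace - Engineering"
#   (keys and values, used for the text that gets embedded)
# keyword_suffix == "John Doe Engineering"
#   (values only, appended for the BM25/keyword representation)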
if not metadata:
|
||||
return ""
|
||||
return "", ""
|
||||
|
||||
metadata_str = "Metadata:\n"
|
||||
metadata_values = []
|
||||
for key, value in metadata.items():
|
||||
if key in get_metadata_keys_to_ignore():
|
||||
continue
|
||||
|
||||
value_str = ", ".join(value) if isinstance(value, list) else value
|
||||
|
||||
if isinstance(value, list):
|
||||
metadata_values.extend(value)
|
||||
else:
|
||||
metadata_values.append(value)
|
||||
|
||||
metadata_str += f"\t{key} - {value_str}\n"
|
||||
return metadata_str.strip()
|
||||
|
||||
metadata_semantic = metadata_str.strip()
|
||||
metadata_keyword = " ".join(metadata_values)
|
||||
|
||||
if include_separator:
|
||||
return RETURN_SEPARATOR + metadata_semantic, RETURN_SEPARATOR + metadata_keyword
|
||||
return metadata_semantic, metadata_keyword
|
||||
|
||||
|
||||
def chunk_document(
|
||||
document: Document,
|
||||
model_name: str,
|
||||
provider_type: EmbeddingProvider | None,
|
||||
enable_multipass: bool,
|
||||
chunk_tok_size: int = DOC_EMBEDDING_CONTEXT_SIZE,
|
||||
subsection_overlap: int = CHUNK_OVERLAP,
|
||||
blurb_size: int = BLURB_SIZE,
|
||||
blurb_size: int = BLURB_SIZE, # Used for both title and content
|
||||
include_metadata: bool = not SKIP_METADATA_IN_CHUNK,
|
||||
mini_chunk_size: int = MINI_CHUNK_SIZE,
|
||||
) -> list[DocAwareChunk]:
|
||||
tokenizer = get_default_tokenizer()
|
||||
from llama_index.text_splitter import SentenceSplitter
|
||||
|
||||
title = document.get_title_for_document_index()
|
||||
title_prefix = f"{title[:MAX_CHUNK_TITLE_LEN]}{RETURN_SEPARATOR}" if title else ""
|
||||
tokenizer = get_tokenizer(
|
||||
model_name=model_name,
|
||||
provider_type=provider_type,
|
||||
)
|
||||
|
||||
blurb_splitter = SentenceSplitter(
|
||||
tokenizer=tokenizer.tokenize, chunk_size=blurb_size, chunk_overlap=0
|
||||
)
|
||||
|
||||
chunk_splitter = SentenceSplitter(
|
||||
tokenizer=tokenizer.tokenize,
|
||||
chunk_size=chunk_tok_size,
|
||||
chunk_overlap=subsection_overlap,
|
||||
)
|
||||
|
||||
mini_chunk_splitter = SentenceSplitter(
|
||||
tokenizer=tokenizer.tokenize,
|
||||
chunk_size=mini_chunk_size,
|
||||
chunk_overlap=0,
|
||||
)
|
||||
|
||||
title = extract_blurb(document.get_title_for_document_index() or "", blurb_splitter)
|
||||
title_prefix = title + RETURN_SEPARATOR if title else ""
|
||||
title_tokens = len(tokenizer.tokenize(title_prefix))
|
||||
|
||||
metadata_suffix = ""
|
||||
metadata_suffix_semantic = ""
|
||||
metadata_suffix_keyword = ""
|
||||
metadata_tokens = 0
|
||||
if include_metadata:
|
||||
metadata = _get_metadata_suffix_for_document_index(document.metadata)
|
||||
metadata_suffix = RETURN_SEPARATOR + metadata if metadata else ""
|
||||
metadata_tokens = len(tokenizer.tokenize(metadata_suffix))
|
||||
(
|
||||
metadata_suffix_semantic,
|
||||
metadata_suffix_keyword,
|
||||
) = _get_metadata_suffix_for_document_index(
|
||||
document.metadata, include_separator=True
|
||||
)
|
||||
metadata_tokens = len(tokenizer.tokenize(metadata_suffix_semantic))
|
||||
|
||||
if metadata_tokens >= chunk_tok_size * MAX_METADATA_PERCENTAGE:
|
||||
metadata_suffix = ""
|
||||
# Note: we can keep the keyword suffix even if the semantic suffix is too long to fit in the model
|
||||
# context, there is no limit for the keyword component
|
||||
metadata_suffix_semantic = ""
|
||||
metadata_tokens = 0
|
||||
|
||||
content_token_limit = chunk_tok_size - title_tokens - metadata_tokens
|
||||
@@ -130,7 +183,7 @@ def chunk_document(
|
||||
if content_token_limit <= CHUNK_MIN_CONTENT:
|
||||
content_token_limit = chunk_tok_size
|
||||
title_prefix = ""
|
||||
metadata_suffix = ""
|
||||
metadata_suffix_semantic = ""
|
||||
|
||||
chunks: list[DocAwareChunk] = []
|
||||
link_offsets: dict[int, str] = {}
|
||||
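For illustration, a worked example of the content token budget above; chunk_tok_size=512 is an assumed value, while MAX_METADATA_PERCENTAGE (0.25) and CHUNK_MIN_CONTENT (256) come from the constants defined near the top of this file:

chunk_tok_size = 512   # assumed DOC_EMBEDDING_CONTEXT_SIZE for this example
title_tokens = 20
metadata_tokens = 40   # 40 < 512 * 0.25 == 128, so the semantic suffix is kept

content_token_limit = chunk_tok_size - title_tokens - metadata_tokens  # 452

# Had metadata_tokens been >= 128, metadata_suffix_semantic would be dropped
# (the keyword suffix is kept regardless). If content_token_limit were to fall
# to <= CHUNK_MIN_CONTENT (256), the title prefix and semantic suffix are both
# dropped and the full chunk_tok_size budget goes to the content itself.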
@@ -151,12 +204,16 @@ def chunk_document(
|
||||
DocAwareChunk(
|
||||
source_document=document,
|
||||
chunk_id=len(chunks),
|
||||
blurb=extract_blurb(chunk_text, blurb_size),
|
||||
content=f"{title_prefix}{chunk_text}{metadata_suffix}",
|
||||
content_summary=chunk_text,
|
||||
blurb=extract_blurb(chunk_text, blurb_splitter),
|
||||
content=chunk_text,
|
||||
source_links=link_offsets,
|
||||
section_continuation=False,
|
||||
metadata_suffix=metadata_suffix,
|
||||
title_prefix=title_prefix,
|
||||
metadata_suffix_semantic=metadata_suffix_semantic,
|
||||
metadata_suffix_keyword=metadata_suffix_keyword,
|
||||
mini_chunk_texts=mini_chunk_splitter.split_text(chunk_text)
|
||||
if enable_multipass and chunk_text.strip()
|
||||
else None,
|
||||
)
|
||||
)
|
||||
link_offsets = {}
|
||||
@@ -167,12 +224,14 @@ def chunk_document(
|
||||
section_link_text=section_link_text,
|
||||
document=document,
|
||||
start_chunk_id=len(chunks),
|
||||
tokenizer=tokenizer,
|
||||
chunk_size=content_token_limit,
|
||||
chunk_overlap=subsection_overlap,
|
||||
blurb_size=blurb_size,
|
||||
chunk_splitter=chunk_splitter,
|
||||
mini_chunk_splitter=mini_chunk_splitter
|
||||
if enable_multipass and chunk_text.strip()
|
||||
else None,
|
||||
blurb=extract_blurb(section_text, blurb_splitter),
|
||||
title_prefix=title_prefix,
|
||||
metadata_suffix=metadata_suffix,
|
||||
metadata_suffix_semantic=metadata_suffix_semantic,
|
||||
metadata_suffix_keyword=metadata_suffix_keyword,
|
||||
)
|
||||
chunks.extend(large_section_chunks)
|
||||
continue
|
||||
@@ -193,60 +252,75 @@ def chunk_document(
|
||||
DocAwareChunk(
|
||||
source_document=document,
|
||||
chunk_id=len(chunks),
|
||||
blurb=extract_blurb(chunk_text, blurb_size),
|
||||
content=f"{title_prefix}{chunk_text}{metadata_suffix}",
|
||||
content_summary=chunk_text,
|
||||
blurb=extract_blurb(chunk_text, blurb_splitter),
|
||||
content=chunk_text,
|
||||
source_links=link_offsets,
|
||||
section_continuation=False,
|
||||
metadata_suffix=metadata_suffix,
|
||||
title_prefix=title_prefix,
|
||||
metadata_suffix_semantic=metadata_suffix_semantic,
|
||||
metadata_suffix_keyword=metadata_suffix_keyword,
|
||||
mini_chunk_texts=mini_chunk_splitter.split_text(chunk_text)
|
||||
if enable_multipass and chunk_text.strip()
|
||||
else None,
|
||||
)
|
||||
)
|
||||
link_offsets = {0: section_link_text}
|
||||
chunk_text = section_text
|
||||
|
||||
# Once we hit the end, if we're still in the process of building a chunk, add what we have
|
||||
# NOTE: if it's just whitespace, ignore it.
|
||||
if chunk_text.strip():
|
||||
# Once we hit the end, if we're still in the process of building a chunk, add what we have. If there is only whitespace left
|
||||
# then don't include it. If there are no chunks at all from the doc, we can just create a single chunk with the title.
|
||||
if chunk_text.strip() or not chunks:
|
||||
chunks.append(
|
||||
DocAwareChunk(
|
||||
source_document=document,
|
||||
chunk_id=len(chunks),
|
||||
blurb=extract_blurb(chunk_text, blurb_size),
|
||||
content=f"{title_prefix}{chunk_text}{metadata_suffix}",
|
||||
content_summary=chunk_text,
|
||||
source_links=link_offsets,
|
||||
blurb=extract_blurb(chunk_text, blurb_splitter),
|
||||
content=chunk_text,
|
||||
source_links=link_offsets or {0: section_link_text},
|
||||
section_continuation=False,
|
||||
metadata_suffix=metadata_suffix,
|
||||
title_prefix=title_prefix,
|
||||
metadata_suffix_semantic=metadata_suffix_semantic,
|
||||
metadata_suffix_keyword=metadata_suffix_keyword,
|
||||
mini_chunk_texts=mini_chunk_splitter.split_text(chunk_text)
|
||||
if enable_multipass and chunk_text.strip()
|
||||
else None,
|
||||
)
|
||||
)
|
||||
|
||||
# If the chunk does not have any usable content, it will not be indexed
|
||||
return chunks
|
||||
|
||||
|
||||
def split_chunk_text_into_mini_chunks(
|
||||
chunk_text: str, mini_chunk_size: int = MINI_CHUNK_SIZE
|
||||
) -> list[str]:
|
||||
"""The minichunks won't all have the title prefix or metadata suffix
|
||||
It could be a significant percentage of every minichunk so better to not include it
|
||||
"""
|
||||
from llama_index.text_splitter import SentenceSplitter
|
||||
|
||||
token_count_func = get_default_tokenizer().tokenize
|
||||
sentence_aware_splitter = SentenceSplitter(
|
||||
tokenizer=token_count_func, chunk_size=mini_chunk_size, chunk_overlap=0
|
||||
)
|
||||
|
||||
return sentence_aware_splitter.split_text(chunk_text)
|
||||
|
||||
|
||||
class Chunker:
|
||||
@abc.abstractmethod
|
||||
def chunk(self, document: Document) -> list[DocAwareChunk]:
|
||||
def chunk(
|
||||
self,
|
||||
document: Document,
|
||||
) -> list[DocAwareChunk]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class DefaultChunker(Chunker):
|
||||
def chunk(self, document: Document) -> list[DocAwareChunk]:
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
provider_type: EmbeddingProvider | None,
|
||||
enable_multipass: bool,
|
||||
):
|
||||
self.model_name = model_name
|
||||
self.provider_type = provider_type
|
||||
self.enable_multipass = enable_multipass
|
||||
|
||||
def chunk(
|
||||
self,
|
||||
document: Document,
|
||||
) -> list[DocAwareChunk]:
|
||||
# Specifically for reproducing an issue with gmail
|
||||
if document.source == DocumentSource.GMAIL:
|
||||
logger.debug(f"Chunking {document.semantic_identifier}")
|
||||
return chunk_document(document)
|
||||
return chunk_document(
|
||||
document=document,
|
||||
model_name=self.model_name,
|
||||
provider_type=self.provider_type,
|
||||
enable_multipass=self.enable_multipass,
|
||||
)
|
||||
|
||||
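For illustration, a minimal sketch of constructing the reworked DefaultChunker; the model name is an assumed example and provider_type=None stands for a locally hosted embedding model:

from danswer.connectors.models import Document
from danswer.indexing.chunker import DefaultChunker
from danswer.indexing.models import DocAwareChunk


def chunk_one_document(document: Document) -> list[DocAwareChunk]:
    chunker = DefaultChunker(
        model_name="intfloat/e5-base-v2",  # assumed example model name
        provider_type=None,                # None -> self-hosted embedding model
        enable_multipass=False,            # mini-chunks are only built with multipass indexing
    )
    return chunker.chunk(document=document)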
@@ -3,23 +3,21 @@ from abc import abstractmethod
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.app_configs import ENABLE_MINI_CHUNK
|
||||
from danswer.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
|
||||
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
|
||||
from danswer.db.embedding_model import get_current_db_embedding_model
|
||||
from danswer.db.embedding_model import get_secondary_db_embedding_model
|
||||
from danswer.db.models import EmbeddingModel as DbEmbeddingModel
|
||||
from danswer.db.models import IndexModelStatus
|
||||
from danswer.indexing.chunker import split_chunk_text_into_mini_chunks
|
||||
from danswer.indexing.models import ChunkEmbedding
|
||||
from danswer.indexing.models import DocAwareChunk
|
||||
from danswer.indexing.models import IndexChunk
|
||||
from danswer.search.search_nlp_models import EmbeddingModel
|
||||
from danswer.utils.batching import batch_list
|
||||
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
|
||||
from shared_configs.configs import INDEXING_MODEL_SERVER_PORT
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
from shared_configs.enums import EmbedTextType
|
||||
from shared_configs.model_server_models import Embedding
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -32,14 +30,21 @@ class IndexingEmbedder(ABC):
|
||||
normalize: bool,
|
||||
query_prefix: str | None,
|
||||
passage_prefix: str | None,
|
||||
provider_type: EmbeddingProvider | None,
|
||||
api_key: str | None,
|
||||
):
|
||||
self.model_name = model_name
|
||||
self.normalize = normalize
|
||||
self.query_prefix = query_prefix
|
||||
self.passage_prefix = passage_prefix
|
||||
self.provider_type = provider_type
|
||||
self.api_key = api_key
|
||||
|
||||
@abstractmethod
|
||||
def embed_chunks(self, chunks: list[DocAwareChunk]) -> list[IndexChunk]:
|
||||
def embed_chunks(
|
||||
self,
|
||||
chunks: list[DocAwareChunk],
|
||||
) -> list[IndexChunk]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -50,10 +55,12 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
|
||||
normalize: bool,
|
||||
query_prefix: str | None,
|
||||
passage_prefix: str | None,
|
||||
provider_type: EmbeddingProvider | None = None,
|
||||
api_key: str | None = None,
|
||||
provider_type: str | None = None,
|
||||
):
|
||||
super().__init__(model_name, normalize, query_prefix, passage_prefix)
|
||||
super().__init__(
|
||||
model_name, normalize, query_prefix, passage_prefix, provider_type, api_key
|
||||
)
|
||||
self.max_seq_length = DOC_EMBEDDING_CONTEXT_SIZE # Currently not customizable
|
||||
|
||||
self.embedding_model = EmbeddingModel(
|
||||
@@ -66,72 +73,63 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
|
||||
# The below are globally set, this flow always uses the indexing one
|
||||
server_host=INDEXING_MODEL_SERVER_HOST,
|
||||
server_port=INDEXING_MODEL_SERVER_PORT,
|
||||
retrim_content=True,
|
||||
)
|
||||
|
||||
def embed_chunks(
|
||||
self,
|
||||
chunks: list[DocAwareChunk],
|
||||
batch_size: int = BATCH_SIZE_ENCODE_CHUNKS,
|
||||
enable_mini_chunk: bool = ENABLE_MINI_CHUNK,
|
||||
) -> list[IndexChunk]:
|
||||
# Cache the Title embeddings to only have to do it once
|
||||
title_embed_dict: dict[str, list[float]] = {}
|
||||
embedded_chunks: list[IndexChunk] = []
|
||||
# All chunks at this point must have some non-empty content
|
||||
flat_chunk_texts: list[str] = []
|
||||
for chunk in chunks:
|
||||
chunk_text = (
|
||||
f"{chunk.title_prefix}{chunk.content}{chunk.metadata_suffix_semantic}"
|
||||
) or chunk.source_document.get_title_for_document_index()
|
||||
|
||||
# Create Mini Chunks for more precise matching of details
|
||||
# Off by default with unedited settings
|
||||
chunk_texts = []
|
||||
chunk_mini_chunks_count = {}
|
||||
for chunk_ind, chunk in enumerate(chunks):
|
||||
chunk_texts.append(chunk.content)
|
||||
mini_chunk_texts = (
|
||||
split_chunk_text_into_mini_chunks(chunk.content_summary)
|
||||
if enable_mini_chunk
|
||||
else []
|
||||
)
|
||||
chunk_texts.extend(mini_chunk_texts)
|
||||
chunk_mini_chunks_count[chunk_ind] = 1 + len(mini_chunk_texts)
|
||||
if not chunk_text:
|
||||
# This should never happen, the document would have been dropped
|
||||
# before getting to this point
|
||||
raise ValueError(f"Chunk has no content: {chunk.to_short_descriptor()}")
|
||||
|
||||
# Batching for embedding
|
||||
text_batches = batch_list(chunk_texts, batch_size)
|
||||
flat_chunk_texts.append(chunk_text)
|
||||
|
||||
embeddings: list[list[float]] = []
|
||||
len_text_batches = len(text_batches)
|
||||
for idx, text_batch in enumerate(text_batches, start=1):
|
||||
logger.debug(f"Embedding Content Texts batch {idx} of {len_text_batches}")
|
||||
# Normalize embeddings is only configured via model_configs.py, be sure to use right
|
||||
# value for the set loss
|
||||
embeddings.extend(
|
||||
self.embedding_model.encode(text_batch, text_type=EmbedTextType.PASSAGE)
|
||||
)
|
||||
if chunk.mini_chunk_texts:
|
||||
flat_chunk_texts.extend(chunk.mini_chunk_texts)
|
||||
|
||||
# Replace line above with the line below for easy debugging of indexing flow
|
||||
# skipping the actual model
|
||||
# embeddings.extend([[0.0] * 384 for _ in range(len(text_batch))])
|
||||
embeddings = self.embedding_model.encode(
|
||||
flat_chunk_texts, text_type=EmbedTextType.PASSAGE
|
||||
)
|
||||
|
||||
chunk_titles = {
|
||||
chunk.source_document.get_title_for_document_index() for chunk in chunks
|
||||
}
|
||||
|
||||
# Drop any None or empty strings
|
||||
# If there is no title or the title is empty, the title embedding field will be null
|
||||
# which is ok, it just won't contribute at all to the scoring.
|
||||
chunk_titles_list = [title for title in chunk_titles if title]
|
||||
|
||||
# Embed Titles in batches
|
||||
title_batches = batch_list(chunk_titles_list, batch_size)
|
||||
len_title_batches = len(title_batches)
|
||||
for ind_batch, title_batch in enumerate(title_batches, start=1):
|
||||
logger.debug(f"Embedding Titles batch {ind_batch} of {len_title_batches}")
|
||||
# Cache the Title embeddings to only have to do it once
|
||||
title_embed_dict: dict[str, Embedding] = {}
|
||||
if chunk_titles_list:
|
||||
title_embeddings = self.embedding_model.encode(
|
||||
title_batch, text_type=EmbedTextType.PASSAGE
|
||||
chunk_titles_list, text_type=EmbedTextType.PASSAGE
|
||||
)
|
||||
title_embed_dict.update(
|
||||
{title: vector for title, vector in zip(title_batch, title_embeddings)}
|
||||
{
|
||||
title: vector
|
||||
for title, vector in zip(chunk_titles_list, title_embeddings)
|
||||
}
|
||||
)
|
||||
|
||||
# Mapping embeddings to chunks
|
||||
embedded_chunks: list[IndexChunk] = []
|
||||
embedding_ind_start = 0
|
||||
for chunk_ind, chunk in enumerate(chunks):
|
||||
num_embeddings = chunk_mini_chunks_count[chunk_ind]
|
||||
for chunk in chunks:
|
||||
num_embeddings = 1 + (
|
||||
len(chunk.mini_chunk_texts) if chunk.mini_chunk_texts else 0
|
||||
)
|
||||
chunk_embeddings = embeddings[
|
||||
embedding_ind_start : embedding_ind_start + num_embeddings
|
||||
]
|
||||
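For illustration, the offset bookkeeping above with made-up counts: each chunk contributes one full-chunk embedding plus one embedding per mini-chunk, all flattened in chunk order:

mini_chunk_counts = [0, 2, 1]        # made-up per-chunk mini-chunk counts
embeddings = list(range(6))          # 6 placeholder vectors: (1+0) + (1+2) + (1+1)

start = 0
per_chunk_embeddings = []
for mini_count in mini_chunk_counts:
    num_embeddings = 1 + mini_count  # the full chunk plus its mini-chunks
    per_chunk_embeddings.append(embeddings[start:start + num_embeddings])
    start += num_embeddings

# per_chunk_embeddings == [[0], [1, 2, 3], [4, 5]]
# index 0 of each group is the full-chunk vector, the rest are mini-chunk vectors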
@@ -165,6 +163,19 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
|
||||
|
||||
return embedded_chunks
|
||||
|
||||
@classmethod
|
||||
def from_db_embedding_model(
|
||||
cls, embedding_model: DbEmbeddingModel
|
||||
) -> "DefaultIndexingEmbedder":
|
||||
return cls(
|
||||
model_name=embedding_model.model_name,
|
||||
normalize=embedding_model.normalize,
|
||||
query_prefix=embedding_model.query_prefix,
|
||||
passage_prefix=embedding_model.passage_prefix,
|
||||
provider_type=embedding_model.provider_type,
|
||||
api_key=embedding_model.api_key,
|
||||
)
|
||||
|
||||
|
||||
def get_embedding_model_from_db_embedding_model(
|
||||
db_session: Session, index_model_status: IndexModelStatus = IndexModelStatus.PRESENT
|
||||
@@ -184,4 +195,6 @@ def get_embedding_model_from_db_embedding_model(
|
||||
normalize=db_embedding_model.normalize,
|
||||
query_prefix=db_embedding_model.query_prefix,
|
||||
passage_prefix=db_embedding_model.passage_prefix,
|
||||
provider_type=db_embedding_model.provider_type,
|
||||
api_key=db_embedding_model.api_key,
|
||||
)
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from functools import partial
|
||||
from itertools import chain
|
||||
from typing import Protocol
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.access.access import get_access_for_documents
|
||||
from danswer.configs.app_configs import ENABLE_MULTIPASS_INDEXING
|
||||
from danswer.configs.constants import DEFAULT_BOOST
|
||||
from danswer.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||
get_experts_stores_representations,
|
||||
@@ -26,6 +26,7 @@ from danswer.indexing.chunker import DefaultChunker
|
||||
from danswer.indexing.embedder import IndexingEmbedder
|
||||
from danswer.indexing.models import DocAwareChunk
|
||||
from danswer.indexing.models import DocMetadataAwareIndexChunk
|
||||
from danswer.search.search_settings import get_search_settings
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.timing import log_function_time
|
||||
|
||||
@@ -34,7 +35,9 @@ logger = setup_logger()
|
||||
|
||||
class IndexingPipelineProtocol(Protocol):
|
||||
def __call__(
|
||||
self, documents: list[Document], index_attempt_metadata: IndexAttemptMetadata
|
||||
self,
|
||||
document_batch: list[Document],
|
||||
index_attempt_metadata: IndexAttemptMetadata,
|
||||
) -> tuple[int, int]:
|
||||
...
|
||||
|
||||
@@ -116,7 +119,7 @@ def index_doc_batch(
|
||||
chunker: Chunker,
|
||||
embedder: IndexingEmbedder,
|
||||
document_index: DocumentIndex,
|
||||
documents: list[Document],
|
||||
document_batch: list[Document],
|
||||
index_attempt_metadata: IndexAttemptMetadata,
|
||||
db_session: Session,
|
||||
ignore_time_skip: bool = False,
|
||||
@@ -124,6 +127,32 @@ def index_doc_batch(
|
||||
"""Takes different pieces of the indexing pipeline and applies it to a batch of documents
|
||||
Note that the documents should already be batched at this point so that it does not inflate the
|
||||
memory requirements"""
|
||||
documents = []
|
||||
for document in document_batch:
|
||||
empty_contents = not any(section.text.strip() for section in document.sections)
|
||||
if (
|
||||
(not document.title or not document.title.strip())
|
||||
and not document.semantic_identifier.strip()
|
||||
and empty_contents
|
||||
):
|
||||
# Skip documents that have neither title nor content
|
||||
# If the document doesn't have either, then there is no useful information in it
|
||||
# This is again verified later in the pipeline after chunking but at that point there should
|
||||
# already be no documents that are empty.
|
||||
logger.warning(
|
||||
f"Skipping document with ID {document.id} as it has neither title nor content."
|
||||
)
|
||||
elif (
|
||||
document.title is not None and not document.title.strip() and empty_contents
|
||||
):
|
||||
# The title is explicitly empty ("" and not None) and the document is empty
|
||||
# so when building the chunk text representation, it will be empty and unusable
|
||||
logger.warning(
|
||||
f"Skipping document with ID {document.id} as the chunks will be empty."
|
||||
)
|
||||
else:
|
||||
documents.append(document)
|
||||
|
||||
document_ids = [document.id for document in documents]
|
||||
db_docs = get_documents_by_ids(
|
||||
document_ids=document_ids,
|
||||
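The added block above filters out documents that would only produce empty chunks before any chunking or embedding work is done. A compact, runnable sketch of that filter with simplified stand-in types (the `Document`/`Section` classes here are hypothetical, not the danswer models):

from dataclasses import dataclass, field

@dataclass
class Section:
    text: str

@dataclass
class Document:
    id: str
    title: str | None
    semantic_identifier: str
    sections: list[Section] = field(default_factory=list)

def filter_documents(document_batch: list[Document]) -> list[Document]:
    kept: list[Document] = []
    for document in document_batch:
        empty_contents = not any(section.text.strip() for section in document.sections)
        no_title = not document.title or not document.title.strip()
        if no_title and not document.semantic_identifier.strip() and empty_contents:
            # Nothing useful to index at all.
            continue
        if document.title is not None and not document.title.strip() and empty_contents:
            # Title is explicitly "" and there is no body text, so every chunk would be empty.
            continue
        kept.append(document)
    return kept

docs = [
    Document(id="a", title="", semantic_identifier="", sections=[Section(" ")]),
    Document(id="b", title="Real doc", semantic_identifier="b", sections=[Section("content")]),
]
print([d.id for d in filter_documents(docs)])  # ['b']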
@@ -138,6 +167,11 @@ def index_doc_batch(
        if not ignore_time_skip
        else documents
    )

    # No docs to update either because the batch is empty or every doc was already indexed
    if not updatable_docs:
        return 0, 0

    updatable_ids = [doc.id for doc in updatable_docs]

    # Create records in the source of truth about these documents,
@@ -149,14 +183,20 @@ def index_doc_batch(
    )

    logger.debug("Starting chunking")

    # The first chunk additionally contains the Title of the Document
    chunks: list[DocAwareChunk] = list(
        chain(*[chunker.chunk(document=document) for document in updatable_docs])
    )
    chunks: list[DocAwareChunk] = [
        chunk
        for document in updatable_docs
        for chunk in chunker.chunk(document=document)
    ]

    logger.debug("Starting embedding")
    chunks_with_embeddings = embedder.embed_chunks(chunks=chunks)
    chunks_with_embeddings = (
        embedder.embed_chunks(
            chunks=chunks,
        )
        if chunks
        else []
    )

    # Acquires a lock on the documents so that no other process can modify them
    # NOTE: don't need to acquire till here, since this is when the actual race condition
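Two small changes above: the `chain(*[...])` flattening is rewritten as a nested list comprehension, and the embedder is only invoked when there is at least one chunk. A quick stand-alone illustration of both, using a dummy chunker and embedder rather than the danswer classes:

def chunk(document: str) -> list[str]:
    # Dummy chunker: split a document into fixed-size pieces.
    return [document[i : i + 8] for i in range(0, len(document), 8)]

def embed_chunks(chunks: list[str]) -> list[tuple[str, list[float]]]:
    # Dummy embedder stand-in; a real one would call a model server and
    # may reject or error on an empty batch, hence the guard below.
    return [(chunk, [float(len(chunk))]) for chunk in chunks]

updatable_docs = ["first document text", "second one"]

# Nested comprehension flattens per-document chunk lists in document order.
chunks = [piece for document in updatable_docs for piece in chunk(document)]

# Skip the embedding call entirely when the batch produced no chunks.
chunks_with_embeddings = embed_chunks(chunks) if chunks else []
print(len(chunks), len(chunks_with_embeddings))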
@@ -191,7 +231,7 @@ def index_doc_batch(
    ]

    logger.debug(
        f"Indexing the following chunks: {[chunk.to_short_descriptor() for chunk in chunks]}"
        f"Indexing the following chunks: {[chunk.to_short_descriptor() for chunk in access_aware_chunks]}"
    )
    # A document will not be spread across different batches, so all the
    # documents with chunks in this set, are fully represented by the chunks
@@ -215,7 +255,7 @@ def index_doc_batch(
    )

    return len([r for r in insertion_records if r.already_existed is False]), len(
        chunks
        access_aware_chunks
    )


@@ -227,8 +267,18 @@ def build_indexing_pipeline(
    chunker: Chunker | None = None,
    ignore_time_skip: bool = False,
) -> IndexingPipelineProtocol:
    """Builds a pipline which takes in a list (batch) of docs and indexes them."""
    chunker = chunker or DefaultChunker()
    """Builds a pipeline which takes in a list (batch) of docs and indexes them."""
    search_settings = get_search_settings()
    multipass = (
        search_settings.multipass_indexing
        if search_settings
        else ENABLE_MULTIPASS_INDEXING
    )
    chunker = chunker or DefaultChunker(
        model_name=embedder.model_name,
        provider_type=embedder.provider_type,
        enable_multipass=multipass,
    )

    return partial(
        index_doc_batch,
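`build_indexing_pipeline` now resolves the multipass setting at build time, preferring the persisted search settings and falling back to the static config flag, then returns a `partial` that satisfies the pipeline protocol. A hedged sketch of that resolve-then-bind pattern with placeholder names (the settings store and pipeline body here are invented for illustration):

from dataclasses import dataclass
from functools import partial

ENABLE_MULTIPASS_INDEXING = False  # stand-in for the static config default

@dataclass
class SearchSettings:  # hypothetical simplification of the persisted settings
    multipass_indexing: bool

def get_search_settings() -> SearchSettings | None:
    # Would normally read from a key-value store / DB; None means "not configured yet".
    return SearchSettings(multipass_indexing=True)

def index_doc_batch(document_batch: list[str], *, enable_multipass: bool) -> tuple[int, int]:
    # Pretend every document becomes one chunk (or two when multipass is on).
    chunks_per_doc = 2 if enable_multipass else 1
    return len(document_batch), len(document_batch) * chunks_per_doc

def build_indexing_pipeline():
    search_settings = get_search_settings()
    multipass = (
        search_settings.multipass_indexing if search_settings else ENABLE_MULTIPASS_INDEXING
    )
    # Bind everything except the per-call batch, mirroring partial(index_doc_batch, ...) above.
    return partial(index_doc_batch, enable_multipass=multipass)

pipeline = build_indexing_pipeline()
print(pipeline(["doc a", "doc b"]))  # (2, 4) with multipass enabled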
@@ -5,6 +5,7 @@ from pydantic import BaseModel
from danswer.access.models import DocumentAccess
from danswer.connectors.models import Document
from danswer.utils.logger import setup_logger
from shared_configs.model_server_models import Embedding

if TYPE_CHECKING:
    from danswer.db.models import EmbeddingModel
@@ -13,9 +14,6 @@ if TYPE_CHECKING:
logger = setup_logger()


Embedding = list[float]


class ChunkEmbedding(BaseModel):
    full_embedding: Embedding
    mini_chunk_embeddings: list[Embedding]
@@ -36,15 +34,17 @@ class DocAwareChunk(BaseChunk):
    # During inference we only have access to the document id and do not reconstruct the Document
    source_document: Document

    # The Vespa documents require a separate highlight field. Since it is stored as a duplicate anyway,
    # it's easier to just store a not prefixed/suffixed string for the highlighting
    # Also during the chunking, this non-prefixed/suffixed string is used for mini-chunks
    content_summary: str
    # This could be an empty string if the title is too long and taking up too much of the chunk
    # This does not mean necessarily that the document does not have a title
    title_prefix: str

    # During indexing we also (optionally) build a metadata string from the metadata dict
    # This is also indexed so that we can strip it out after indexing, this way it supports
    # multiple iterations of metadata representation for backwards compatibility
    metadata_suffix: str
    metadata_suffix_semantic: str
    metadata_suffix_keyword: str

    mini_chunk_texts: list[str] | None

    def to_short_descriptor(self) -> str:
        """Used when logging the identity of a chunk"""
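The `DocAwareChunk` change splits the single `metadata_suffix` into separate semantic and keyword representations. How those strings are built is not shown in this diff, so the following is only an illustrative guess at the idea: render the same metadata dict once as readable prose for the embedding text and once as flat tokens for keyword matching.

def build_metadata_suffixes(metadata: dict[str, str | list[str]]) -> tuple[str, str]:
    # Hypothetical helper, not the danswer implementation: one suffix phrased for
    # semantic/embedding text, one as plain tokens for keyword search.
    if not metadata:
        return "", ""
    parts_semantic: list[str] = []
    parts_keyword: list[str] = []
    for key, value in metadata.items():
        values = value if isinstance(value, list) else [value]
        parts_semantic.append(f"{key} is {', '.join(values)}")
        parts_keyword.extend(values)
    return "; ".join(parts_semantic), " ".join(parts_keyword)

semantic, keyword = build_metadata_suffixes({"team": "search", "tags": ["vespa", "chunking"]})
print(semantic)  # team is search; tags is vespa, chunking
print(keyword)   # search vespa chunking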
@@ -34,8 +34,8 @@ from danswer.llm.answering.stream_processing.quotes_processing import (
from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping
from danswer.llm.answering.stream_processing.utils import map_document_id_order
from danswer.llm.interfaces import LLM
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.llm.utils import message_generator_to_string_generator
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.tools.custom.custom_tool_prompt_builder import (
    build_user_message_for_custom_tool_for_non_tool_calling_llm,
)
@@ -99,6 +99,7 @@ class Answer:
        answer_style_config: AnswerStyleConfig,
        llm: LLM,
        prompt_config: PromptConfig,
        force_use_tool: ForceUseTool,
        # must be the same length as `docs`. If None, all docs are considered "relevant"
        message_history: list[PreviousMessage] | None = None,
        single_message_history: str | None = None,
@@ -107,10 +108,8 @@ class Answer:
        latest_query_files: list[InMemoryChatFile] | None = None,
        files: list[InMemoryChatFile] | None = None,
        tools: list[Tool] | None = None,
        # if specified, tells the LLM to always this tool
        # NOTE: for native tool-calling, this is only supported by OpenAI atm,
        # but we only support them anyways
        force_use_tool: ForceUseTool | None = None,
        # if set to True, then never use the LLMs provided tool-calling functonality
        skip_explicit_tool_calling: bool = False,
        # Returns the full document sections text from the search tool
@@ -129,6 +128,7 @@ class Answer:

        self.tools = tools or []
        self.force_use_tool = force_use_tool

        self.skip_explicit_tool_calling = skip_explicit_tool_calling

        self.message_history = message_history or []
@@ -139,7 +139,10 @@ class Answer:
        self.prompt_config = prompt_config

        self.llm = llm
        self.llm_tokenizer = get_default_llm_tokenizer()
        self.llm_tokenizer = get_tokenizer(
            provider_type=llm.config.model_provider,
            model_name=llm.config.model_name,
        )

        self._final_prompt: list[BaseMessage] | None = None
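Several files in this compare swap the global `get_default_llm_tokenizer()` for a `get_tokenizer(provider_type=..., model_name=...)` lookup so that token counting matches the active model. The real implementation is not shown here; the following is only a plausible, hedged sketch of such a factory, using `tiktoken` when it is installed and falling back to a crude whitespace count otherwise:

from typing import Protocol

class BaseTokenizer(Protocol):
    def encode(self, text: str) -> list[int]: ...

class WhitespaceTokenizer:
    # Fallback stand-in; real code would load a proper HF tokenizer instead.
    def encode(self, text: str) -> list[int]:
        return list(range(len(text.split())))

def get_tokenizer(provider_type: str | None, model_name: str | None) -> BaseTokenizer:
    # Assumption: OpenAI-style models can be counted with tiktoken when available;
    # everything else falls back to a generic tokenizer.
    if provider_type == "openai" and model_name:
        try:
            import tiktoken
            return tiktoken.encoding_for_model(model_name)  # has .encode(str) -> list[int]
        except Exception:
            pass
    return WhitespaceTokenizer()

tokenizer = get_tokenizer(provider_type="openai", model_name="gpt-4")
print(len(tokenizer.encode("counting tokens per provider and model")))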
@@ -187,7 +190,7 @@ class Answer:
        prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)

        tool_call_chunk: AIMessageChunk | None = None
        if self.force_use_tool and self.force_use_tool.args is not None:
        if self.force_use_tool.force_use and self.force_use_tool.args is not None:
            # if we are forcing a tool WITH args specified, we don't need to check which tools to run
            # / need to generate the args
            tool_call_chunk = AIMessageChunk(
@@ -221,7 +224,7 @@ class Answer:
            for message in self.llm.stream(
                prompt=prompt,
                tools=final_tool_definitions if final_tool_definitions else None,
                tool_choice="required" if self.force_use_tool else None,
                tool_choice="required" if self.force_use_tool.force_use else None,
            ):
                if isinstance(message, AIMessageChunk) and (
                    message.tool_call_chunks or message.tool_calls
@@ -240,12 +243,26 @@ class Answer:
        # if we have a tool call, we need to call the tool
        tool_call_requests = tool_call_chunk.tool_calls
        for tool_call_request in tool_call_requests:
            tool = [
            known_tools_by_name = [
                tool for tool in self.tools if tool.name == tool_call_request["name"]
            ][0]
            ]

            if not known_tools_by_name:
                logger.error(
                    "Tool call requested with unknown name field. \n"
                    f"self.tools: {self.tools}"
                    f"tool_call_request: {tool_call_request}"
                )
                if self.tools:
                    tool = self.tools[0]
                else:
                    continue
            else:
                tool = known_tools_by_name[0]
            tool_args = (
                self.force_use_tool.args
                if self.force_use_tool and self.force_use_tool.args
                if self.force_use_tool.tool_name == tool.name
                and self.force_use_tool.args
                else tool_call_request["args"]
            )
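The lookup for the requested tool no longer indexes `[0]` into a possibly empty list (which raised an `IndexError` for unknown tool names); it now logs the mismatch and falls back gracefully. A self-contained sketch of that defensive lookup with a simplified stand-in `Tool` type:

import logging
from dataclasses import dataclass

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

@dataclass
class Tool:  # minimal stand-in
    name: str

def resolve_tool(tools: list[Tool], requested_name: str) -> Tool | None:
    known_tools_by_name = [tool for tool in tools if tool.name == requested_name]
    if known_tools_by_name:
        return known_tools_by_name[0]
    # Unknown name: log it and fall back to the first available tool instead of crashing.
    logger.error("Tool call requested with unknown name %r; tools=%s", requested_name, tools)
    return tools[0] if tools else None

tools = [Tool("run_search"), Tool("run_image_generation")]
print(resolve_tool(tools, "run_search").name)         # run_search
print(resolve_tool(tools, "nonexistent_tool").name)   # run_search (fallback)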
@@ -263,9 +280,13 @@ class Answer:
            if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}:
                self._update_prompt_builder_for_search_tool(prompt_builder, [])
            elif tool.name == ImageGenerationTool._NAME:
                img_urls = [
                    img_generation_result["url"]
                    for img_generation_result in tool_runner.tool_final_result().tool_result
                ]
                prompt_builder.update_user_prompt(
                    build_image_generation_user_prompt(
                        query=self.question,
                        query=self.question, img_urls=img_urls
                    )
                )
            yield tool_runner.tool_final_result()
@@ -286,7 +307,7 @@ class Answer:
        prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
        chosen_tool_and_args: tuple[Tool, dict] | None = None

        if self.force_use_tool:
        if self.force_use_tool.force_use:
            # if we are forcing a tool, we don't need to check which tools to run
            tool = next(
                iter(
@@ -303,7 +324,7 @@ class Answer:

            tool_args = (
                self.force_use_tool.args
                if self.force_use_tool.args
                if self.force_use_tool.args is not None
                else tool.get_args_for_non_tool_calling_llm(
                    query=self.question,
                    history=self.message_history,
@@ -462,6 +483,7 @@ class Answer:
            ]
        elif message.id == FINAL_CONTEXT_DOCUMENTS:
            final_context_docs = cast(list[LlmDoc], message.response)

        elif (
            message.id == SEARCH_DOC_CONTENT_ID
            and not self._return_contexts
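`force_use_tool` is now a required, non-optional argument, and the branches check `force_use_tool.force_use` (and compare `tool_name`) instead of treating the whole object as optional. The model itself is not shown in this diff, so the following is only a hedged guess at a minimal shape that supports the checks used above; the real danswer model may differ.

from pydantic import BaseModel

class ForceUseTool(BaseModel):
    # Hypothetical reconstruction based only on the attributes referenced in the hunks
    # above (force_use, tool_name, args).
    force_use: bool = False
    tool_name: str = ""
    args: dict | None = None

def pick_tool_args(force_use_tool: ForceUseTool, tool_name: str, requested_args: dict) -> dict:
    # Prefer the forced args only when this tool is the forced one and args were supplied.
    if force_use_tool.tool_name == tool_name and force_use_tool.args:
        return force_use_tool.args
    return requested_args

forced = ForceUseTool(force_use=True, tool_name="run_search", args={"query": "pricing"})
print(pick_tool_args(forced, "run_search", {"query": "from llm"}))        # {'query': 'pricing'}
print(pick_tool_args(forced, "run_image_generation", {"prompt": "cat"}))  # {'prompt': 'cat'}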
@@ -16,6 +16,7 @@ from danswer.configs.constants import MessageType
from danswer.file_store.models import InMemoryChatFile
from danswer.llm.override_models import PromptOverride
from danswer.llm.utils import build_content_with_imgs
from danswer.tools.models import ToolCallFinalResult

if TYPE_CHECKING:
    from danswer.db.models import ChatMessage
@@ -32,6 +33,7 @@ class PreviousMessage(BaseModel):
    token_count: int
    message_type: MessageType
    files: list[InMemoryChatFile]
    tool_calls: list[ToolCallFinalResult]

    @classmethod
    def from_chat_message(
@@ -49,6 +51,14 @@ class PreviousMessage(BaseModel):
                for file in available_files
                if str(file.file_id) in message_file_ids
            ],
            tool_calls=[
                ToolCallFinalResult(
                    tool_name=tool_call.tool_name,
                    tool_args=tool_call.tool_arguments,
                    tool_result=tool_call.tool_result,
                )
                for tool_call in chat_message.tool_calls
            ],
        )

    def to_langchain_msg(self) -> BaseMessage:
@@ -12,8 +12,8 @@ from danswer.llm.answering.prompts.citations_prompt import compute_max_llm_input
from danswer.llm.interfaces import LLMConfig
from danswer.llm.utils import build_content_with_imgs
from danswer.llm.utils import check_message_tokens
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.llm.utils import translate_history_to_basemessages
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
from danswer.prompts.prompt_utils import add_date_time_to_prompt
from danswer.prompts.prompt_utils import drop_messages_history_overflow
@@ -66,7 +66,10 @@ class AnswerPromptBuilder:
        self.system_message_and_token_cnt: tuple[SystemMessage, int] | None = None
        self.user_message_and_token_cnt: tuple[HumanMessage, int] | None = None

        llm_tokenizer = get_default_llm_tokenizer()
        llm_tokenizer = get_tokenizer(
            provider_type=llm_config.model_provider,
            model_name=llm_config.model_name,
        )
        self.llm_tokenizer_encode_func = cast(
            Callable[[str], list[int]], llm_tokenizer.encode
        )
@@ -111,8 +114,24 @@ class AnswerPromptBuilder:
        final_messages_with_tokens.append(self.user_message_and_token_cnt)

        if tool_call_summary:
            final_messages_with_tokens.append((tool_call_summary.tool_call_request, 0))
            final_messages_with_tokens.append((tool_call_summary.tool_call_result, 0))
            final_messages_with_tokens.append(
                (
                    tool_call_summary.tool_call_request,
                    check_message_tokens(
                        tool_call_summary.tool_call_request,
                        self.llm_tokenizer_encode_func,
                    ),
                )
            )
            final_messages_with_tokens.append(
                (
                    tool_call_summary.tool_call_result,
                    check_message_tokens(
                        tool_call_summary.tool_call_result,
                        self.llm_tokenizer_encode_func,
                    ),
                )
            )

        return drop_messages_history_overflow(
            final_messages_with_tokens, self.max_tokens
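Previously the tool-call request and result messages were appended with a hard-coded token count of 0, which let the final prompt silently exceed the model's window; now they are counted with the same encode function as every other message. A small, toy illustration of why the hard-coded 0 was a problem (the tokenizer and message strings here are made up for the example):

def encode(text: str) -> list[int]:
    # Toy tokenizer: one token per whitespace-separated word.
    return list(range(len(text.split())))

def check_message_tokens(message: str, encode_func) -> int:
    return len(encode_func(message))

messages = [
    ("system prompt here", 3),
    ("user question about quarterly revenue", 5),
]
tool_call_request = "search(query='quarterly revenue 2024')"
tool_call_result = "doc 1: revenue grew 12 percent ... (long tool output)"

# Old behaviour: pretend tool messages are free.
old_total = sum(count for _, count in messages) + 0 + 0
# New behaviour: count them like any other message before trimming history.
new_total = (
    sum(count for _, count in messages)
    + check_message_tokens(tool_call_request, encode)
    + check_message_tokens(tool_call_result, encode)
)
print(old_total, new_total)  # the new total is what the overflow check should see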
@@ -2,7 +2,6 @@ from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage

from danswer.chat.models import LlmDoc
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
from danswer.db.models import Persona
from danswer.db.persona import get_default_prompt__read_only
@@ -29,17 +28,19 @@ from danswer.prompts.token_counts import CITATION_REMINDER_TOKEN_CNT
from danswer.prompts.token_counts import CITATION_STATEMENT_TOKEN_CNT
from danswer.prompts.token_counts import LANGUAGE_HINT_TOKEN_CNT
from danswer.search.models import InferenceChunk
from danswer.search.search_settings import get_multilingual_expansion


def get_prompt_tokens(prompt_config: PromptConfig) -> int:
    # Note: currently custom prompts do not allow datetime aware, only default prompts
    multilingual_expansion = get_multilingual_expansion()
    return (
        check_number_of_tokens(prompt_config.system_prompt)
        + check_number_of_tokens(prompt_config.task_prompt)
        + CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT
        + CITATION_STATEMENT_TOKEN_CNT
        + CITATION_REMINDER_TOKEN_CNT
        + (LANGUAGE_HINT_TOKEN_CNT if bool(MULTILINGUAL_QUERY_EXPANSION) else 0)
        + (LANGUAGE_HINT_TOKEN_CNT if multilingual_expansion else 0)
        + (ADDITIONAL_INFO_TOKEN_CNT if prompt_config.datetime_aware else 0)
    )

@@ -135,7 +136,10 @@ def build_citations_user_message(
    all_doc_useful: bool,
    history_message: str = "",
) -> HumanMessage:
    task_prompt_with_reminder = build_task_prompt_reminders(prompt_config)
    multilingual_expansion = get_multilingual_expansion()
    task_prompt_with_reminder = build_task_prompt_reminders(
        prompt=prompt_config, use_language_hint=bool(multilingual_expansion)
    )

    if context_docs:
        context_docs_str = build_complete_context_str(context_docs)
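The pattern across these prompt files is the same: the language-hint decision moves from the import-time constant `MULTILINGUAL_QUERY_EXPANSION` to a runtime `get_multilingual_expansion()` lookup, so changing the setting does not require a restart. A hedged sketch of the difference, using a made-up in-memory settings store rather than danswer's:

import os

# Old style: evaluated once at import time; later changes to the setting are ignored.
MULTILINGUAL_QUERY_EXPANSION = os.environ.get("MULTILINGUAL_QUERY_EXPANSION", "")

_settings_store = {"multilingual_expansion": ["French", "German"]}  # pretend DB/KV store

def get_multilingual_expansion() -> list[str]:
    # New style: read the current value every time the prompt is built.
    return _settings_store.get("multilingual_expansion", [])

def language_hint_tokens(language_hint_token_cnt: int = 10) -> int:
    return language_hint_token_cnt if get_multilingual_expansion() else 0

print(language_hint_tokens())   # 10, hint is counted
_settings_store["multilingual_expansion"] = []
print(language_hint_tokens())   # 0, picked up without re-importing anything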
@@ -2,7 +2,6 @@ from langchain.schema.messages import HumanMessage

from danswer.chat.models import LlmDoc
from danswer.configs.chat_configs import LANGUAGE_HINT
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE
from danswer.llm.answering.models import PromptConfig
from danswer.prompts.direct_qa_prompts import CONTEXT_BLOCK
@@ -12,6 +11,7 @@ from danswer.prompts.direct_qa_prompts import WEAK_LLM_PROMPT
from danswer.prompts.prompt_utils import add_date_time_to_prompt
from danswer.prompts.prompt_utils import build_complete_context_str
from danswer.search.models import InferenceChunk
from danswer.search.search_settings import get_search_settings


def _build_weak_llm_quotes_prompt(
@@ -19,7 +19,6 @@ def _build_weak_llm_quotes_prompt(
    context_docs: list[LlmDoc] | list[InferenceChunk],
    history_str: str,
    prompt: PromptConfig,
    use_language_hint: bool,
) -> HumanMessage:
    """Since Danswer supports a variety of LLMs, this less demanding prompt is provided
    as an option to use with weaker LLMs such as small version, low float precision, quantized,
@@ -48,8 +47,12 @@ def _build_strong_llm_quotes_prompt(
    context_docs: list[LlmDoc] | list[InferenceChunk],
    history_str: str,
    prompt: PromptConfig,
    use_language_hint: bool,
) -> HumanMessage:
    search_settings = get_search_settings()
    use_language_hint = (
        bool(search_settings.multilingual_expansion) if search_settings else False
    )

    context_block = ""
    if context_docs:
        context_docs_str = build_complete_context_str(context_docs)
@@ -79,7 +82,6 @@ def build_quotes_user_message(
    context_docs: list[LlmDoc] | list[InferenceChunk],
    history_str: str,
    prompt: PromptConfig,
    use_language_hint: bool = bool(MULTILINGUAL_QUERY_EXPANSION),
) -> HumanMessage:
    prompt_builder = (
        _build_weak_llm_quotes_prompt
@@ -92,7 +94,6 @@ def build_quotes_user_message(
        context_docs=context_docs,
        history_str=history_str,
        prompt=prompt,
        use_language_hint=use_language_hint,
    )


@@ -101,7 +102,6 @@ def build_quotes_prompt(
    context_docs: list[LlmDoc] | list[InferenceChunk],
    history_str: str,
    prompt: PromptConfig,
    use_language_hint: bool = bool(MULTILINGUAL_QUERY_EXPANSION),
) -> HumanMessage:
    prompt_builder = (
        _build_weak_llm_quotes_prompt
@@ -114,5 +114,4 @@ def build_quotes_prompt(
        context_docs=context_docs,
        history_str=history_str,
        prompt=prompt,
        use_language_hint=use_language_hint,
    )
@@ -14,8 +14,8 @@ from danswer.llm.answering.models import DocumentPruningConfig
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.prompts.citations_prompt import compute_max_document_tokens
from danswer.llm.interfaces import LLMConfig
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.llm.utils import tokenizer_trim_content
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.natural_language_processing.utils import tokenizer_trim_content
from danswer.prompts.prompt_utils import build_doc_context_str
from danswer.search.models import InferenceChunk
from danswer.search.models import InferenceSection
@@ -28,6 +28,9 @@ logger = setup_logger()
T = TypeVar("T", bound=LlmDoc | InferenceChunk | InferenceSection)

_METADATA_TOKEN_ESTIMATE = 75
# Title and additional tokens as part of the tool message json
# this is only used to log a warning so we can be more forgiving with the buffer
_OVERCOUNT_ESTIMATE = 256


class PruningError(Exception):
@@ -135,8 +138,12 @@ def _apply_pruning(
    is_manually_selected_docs: bool,
    use_sections: bool,
    using_tool_message: bool,
    llm_config: LLMConfig,
) -> list[InferenceSection]:
    llm_tokenizer = get_default_llm_tokenizer()
    llm_tokenizer = get_tokenizer(
        provider_type=llm_config.model_provider,
        model_name=llm_config.model_name,
    )
    sections = deepcopy(sections)  # don't modify in place

    # re-order docs with all the "relevant" docs at the front
@@ -165,27 +172,36 @@ def _apply_pruning(
            )
        )

        section_tokens = len(llm_tokenizer.encode(section_str))
        section_token_count = len(llm_tokenizer.encode(section_str))
        # if not using sections (specifically, using Sections where each section maps exactly to the one center chunk),
        # truncate chunks that are way too long. This can happen if the embedding model tokenizer is different
        # than the LLM tokenizer
        if (
            not is_manually_selected_docs
            and not use_sections
            and section_tokens > DOC_EMBEDDING_CONTEXT_SIZE + _METADATA_TOKEN_ESTIMATE
            and section_token_count
            > DOC_EMBEDDING_CONTEXT_SIZE + _METADATA_TOKEN_ESTIMATE
        ):
            logger.warning(
                "Found more tokens in Section than expected, "
                "likely mismatch between embedding and LLM tokenizers. Trimming content..."
            )
            if (
                section_token_count
                > DOC_EMBEDDING_CONTEXT_SIZE
                + _METADATA_TOKEN_ESTIMATE
                + _OVERCOUNT_ESTIMATE
            ):
                # If the section is just a little bit over, it is likely due to the additional tool message tokens
                # no need to record this, the content will be trimmed just in case
                logger.info(
                    "Found more tokens in Section than expected, "
                    "likely mismatch between embedding and LLM tokenizers. Trimming content..."
                )
            section.combined_content = tokenizer_trim_content(
                content=section.combined_content,
                desired_length=DOC_EMBEDDING_CONTEXT_SIZE,
                tokenizer=llm_tokenizer,
            )
            section_tokens = DOC_EMBEDDING_CONTEXT_SIZE
            section_token_count = DOC_EMBEDDING_CONTEXT_SIZE

        total_tokens += section_tokens
        total_tokens += section_token_count
        if total_tokens > token_limit:
            final_section_ind = ind
            break
@@ -273,6 +289,7 @@ def prune_sections(
        is_manually_selected_docs=document_pruning_config.is_manually_selected_docs,
        use_sections=document_pruning_config.use_sections,  # Now default True
        using_tool_message=document_pruning_config.using_tool_message,
        llm_config=llm_config,
    )
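The pruning change adds an `_OVERCOUNT_ESTIMATE` buffer: a section only slightly over the expected size (likely just tool-message overhead) is trimmed quietly, while a grossly oversized one is still logged as a probable tokenizer mismatch. A compact stand-alone sketch of that two-tier check, with made-up numbers rather than danswer's real settings:

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

DOC_EMBEDDING_CONTEXT_SIZE = 512   # illustrative values only
_METADATA_TOKEN_ESTIMATE = 75
_OVERCOUNT_ESTIMATE = 256

def maybe_trim(section_token_count: int) -> int:
    expected_max = DOC_EMBEDDING_CONTEXT_SIZE + _METADATA_TOKEN_ESTIMATE
    if section_token_count <= expected_max:
        return section_token_count  # nothing to do
    if section_token_count > expected_max + _OVERCOUNT_ESTIMATE:
        # Far over the budget: almost certainly an embedding-vs-LLM tokenizer mismatch.
        logger.info("Found more tokens in Section than expected, trimming content...")
    # Either way, trim down to the embedding context size.
    return DOC_EMBEDDING_CONTEXT_SIZE

print(maybe_trim(560))   # a little over: trimmed silently
print(maybe_trim(1200))  # way over: logged, then trimmed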
@@ -125,6 +125,30 @@ def extract_citations_from_stream(
                length_to_add -= diff
                continue

            # Handle edge case where LLM outputs citation itself
            # by allowing it to generate citations on its own.
            if curr_segment.startswith("[["):
                match = re.match(r"\[\[(\d+)\]\]", curr_segment)
                if match:
                    try:
                        doc_id = int(match.group(1))
                        context_llm_doc = context_docs[doc_id - 1]
                        yield CitationInfo(
                            citation_num=target_citation_num,
                            document_id=context_llm_doc.document_id,
                        )
                    except Exception as e:
                        logger.warning(
                            f"Manual LLM citation didn't properly cite documents {e}"
                        )
                else:
                    # Will continue attempt on next loops
                    logger.warning(
                        "Manual LLM citation wasn't able to close brackets"
                    )

                continue

            link = context_llm_doc.link

            # Replace the citation in the current segment
@@ -162,6 +186,7 @@ def extract_citations_from_stream(
                + curr_segment[end + length_to_add :]
            )
            length_to_add += len(curr_segment) - prev_length

            last_citation_end = end + length_to_add

        if last_citation_end > 0:
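The added branch handles the model emitting a bracketed citation such as `[[3]]` directly: the number is parsed, mapped back to the corresponding context document, and yielded as a citation instead of being treated as plain text. A small, self-contained sketch of that mapping with a simplified document type and hypothetical helper name:

import re
from dataclasses import dataclass

@dataclass
class LlmDoc:  # stand-in for the context documents passed to the stream processor
    document_id: str

def parse_manual_citation(segment: str, context_docs: list[LlmDoc]) -> str | None:
    # Returns the cited document_id when the segment starts with a closed "[[n]]" citation.
    if not segment.startswith("[["):
        return None
    match = re.match(r"\[\[(\d+)\]\]", segment)
    if not match:
        return None  # brackets not closed yet; try again once more tokens arrive
    try:
        doc_index = int(match.group(1)) - 1  # citations are 1-based
        return context_docs[doc_index].document_id
    except IndexError:
        return None  # the model cited a document that was never provided

docs = [LlmDoc("doc-a"), LlmDoc("doc-b"), LlmDoc("doc-c")]
print(parse_manual_citation("[[3]] according to the handbook", docs))  # doc-c
print(parse_manual_citation("[[7]] bogus citation", docs))             # None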
@@ -17,7 +17,6 @@ from danswer.chat.models import LlmDoc
from danswer.configs.chat_configs import QUOTE_ALLOWED_ERROR_PERCENT
from danswer.prompts.constants import ANSWER_PAT
from danswer.prompts.constants import QUOTE_PAT
from danswer.prompts.constants import UNCERTAINTY_PAT
from danswer.search.models import InferenceChunk
from danswer.utils.logger import setup_logger
from danswer.utils.text_processing import clean_model_quote
@@ -27,6 +26,7 @@ from danswer.utils.text_processing import shared_precompare_cleanup


logger = setup_logger()
answer_pattern = re.compile(r'{\s*"answer"\s*:\s*"', re.IGNORECASE)


def _extract_answer_quotes_freeform(
@@ -166,11 +166,8 @@ def process_answer(
    into an Answer and Quotes AND (2) after the complete streaming response
    has been received to process the model output into an Answer and Quotes."""
    answer, quote_strings = separate_answer_quotes(answer_raw, is_json_prompt)
    if answer == UNCERTAINTY_PAT or not answer:
        if answer == UNCERTAINTY_PAT:
            logger.debug("Answer matched UNCERTAINTY_PAT")
        else:
            logger.debug("No answer extracted from raw output")
    if not answer:
        logger.debug("No answer extracted from raw output")
        return DanswerAnswer(answer=None), DanswerQuotes(quotes=[])

    logger.info(f"Answer: {answer}")
@@ -227,22 +224,27 @@ def process_model_tokens(
    found_answer_start = False if is_json_prompt else True
    found_answer_end = False
    hold_quote = ""

    for token in tokens:
        model_previous = model_output
        model_output += token

        if not found_answer_start and '{"answer":"' in re.sub(r"\s", "", model_output):
            # Note, if the token that completes the pattern has additional text, for example if the token is "?
            # Then the chars after " will not be streamed, but this is ok as it prevents streaming the ? in the
            # event that the model outputs the UNCERTAINTY_PAT
            found_answer_start = True
        if not found_answer_start:
            m = answer_pattern.search(model_output)
            if m:
                found_answer_start = True

            # Prevent heavy cases of hallucinations where model is not even providing a json until later
            if is_json_prompt and len(model_output) > 40:
                logger.warning("LLM did not produce json as prompted")
                found_answer_end = True
            # Prevent heavy cases of hallucinations where model is never providing a JSON
            # We want to quickly update the user - not stream forever
            if is_json_prompt and len(model_output) > 70:
                logger.warning("LLM did not produce json as prompted")
                found_answer_end = True
                continue

            continue
            remaining = model_output[m.end() :]
            if len(remaining) > 0:
                yield DanswerAnswerPiece(answer_piece=remaining)
            continue

        if found_answer_start and not found_answer_end:
            if is_json_prompt and _stream_json_answer_end(model_previous, token):
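The streamed-JSON handling above replaces the exact substring check for '{"answer":"' with a precompiled, whitespace-tolerant regex (`answer_pattern`), streams whatever answer text already followed the match, and gives the model a longer 70-character budget before concluding it is not producing JSON. A self-contained sketch of the detection-and-flush idea, using a toy token list rather than the danswer streaming plumbing (end-of-answer detection is omitted for brevity):

import re

answer_pattern = re.compile(r'{\s*"answer"\s*:\s*"', re.IGNORECASE)

def stream_answer_pieces(tokens: list[str], max_prefix_len: int = 70):
    model_output = ""
    found_answer_start = False
    for token in tokens:
        model_output += token
        if not found_answer_start:
            m = answer_pattern.search(model_output)
            if m:
                found_answer_start = True
                # Flush any answer text that arrived in the same token as the opening pattern.
                remaining = model_output[m.end():]
                if remaining:
                    yield remaining
                continue
            if len(model_output) > max_prefix_len:
                # Model is clearly not producing the requested JSON; stop waiting.
                return
            continue
        yield token  # once inside the answer, stream tokens through as-is

tokens = ['{ "answer"', ': "Revenue grew', " 12% year over year"]
print("".join(stream_answer_pieces(tokens)))  # Revenue grew 12% year over year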
Some files were not shown because too many files have changed in this diff.