mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-17 15:55:45 +00:00
Compare commits
52 Commits
final_grap
...
debug-test
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5f82de7c45 | ||
|
|
45f67368a2 | ||
|
|
014ba9e220 | ||
|
|
ba64543dd7 | ||
|
|
18c62a0c24 | ||
|
|
33f555922c | ||
|
|
05f6f6d5b5 | ||
|
|
19dae1d870 | ||
|
|
6d859bd37c | ||
|
|
122e3fa3fa | ||
|
|
87b542b335 | ||
|
|
00229d2abe | ||
|
|
5f2644985c | ||
|
|
c82a36ad68 | ||
|
|
16d1c19d9f | ||
|
|
9f179940f8 | ||
|
|
8a8e2b310e | ||
|
|
2274cab554 | ||
|
|
ef104e9a82 | ||
|
|
a575d7f1eb | ||
|
|
f404c4b448 | ||
|
|
3884f1d70a | ||
|
|
bc9d5fece7 | ||
|
|
bb279a8580 | ||
|
|
a9403016c9 | ||
|
|
f3cea79c1c | ||
|
|
54bb79303c | ||
|
|
d3dfabb20e | ||
|
|
7d1ec1095c | ||
|
|
f531d071af | ||
|
|
4218814385 | ||
|
|
e662e3b57d | ||
|
|
2073820e33 | ||
|
|
5f25b243c5 | ||
|
|
a9427f190a | ||
|
|
18fbe9d7e8 | ||
|
|
75c9b1cafe | ||
|
|
632a8f700b | ||
|
|
cd58c96014 | ||
|
|
c5032d25c9 | ||
|
|
72acde6fd4 | ||
|
|
5596a68d08 | ||
|
|
5b18409c89 | ||
|
|
84272af5ac | ||
|
|
6bef70c8b7 | ||
|
|
7f7559e3d2 | ||
|
|
7ba829a585 | ||
|
|
8b2ecb4eab | ||
|
|
2dd3870504 | ||
|
|
df464fc54b | ||
|
|
96b98fbc4a | ||
|
|
66cf67d04d |
@@ -27,6 +27,11 @@ jobs:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Install build-essential
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y build-essential
|
||||
|
||||
- name: Backend Image Docker Build and Push
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
|
||||
@@ -37,9 +37,9 @@ jobs:
|
||||
backend/requirements/model_server.txt
|
||||
- run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r backend/requirements/default.txt
|
||||
pip install -r backend/requirements/dev.txt
|
||||
pip install -r backend/requirements/model_server.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
|
||||
|
||||
- name: Set up chart-testing
|
||||
uses: helm/chart-testing-action@v2.6.1
|
||||
|
||||
6
.github/workflows/pr-python-checks.yml
vendored
6
.github/workflows/pr-python-checks.yml
vendored
@@ -24,9 +24,9 @@ jobs:
|
||||
backend/requirements/model_server.txt
|
||||
- run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r backend/requirements/default.txt
|
||||
pip install -r backend/requirements/dev.txt
|
||||
pip install -r backend/requirements/model_server.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
|
||||
|
||||
- name: Run MyPy
|
||||
run: |
|
||||
|
||||
@@ -39,8 +39,8 @@ jobs:
|
||||
- name: Install Dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r backend/requirements/default.txt
|
||||
pip install -r backend/requirements/dev.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
|
||||
- name: Run Tests
|
||||
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
|
||||
|
||||
4
.github/workflows/pr-python-tests.yml
vendored
4
.github/workflows/pr-python-tests.yml
vendored
@@ -29,8 +29,8 @@ jobs:
|
||||
- name: Install Dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r backend/requirements/default.txt
|
||||
pip install -r backend/requirements/dev.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
|
||||
- name: Run Tests
|
||||
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
|
||||
|
||||
9
.github/workflows/run-it.yml
vendored
9
.github/workflows/run-it.yml
vendored
@@ -13,8 +13,7 @@ env:
|
||||
|
||||
jobs:
|
||||
integration-tests:
|
||||
runs-on:
|
||||
group: 'arm64-image-builders'
|
||||
runs-on: Amd64
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
@@ -41,7 +40,7 @@ jobs:
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile
|
||||
platforms: linux/arm64
|
||||
platforms: linux/amd64
|
||||
tags: danswer/danswer-backend:it
|
||||
cache-from: type=registry,ref=danswer/danswer-backend:it
|
||||
cache-to: |
|
||||
@@ -53,7 +52,7 @@ jobs:
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
platforms: linux/arm64
|
||||
platforms: linux/amd64
|
||||
tags: danswer/danswer-model-server:it
|
||||
cache-from: type=registry,ref=danswer/danswer-model-server:it
|
||||
cache-to: |
|
||||
@@ -65,7 +64,7 @@ jobs:
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/tests/integration/Dockerfile
|
||||
platforms: linux/arm64
|
||||
platforms: linux/amd64
|
||||
tags: danswer/integration-test-runner:it
|
||||
cache-from: type=registry,ref=danswer/integration-test-runner:it
|
||||
cache-to: |
|
||||
|
||||
@@ -41,6 +41,8 @@ RUN apt-get update && \
|
||||
COPY ./requirements/default.txt /tmp/requirements.txt
|
||||
COPY ./requirements/ee.txt /tmp/ee-requirements.txt
|
||||
RUN pip install --no-cache-dir --upgrade \
|
||||
--retries 5 \
|
||||
--timeout 30 \
|
||||
-r /tmp/requirements.txt \
|
||||
-r /tmp/ee-requirements.txt && \
|
||||
pip uninstall -y py && \
|
||||
|
||||
@@ -15,7 +15,10 @@ ENV DANSWER_VERSION=${DANSWER_VERSION} \
|
||||
RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}"
|
||||
|
||||
COPY ./requirements/model_server.txt /tmp/requirements.txt
|
||||
RUN pip install --no-cache-dir --upgrade -r /tmp/requirements.txt
|
||||
RUN pip install --no-cache-dir --upgrade \
|
||||
--retries 5 \
|
||||
--timeout 30 \
|
||||
-r /tmp/requirements.txt
|
||||
|
||||
RUN apt-get remove -y --allow-remove-essential perl-base && \
|
||||
apt-get autoremove -y
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 27 KiB |
@@ -0,0 +1,102 @@
|
||||
"""add_user_delete_cascades
|
||||
|
||||
Revision ID: 1b8206b29c5d
|
||||
Revises: 35e6853a51d5
|
||||
Create Date: 2024-09-18 11:48:59.418726
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "1b8206b29c5d"
|
||||
down_revision = "35e6853a51d5"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.drop_constraint("credential_user_id_fkey", "credential", type_="foreignkey")
|
||||
op.create_foreign_key(
|
||||
"credential_user_id_fkey",
|
||||
"credential",
|
||||
"user",
|
||||
["user_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
op.drop_constraint("chat_session_user_id_fkey", "chat_session", type_="foreignkey")
|
||||
op.create_foreign_key(
|
||||
"chat_session_user_id_fkey",
|
||||
"chat_session",
|
||||
"user",
|
||||
["user_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
op.drop_constraint("chat_folder_user_id_fkey", "chat_folder", type_="foreignkey")
|
||||
op.create_foreign_key(
|
||||
"chat_folder_user_id_fkey",
|
||||
"chat_folder",
|
||||
"user",
|
||||
["user_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
op.drop_constraint("prompt_user_id_fkey", "prompt", type_="foreignkey")
|
||||
op.create_foreign_key(
|
||||
"prompt_user_id_fkey", "prompt", "user", ["user_id"], ["id"], ondelete="CASCADE"
|
||||
)
|
||||
|
||||
op.drop_constraint("notification_user_id_fkey", "notification", type_="foreignkey")
|
||||
op.create_foreign_key(
|
||||
"notification_user_id_fkey",
|
||||
"notification",
|
||||
"user",
|
||||
["user_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
op.drop_constraint("inputprompt_user_id_fkey", "inputprompt", type_="foreignkey")
|
||||
op.create_foreign_key(
|
||||
"inputprompt_user_id_fkey",
|
||||
"inputprompt",
|
||||
"user",
|
||||
["user_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_constraint("credential_user_id_fkey", "credential", type_="foreignkey")
|
||||
op.create_foreign_key(
|
||||
"credential_user_id_fkey", "credential", "user", ["user_id"], ["id"]
|
||||
)
|
||||
|
||||
op.drop_constraint("chat_session_user_id_fkey", "chat_session", type_="foreignkey")
|
||||
op.create_foreign_key(
|
||||
"chat_session_user_id_fkey", "chat_session", "user", ["user_id"], ["id"]
|
||||
)
|
||||
|
||||
op.drop_constraint("chat_folder_user_id_fkey", "chat_folder", type_="foreignkey")
|
||||
op.create_foreign_key(
|
||||
"chat_folder_user_id_fkey", "chat_folder", "user", ["user_id"], ["id"]
|
||||
)
|
||||
|
||||
op.drop_constraint("prompt_user_id_fkey", "prompt", type_="foreignkey")
|
||||
op.create_foreign_key("prompt_user_id_fkey", "prompt", "user", ["user_id"], ["id"])
|
||||
|
||||
op.drop_constraint("notification_user_id_fkey", "notification", type_="foreignkey")
|
||||
op.create_foreign_key(
|
||||
"notification_user_id_fkey", "notification", "user", ["user_id"], ["id"]
|
||||
)
|
||||
|
||||
op.drop_constraint("inputprompt_user_id_fkey", "inputprompt", type_="foreignkey")
|
||||
op.create_foreign_key(
|
||||
"inputprompt_user_id_fkey", "inputprompt", "user", ["user_id"], ["id"]
|
||||
)
|
||||
@@ -0,0 +1,64 @@
|
||||
"""server default chosen assistants
|
||||
|
||||
Revision ID: 35e6853a51d5
|
||||
Revises: c99d76fcd298
|
||||
Create Date: 2024-09-13 13:20:32.885317
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "35e6853a51d5"
|
||||
down_revision = "c99d76fcd298"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
DEFAULT_ASSISTANTS = [-2, -1, 0]
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Step 1: Update any NULL values to the default value
|
||||
# This upgrades existing users without ordered assistant
|
||||
# to have default assistants set to visible assistants which are
|
||||
# accessible by them.
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE "user" u
|
||||
SET chosen_assistants = (
|
||||
SELECT jsonb_agg(
|
||||
p.id ORDER BY
|
||||
COALESCE(p.display_priority, 2147483647) ASC,
|
||||
p.id ASC
|
||||
)
|
||||
FROM persona p
|
||||
LEFT JOIN persona__user pu ON p.id = pu.persona_id AND pu.user_id = u.id
|
||||
WHERE p.is_visible = true
|
||||
AND (p.is_public = true OR pu.user_id IS NOT NULL)
|
||||
)
|
||||
WHERE chosen_assistants IS NULL
|
||||
OR chosen_assistants = 'null'
|
||||
OR jsonb_typeof(chosen_assistants) = 'null'
|
||||
OR (jsonb_typeof(chosen_assistants) = 'string' AND chosen_assistants = '"null"')
|
||||
"""
|
||||
)
|
||||
|
||||
# Step 2: Alter the column to make it non-nullable
|
||||
op.alter_column(
|
||||
"user",
|
||||
"chosen_assistants",
|
||||
type_=postgresql.JSONB(astext_type=sa.Text()),
|
||||
nullable=False,
|
||||
server_default=sa.text(f"'{DEFAULT_ASSISTANTS}'::jsonb"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.alter_column(
|
||||
"user",
|
||||
"chosen_assistants",
|
||||
type_=postgresql.JSONB(astext_type=sa.Text()),
|
||||
nullable=True,
|
||||
server_default=None,
|
||||
)
|
||||
@@ -1,65 +0,0 @@
|
||||
"""single tool call per message
|
||||
|
||||
|
||||
Revision ID: 4e8e7ae58189
|
||||
Revises: 5c7fdadae813
|
||||
Create Date: 2024-09-09 10:07:58.008838
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "4e8e7ae58189"
|
||||
down_revision = "5c7fdadae813"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create the new column
|
||||
op.add_column(
|
||||
"chat_message", sa.Column("tool_call_id", sa.Integer(), nullable=True)
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"fk_chat_message_tool_call",
|
||||
"chat_message",
|
||||
"tool_call",
|
||||
["tool_call_id"],
|
||||
["id"],
|
||||
)
|
||||
|
||||
# Migrate existing data
|
||||
op.execute(
|
||||
"UPDATE chat_message SET tool_call_id = (SELECT id FROM tool_call WHERE tool_call.message_id = chat_message.id LIMIT 1)"
|
||||
)
|
||||
|
||||
# Drop the old relationship
|
||||
op.drop_constraint("tool_call_message_id_fkey", "tool_call", type_="foreignkey")
|
||||
op.drop_column("tool_call", "message_id")
|
||||
|
||||
# Add a unique constraint to ensure one-to-one relationship
|
||||
op.create_unique_constraint(
|
||||
"uq_chat_message_tool_call_id", "chat_message", ["tool_call_id"]
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Add back the old column
|
||||
op.add_column(
|
||||
"tool_call",
|
||||
sa.Column("message_id", sa.INTEGER(), autoincrement=False, nullable=True),
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"tool_call_message_id_fkey", "tool_call", "chat_message", ["message_id"], ["id"]
|
||||
)
|
||||
|
||||
# Migrate data back
|
||||
op.execute(
|
||||
"UPDATE tool_call SET message_id = (SELECT id FROM chat_message WHERE chat_message.tool_call_id = tool_call.id)"
|
||||
)
|
||||
|
||||
# Drop the new column
|
||||
op.drop_constraint("fk_chat_message_tool_call", "chat_message", type_="foreignkey")
|
||||
op.drop_column("chat_message", "tool_call_id")
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Add last synced and last modified to document table
|
||||
|
||||
Revision ID: 52a219fb5233
|
||||
Revises: f17bf3b0d9f1
|
||||
Revises: f7e58d357687
|
||||
Create Date: 2024-08-28 17:40:46.077470
|
||||
|
||||
"""
|
||||
|
||||
79
backend/alembic/versions/55546a7967ee_assistant_rework.py
Normal file
79
backend/alembic/versions/55546a7967ee_assistant_rework.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""assistant_rework
|
||||
|
||||
Revision ID: 55546a7967ee
|
||||
Revises: 61ff3651add4
|
||||
Create Date: 2024-09-18 17:00:23.755399
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "55546a7967ee"
|
||||
down_revision = "61ff3651add4"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Reworking persona and user tables for new assistant features
|
||||
# keep track of user's chosen assistants separate from their `ordering`
|
||||
op.add_column("persona", sa.Column("builtin_persona", sa.Boolean(), nullable=True))
|
||||
op.execute("UPDATE persona SET builtin_persona = default_persona")
|
||||
op.alter_column("persona", "builtin_persona", nullable=False)
|
||||
op.drop_index("_default_persona_name_idx", table_name="persona")
|
||||
op.create_index(
|
||||
"_builtin_persona_name_idx",
|
||||
"persona",
|
||||
["name"],
|
||||
unique=True,
|
||||
postgresql_where=sa.text("builtin_persona = true"),
|
||||
)
|
||||
|
||||
op.add_column(
|
||||
"user", sa.Column("visible_assistants", postgresql.JSONB(), nullable=True)
|
||||
)
|
||||
op.add_column(
|
||||
"user", sa.Column("hidden_assistants", postgresql.JSONB(), nullable=True)
|
||||
)
|
||||
op.execute(
|
||||
"UPDATE \"user\" SET visible_assistants = '[]'::jsonb, hidden_assistants = '[]'::jsonb"
|
||||
)
|
||||
op.alter_column(
|
||||
"user",
|
||||
"visible_assistants",
|
||||
nullable=False,
|
||||
server_default=sa.text("'[]'::jsonb"),
|
||||
)
|
||||
op.alter_column(
|
||||
"user",
|
||||
"hidden_assistants",
|
||||
nullable=False,
|
||||
server_default=sa.text("'[]'::jsonb"),
|
||||
)
|
||||
op.drop_column("persona", "default_persona")
|
||||
op.add_column(
|
||||
"persona", sa.Column("is_default_persona", sa.Boolean(), nullable=True)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Reverting changes made in upgrade
|
||||
op.drop_column("user", "hidden_assistants")
|
||||
op.drop_column("user", "visible_assistants")
|
||||
op.drop_index("_builtin_persona_name_idx", table_name="persona")
|
||||
|
||||
op.drop_column("persona", "is_default_persona")
|
||||
op.add_column("persona", sa.Column("default_persona", sa.Boolean(), nullable=True))
|
||||
op.execute("UPDATE persona SET default_persona = builtin_persona")
|
||||
op.alter_column("persona", "default_persona", nullable=False)
|
||||
op.drop_column("persona", "builtin_persona")
|
||||
op.create_index(
|
||||
"_default_persona_name_idx",
|
||||
"persona",
|
||||
["name"],
|
||||
unique=True,
|
||||
postgresql_where=sa.text("default_persona = true"),
|
||||
)
|
||||
162
backend/alembic/versions/61ff3651add4_add_permission_syncing.py
Normal file
162
backend/alembic/versions/61ff3651add4_add_permission_syncing.py
Normal file
@@ -0,0 +1,162 @@
|
||||
"""Add Permission Syncing
|
||||
|
||||
Revision ID: 61ff3651add4
|
||||
Revises: 1b8206b29c5d
|
||||
Create Date: 2024-09-05 13:57:11.770413
|
||||
|
||||
"""
|
||||
import fastapi_users_db_sqlalchemy
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "61ff3651add4"
|
||||
down_revision = "1b8206b29c5d"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Admin user who set up connectors will lose access to the docs temporarily
|
||||
# only way currently to give back access is to rerun from beginning
|
||||
op.add_column(
|
||||
"connector_credential_pair",
|
||||
sa.Column(
|
||||
"access_type",
|
||||
sa.String(),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
op.execute(
|
||||
"UPDATE connector_credential_pair SET access_type = 'PUBLIC' WHERE is_public = true"
|
||||
)
|
||||
op.execute(
|
||||
"UPDATE connector_credential_pair SET access_type = 'PRIVATE' WHERE is_public = false"
|
||||
)
|
||||
op.alter_column("connector_credential_pair", "access_type", nullable=False)
|
||||
|
||||
op.add_column(
|
||||
"connector_credential_pair",
|
||||
sa.Column(
|
||||
"auto_sync_options",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"connector_credential_pair",
|
||||
sa.Column("last_time_perm_sync", sa.DateTime(timezone=True), nullable=True),
|
||||
)
|
||||
op.drop_column("connector_credential_pair", "is_public")
|
||||
|
||||
op.add_column(
|
||||
"document",
|
||||
sa.Column("external_user_emails", postgresql.ARRAY(sa.String()), nullable=True),
|
||||
)
|
||||
op.add_column(
|
||||
"document",
|
||||
sa.Column(
|
||||
"external_user_group_ids", postgresql.ARRAY(sa.String()), nullable=True
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"document",
|
||||
sa.Column("is_public", sa.Boolean(), nullable=True),
|
||||
)
|
||||
|
||||
op.create_table(
|
||||
"user__external_user_group_id",
|
||||
sa.Column(
|
||||
"user_id", fastapi_users_db_sqlalchemy.generics.GUID(), nullable=False
|
||||
),
|
||||
sa.Column("external_user_group_id", sa.String(), nullable=False),
|
||||
sa.Column("cc_pair_id", sa.Integer(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("user_id"),
|
||||
)
|
||||
|
||||
op.drop_column("external_permission", "user_id")
|
||||
op.drop_column("email_to_external_user_cache", "user_id")
|
||||
op.drop_table("permission_sync_run")
|
||||
op.drop_table("external_permission")
|
||||
op.drop_table("email_to_external_user_cache")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.add_column(
|
||||
"connector_credential_pair",
|
||||
sa.Column("is_public", sa.BOOLEAN(), nullable=True),
|
||||
)
|
||||
op.execute(
|
||||
"UPDATE connector_credential_pair SET is_public = (access_type = 'PUBLIC')"
|
||||
)
|
||||
op.alter_column("connector_credential_pair", "is_public", nullable=False)
|
||||
|
||||
op.drop_column("connector_credential_pair", "auto_sync_options")
|
||||
op.drop_column("connector_credential_pair", "access_type")
|
||||
op.drop_column("connector_credential_pair", "last_time_perm_sync")
|
||||
op.drop_column("document", "external_user_emails")
|
||||
op.drop_column("document", "external_user_group_ids")
|
||||
op.drop_column("document", "is_public")
|
||||
|
||||
op.drop_table("user__external_user_group_id")
|
||||
|
||||
# Drop the enum type at the end of the downgrade
|
||||
op.create_table(
|
||||
"permission_sync_run",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"source_type",
|
||||
sa.String(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("update_type", sa.String(), nullable=False),
|
||||
sa.Column("cc_pair_id", sa.Integer(), nullable=True),
|
||||
sa.Column(
|
||||
"status",
|
||||
sa.String(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("error_msg", sa.Text(), nullable=True),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["cc_pair_id"],
|
||||
["connector_credential_pair.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_table(
|
||||
"external_permission",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_id", sa.UUID(), nullable=True),
|
||||
sa.Column("user_email", sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
"source_type",
|
||||
sa.String(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("external_permission_group", sa.String(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["user.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_table(
|
||||
"email_to_external_user_cache",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("external_user_id", sa.String(), nullable=False),
|
||||
sa.Column("user_id", sa.UUID(), nullable=True),
|
||||
sa.Column("user_email", sa.String(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["user.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
27
backend/alembic/versions/797089dfb4d2_persona_start_date.py
Normal file
27
backend/alembic/versions/797089dfb4d2_persona_start_date.py
Normal file
@@ -0,0 +1,27 @@
|
||||
"""persona_start_date
|
||||
|
||||
Revision ID: 797089dfb4d2
|
||||
Revises: 55546a7967ee
|
||||
Create Date: 2024-09-11 14:51:49.785835
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "797089dfb4d2"
|
||||
down_revision = "55546a7967ee"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column("search_start_date", sa.DateTime(timezone=True), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("persona", "search_start_date")
|
||||
@@ -0,0 +1,43 @@
|
||||
"""non nullable default persona
|
||||
|
||||
Revision ID: bd2921608c3a
|
||||
Revises: 797089dfb4d2
|
||||
Create Date: 2024-09-20 10:28:37.992042
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "bd2921608c3a"
|
||||
down_revision = "797089dfb4d2"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Set existing NULL values to False
|
||||
op.execute(
|
||||
"UPDATE persona SET is_default_persona = FALSE WHERE is_default_persona IS NULL"
|
||||
)
|
||||
|
||||
# Alter the column to be not nullable with a default value of False
|
||||
op.alter_column(
|
||||
"persona",
|
||||
"is_default_persona",
|
||||
existing_type=sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.text("false"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Revert the changes
|
||||
op.alter_column(
|
||||
"persona",
|
||||
"is_default_persona",
|
||||
existing_type=sa.Boolean(),
|
||||
nullable=True,
|
||||
server_default=None,
|
||||
)
|
||||
@@ -0,0 +1,31 @@
|
||||
"""add nullable to persona id in Chat Session
|
||||
|
||||
Revision ID: c99d76fcd298
|
||||
Revises: 5c7fdadae813
|
||||
Create Date: 2024-07-09 19:27:01.579697
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "c99d76fcd298"
|
||||
down_revision = "5c7fdadae813"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.alter_column(
|
||||
"chat_session", "persona_id", existing_type=sa.INTEGER(), nullable=True
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.alter_column(
|
||||
"chat_session",
|
||||
"persona_id",
|
||||
existing_type=sa.INTEGER(),
|
||||
nullable=False,
|
||||
)
|
||||
@@ -1,7 +1,7 @@
|
||||
"""standard answer match_regex flag
|
||||
|
||||
Revision ID: efb35676026c
|
||||
Revises: 52a219fb5233
|
||||
Revises: 0ebb1d516877
|
||||
Create Date: 2024-09-11 13:55:46.101149
|
||||
|
||||
"""
|
||||
@@ -19,7 +19,9 @@ def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column(
|
||||
"standard_answer",
|
||||
sa.Column("match_regex", sa.Boolean(), nullable=False, default=False),
|
||||
sa.Column(
|
||||
"match_regex", sa.Boolean(), nullable=False, server_default=sa.false()
|
||||
),
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
"""add custom headers to tools
|
||||
|
||||
Revision ID: f32615f71aeb
|
||||
Revises: bd2921608c3a
|
||||
Create Date: 2024-09-12 20:26:38.932377
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "f32615f71aeb"
|
||||
down_revision = "bd2921608c3a"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"tool", sa.Column("custom_headers", postgresql.JSONB(), nullable=True)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("tool", "custom_headers")
|
||||
@@ -1,7 +1,7 @@
|
||||
"""add has_web_login column to user
|
||||
|
||||
Revision ID: f7e58d357687
|
||||
Revises: bceb1e139447
|
||||
Revises: ba98eba0f66a
|
||||
Create Date: 2024-09-07 20:20:54.522620
|
||||
|
||||
"""
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.access.models import DocumentAccess
|
||||
from danswer.access.utils import prefix_user
|
||||
from danswer.access.utils import prefix_user_email
|
||||
from danswer.configs.constants import PUBLIC_DOC_PAT
|
||||
from danswer.db.document import get_access_info_for_document
|
||||
from danswer.db.document import get_access_info_for_documents
|
||||
@@ -18,10 +18,13 @@ def _get_access_for_document(
|
||||
document_id=document_id,
|
||||
)
|
||||
|
||||
if not info:
|
||||
return DocumentAccess.build(user_ids=[], user_groups=[], is_public=False)
|
||||
|
||||
return DocumentAccess.build(user_ids=info[1], user_groups=[], is_public=info[2])
|
||||
return DocumentAccess.build(
|
||||
user_emails=info[1] if info and info[1] else [],
|
||||
user_groups=[],
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
is_public=info[2] if info else False,
|
||||
)
|
||||
|
||||
|
||||
def get_access_for_document(
|
||||
@@ -34,6 +37,16 @@ def get_access_for_document(
|
||||
return versioned_get_access_for_document_fn(document_id, db_session) # type: ignore
|
||||
|
||||
|
||||
def get_null_document_access() -> DocumentAccess:
|
||||
return DocumentAccess(
|
||||
user_emails=set(),
|
||||
user_groups=set(),
|
||||
is_public=False,
|
||||
external_user_emails=set(),
|
||||
external_user_group_ids=set(),
|
||||
)
|
||||
|
||||
|
||||
def _get_access_for_documents(
|
||||
document_ids: list[str],
|
||||
db_session: Session,
|
||||
@@ -42,13 +55,27 @@ def _get_access_for_documents(
|
||||
db_session=db_session,
|
||||
document_ids=document_ids,
|
||||
)
|
||||
return {
|
||||
document_id: DocumentAccess.build(
|
||||
user_ids=user_ids, user_groups=[], is_public=is_public
|
||||
doc_access = {
|
||||
document_id: DocumentAccess(
|
||||
user_emails=set([email for email in user_emails if email]),
|
||||
# MIT version will wipe all groups and external groups on update
|
||||
user_groups=set(),
|
||||
is_public=is_public,
|
||||
external_user_emails=set(),
|
||||
external_user_group_ids=set(),
|
||||
)
|
||||
for document_id, user_ids, is_public in document_access_info
|
||||
for document_id, user_emails, is_public in document_access_info
|
||||
}
|
||||
|
||||
# Sometimes the document has not be indexed by the indexing job yet, in those cases
|
||||
# the document does not exist and so we use least permissive. Specifically the EE version
|
||||
# checks the MIT version permissions and creates a superset. This ensures that this flow
|
||||
# does not fail even if the Document has not yet been indexed.
|
||||
for doc_id in document_ids:
|
||||
if doc_id not in doc_access:
|
||||
doc_access[doc_id] = get_null_document_access()
|
||||
return doc_access
|
||||
|
||||
|
||||
def get_access_for_documents(
|
||||
document_ids: list[str],
|
||||
@@ -70,7 +97,7 @@ def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]:
|
||||
matches one entry in the returned set.
|
||||
"""
|
||||
if user:
|
||||
return {prefix_user(str(user.id)), PUBLIC_DOC_PAT}
|
||||
return {prefix_user_email(user.email), PUBLIC_DOC_PAT}
|
||||
return {PUBLIC_DOC_PAT}
|
||||
|
||||
|
||||
|
||||
@@ -1,30 +1,72 @@
|
||||
from dataclasses import dataclass
|
||||
from uuid import UUID
|
||||
|
||||
from danswer.access.utils import prefix_user
|
||||
from danswer.access.utils import prefix_external_group
|
||||
from danswer.access.utils import prefix_user_email
|
||||
from danswer.access.utils import prefix_user_group
|
||||
from danswer.configs.constants import PUBLIC_DOC_PAT
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DocumentAccess:
|
||||
user_ids: set[str] # stringified UUIDs
|
||||
user_groups: set[str] # names of user groups associated with this document
|
||||
class ExternalAccess:
|
||||
# Emails of external users with access to the doc externally
|
||||
external_user_emails: set[str]
|
||||
# Names or external IDs of groups with access to the doc
|
||||
external_user_group_ids: set[str]
|
||||
# Whether the document is public in the external system or Danswer
|
||||
is_public: bool
|
||||
|
||||
def to_acl(self) -> list[str]:
|
||||
return (
|
||||
[prefix_user(user_id) for user_id in self.user_ids]
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DocumentAccess(ExternalAccess):
|
||||
# User emails for Danswer users, None indicates admin
|
||||
user_emails: set[str | None]
|
||||
# Names of user groups associated with this document
|
||||
user_groups: set[str]
|
||||
|
||||
def to_acl(self) -> set[str]:
|
||||
return set(
|
||||
[
|
||||
prefix_user_email(user_email)
|
||||
for user_email in self.user_emails
|
||||
if user_email
|
||||
]
|
||||
+ [prefix_user_group(group_name) for group_name in self.user_groups]
|
||||
+ [
|
||||
prefix_user_email(user_email)
|
||||
for user_email in self.external_user_emails
|
||||
]
|
||||
+ [
|
||||
# The group names are already prefixed by the source type
|
||||
# This adds an additional prefix of "external_group:"
|
||||
prefix_external_group(group_name)
|
||||
for group_name in self.external_user_group_ids
|
||||
]
|
||||
+ ([PUBLIC_DOC_PAT] if self.is_public else [])
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def build(
|
||||
cls, user_ids: list[UUID | None], user_groups: list[str], is_public: bool
|
||||
cls,
|
||||
user_emails: list[str | None],
|
||||
user_groups: list[str],
|
||||
external_user_emails: list[str],
|
||||
external_user_group_ids: list[str],
|
||||
is_public: bool,
|
||||
) -> "DocumentAccess":
|
||||
return cls(
|
||||
user_ids={str(user_id) for user_id in user_ids if user_id},
|
||||
external_user_emails={
|
||||
prefix_user_email(external_email)
|
||||
for external_email in external_user_emails
|
||||
},
|
||||
external_user_group_ids={
|
||||
prefix_external_group(external_group_id)
|
||||
for external_group_id in external_user_group_ids
|
||||
},
|
||||
user_emails={
|
||||
prefix_user_email(user_email)
|
||||
for user_email in user_emails
|
||||
if user_email
|
||||
},
|
||||
user_groups=set(user_groups),
|
||||
is_public=is_public,
|
||||
)
|
||||
|
||||
@@ -1,10 +1,24 @@
|
||||
def prefix_user(user_id: str) -> str:
|
||||
"""Prefixes a user ID to eliminate collision with group names.
|
||||
This assumes that groups are prefixed with a different prefix."""
|
||||
return f"user_id:{user_id}"
|
||||
from danswer.configs.constants import DocumentSource
|
||||
|
||||
|
||||
def prefix_user_email(user_email: str) -> str:
|
||||
"""Prefixes a user email to eliminate collision with group names.
|
||||
This applies to both a Danswer user and an External user, this is to make the query time
|
||||
more efficient"""
|
||||
return f"user_email:{user_email}"
|
||||
|
||||
|
||||
def prefix_user_group(user_group_name: str) -> str:
|
||||
"""Prefixes a user group name to eliminate collision with user IDs.
|
||||
"""Prefixes a user group name to eliminate collision with user emails.
|
||||
This assumes that user ids are prefixed with a different prefix."""
|
||||
return f"group:{user_group_name}"
|
||||
|
||||
|
||||
def prefix_external_group(ext_group_name: str) -> str:
|
||||
"""Prefixes an external group name to eliminate collision with user emails / Danswer groups."""
|
||||
return f"external_group:{ext_group_name}"
|
||||
|
||||
|
||||
def prefix_group_w_source(ext_group_name: str, source: DocumentSource) -> str:
|
||||
"""External groups may collide across sources, every source needs its own prefix."""
|
||||
return f"{source.value.upper()}_{ext_group_name}"
|
||||
|
||||
@@ -300,17 +300,27 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
async def authenticate(
|
||||
self, credentials: OAuth2PasswordRequestForm
|
||||
) -> Optional[User]:
|
||||
user = await super().authenticate(credentials)
|
||||
if user is None:
|
||||
try:
|
||||
user = await self.get_by_email(credentials.username)
|
||||
if not user.has_web_login:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="NO_WEB_LOGIN_AND_HAS_NO_PASSWORD",
|
||||
)
|
||||
except exceptions.UserNotExists:
|
||||
pass
|
||||
try:
|
||||
user = await self.get_by_email(credentials.username)
|
||||
except exceptions.UserNotExists:
|
||||
self.password_helper.hash(credentials.password)
|
||||
return None
|
||||
|
||||
if not user.has_web_login:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="NO_WEB_LOGIN_AND_HAS_NO_PASSWORD",
|
||||
)
|
||||
|
||||
verified, updated_password_hash = self.password_helper.verify_and_update(
|
||||
credentials.password, user.hashed_password
|
||||
)
|
||||
if not verified:
|
||||
return None
|
||||
|
||||
if updated_password_hash is not None:
|
||||
await self.user_db.update(user, {"hashed_password": updated_password_hash})
|
||||
|
||||
return user
|
||||
|
||||
|
||||
|
||||
@@ -21,18 +21,17 @@ from redis import Redis
|
||||
from sqlalchemy import inspect
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm.exc import ObjectDeletedError
|
||||
|
||||
from danswer.access.access import get_access_for_document
|
||||
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
|
||||
from danswer.background.celery.celery_redis import RedisConnectorDeletion
|
||||
from danswer.background.celery.celery_redis import RedisDocumentSet
|
||||
from danswer.background.celery.celery_redis import RedisUserGroup
|
||||
from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector
|
||||
from danswer.background.celery.celery_utils import should_kick_off_deletion_of_cc_pair
|
||||
from danswer.background.celery.celery_utils import should_prune_cc_pair
|
||||
from danswer.background.connector_deletion import delete_connector_credential_pair
|
||||
from danswer.background.connector_deletion import delete_connector_credential_pair_batch
|
||||
from danswer.background.task_utils import build_celery_task_wrapper
|
||||
from danswer.background.task_utils import name_cc_cleanup_task
|
||||
from danswer.background.task_utils import name_cc_prune_task
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
@@ -43,29 +42,42 @@ from danswer.configs.constants import POSTGRES_CELERY_WORKER_APP_NAME
|
||||
from danswer.configs.constants import PostgresAdvisoryLocks
|
||||
from danswer.connectors.factory import instantiate_connector
|
||||
from danswer.connectors.models import InputType
|
||||
from danswer.db.connector import fetch_connector_by_id
|
||||
from danswer.db.connector_credential_pair import add_deletion_failure_message
|
||||
from danswer.db.connector_credential_pair import (
|
||||
get_connector_credential_pair,
|
||||
delete_connector_credential_pair__no_commit,
|
||||
)
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
|
||||
from danswer.db.document import count_documents_by_needs_sync
|
||||
from danswer.db.document import delete_document_by_connector_credential_pair__no_commit
|
||||
from danswer.db.document import delete_documents_complete__no_commit
|
||||
from danswer.db.document import get_document
|
||||
from danswer.db.document import get_document_connector_count
|
||||
from danswer.db.document import get_documents_for_connector_credential_pair
|
||||
from danswer.db.document import mark_document_as_synced
|
||||
from danswer.db.document_set import delete_document_set
|
||||
from danswer.db.document_set import fetch_document_set_for_document
|
||||
from danswer.db.document_set import delete_document_set_cc_pair_relationship__no_commit
|
||||
from danswer.db.document_set import fetch_document_sets
|
||||
from danswer.db.document_set import fetch_document_sets_for_document
|
||||
from danswer.db.document_set import get_document_set_by_id
|
||||
from danswer.db.document_set import mark_document_set_as_synced
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.engine import init_sqlalchemy_engine
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.enums import IndexingStatus
|
||||
from danswer.db.index_attempt import delete_index_attempts
|
||||
from danswer.db.index_attempt import get_last_attempt
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import DocumentSet
|
||||
from danswer.db.models import UserGroup
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.document_index.document_index_utils import get_both_index_names
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.document_index.interfaces import UpdateRequest
|
||||
from danswer.redis.redis_pool import RedisPool
|
||||
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
|
||||
from danswer.utils.logger import ColoredFormatter
|
||||
from danswer.utils.logger import PlainFormatter
|
||||
from danswer.utils.logger import setup_logger
|
||||
@@ -93,56 +105,6 @@ celery_app.config_from_object(
|
||||
#
|
||||
# If imports from this module are needed, use local imports to avoid circular importing
|
||||
#####
|
||||
@build_celery_task_wrapper(name_cc_cleanup_task)
|
||||
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
|
||||
def cleanup_connector_credential_pair_task(
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
) -> int:
|
||||
"""Connector deletion task. This is run as an async task because it is a somewhat slow job.
|
||||
Needs to potentially update a large number of Postgres and Vespa docs, including deleting them
|
||||
or updating the ACL"""
|
||||
engine = get_sqlalchemy_engine()
|
||||
|
||||
with Session(engine) as db_session:
|
||||
# validate that the connector / credential pair is deletable
|
||||
cc_pair = get_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
if not cc_pair:
|
||||
raise ValueError(
|
||||
f"Cannot run deletion attempt - connector_credential_pair with Connector ID: "
|
||||
f"{connector_id} and Credential ID: {credential_id} does not exist."
|
||||
)
|
||||
try:
|
||||
deletion_attempt_disallowed_reason = check_deletion_attempt_is_allowed(
|
||||
connector_credential_pair=cc_pair, db_session=db_session
|
||||
)
|
||||
if deletion_attempt_disallowed_reason:
|
||||
raise ValueError(deletion_attempt_disallowed_reason)
|
||||
|
||||
# The bulk of the work is in here, updates Postgres and Vespa
|
||||
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
|
||||
document_index = get_default_document_index(
|
||||
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
|
||||
)
|
||||
return delete_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
document_index=document_index,
|
||||
cc_pair=cc_pair,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
stack_trace = traceback.format_exc()
|
||||
error_message = f"Error: {str(e)}\n\nStack Trace:\n{stack_trace}"
|
||||
add_deletion_failure_message(db_session, cc_pair.id, error_message)
|
||||
task_logger.exception(
|
||||
f"Failed to run connector_deletion. "
|
||||
f"connector_id={connector_id} credential_id={credential_id}"
|
||||
)
|
||||
raise e
|
||||
|
||||
|
||||
@build_celery_task_wrapper(name_cc_prune_task)
|
||||
@@ -166,11 +128,11 @@ def prune_documents_task(connector_id: int, credential_id: int) -> None:
|
||||
return
|
||||
|
||||
runnable_connector = instantiate_connector(
|
||||
cc_pair.connector.source,
|
||||
InputType.PRUNE,
|
||||
cc_pair.connector.connector_specific_config,
|
||||
cc_pair.credential,
|
||||
db_session,
|
||||
db_session=db_session,
|
||||
source=cc_pair.connector.source,
|
||||
input_type=InputType.PRUNE,
|
||||
connector_specific_config=cc_pair.connector.connector_specific_config,
|
||||
credential=cc_pair.credential,
|
||||
)
|
||||
|
||||
all_connector_doc_ids: set[str] = extract_ids_from_runnable_connector(
|
||||
@@ -218,6 +180,11 @@ def prune_documents_task(connector_id: int, credential_id: int) -> None:
|
||||
def try_generate_stale_document_sync_tasks(
|
||||
db_session: Session, r: Redis, lock_beat: redis.lock.Lock
|
||||
) -> int | None:
|
||||
"""This picks up stale documents (typically from indexing) and queues them for sync to Vespa.
|
||||
|
||||
Returns an int if syncing is needed. The int represents the number of sync tasks generated.
|
||||
Returns None if no syncing is required.
|
||||
"""
|
||||
# the fence is up, do nothing
|
||||
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
|
||||
return None
|
||||
@@ -233,6 +200,8 @@ def try_generate_stale_document_sync_tasks(
|
||||
f"Stale documents found (at least {stale_doc_count}). Generating sync tasks by cc pair."
|
||||
)
|
||||
|
||||
task_logger.info("RedisConnector.generate_tasks starting by cc_pair.")
|
||||
|
||||
# rkuo: we could technically sync all stale docs in one big pass.
|
||||
# but I feel it's more understandable to group the docs by cc_pair
|
||||
total_tasks_generated = 0
|
||||
@@ -255,7 +224,7 @@ def try_generate_stale_document_sync_tasks(
|
||||
total_tasks_generated += tasks_generated
|
||||
|
||||
task_logger.info(
|
||||
f"All per connector generate_tasks finished. total_tasks_generated={total_tasks_generated}"
|
||||
f"RedisConnector.generate_tasks finished for all cc_pairs. total_tasks_generated={total_tasks_generated}"
|
||||
)
|
||||
|
||||
r.set(RedisConnectorCredentialPair.get_fence_key(), total_tasks_generated)
|
||||
@@ -265,6 +234,10 @@ def try_generate_stale_document_sync_tasks(
|
||||
def try_generate_document_set_sync_tasks(
|
||||
document_set: DocumentSet, db_session: Session, r: Redis, lock_beat: redis.lock.Lock
|
||||
) -> int | None:
|
||||
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
|
||||
Note that syncing can still be required even if the number of sync tasks generated is zero.
|
||||
Returns None if no syncing is required.
|
||||
"""
|
||||
lock_beat.reacquire()
|
||||
|
||||
rds = RedisDocumentSet(document_set.id)
|
||||
@@ -310,6 +283,10 @@ def try_generate_document_set_sync_tasks(
|
||||
def try_generate_user_group_sync_tasks(
|
||||
usergroup: UserGroup, db_session: Session, r: Redis, lock_beat: redis.lock.Lock
|
||||
) -> int | None:
|
||||
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
|
||||
Note that syncing can still be required even if the number of sync tasks generated is zero.
|
||||
Returns None if no syncing is required.
|
||||
"""
|
||||
lock_beat.reacquire()
|
||||
|
||||
rug = RedisUserGroup(usergroup.id)
|
||||
@@ -327,7 +304,9 @@ def try_generate_user_group_sync_tasks(
|
||||
r.delete(rug.taskset_key)
|
||||
|
||||
# Add all documents that need to be updated into the queue
|
||||
task_logger.info(f"generate_tasks starting. usergroup_id={usergroup.id}")
|
||||
task_logger.info(
|
||||
f"RedisUserGroup.generate_tasks starting. usergroup_id={usergroup.id}"
|
||||
)
|
||||
tasks_generated = rug.generate_tasks(celery_app, db_session, r, lock_beat)
|
||||
if tasks_generated is None:
|
||||
return None
|
||||
@@ -339,7 +318,7 @@ def try_generate_user_group_sync_tasks(
|
||||
# return 0
|
||||
|
||||
task_logger.info(
|
||||
f"generate_tasks finished. "
|
||||
f"RedisUserGroup.generate_tasks finished. "
|
||||
f"usergroup_id={usergroup.id} tasks_generated={tasks_generated}"
|
||||
)
|
||||
|
||||
@@ -348,6 +327,78 @@ def try_generate_user_group_sync_tasks(
|
||||
return tasks_generated
|
||||
|
||||
|
||||
def try_generate_document_cc_pair_cleanup_tasks(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
lock_beat: redis.lock.Lock,
|
||||
) -> int | None:
|
||||
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
|
||||
Note that syncing can still be required even if the number of sync tasks generated is zero.
|
||||
Returns None if no syncing is required.
|
||||
"""
|
||||
|
||||
lock_beat.reacquire()
|
||||
|
||||
rcd = RedisConnectorDeletion(cc_pair.id)
|
||||
|
||||
# don't generate sync tasks if tasks are still pending
|
||||
if r.exists(rcd.fence_key):
|
||||
return None
|
||||
|
||||
# we need to refresh the state of the object inside the fence
|
||||
# to avoid a race condition with db.commit/fence deletion
|
||||
# at the end of this taskset
|
||||
try:
|
||||
db_session.refresh(cc_pair)
|
||||
except ObjectDeletedError:
|
||||
return None
|
||||
|
||||
if cc_pair.status != ConnectorCredentialPairStatus.DELETING:
|
||||
return None
|
||||
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
|
||||
last_indexing = get_last_attempt(
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
search_settings_id=search_settings.id,
|
||||
db_session=db_session,
|
||||
)
|
||||
if last_indexing:
|
||||
if (
|
||||
last_indexing.status == IndexingStatus.IN_PROGRESS
|
||||
or last_indexing.status == IndexingStatus.NOT_STARTED
|
||||
):
|
||||
return None
|
||||
|
||||
# add tasks to celery and build up the task set to monitor in redis
|
||||
r.delete(rcd.taskset_key)
|
||||
|
||||
# Add all documents that need to be updated into the queue
|
||||
task_logger.info(
|
||||
f"RedisConnectorDeletion.generate_tasks starting. cc_pair_id={cc_pair.id}"
|
||||
)
|
||||
tasks_generated = rcd.generate_tasks(celery_app, db_session, r, lock_beat)
|
||||
if tasks_generated is None:
|
||||
return None
|
||||
|
||||
# Currently we are allowing the sync to proceed with 0 tasks.
|
||||
# It's possible for sets/groups to be generated initially with no entries
|
||||
# and they still need to be marked as up to date.
|
||||
# if tasks_generated == 0:
|
||||
# return 0
|
||||
|
||||
task_logger.info(
|
||||
f"RedisConnectorDeletion.generate_tasks finished. "
|
||||
f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}"
|
||||
)
|
||||
|
||||
# set this only after all tasks have been added
|
||||
r.set(rcd.fence_key, tasks_generated)
|
||||
return tasks_generated
|
||||
|
||||
|
||||
#####
|
||||
# Periodic Tasks
|
||||
#####
|
||||
@@ -411,26 +462,37 @@ def check_for_vespa_sync_task() -> None:
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
name="check_for_cc_pair_deletion_task",
|
||||
name="check_for_connector_deletion_task",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
)
|
||||
def check_for_cc_pair_deletion_task() -> None:
|
||||
"""Runs periodically to check if any deletion tasks should be run"""
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
# check if any cc pairs are up for deletion
|
||||
cc_pairs = get_connector_credential_pairs(db_session)
|
||||
for cc_pair in cc_pairs:
|
||||
if should_kick_off_deletion_of_cc_pair(cc_pair, db_session):
|
||||
task_logger.info(
|
||||
f"Deleting the {cc_pair.name} connector credential pair"
|
||||
)
|
||||
def check_for_connector_deletion_task() -> None:
|
||||
r = redis_pool.get_client()
|
||||
|
||||
cleanup_connector_credential_pair_task.apply_async(
|
||||
kwargs=dict(
|
||||
connector_id=cc_pair.connector.id,
|
||||
credential_id=cc_pair.credential.id,
|
||||
),
|
||||
lock_beat = r.lock(
|
||||
DanswerRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
try:
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return
|
||||
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
cc_pairs = get_connector_credential_pairs(db_session)
|
||||
for cc_pair in cc_pairs:
|
||||
try_generate_document_cc_pair_cleanup_tasks(
|
||||
cc_pair, db_session, r, lock_beat
|
||||
)
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception("Unexpected exception")
|
||||
finally:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
@@ -602,7 +664,7 @@ def vespa_metadata_sync_task(self: Task, document_id: str) -> bool:
|
||||
return False
|
||||
|
||||
# document set sync
|
||||
doc_sets = fetch_document_set_for_document(document_id, db_session)
|
||||
doc_sets = fetch_document_sets_for_document(document_id, db_session)
|
||||
update_doc_sets: set[str] = set(doc_sets)
|
||||
|
||||
# User group sync
|
||||
@@ -617,8 +679,8 @@ def vespa_metadata_sync_task(self: Task, document_id: str) -> bool:
|
||||
hidden=doc.hidden,
|
||||
)
|
||||
|
||||
# update Vespa
|
||||
document_index.update(update_requests=[update_request])
|
||||
# update Vespa. OK if doc doesn't exist. Raises exception otherwise.
|
||||
document_index.update_single(update_request=update_request)
|
||||
|
||||
# update db last. Worst case = we crash right before this and
|
||||
# the sync might repeat again later
|
||||
@@ -635,6 +697,92 @@ def vespa_metadata_sync_task(self: Task, document_id: str) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
name="document_by_cc_pair_cleanup_task",
|
||||
bind=True,
|
||||
soft_time_limit=45,
|
||||
time_limit=60,
|
||||
max_retries=3,
|
||||
)
|
||||
def document_by_cc_pair_cleanup_task(
|
||||
self: Task, document_id: str, connector_id: int, credential_id: int
|
||||
) -> bool:
|
||||
task_logger.info(f"document_id={document_id}")
|
||||
|
||||
try:
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
|
||||
document_index = get_default_document_index(
|
||||
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
|
||||
)
|
||||
|
||||
count = get_document_connector_count(db_session, document_id)
|
||||
if count == 1:
|
||||
# count == 1 means this is the only remaining cc_pair reference to the doc
|
||||
# delete it from vespa and the db
|
||||
document_index.delete(doc_ids=[document_id])
|
||||
delete_documents_complete__no_commit(
|
||||
db_session=db_session,
|
||||
document_ids=[document_id],
|
||||
)
|
||||
elif count > 1:
|
||||
# count > 1 means the document still has cc_pair references
|
||||
doc = get_document(document_id, db_session)
|
||||
if not doc:
|
||||
return False
|
||||
|
||||
# the below functions do not include cc_pairs being deleted.
|
||||
# i.e. they will correctly omit access for the current cc_pair
|
||||
doc_access = get_access_for_document(
|
||||
document_id=document_id, db_session=db_session
|
||||
)
|
||||
|
||||
doc_sets = fetch_document_sets_for_document(document_id, db_session)
|
||||
update_doc_sets: set[str] = set(doc_sets)
|
||||
|
||||
update_request = UpdateRequest(
|
||||
document_ids=[document_id],
|
||||
document_sets=update_doc_sets,
|
||||
access=doc_access,
|
||||
boost=doc.boost,
|
||||
hidden=doc.hidden,
|
||||
)
|
||||
|
||||
# update Vespa. OK if doc doesn't exist. Raises exception otherwise.
|
||||
document_index.update_single(update_request=update_request)
|
||||
|
||||
# there are still other cc_pair references to the doc, so just resync to Vespa
|
||||
delete_document_by_connector_credential_pair__no_commit(
|
||||
db_session=db_session,
|
||||
document_id=document_id,
|
||||
connector_credential_pair_identifier=ConnectorCredentialPairIdentifier(
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
),
|
||||
)
|
||||
|
||||
mark_document_as_synced(document_id, db_session)
|
||||
else:
|
||||
pass
|
||||
|
||||
# update_docs_last_modified__no_commit(
|
||||
# db_session=db_session,
|
||||
# document_ids=[document_id],
|
||||
# )
|
||||
|
||||
db_session.commit()
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(f"SoftTimeLimitExceeded exception. doc_id={document_id}")
|
||||
except Exception as e:
|
||||
task_logger.exception("Unexpected exception")
|
||||
|
||||
# Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
|
||||
countdown = 2 ** (self.request.retries + 4)
|
||||
self.retry(exc=e, countdown=countdown)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@signals.task_postrun.connect
|
||||
def celery_task_postrun(
|
||||
sender: Any | None = None,
|
||||
@@ -687,6 +835,14 @@ def celery_task_postrun(
|
||||
r.srem(rug.taskset_key, task_id)
|
||||
return
|
||||
|
||||
if task_id.startswith(RedisConnectorDeletion.PREFIX):
|
||||
r = redis_pool.get_client()
|
||||
cc_pair_id = RedisConnectorDeletion.get_id_from_task_id(task_id)
|
||||
if cc_pair_id is not None:
|
||||
rcd = RedisConnectorDeletion(cc_pair_id)
|
||||
r.srem(rcd.taskset_key, task_id)
|
||||
return
|
||||
|
||||
|
||||
def monitor_connector_taskset(r: Redis) -> None:
|
||||
fence_value = r.get(RedisConnectorCredentialPair.get_fence_key())
|
||||
@@ -707,9 +863,7 @@ def monitor_connector_taskset(r: Redis) -> None:
|
||||
task_logger.info(f"Successfully synced stale documents. count={initial_count}")
|
||||
|
||||
|
||||
def monitor_document_set_taskset(
|
||||
key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
def monitor_document_set_taskset(key_bytes: bytes, r: Redis) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
document_set_id = RedisDocumentSet.get_id_from_fence_key(fence_key)
|
||||
if document_set_id is None:
|
||||
@@ -730,41 +884,47 @@ def monitor_document_set_taskset(
|
||||
|
||||
count = cast(int, r.scard(rds.taskset_key))
|
||||
task_logger.info(
|
||||
f"document_set_id={document_set_id} remaining={count} initial={initial_count}"
|
||||
f"Document set sync: document_set_id={document_set_id} remaining={count} initial={initial_count}"
|
||||
)
|
||||
if count > 0:
|
||||
return
|
||||
|
||||
document_set = cast(
|
||||
DocumentSet,
|
||||
get_document_set_by_id(db_session=db_session, document_set_id=document_set_id),
|
||||
) # casting since we "know" a document set with this ID exists
|
||||
if document_set:
|
||||
if not document_set.connector_credential_pairs:
|
||||
# if there are no connectors, then delete the document set.
|
||||
delete_document_set(document_set_row=document_set, db_session=db_session)
|
||||
task_logger.info(
|
||||
f"Successfully deleted document set with ID: '{document_set_id}'!"
|
||||
)
|
||||
else:
|
||||
mark_document_set_as_synced(document_set_id, db_session)
|
||||
task_logger.info(
|
||||
f"Successfully synced document set with ID: '{document_set_id}'!"
|
||||
)
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
document_set = cast(
|
||||
DocumentSet,
|
||||
get_document_set_by_id(
|
||||
db_session=db_session, document_set_id=document_set_id
|
||||
),
|
||||
) # casting since we "know" a document set with this ID exists
|
||||
if document_set:
|
||||
if not document_set.connector_credential_pairs:
|
||||
# if there are no connectors, then delete the document set.
|
||||
delete_document_set(
|
||||
document_set_row=document_set, db_session=db_session
|
||||
)
|
||||
task_logger.info(
|
||||
f"Successfully deleted document set with ID: '{document_set_id}'!"
|
||||
)
|
||||
else:
|
||||
mark_document_set_as_synced(document_set_id, db_session)
|
||||
task_logger.info(
|
||||
f"Successfully synced document set with ID: '{document_set_id}'!"
|
||||
)
|
||||
|
||||
r.delete(rds.taskset_key)
|
||||
r.delete(rds.fence_key)
|
||||
|
||||
|
||||
def monitor_usergroup_taskset(key_bytes: bytes, r: Redis, db_session: Session) -> None:
|
||||
key = key_bytes.decode("utf-8")
|
||||
usergroup_id = RedisUserGroup.get_id_from_fence_key(key)
|
||||
if not usergroup_id:
|
||||
task_logger.warning("Could not parse usergroup id from {key}")
|
||||
def monitor_connector_deletion_taskset(key_bytes: bytes, r: Redis) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
cc_pair_id = RedisConnectorDeletion.get_id_from_fence_key(fence_key)
|
||||
if cc_pair_id is None:
|
||||
task_logger.warning("could not parse document set id from {key}")
|
||||
return
|
||||
|
||||
rug = RedisUserGroup(usergroup_id)
|
||||
fence_value = r.get(rug.fence_key)
|
||||
rcd = RedisConnectorDeletion(cc_pair_id)
|
||||
|
||||
fence_value = r.get(rcd.fence_key)
|
||||
if fence_value is None:
|
||||
return
|
||||
|
||||
@@ -774,44 +934,79 @@ def monitor_usergroup_taskset(key_bytes: bytes, r: Redis, db_session: Session) -
|
||||
task_logger.error("The value is not an integer.")
|
||||
return
|
||||
|
||||
count = cast(int, r.scard(rug.taskset_key))
|
||||
count = cast(int, r.scard(rcd.taskset_key))
|
||||
task_logger.info(
|
||||
f"usergroup_id={usergroup_id} remaining={count} initial={initial_count}"
|
||||
f"Connector deletion: cc_pair_id={cc_pair_id} remaining={count} initial={initial_count}"
|
||||
)
|
||||
if count > 0:
|
||||
return
|
||||
|
||||
try:
|
||||
fetch_user_group = fetch_versioned_implementation(
|
||||
"danswer.db.user_group", "fetch_user_group"
|
||||
)
|
||||
except ModuleNotFoundError:
|
||||
task_logger.exception(
|
||||
"fetch_versioned_implementation failed to look up fetch_user_group."
|
||||
)
|
||||
return
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
|
||||
if not cc_pair:
|
||||
return
|
||||
|
||||
user_group: UserGroup | None = fetch_user_group(
|
||||
db_session=db_session, user_group_id=usergroup_id
|
||||
try:
|
||||
# clean up the rest of the related Postgres entities
|
||||
# index attempts
|
||||
delete_index_attempts(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair.id,
|
||||
)
|
||||
|
||||
# document sets
|
||||
delete_document_set_cc_pair_relationship__no_commit(
|
||||
db_session=db_session,
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
)
|
||||
|
||||
# user groups
|
||||
cleanup_user_groups = fetch_versioned_implementation_with_fallback(
|
||||
"danswer.db.user_group",
|
||||
"delete_user_group_cc_pair_relationship__no_commit",
|
||||
noop_fallback,
|
||||
)
|
||||
cleanup_user_groups(
|
||||
cc_pair_id=cc_pair.id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# finally, delete the cc-pair
|
||||
delete_connector_credential_pair__no_commit(
|
||||
db_session=db_session,
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
)
|
||||
# if there are no credentials left, delete the connector
|
||||
connector = fetch_connector_by_id(
|
||||
db_session=db_session,
|
||||
connector_id=cc_pair.connector_id,
|
||||
)
|
||||
if not connector or not len(connector.credentials):
|
||||
task_logger.info(
|
||||
"Found no credentials left for connector, deleting connector"
|
||||
)
|
||||
db_session.delete(connector)
|
||||
db_session.commit()
|
||||
except Exception as e:
|
||||
stack_trace = traceback.format_exc()
|
||||
error_message = f"Error: {str(e)}\n\nStack Trace:\n{stack_trace}"
|
||||
add_deletion_failure_message(db_session, cc_pair.id, error_message)
|
||||
task_logger.exception(
|
||||
f"Failed to run connector_deletion. "
|
||||
f"connector_id={cc_pair.connector_id} credential_id={cc_pair.credential_id}"
|
||||
)
|
||||
raise e
|
||||
|
||||
task_logger.info(
|
||||
f"Successfully deleted connector_credential_pair with connector_id: '{cc_pair.connector_id}' "
|
||||
f"and credential_id: '{cc_pair.credential_id}'. "
|
||||
f"Deleted {initial_count} docs."
|
||||
)
|
||||
if user_group:
|
||||
if user_group.is_up_for_deletion:
|
||||
delete_user_group = fetch_versioned_implementation_with_fallback(
|
||||
"danswer.db.user_group", "delete_user_group", noop_fallback
|
||||
)
|
||||
|
||||
delete_user_group(db_session=db_session, user_group=user_group)
|
||||
task_logger.info(f" Deleted usergroup. id='{usergroup_id}'")
|
||||
else:
|
||||
mark_user_group_as_synced = fetch_versioned_implementation_with_fallback(
|
||||
"danswer.db.user_group", "mark_user_group_as_synced", noop_fallback
|
||||
)
|
||||
|
||||
mark_user_group_as_synced(db_session=db_session, user_group=user_group)
|
||||
task_logger.info(f"Synced usergroup. id='{usergroup_id}'")
|
||||
|
||||
r.delete(rug.taskset_key)
|
||||
r.delete(rug.fence_key)
|
||||
r.delete(rcd.taskset_key)
|
||||
r.delete(rcd.fence_key)
|
||||
|
||||
|
||||
@celery_app.task(name="monitor_vespa_sync", soft_time_limit=300)
|
||||
@@ -835,17 +1030,24 @@ def monitor_vespa_sync() -> None:
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return
|
||||
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
|
||||
monitor_connector_taskset(r)
|
||||
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
|
||||
monitor_connector_taskset(r)
|
||||
|
||||
for key_bytes in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
|
||||
monitor_document_set_taskset(key_bytes, r, db_session)
|
||||
for key_bytes in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
|
||||
monitor_document_set_taskset(key_bytes, r)
|
||||
|
||||
for key_bytes in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
|
||||
monitor_usergroup_taskset(key_bytes, r, db_session)
|
||||
for key_bytes in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
|
||||
monitor_usergroup_taskset = fetch_versioned_implementation_with_fallback(
|
||||
"danswer.background.celery_utils",
|
||||
"monitor_usergroup_taskset",
|
||||
noop_fallback,
|
||||
)
|
||||
|
||||
monitor_usergroup_taskset(key_bytes, r)
|
||||
|
||||
for key_bytes in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
|
||||
monitor_connector_deletion_taskset(key_bytes, r)
|
||||
|
||||
#
|
||||
# r_celery = celery_app.broker_connection().channel().client
|
||||
# length = celery_get_queue_length(DanswerCeleryQueues.VESPA_METADATA_SYNC, r_celery)
|
||||
# task_logger.warning(f"queue={DanswerCeleryQueues.VESPA_METADATA_SYNC} length={length}")
|
||||
@@ -889,6 +1091,12 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
for key in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorDeletion.TASKSET_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
|
||||
class CeleryTaskPlainFormatter(PlainFormatter):
|
||||
def format(self, record: logging.LogRecord) -> str:
|
||||
@@ -970,14 +1178,18 @@ celery_app.conf.beat_schedule = {
|
||||
"schedule": timedelta(seconds=5),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
"check-for-cc-pair-deletion": {
|
||||
"task": "check_for_cc_pair_deletion_task",
|
||||
# don't need to check too often, since we kick off a deletion initially
|
||||
# during the API call that actually marks the CC pair for deletion
|
||||
"schedule": timedelta(minutes=1),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
}
|
||||
celery_app.conf.beat_schedule.update(
|
||||
{
|
||||
"check-for-connector-deletion-task": {
|
||||
"task": "check_for_connector_deletion_task",
|
||||
# don't need to check too often, since we kick off a deletion initially
|
||||
# during the API call that actually marks the CC pair for deletion
|
||||
"schedule": timedelta(minutes=1),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
}
|
||||
)
|
||||
celery_app.conf.beat_schedule.update(
|
||||
{
|
||||
"check-for-prune": {
|
||||
|
||||
@@ -15,6 +15,7 @@ from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.document import construct_document_select_for_connector_credential_pair
|
||||
from danswer.db.document import (
|
||||
construct_document_select_for_connector_credential_pair_by_needs_sync,
|
||||
)
|
||||
@@ -134,7 +135,7 @@ class RedisDocumentSet(RedisObjectHelper):
|
||||
last_lock_time = current_time
|
||||
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
|
||||
@@ -189,7 +190,7 @@ class RedisUserGroup(RedisObjectHelper):
|
||||
last_lock_time = current_time
|
||||
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
|
||||
@@ -211,6 +212,9 @@ class RedisUserGroup(RedisObjectHelper):
|
||||
|
||||
|
||||
class RedisConnectorCredentialPair(RedisObjectHelper):
|
||||
"""This class differs from the default in that the taskset used spans
|
||||
all connectors and is not per connector."""
|
||||
|
||||
PREFIX = "connectorsync"
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
@@ -256,7 +260,7 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
|
||||
last_lock_time = current_time
|
||||
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
|
||||
@@ -281,6 +285,64 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
|
||||
return len(async_results)
|
||||
|
||||
|
||||
class RedisConnectorDeletion(RedisObjectHelper):
|
||||
PREFIX = "connectordeletion"
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: redis.lock.Lock,
|
||||
) -> int | None:
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
async_results = []
|
||||
cc_pair = get_connector_credential_pair_from_id(self._id, db_session)
|
||||
if not cc_pair:
|
||||
return None
|
||||
|
||||
stmt = construct_document_select_for_connector_credential_pair(
|
||||
cc_pair.connector_id, cc_pair.credential_id
|
||||
)
|
||||
for doc in db_session.scalars(stmt).yield_per(1):
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
|
||||
):
|
||||
lock.reacquire()
|
||||
last_lock_time = current_time
|
||||
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
|
||||
|
||||
# add to the tracking taskset in redis BEFORE creating the celery task.
|
||||
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
|
||||
redis_client.sadd(self.taskset_key, custom_task_id)
|
||||
|
||||
# Priority on sync's triggered by new indexing should be medium
|
||||
result = celery_app.send_task(
|
||||
"document_by_cc_pair_cleanup_task",
|
||||
kwargs=dict(
|
||||
document_id=doc.id,
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
),
|
||||
queue=DanswerCeleryQueues.CONNECTOR_DELETION,
|
||||
task_id=custom_task_id,
|
||||
priority=DanswerCeleryPriority.MEDIUM,
|
||||
)
|
||||
|
||||
async_results.append(result)
|
||||
|
||||
return len(async_results)
|
||||
|
||||
|
||||
def celery_get_queue_length(queue: str, r: Redis) -> int:
|
||||
"""This is a redis specific way to get the length of a celery queue.
|
||||
It is priority aware and knows how to count across the multiple redis lists
|
||||
|
||||
@@ -3,7 +3,7 @@ from datetime import timezone
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.task_utils import name_cc_cleanup_task
|
||||
from danswer.background.celery.celery_redis import RedisConnectorDeletion
|
||||
from danswer.background.task_utils import name_cc_prune_task
|
||||
from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
|
||||
from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
|
||||
@@ -15,29 +15,44 @@ from danswer.connectors.interfaces import IdConnector
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair
|
||||
from danswer.db.engine import get_db_current_time
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.enums import TaskStatus
|
||||
from danswer.db.models import Connector
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import Credential
|
||||
from danswer.db.models import TaskQueueState
|
||||
from danswer.db.tasks import check_task_is_live_and_not_timed_out
|
||||
from danswer.db.tasks import get_latest_task
|
||||
from danswer.db.tasks import get_latest_task_by_type
|
||||
from danswer.redis.redis_pool import RedisPool
|
||||
from danswer.server.documents.models import DeletionAttemptSnapshot
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
redis_pool = RedisPool()
|
||||
|
||||
|
||||
def _get_deletion_status(
|
||||
connector_id: int, credential_id: int, db_session: Session
|
||||
) -> TaskQueueState | None:
|
||||
cleanup_task_name = name_cc_cleanup_task(
|
||||
connector_id=connector_id, credential_id=credential_id
|
||||
"""We no longer store TaskQueueState in the DB for a deletion attempt.
|
||||
This function populates TaskQueueState by just checking redis.
|
||||
"""
|
||||
cc_pair = get_connector_credential_pair(
|
||||
connector_id=connector_id, credential_id=credential_id, db_session=db_session
|
||||
)
|
||||
if not cc_pair:
|
||||
return None
|
||||
|
||||
rcd = RedisConnectorDeletion(cc_pair.id)
|
||||
|
||||
r = redis_pool.get_client()
|
||||
if not r.exists(rcd.fence_key):
|
||||
return None
|
||||
|
||||
return TaskQueueState(
|
||||
task_id="", task_name=rcd.fence_key, status=TaskStatus.STARTED
|
||||
)
|
||||
return get_latest_task(task_name=cleanup_task_name, db_session=db_session)
|
||||
|
||||
|
||||
def get_deletion_attempt_snapshot(
|
||||
@@ -54,31 +69,6 @@ def get_deletion_attempt_snapshot(
|
||||
)
|
||||
|
||||
|
||||
def should_kick_off_deletion_of_cc_pair(
|
||||
cc_pair: ConnectorCredentialPair, db_session: Session
|
||||
) -> bool:
|
||||
if cc_pair.status != ConnectorCredentialPairStatus.DELETING:
|
||||
return False
|
||||
|
||||
if check_deletion_attempt_is_allowed(cc_pair, db_session):
|
||||
return False
|
||||
|
||||
deletion_task = _get_deletion_status(
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
if deletion_task and check_task_is_live_and_not_timed_out(
|
||||
deletion_task,
|
||||
db_session,
|
||||
# 1 hour timeout
|
||||
timeout=60 * 60,
|
||||
):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def should_prune_cc_pair(
|
||||
connector: Connector, credential: Credential, db_session: Session
|
||||
) -> bool:
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
# docs: https://docs.celeryq.dev/en/stable/userguide/configuration.html
|
||||
from danswer.configs.app_configs import CELERY_RESULT_EXPIRES
|
||||
from danswer.configs.app_configs import REDIS_DB_NUMBER_CELERY
|
||||
from danswer.configs.app_configs import REDIS_DB_NUMBER_CELERY_RESULT_BACKEND
|
||||
from danswer.configs.app_configs import REDIS_HOST
|
||||
from danswer.configs.app_configs import REDIS_PASSWORD
|
||||
from danswer.configs.app_configs import REDIS_PORT
|
||||
@@ -27,7 +29,7 @@ if REDIS_SSL:
|
||||
# example celery_broker_url: "redis://:password@localhost:6379/15"
|
||||
broker_url = f"{REDIS_SCHEME}://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY}{SSL_QUERY_PARAMS}"
|
||||
|
||||
result_backend = f"{REDIS_SCHEME}://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY}{SSL_QUERY_PARAMS}"
|
||||
result_backend = f"{REDIS_SCHEME}://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY_RESULT_BACKEND}{SSL_QUERY_PARAMS}"
|
||||
|
||||
# NOTE: prefetch 4 is significantly faster than prefetch 1 for small tasks
|
||||
# however, prefetching is bad when tasks are lengthy as those tasks
|
||||
@@ -42,3 +44,33 @@ broker_transport_options = {
|
||||
|
||||
task_default_priority = DanswerCeleryPriority.MEDIUM
|
||||
task_acks_late = True
|
||||
|
||||
# It's possible we don't even need celery's result backend, in which case all of the optimization below
|
||||
# might be irrelevant
|
||||
result_expires = CELERY_RESULT_EXPIRES # 86400 seconds is the default
|
||||
|
||||
# Option 0: Defaults (json serializer, no compression)
|
||||
# about 1.5 KB per queued task. 1KB in queue, 400B for result, 100 as a child entry in generator result
|
||||
|
||||
# Option 1: Reduces generator task result sizes by roughly 20%
|
||||
# task_compression = "bzip2"
|
||||
# task_serializer = "pickle"
|
||||
# result_compression = "bzip2"
|
||||
# result_serializer = "pickle"
|
||||
# accept_content=["pickle"]
|
||||
|
||||
# Option 2: this significantly reduces the size of the result for generator tasks since the list of children
|
||||
# can be large. small tasks change very little
|
||||
# def pickle_bz2_encoder(data):
|
||||
# return bz2.compress(pickle.dumps(data))
|
||||
|
||||
# def pickle_bz2_decoder(data):
|
||||
# return pickle.loads(bz2.decompress(data))
|
||||
|
||||
# from kombu import serialization # To register custom serialization with Celery/Kombu
|
||||
|
||||
# serialization.register('pickle-bzip2', pickle_bz2_encoder, pickle_bz2_decoder, 'application/x-pickle-bz2', 'binary')
|
||||
|
||||
# task_serializer = "pickle-bzip2"
|
||||
# result_serializer = "pickle-bzip2"
|
||||
# accept_content=["pickle", "pickle-bzip2"]
|
||||
|
||||
@@ -13,28 +13,16 @@ connector / credential pair from the access list
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.access.access import get_access_for_documents
|
||||
from danswer.db.connector import fetch_connector_by_id
|
||||
from danswer.db.connector_credential_pair import (
|
||||
delete_connector_credential_pair__no_commit,
|
||||
)
|
||||
from danswer.db.document import delete_document_by_connector_credential_pair__no_commit
|
||||
from danswer.db.document import delete_documents_by_connector_credential_pair__no_commit
|
||||
from danswer.db.document import delete_documents_complete__no_commit
|
||||
from danswer.db.document import get_document_connector_cnts
|
||||
from danswer.db.document import get_documents_for_connector_credential_pair
|
||||
from danswer.db.document import get_document_connector_counts
|
||||
from danswer.db.document import prepare_to_modify_documents
|
||||
from danswer.db.document_set import delete_document_set_cc_pair_relationship__no_commit
|
||||
from danswer.db.document_set import fetch_document_sets_for_documents
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.index_attempt import delete_index_attempts
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.document_index.interfaces import DocumentIndex
|
||||
from danswer.document_index.interfaces import UpdateRequest
|
||||
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import (
|
||||
fetch_versioned_implementation_with_fallback,
|
||||
)
|
||||
from danswer.utils.variable_functionality import noop_fallback
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -57,13 +45,15 @@ def delete_connector_credential_pair_batch(
|
||||
with prepare_to_modify_documents(
|
||||
db_session=db_session, document_ids=document_ids
|
||||
):
|
||||
document_connector_cnts = get_document_connector_cnts(
|
||||
document_connector_counts = get_document_connector_counts(
|
||||
db_session=db_session, document_ids=document_ids
|
||||
)
|
||||
|
||||
# figure out which docs need to be completely deleted
|
||||
document_ids_to_delete = [
|
||||
document_id for document_id, cnt in document_connector_cnts if cnt == 1
|
||||
document_id
|
||||
for document_id, cnt in document_connector_counts
|
||||
if cnt == 1
|
||||
]
|
||||
logger.debug(f"Deleting documents: {document_ids_to_delete}")
|
||||
|
||||
@@ -76,7 +66,7 @@ def delete_connector_credential_pair_batch(
|
||||
|
||||
# figure out which docs need to be updated
|
||||
document_ids_to_update = [
|
||||
document_id for document_id, cnt in document_connector_cnts if cnt > 1
|
||||
document_id for document_id, cnt in document_connector_counts if cnt > 1
|
||||
]
|
||||
|
||||
# maps document id to list of document set names
|
||||
@@ -109,7 +99,7 @@ def delete_connector_credential_pair_batch(
|
||||
document_index.update(update_requests=update_requests)
|
||||
|
||||
# clean up Postgres
|
||||
delete_document_by_connector_credential_pair__no_commit(
|
||||
delete_documents_by_connector_credential_pair__no_commit(
|
||||
db_session=db_session,
|
||||
document_ids=document_ids_to_update,
|
||||
connector_credential_pair_identifier=ConnectorCredentialPairIdentifier(
|
||||
@@ -118,78 +108,3 @@ def delete_connector_credential_pair_batch(
|
||||
),
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def delete_connector_credential_pair(
|
||||
db_session: Session,
|
||||
document_index: DocumentIndex,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> int:
|
||||
connector_id = cc_pair.connector_id
|
||||
credential_id = cc_pair.credential_id
|
||||
|
||||
num_docs_deleted = 0
|
||||
while True:
|
||||
documents = get_documents_for_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
limit=_DELETION_BATCH_SIZE,
|
||||
)
|
||||
if not documents:
|
||||
break
|
||||
|
||||
delete_connector_credential_pair_batch(
|
||||
document_ids=[document.id for document in documents],
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
document_index=document_index,
|
||||
)
|
||||
num_docs_deleted += len(documents)
|
||||
|
||||
# clean up the rest of the related Postgres entities
|
||||
# index attempts
|
||||
delete_index_attempts(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair.id,
|
||||
)
|
||||
|
||||
# document sets
|
||||
delete_document_set_cc_pair_relationship__no_commit(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
|
||||
# user groups
|
||||
cleanup_user_groups = fetch_versioned_implementation_with_fallback(
|
||||
"danswer.db.user_group",
|
||||
"delete_user_group_cc_pair_relationship__no_commit",
|
||||
noop_fallback,
|
||||
)
|
||||
cleanup_user_groups(
|
||||
cc_pair_id=cc_pair.id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# finally, delete the cc-pair
|
||||
delete_connector_credential_pair__no_commit(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
# if there are no credentials left, delete the connector
|
||||
connector = fetch_connector_by_id(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
)
|
||||
if not connector or not len(connector.credentials):
|
||||
logger.info("Found no credentials left for connector, deleting connector")
|
||||
db_session.delete(connector)
|
||||
db_session.commit()
|
||||
|
||||
logger.notice(
|
||||
"Successfully deleted connector_credential_pair with connector_id:"
|
||||
f" '{connector_id}' and credential_id: '{credential_id}'. Deleted {num_docs_deleted} docs."
|
||||
)
|
||||
return num_docs_deleted
|
||||
|
||||
@@ -56,11 +56,11 @@ def _get_connector_runner(
|
||||
|
||||
try:
|
||||
runnable_connector = instantiate_connector(
|
||||
attempt.connector_credential_pair.connector.source,
|
||||
task,
|
||||
attempt.connector_credential_pair.connector.connector_specific_config,
|
||||
attempt.connector_credential_pair.credential,
|
||||
db_session,
|
||||
db_session=db_session,
|
||||
source=attempt.connector_credential_pair.connector.source,
|
||||
input_type=task,
|
||||
connector_specific_config=attempt.connector_credential_pair.connector.connector_specific_config,
|
||||
credential=attempt.connector_credential_pair.credential,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Unable to instantiate connector due to {e}")
|
||||
|
||||
@@ -14,14 +14,6 @@ from danswer.db.tasks import mark_task_start
|
||||
from danswer.db.tasks import register_task
|
||||
|
||||
|
||||
def name_cc_cleanup_task(connector_id: int, credential_id: int) -> str:
|
||||
return f"cleanup_connector_credential_pair_{connector_id}_{credential_id}"
|
||||
|
||||
|
||||
def name_document_set_sync_task(document_set_id: int) -> str:
|
||||
return f"sync_doc_set_{document_set_id}"
|
||||
|
||||
|
||||
def name_cc_prune_task(
|
||||
connector_id: int | None = None, credential_id: int | None = None
|
||||
) -> str:
|
||||
|
||||
@@ -211,7 +211,6 @@ def cleanup_indexing_jobs(
|
||||
timeout_hours: int = CLEANUP_INDEXING_JOBS_TIMEOUT,
|
||||
) -> dict[int, Future | SimpleJob]:
|
||||
existing_jobs_copy = existing_jobs.copy()
|
||||
|
||||
# clean up completed jobs
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
for attempt_id, job in existing_jobs.items():
|
||||
@@ -312,7 +311,12 @@ def kickoff_indexing_jobs(
|
||||
|
||||
indexing_attempt_count = 0
|
||||
|
||||
primary_client_full = False
|
||||
secondary_client_full = False
|
||||
for attempt, search_settings in new_indexing_attempts:
|
||||
if primary_client_full and secondary_client_full:
|
||||
break
|
||||
|
||||
use_secondary_index = (
|
||||
search_settings.status == IndexModelStatus.FUTURE
|
||||
if search_settings is not None
|
||||
@@ -337,22 +341,28 @@ def kickoff_indexing_jobs(
|
||||
)
|
||||
continue
|
||||
|
||||
if use_secondary_index:
|
||||
run = secondary_client.submit(
|
||||
run_indexing_entrypoint,
|
||||
attempt.id,
|
||||
attempt.connector_credential_pair_id,
|
||||
global_version.get_is_ee_version(),
|
||||
pure=False,
|
||||
)
|
||||
if not use_secondary_index:
|
||||
if not primary_client_full:
|
||||
run = client.submit(
|
||||
run_indexing_entrypoint,
|
||||
attempt.id,
|
||||
attempt.connector_credential_pair_id,
|
||||
global_version.get_is_ee_version(),
|
||||
pure=False,
|
||||
)
|
||||
if not run:
|
||||
primary_client_full = True
|
||||
else:
|
||||
run = client.submit(
|
||||
run_indexing_entrypoint,
|
||||
attempt.id,
|
||||
attempt.connector_credential_pair_id,
|
||||
global_version.get_is_ee_version(),
|
||||
pure=False,
|
||||
)
|
||||
if not secondary_client_full:
|
||||
run = secondary_client.submit(
|
||||
run_indexing_entrypoint,
|
||||
attempt.id,
|
||||
attempt.connector_credential_pair_id,
|
||||
global_version.get_is_ee_version(),
|
||||
pure=False,
|
||||
)
|
||||
if not run:
|
||||
secondary_client_full = True
|
||||
|
||||
if run:
|
||||
if indexing_attempt_count == 0:
|
||||
|
||||
@@ -122,7 +122,7 @@ def load_personas_from_yaml(
|
||||
prompt_ids=prompt_ids,
|
||||
document_set_ids=doc_set_ids,
|
||||
tool_ids=tool_ids,
|
||||
default_persona=True,
|
||||
builtin_persona=True,
|
||||
is_public=True,
|
||||
display_priority=existing_persona.display_priority
|
||||
if existing_persona is not None
|
||||
|
||||
@@ -11,7 +11,6 @@ from danswer.search.enums import SearchType
|
||||
from danswer.search.models import RetrievalDocs
|
||||
from danswer.search.models import SearchResponse
|
||||
from danswer.tools.custom.base_tool_types import ToolResultType
|
||||
from danswer.tools.graphing.models import GraphGenerationDisplay
|
||||
|
||||
|
||||
class LlmDoc(BaseModel):
|
||||
@@ -49,8 +48,6 @@ class QADocsResponse(RetrievalDocs):
|
||||
class StreamStopReason(Enum):
|
||||
CONTEXT_LENGTH = "context_length"
|
||||
CANCELLED = "cancelled"
|
||||
FINISHED = "finished"
|
||||
NEW_RESPONSE = "new_response"
|
||||
|
||||
|
||||
class StreamStopInfo(BaseModel):
|
||||
@@ -176,7 +173,6 @@ AnswerQuestionPossibleReturn = (
|
||||
| ImageGenerationDisplay
|
||||
| CustomToolResponse
|
||||
| StreamingError
|
||||
| GraphGenerationDisplay
|
||||
| StreamStopInfo
|
||||
)
|
||||
|
||||
|
||||
@@ -18,8 +18,6 @@ from danswer.chat.models import MessageResponseIDInfo
|
||||
from danswer.chat.models import MessageSpecificCitations
|
||||
from danswer.chat.models import QADocsResponse
|
||||
from danswer.chat.models import StreamingError
|
||||
from danswer.chat.models import StreamStopInfo
|
||||
from danswer.chat.models import StreamStopReason
|
||||
from danswer.configs.chat_configs import BING_API_KEY
|
||||
from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
|
||||
from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
|
||||
@@ -74,16 +72,13 @@ from danswer.search.utils import relevant_sections_to_indices
|
||||
from danswer.server.query_and_chat.models import ChatMessageDetail
|
||||
from danswer.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from danswer.server.utils import get_json_line
|
||||
from danswer.tools.analysis.analysis_tool import CSVAnalysisTool
|
||||
from danswer.tools.built_in_tools import get_built_in_tool_by_id
|
||||
from danswer.tools.custom.custom_tool import build_custom_tools_from_openapi_schema
|
||||
from danswer.tools.custom.custom_tool import (
|
||||
build_custom_tools_from_openapi_schema_and_headers,
|
||||
)
|
||||
from danswer.tools.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID
|
||||
from danswer.tools.custom.custom_tool import CustomToolCallSummary
|
||||
from danswer.tools.force import ForceUseTool
|
||||
from danswer.tools.graphing.graphing_tool import GraphingResponse
|
||||
from danswer.tools.graphing.graphing_tool import GraphingTool
|
||||
from danswer.tools.graphing.models import GraphGenerationDisplay
|
||||
from danswer.tools.graphing.models import GRAPHING_RESPONSE_ID
|
||||
from danswer.tools.images.image_generation_tool import IMAGE_GENERATION_RESPONSE_ID
|
||||
from danswer.tools.images.image_generation_tool import ImageGenerationResponse
|
||||
from danswer.tools.images.image_generation_tool import ImageGenerationTool
|
||||
@@ -250,7 +245,6 @@ def _get_force_search_settings(
|
||||
ChatPacket = (
|
||||
StreamingError
|
||||
| QADocsResponse
|
||||
| GraphingResponse
|
||||
| LLMRelevanceFilterResponse
|
||||
| FinalUsedContextDocsResponse
|
||||
| ChatMessageDetail
|
||||
@@ -259,7 +253,6 @@ ChatPacket = (
|
||||
| CitationInfo
|
||||
| ImageGenerationDisplay
|
||||
| CustomToolResponse
|
||||
| GraphGenerationDisplay
|
||||
| MessageSpecificCitations
|
||||
| MessageResponseIDInfo
|
||||
)
|
||||
@@ -280,6 +273,7 @@ def stream_chat_message_objects(
|
||||
use_existing_user_message: bool = False,
|
||||
litellm_additional_headers: dict[str, str] | None = None,
|
||||
is_connected: Callable[[], bool] | None = None,
|
||||
enforce_chat_session_id_for_search_docs: bool = True,
|
||||
) -> ChatPacketStream:
|
||||
"""Streams in order:
|
||||
1. [conditional] Retrieved documents if a search needs to be run
|
||||
@@ -451,6 +445,7 @@ def stream_chat_message_objects(
|
||||
chat_session=chat_session,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
enforce_chat_session_id_for_search_docs=enforce_chat_session_id_for_search_docs,
|
||||
)
|
||||
|
||||
# Generates full documents currently
|
||||
@@ -537,21 +532,8 @@ def stream_chat_message_objects(
|
||||
tool_dict: dict[int, list[Tool]] = {} # tool_id to tool
|
||||
for db_tool_model in persona.tools:
|
||||
# handle in-code tools specially
|
||||
|
||||
if db_tool_model.in_code_tool_id:
|
||||
tool_cls = get_built_in_tool_by_id(db_tool_model.id, db_session)
|
||||
if (
|
||||
tool_cls.__name__ == CSVAnalysisTool.__name__
|
||||
and not latest_query_files
|
||||
):
|
||||
tool_dict[db_tool_model.id] = [CSVAnalysisTool()]
|
||||
|
||||
if (
|
||||
tool_cls.__name__ == GraphingTool.__name__
|
||||
and not latest_query_files
|
||||
):
|
||||
tool_dict[db_tool_model.id] = [GraphingTool(output_dir="output")]
|
||||
|
||||
if tool_cls.__name__ == SearchTool.__name__ and not latest_query_files:
|
||||
search_tool = SearchTool(
|
||||
db_session=db_session,
|
||||
@@ -622,27 +604,24 @@ def stream_chat_message_objects(
|
||||
]
|
||||
|
||||
continue
|
||||
|
||||
# handle all custom tools
|
||||
if db_tool_model.openapi_schema:
|
||||
tool_dict[db_tool_model.id] = cast(
|
||||
list[Tool],
|
||||
build_custom_tools_from_openapi_schema(
|
||||
build_custom_tools_from_openapi_schema_and_headers(
|
||||
db_tool_model.openapi_schema,
|
||||
dynamic_schema_info=DynamicSchemaInfo(
|
||||
chat_session_id=chat_session_id,
|
||||
message_id=user_message.id if user_message else None,
|
||||
),
|
||||
custom_headers=db_tool_model.custom_headers,
|
||||
),
|
||||
)
|
||||
|
||||
tools: list[Tool] = []
|
||||
for tool_list in tool_dict.values():
|
||||
tools.extend(tool_list)
|
||||
# Saving Gen AI answer and responding with message info
|
||||
tool_name_to_tool_id: dict[str, int] = {}
|
||||
for tool_id, tool_list in tool_dict.items():
|
||||
for tool in tool_list:
|
||||
tool_name_to_tool_id[tool.name] = tool_id
|
||||
|
||||
# factor in tool definition size when pruning
|
||||
document_pruning_config.tool_num_tokens = compute_all_tool_tokens(
|
||||
@@ -688,185 +667,102 @@ def stream_chat_message_objects(
|
||||
ai_message_files = None # any files to associate with the AI message e.g. dall-e generated images
|
||||
dropped_indices = None
|
||||
tool_result = None
|
||||
yielded_message_id_info = True
|
||||
|
||||
for packet in answer.processed_streamed_output:
|
||||
if isinstance(packet, StreamStopInfo):
|
||||
if packet.stop_reason is not StreamStopReason.NEW_RESPONSE:
|
||||
break
|
||||
if isinstance(packet, ToolResponse):
|
||||
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
|
||||
(
|
||||
qa_docs_response,
|
||||
reference_db_search_docs,
|
||||
dropped_indices,
|
||||
) = _handle_search_tool_response_summary(
|
||||
packet=packet,
|
||||
db_session=db_session,
|
||||
selected_search_docs=selected_db_search_docs,
|
||||
# Deduping happens at the last step to avoid harming quality by dropping content early on
|
||||
dedupe_docs=(
|
||||
retrieval_options.dedupe_docs
|
||||
if retrieval_options
|
||||
else False
|
||||
),
|
||||
)
|
||||
yield qa_docs_response
|
||||
elif packet.id == SECTION_RELEVANCE_LIST_ID:
|
||||
relevance_sections = packet.response
|
||||
|
||||
db_citations = None
|
||||
if reference_db_search_docs is not None:
|
||||
llm_indices = relevant_sections_to_indices(
|
||||
relevance_sections=relevance_sections,
|
||||
items=[
|
||||
translate_db_search_doc_to_server_search_doc(doc)
|
||||
for doc in reference_db_search_docs
|
||||
],
|
||||
)
|
||||
|
||||
if reference_db_search_docs:
|
||||
db_citations = _translate_citations(
|
||||
citations_list=answer.citations,
|
||||
db_docs=reference_db_search_docs,
|
||||
if dropped_indices:
|
||||
llm_indices = drop_llm_indices(
|
||||
llm_indices=llm_indices,
|
||||
search_docs=reference_db_search_docs,
|
||||
dropped_indices=dropped_indices,
|
||||
)
|
||||
|
||||
yield LLMRelevanceFilterResponse(
|
||||
llm_selected_doc_indices=llm_indices
|
||||
)
|
||||
|
||||
elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID:
|
||||
yield FinalUsedContextDocsResponse(
|
||||
final_context_docs=packet.response
|
||||
)
|
||||
|
||||
# Saving Gen AI answer and responding with message info
|
||||
if tool_result is None:
|
||||
tool_call = None
|
||||
else:
|
||||
tool_call = ToolCall(
|
||||
tool_id=tool_name_to_tool_id[tool_result.tool_name],
|
||||
tool_name=tool_result.tool_name,
|
||||
tool_arguments={
|
||||
k: v if not isinstance(v, bytes) else v.decode("utf-8")
|
||||
for k, v in tool_result.tool_args.items()
|
||||
},
|
||||
tool_result=tool_result.tool_result,
|
||||
elif packet.id == IMAGE_GENERATION_RESPONSE_ID:
|
||||
img_generation_response = cast(
|
||||
list[ImageGenerationResponse], packet.response
|
||||
)
|
||||
|
||||
gen_ai_response_message = partial_response(
|
||||
reserved_message_id=reserved_message_id,
|
||||
message=answer.llm_answer,
|
||||
rephrased_query=cast(
|
||||
QADocsResponse, qa_docs_response
|
||||
).rephrased_query
|
||||
if qa_docs_response is not None
|
||||
else None,
|
||||
reference_docs=reference_db_search_docs,
|
||||
files=ai_message_files,
|
||||
token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
|
||||
citations=cast(MessageSpecificCitations, db_citations).citation_map
|
||||
if db_citations is not None
|
||||
else None,
|
||||
error=None,
|
||||
tool_call=tool_call,
|
||||
)
|
||||
|
||||
db_session.commit() # actually save user / assistant message
|
||||
|
||||
msg_detail_response = translate_db_message_to_chat_message_detail(
|
||||
gen_ai_response_message
|
||||
)
|
||||
|
||||
yield msg_detail_response
|
||||
reserved_message_id = reserve_message_id(
|
||||
db_session=db_session,
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message=gen_ai_response_message.id
|
||||
if user_message is not None
|
||||
else gen_ai_response_message.id,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
)
|
||||
yielded_message_id_info = False
|
||||
|
||||
partial_response = partial(
|
||||
create_new_chat_message,
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message=gen_ai_response_message,
|
||||
prompt_id=prompt_id,
|
||||
overridden_model=overridden_model,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
alternate_assistant_id=new_msg_req.alternate_assistant_id,
|
||||
db_session=db_session,
|
||||
commit=False,
|
||||
)
|
||||
reference_db_search_docs = None
|
||||
file_ids = save_files_from_urls(
|
||||
[img.url for img in img_generation_response]
|
||||
)
|
||||
ai_message_files = [
|
||||
FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE)
|
||||
for file_id in file_ids
|
||||
]
|
||||
yield ImageGenerationDisplay(
|
||||
file_ids=[str(file_id) for file_id in file_ids]
|
||||
)
|
||||
elif packet.id == INTERNET_SEARCH_RESPONSE_ID:
|
||||
(
|
||||
qa_docs_response,
|
||||
reference_db_search_docs,
|
||||
) = _handle_internet_search_tool_response_summary(
|
||||
packet=packet,
|
||||
db_session=db_session,
|
||||
)
|
||||
yield qa_docs_response
|
||||
elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
|
||||
custom_tool_response = cast(CustomToolCallSummary, packet.response)
|
||||
yield CustomToolResponse(
|
||||
response=custom_tool_response.tool_result,
|
||||
tool_name=custom_tool_response.tool_name,
|
||||
)
|
||||
|
||||
else:
|
||||
if not yielded_message_id_info:
|
||||
yield MessageResponseIDInfo(
|
||||
user_message_id=gen_ai_response_message.id,
|
||||
reserved_assistant_message_id=reserved_message_id,
|
||||
)
|
||||
yielded_message_id_info = True
|
||||
|
||||
if isinstance(packet, ToolResponse):
|
||||
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
|
||||
(
|
||||
qa_docs_response,
|
||||
reference_db_search_docs,
|
||||
dropped_indices,
|
||||
) = _handle_search_tool_response_summary(
|
||||
packet=packet,
|
||||
db_session=db_session,
|
||||
selected_search_docs=selected_db_search_docs,
|
||||
# Deduping happens at the last step to avoid harming quality by dropping content early on
|
||||
dedupe_docs=retrieval_options.dedupe_docs
|
||||
if retrieval_options
|
||||
else False,
|
||||
)
|
||||
yield qa_docs_response
|
||||
elif packet.id == SECTION_RELEVANCE_LIST_ID:
|
||||
relevance_sections = packet.response
|
||||
|
||||
if reference_db_search_docs is not None:
|
||||
llm_indices = relevant_sections_to_indices(
|
||||
relevance_sections=relevance_sections,
|
||||
items=[
|
||||
translate_db_search_doc_to_server_search_doc(doc)
|
||||
for doc in reference_db_search_docs
|
||||
],
|
||||
)
|
||||
|
||||
if dropped_indices:
|
||||
llm_indices = drop_llm_indices(
|
||||
llm_indices=llm_indices,
|
||||
search_docs=reference_db_search_docs,
|
||||
dropped_indices=dropped_indices,
|
||||
)
|
||||
|
||||
yield LLMRelevanceFilterResponse(
|
||||
llm_selected_doc_indices=llm_indices
|
||||
)
|
||||
|
||||
elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID:
|
||||
yield FinalUsedContextDocsResponse(
|
||||
final_context_docs=packet.response
|
||||
)
|
||||
elif packet.id == GRAPHING_RESPONSE_ID:
|
||||
graph_generation = cast(GraphingResponse, packet.response)
|
||||
yield graph_generation
|
||||
|
||||
# yield GraphGenerationDisplay(
|
||||
# file_id=graph_generation.extra_graph_display.file_id,
|
||||
# line_graph=graph_generation.extra_graph_display.line_graph,
|
||||
# )
|
||||
|
||||
elif packet.id == IMAGE_GENERATION_RESPONSE_ID:
|
||||
img_generation_response = cast(
|
||||
list[ImageGenerationResponse], packet.response
|
||||
)
|
||||
|
||||
file_ids = save_files_from_urls(
|
||||
[img.url for img in img_generation_response]
|
||||
)
|
||||
ai_message_files = [
|
||||
FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE)
|
||||
for file_id in file_ids
|
||||
]
|
||||
yield ImageGenerationDisplay(
|
||||
file_ids=[str(file_id) for file_id in file_ids]
|
||||
)
|
||||
elif packet.id == INTERNET_SEARCH_RESPONSE_ID:
|
||||
(
|
||||
qa_docs_response,
|
||||
reference_db_search_docs,
|
||||
) = _handle_internet_search_tool_response_summary(
|
||||
packet=packet,
|
||||
db_session=db_session,
|
||||
)
|
||||
yield qa_docs_response
|
||||
elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
|
||||
custom_tool_response = cast(
|
||||
CustomToolCallSummary, packet.response
|
||||
)
|
||||
yield CustomToolResponse(
|
||||
response=custom_tool_response.tool_result,
|
||||
tool_name=custom_tool_response.tool_name,
|
||||
)
|
||||
else:
|
||||
if isinstance(packet, ToolCallFinalResult):
|
||||
tool_result = packet
|
||||
yield cast(ChatPacket, packet)
|
||||
|
||||
if isinstance(packet, ToolCallFinalResult):
|
||||
tool_result = packet
|
||||
yield cast(ChatPacket, packet)
|
||||
logger.debug("Reached end of stream")
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
logger.exception(f"Failed to process chat message: {error_msg}")
|
||||
except ValueError as e:
|
||||
logger.exception("Failed to process chat message.")
|
||||
|
||||
error_msg = str(e)
|
||||
yield StreamingError(error=error_msg)
|
||||
db_session.rollback()
|
||||
return
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to process chat message.")
|
||||
|
||||
error_msg = str(e)
|
||||
stack_trace = traceback.format_exc()
|
||||
client_error_msg = litellm_exception_to_error_msg(e, llm)
|
||||
if llm.config.api_key and len(llm.config.api_key) > 2:
|
||||
@@ -887,8 +783,11 @@ def stream_chat_message_objects(
|
||||
)
|
||||
yield AllCitations(citations=answer.citations)
|
||||
|
||||
if answer.llm_answer == "":
|
||||
return
|
||||
# Saving Gen AI answer and responding with message info
|
||||
tool_name_to_tool_id: dict[str, int] = {}
|
||||
for tool_id, tool_list in tool_dict.items():
|
||||
for tool in tool_list:
|
||||
tool_name_to_tool_id[tool.name] = tool_id
|
||||
|
||||
gen_ai_response_message = partial_response(
|
||||
reserved_message_id=reserved_message_id,
|
||||
@@ -903,14 +802,18 @@ def stream_chat_message_objects(
|
||||
if message_specific_citations
|
||||
else None,
|
||||
error=None,
|
||||
tool_call=ToolCall(
|
||||
tool_id=tool_name_to_tool_id[tool_result.tool_name],
|
||||
tool_name=tool_result.tool_name,
|
||||
tool_arguments=tool_result.tool_args,
|
||||
tool_result=tool_result.tool_result,
|
||||
)
|
||||
if tool_result
|
||||
else None,
|
||||
tool_calls=(
|
||||
[
|
||||
ToolCall(
|
||||
tool_id=tool_name_to_tool_id[tool_result.tool_name],
|
||||
tool_name=tool_result.tool_name,
|
||||
tool_arguments=tool_result.tool_args,
|
||||
tool_result=tool_result.tool_result,
|
||||
)
|
||||
]
|
||||
if tool_result
|
||||
else []
|
||||
),
|
||||
)
|
||||
|
||||
logger.debug("Committing messages")
|
||||
|
||||
@@ -135,7 +135,7 @@ POSTGRES_PASSWORD = urllib.parse.quote_plus(
|
||||
os.environ.get("POSTGRES_PASSWORD") or "password"
|
||||
)
|
||||
POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost"
|
||||
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5433"
|
||||
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432"
|
||||
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
|
||||
|
||||
# defaults to False
|
||||
@@ -159,11 +159,18 @@ REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD") or ""
|
||||
REDIS_DB_NUMBER = int(os.environ.get("REDIS_DB_NUMBER", 0))
|
||||
|
||||
# Used by celery as broker and backend
|
||||
REDIS_DB_NUMBER_CELERY = int(os.environ.get("REDIS_DB_NUMBER_CELERY", 15))
|
||||
REDIS_DB_NUMBER_CELERY_RESULT_BACKEND = int(
|
||||
os.environ.get("REDIS_DB_NUMBER_CELERY_RESULT_BACKEND", 14)
|
||||
)
|
||||
REDIS_DB_NUMBER_CELERY = int(os.environ.get("REDIS_DB_NUMBER_CELERY", 15)) # broker
|
||||
|
||||
REDIS_SSL_CERT_REQS = os.getenv("REDIS_SSL_CERT_REQS", "CERT_NONE")
|
||||
# https://docs.celeryq.dev/en/stable/userguide/configuration.html#redis-backend-settings
|
||||
# should be one of "required", "optional", or "none"
|
||||
REDIS_SSL_CERT_REQS = os.getenv("REDIS_SSL_CERT_REQS", "none")
|
||||
REDIS_SSL_CA_CERTS = os.getenv("REDIS_SSL_CA_CERTS", "")
|
||||
|
||||
CELERY_RESULT_EXPIRES = int(os.environ.get("CELERY_RESULT_EXPIRES", 86400)) # seconds
|
||||
|
||||
#####
|
||||
# Connector Configs
|
||||
#####
|
||||
|
||||
@@ -99,6 +99,7 @@ class DocumentSource(str, Enum):
|
||||
CLICKUP = "clickup"
|
||||
MEDIAWIKI = "mediawiki"
|
||||
WIKIPEDIA = "wikipedia"
|
||||
ASANA = "asana"
|
||||
S3 = "s3"
|
||||
R2 = "r2"
|
||||
GOOGLE_CLOUD_STORAGE = "google_cloud_storage"
|
||||
@@ -133,6 +134,12 @@ class AuthType(str, Enum):
|
||||
SAML = "saml"
|
||||
|
||||
|
||||
class SessionType(str, Enum):
|
||||
CHAT = "Chat"
|
||||
SEARCH = "Search"
|
||||
SLACK = "Slack"
|
||||
|
||||
|
||||
class QAFeedbackType(str, Enum):
|
||||
LIKE = "like" # User likes the answer, used for metrics
|
||||
DISLIKE = "dislike" # User dislikes the answer, used for metrics
|
||||
@@ -165,7 +172,6 @@ class FileOrigin(str, Enum):
|
||||
CONNECTOR = "connector"
|
||||
GENERATED_REPORT = "generated_report"
|
||||
OTHER = "other"
|
||||
GRAPH_GEN = "graph_gen"
|
||||
|
||||
|
||||
class PostgresAdvisoryLocks(Enum):
|
||||
@@ -182,6 +188,8 @@ class DanswerCeleryQueues:
|
||||
class DanswerRedisLocks:
|
||||
CHECK_VESPA_SYNC_BEAT_LOCK = "da_lock:check_vespa_sync_beat"
|
||||
MONITOR_VESPA_SYNC_BEAT_LOCK = "da_lock:monitor_vespa_sync_beat"
|
||||
CHECK_CONNECTOR_DELETION_BEAT_LOCK = "da_lock:check_connector_deletion_beat"
|
||||
MONITOR_CONNECTOR_DELETION_BEAT_LOCK = "da_lock:monitor_connector_deletion_beat"
|
||||
|
||||
|
||||
class DanswerCeleryPriority(int, Enum):
|
||||
|
||||
233
backend/danswer/connectors/asana/asana_api.py
Executable file
233
backend/danswer/connectors/asana/asana_api.py
Executable file
@@ -0,0 +1,233 @@
|
||||
import time
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from typing import Dict
|
||||
|
||||
import asana # type: ignore
|
||||
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
# https://github.com/Asana/python-asana/tree/master?tab=readme-ov-file#documentation-for-api-endpoints
|
||||
class AsanaTask:
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
title: str,
|
||||
text: str,
|
||||
link: str,
|
||||
last_modified: datetime,
|
||||
project_gid: str,
|
||||
project_name: str,
|
||||
) -> None:
|
||||
self.id = id
|
||||
self.title = title
|
||||
self.text = text
|
||||
self.link = link
|
||||
self.last_modified = last_modified
|
||||
self.project_gid = project_gid
|
||||
self.project_name = project_name
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"ID: {self.id}\nTitle: {self.title}\nLast modified: {self.last_modified}\nText: {self.text}"
|
||||
|
||||
|
||||
class AsanaAPI:
|
||||
def __init__(
|
||||
self, api_token: str, workspace_gid: str, team_gid: str | None
|
||||
) -> None:
|
||||
self._user = None # type: ignore
|
||||
self.workspace_gid = workspace_gid
|
||||
self.team_gid = team_gid
|
||||
|
||||
self.configuration = asana.Configuration()
|
||||
self.api_client = asana.ApiClient(self.configuration)
|
||||
self.tasks_api = asana.TasksApi(self.api_client)
|
||||
self.stories_api = asana.StoriesApi(self.api_client)
|
||||
self.users_api = asana.UsersApi(self.api_client)
|
||||
self.project_api = asana.ProjectsApi(self.api_client)
|
||||
self.workspaces_api = asana.WorkspacesApi(self.api_client)
|
||||
|
||||
self.api_error_count = 0
|
||||
self.configuration.access_token = api_token
|
||||
self.task_count = 0
|
||||
|
||||
def get_tasks(
|
||||
self, project_gids: list[str] | None, start_date: str
|
||||
) -> Iterator[AsanaTask]:
|
||||
"""Get all tasks from the projects with the given gids that were modified since the given date.
|
||||
If project_gids is None, get all tasks from all projects in the workspace."""
|
||||
logger.info("Starting to fetch Asana projects")
|
||||
projects = self.project_api.get_projects(
|
||||
opts={
|
||||
"workspace": self.workspace_gid,
|
||||
"opt_fields": "gid,name,archived,modified_at",
|
||||
}
|
||||
)
|
||||
start_seconds = int(time.mktime(datetime.now().timetuple()))
|
||||
projects_list = []
|
||||
project_count = 0
|
||||
for project_info in projects:
|
||||
project_gid = project_info["gid"]
|
||||
if project_gids is None or project_gid in project_gids:
|
||||
projects_list.append(project_gid)
|
||||
else:
|
||||
logger.debug(
|
||||
f"Skipping project: {project_gid} - not in accepted project_gids"
|
||||
)
|
||||
project_count += 1
|
||||
if project_count % 100 == 0:
|
||||
logger.info(f"Processed {project_count} projects")
|
||||
|
||||
logger.info(f"Found {len(projects_list)} projects to process")
|
||||
for project_gid in projects_list:
|
||||
for task in self._get_tasks_for_project(
|
||||
project_gid, start_date, start_seconds
|
||||
):
|
||||
yield task
|
||||
logger.info(f"Completed fetching {self.task_count} tasks from Asana")
|
||||
if self.api_error_count > 0:
|
||||
logger.warning(
|
||||
f"Encountered {self.api_error_count} API errors during task fetching"
|
||||
)
|
||||
|
||||
def _get_tasks_for_project(
|
||||
self, project_gid: str, start_date: str, start_seconds: int
|
||||
) -> Iterator[AsanaTask]:
|
||||
project = self.project_api.get_project(project_gid, opts={})
|
||||
if project["archived"]:
|
||||
logger.info(f"Skipping archived project: {project['name']} ({project_gid})")
|
||||
return []
|
||||
if not project["team"] or not project["team"]["gid"]:
|
||||
logger.info(
|
||||
f"Skipping project without a team: {project['name']} ({project_gid})"
|
||||
)
|
||||
return []
|
||||
if project["privacy_setting"] == "private":
|
||||
if self.team_gid and project["team"]["gid"] != self.team_gid:
|
||||
logger.info(
|
||||
f"Skipping private project not in configured team: {project['name']} ({project_gid})"
|
||||
)
|
||||
return []
|
||||
else:
|
||||
logger.info(
|
||||
f"Processing private project in configured team: {project['name']} ({project_gid})"
|
||||
)
|
||||
|
||||
simple_start_date = start_date.split(".")[0].split("+")[0]
|
||||
logger.info(
|
||||
f"Fetching tasks modified since {simple_start_date} for project: {project['name']} ({project_gid})"
|
||||
)
|
||||
|
||||
opts = {
|
||||
"opt_fields": "name,memberships,memberships.project,completed_at,completed_by,created_at,"
|
||||
"created_by,custom_fields,dependencies,due_at,due_on,external,html_notes,liked,likes,"
|
||||
"modified_at,notes,num_hearts,parent,projects,resource_subtype,resource_type,start_on,"
|
||||
"workspace,permalink_url",
|
||||
"modified_since": start_date,
|
||||
}
|
||||
tasks_from_api = self.tasks_api.get_tasks_for_project(project_gid, opts)
|
||||
for data in tasks_from_api:
|
||||
self.task_count += 1
|
||||
if self.task_count % 10 == 0:
|
||||
end_seconds = time.mktime(datetime.now().timetuple())
|
||||
runtime_seconds = end_seconds - start_seconds
|
||||
if runtime_seconds > 0:
|
||||
logger.info(
|
||||
f"Processed {self.task_count} tasks in {runtime_seconds:.0f} seconds "
|
||||
f"({self.task_count / runtime_seconds:.2f} tasks/second)"
|
||||
)
|
||||
|
||||
logger.debug(f"Processing Asana task: {data['name']}")
|
||||
|
||||
text = self._construct_task_text(data)
|
||||
|
||||
try:
|
||||
text += self._fetch_and_add_comments(data["gid"])
|
||||
|
||||
last_modified_date = self.format_date(data["modified_at"])
|
||||
text += f"Last modified: {last_modified_date}\n"
|
||||
|
||||
task = AsanaTask(
|
||||
id=data["gid"],
|
||||
title=data["name"],
|
||||
text=text,
|
||||
link=data["permalink_url"],
|
||||
last_modified=datetime.fromisoformat(data["modified_at"]),
|
||||
project_gid=project_gid,
|
||||
project_name=project["name"],
|
||||
)
|
||||
yield task
|
||||
except Exception:
|
||||
logger.error(
|
||||
f"Error processing task {data['gid']} in project {project_gid}",
|
||||
exc_info=True,
|
||||
)
|
||||
self.api_error_count += 1
|
||||
|
||||
def _construct_task_text(self, data: Dict) -> str:
|
||||
text = f"{data['name']}\n\n"
|
||||
|
||||
if data["notes"]:
|
||||
text += f"{data['notes']}\n\n"
|
||||
|
||||
if data["created_by"] and data["created_by"]["gid"]:
|
||||
creator = self.get_user(data["created_by"]["gid"])["name"]
|
||||
created_date = self.format_date(data["created_at"])
|
||||
text += f"Created by: {creator} on {created_date}\n"
|
||||
|
||||
if data["due_on"]:
|
||||
due_date = self.format_date(data["due_on"])
|
||||
text += f"Due date: {due_date}\n"
|
||||
|
||||
if data["completed_at"]:
|
||||
completed_date = self.format_date(data["completed_at"])
|
||||
text += f"Completed on: {completed_date}\n"
|
||||
|
||||
text += "\n"
|
||||
return text
|
||||
|
||||
def _fetch_and_add_comments(self, task_gid: str) -> str:
|
||||
text = ""
|
||||
stories_opts: Dict[str, str] = {}
|
||||
story_start = time.time()
|
||||
stories = self.stories_api.get_stories_for_task(task_gid, stories_opts)
|
||||
|
||||
story_count = 0
|
||||
comment_count = 0
|
||||
|
||||
for story in stories:
|
||||
story_count += 1
|
||||
if story["resource_subtype"] == "comment_added":
|
||||
comment = self.stories_api.get_story(
|
||||
story["gid"], opts={"opt_fields": "text,created_by,created_at"}
|
||||
)
|
||||
commenter = self.get_user(comment["created_by"]["gid"])["name"]
|
||||
text += f"Comment by {commenter}: {comment['text']}\n\n"
|
||||
comment_count += 1
|
||||
|
||||
story_duration = time.time() - story_start
|
||||
logger.debug(
|
||||
f"Processed {story_count} stories (including {comment_count} comments) in {story_duration:.2f} seconds"
|
||||
)
|
||||
|
||||
return text
|
||||
|
||||
def get_user(self, user_gid: str) -> Dict:
|
||||
if self._user is not None:
|
||||
return self._user
|
||||
self._user = self.users_api.get_user(user_gid, {"opt_fields": "name,email"})
|
||||
|
||||
if not self._user:
|
||||
logger.warning(f"Unable to fetch user information for user_gid: {user_gid}")
|
||||
return {"name": "Unknown"}
|
||||
return self._user
|
||||
|
||||
def format_date(self, date_str: str) -> str:
|
||||
date = datetime.fromisoformat(date_str)
|
||||
return time.strftime("%Y-%m-%d", date.timetuple())
|
||||
|
||||
def get_time(self) -> str:
|
||||
return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
||||
120
backend/danswer/connectors/asana/connector.py
Executable file
120
backend/danswer/connectors/asana/connector.py
Executable file
@@ -0,0 +1,120 @@
|
||||
import datetime
|
||||
from typing import Any
|
||||
|
||||
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.asana import asana_api
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class AsanaConnector(LoadConnector, PollConnector):
|
||||
def __init__(
|
||||
self,
|
||||
asana_workspace_id: str,
|
||||
asana_project_ids: str | None = None,
|
||||
asana_team_id: str | None = None,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE,
|
||||
) -> None:
|
||||
self.workspace_id = asana_workspace_id
|
||||
self.project_ids_to_index: list[str] | None = (
|
||||
asana_project_ids.split(",") if asana_project_ids is not None else None
|
||||
)
|
||||
self.asana_team_id = asana_team_id
|
||||
self.batch_size = batch_size
|
||||
self.continue_on_failure = continue_on_failure
|
||||
logger.info(
|
||||
f"AsanaConnector initialized with workspace_id: {asana_workspace_id}"
|
||||
)
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
self.api_token = credentials["asana_api_token_secret"]
|
||||
self.asana_client = asana_api.AsanaAPI(
|
||||
api_token=self.api_token,
|
||||
workspace_gid=self.workspace_id,
|
||||
team_gid=self.asana_team_id,
|
||||
)
|
||||
logger.info("Asana credentials loaded and API client initialized")
|
||||
return None
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch | None
|
||||
) -> GenerateDocumentsOutput:
|
||||
start_time = datetime.datetime.fromtimestamp(start).isoformat()
|
||||
logger.info(f"Starting Asana poll from {start_time}")
|
||||
asana = asana_api.AsanaAPI(
|
||||
api_token=self.api_token,
|
||||
workspace_gid=self.workspace_id,
|
||||
team_gid=self.asana_team_id,
|
||||
)
|
||||
docs_batch: list[Document] = []
|
||||
tasks = asana.get_tasks(self.project_ids_to_index, start_time)
|
||||
|
||||
for task in tasks:
|
||||
doc = self._message_to_doc(task)
|
||||
docs_batch.append(doc)
|
||||
|
||||
if len(docs_batch) >= self.batch_size:
|
||||
logger.info(f"Yielding batch of {len(docs_batch)} documents")
|
||||
yield docs_batch
|
||||
docs_batch = []
|
||||
|
||||
if docs_batch:
|
||||
logger.info(f"Yielding final batch of {len(docs_batch)} documents")
|
||||
yield docs_batch
|
||||
|
||||
logger.info("Asana poll completed")
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
logger.notice("Starting full index of all Asana tasks")
|
||||
return self.poll_source(start=0, end=None)
|
||||
|
||||
def _message_to_doc(self, task: asana_api.AsanaTask) -> Document:
|
||||
logger.debug(f"Converting Asana task {task.id} to Document")
|
||||
return Document(
|
||||
id=task.id,
|
||||
sections=[Section(link=task.link, text=task.text)],
|
||||
doc_updated_at=task.last_modified,
|
||||
source=DocumentSource.ASANA,
|
||||
semantic_identifier=task.title,
|
||||
metadata={
|
||||
"group": task.project_gid,
|
||||
"project": task.project_name,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import time
|
||||
import os
|
||||
|
||||
logger.notice("Starting Asana connector test")
|
||||
connector = AsanaConnector(
|
||||
os.environ["WORKSPACE_ID"],
|
||||
os.environ["PROJECT_IDS"],
|
||||
os.environ["TEAM_ID"],
|
||||
)
|
||||
connector.load_credentials(
|
||||
{
|
||||
"asana_api_token_secret": os.environ["API_TOKEN"],
|
||||
}
|
||||
)
|
||||
logger.info("Loading all documents from Asana")
|
||||
all_docs = connector.load_from_state()
|
||||
current = time.time()
|
||||
one_day_ago = current - 24 * 60 * 60 # 1 day
|
||||
logger.info("Polling for documents updated in the last 24 hours")
|
||||
latest_docs = connector.poll_source(one_day_ago, current)
|
||||
for docs in latest_docs:
|
||||
for doc in docs:
|
||||
print(doc.id)
|
||||
logger.notice("Asana connector test completed")
|
||||
@@ -4,6 +4,7 @@ from typing import Type
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.asana.connector import AsanaConnector
|
||||
from danswer.connectors.axero.connector import AxeroConnector
|
||||
from danswer.connectors.blob.connector import BlobStorageConnector
|
||||
from danswer.connectors.bookstack.connector import BookstackConnector
|
||||
@@ -91,6 +92,7 @@ def identify_connector_class(
|
||||
DocumentSource.CLICKUP: ClickupConnector,
|
||||
DocumentSource.MEDIAWIKI: MediaWikiConnector,
|
||||
DocumentSource.WIKIPEDIA: WikipediaConnector,
|
||||
DocumentSource.ASANA: AsanaConnector,
|
||||
DocumentSource.S3: BlobStorageConnector,
|
||||
DocumentSource.R2: BlobStorageConnector,
|
||||
DocumentSource.GOOGLE_CLOUD_STORAGE: BlobStorageConnector,
|
||||
@@ -124,11 +126,11 @@ def identify_connector_class(
|
||||
|
||||
|
||||
def instantiate_connector(
|
||||
db_session: Session,
|
||||
source: DocumentSource,
|
||||
input_type: InputType,
|
||||
connector_specific_config: dict[str, Any],
|
||||
credential: Credential,
|
||||
db_session: Session,
|
||||
) -> BaseConnector:
|
||||
connector_class = identify_connector_class(source, input_type)
|
||||
connector = connector_class(**connector_specific_config)
|
||||
|
||||
@@ -6,7 +6,6 @@ from datetime import timezone
|
||||
from enum import Enum
|
||||
from itertools import chain
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
|
||||
@@ -21,19 +20,13 @@ from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import IGNORE_FOR_QA
|
||||
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
|
||||
from danswer.connectors.google_drive.connector_auth import (
|
||||
get_google_drive_creds_for_authorized_user,
|
||||
)
|
||||
from danswer.connectors.google_drive.connector_auth import (
|
||||
get_google_drive_creds_for_service_account,
|
||||
)
|
||||
from danswer.connectors.google_drive.connector_auth import get_google_drive_creds
|
||||
from danswer.connectors.google_drive.constants import (
|
||||
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY,
|
||||
)
|
||||
from danswer.connectors.google_drive.constants import (
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
|
||||
)
|
||||
from danswer.connectors.google_drive.constants import DB_CREDENTIALS_DICT_TOKEN_KEY
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
@@ -407,42 +400,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector):
|
||||
(2) A credential which holds a service account key JSON file, which
|
||||
can then be used to impersonate any user in the workspace.
|
||||
"""
|
||||
creds: OAuthCredentials | ServiceAccountCredentials | None = None
|
||||
new_creds_dict = None
|
||||
if DB_CREDENTIALS_DICT_TOKEN_KEY in credentials:
|
||||
access_token_json_str = cast(
|
||||
str, credentials[DB_CREDENTIALS_DICT_TOKEN_KEY]
|
||||
)
|
||||
creds = get_google_drive_creds_for_authorized_user(
|
||||
token_json_str=access_token_json_str
|
||||
)
|
||||
|
||||
# tell caller to update token stored in DB if it has changed
|
||||
# (e.g. the token has been refreshed)
|
||||
new_creds_json_str = creds.to_json() if creds else ""
|
||||
if new_creds_json_str != access_token_json_str:
|
||||
new_creds_dict = {DB_CREDENTIALS_DICT_TOKEN_KEY: new_creds_json_str}
|
||||
|
||||
if DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY in credentials:
|
||||
service_account_key_json_str = credentials[
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY
|
||||
]
|
||||
creds = get_google_drive_creds_for_service_account(
|
||||
service_account_key_json_str=service_account_key_json_str
|
||||
)
|
||||
|
||||
# "Impersonate" a user if one is specified
|
||||
delegated_user_email = cast(
|
||||
str | None, credentials.get(DB_CREDENTIALS_DICT_DELEGATED_USER_KEY)
|
||||
)
|
||||
if delegated_user_email:
|
||||
creds = creds.with_subject(delegated_user_email) if creds else None # type: ignore
|
||||
|
||||
if creds is None:
|
||||
raise PermissionError(
|
||||
"Unable to access Google Drive - unknown credential structure."
|
||||
)
|
||||
|
||||
creds, new_creds_dict = get_google_drive_creds(credentials)
|
||||
self.creds = creds
|
||||
return new_creds_dict
|
||||
|
||||
@@ -509,6 +467,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector):
|
||||
file["modifiedTime"]
|
||||
).astimezone(timezone.utc),
|
||||
metadata={} if text_contents else {IGNORE_FOR_QA: "True"},
|
||||
additional_info=file.get("id"),
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
|
||||
@@ -10,11 +10,13 @@ from google.oauth2.service_account import Credentials as ServiceAccountCredentia
|
||||
from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.app_configs import ENTERPRISE_EDITION_ENABLED
|
||||
from danswer.configs.app_configs import WEB_DOMAIN
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import KV_CRED_KEY
|
||||
from danswer.configs.constants import KV_GOOGLE_DRIVE_CRED_KEY
|
||||
from danswer.configs.constants import KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY
|
||||
from danswer.connectors.google_drive.constants import BASE_SCOPES
|
||||
from danswer.connectors.google_drive.constants import (
|
||||
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY,
|
||||
)
|
||||
@@ -22,7 +24,8 @@ from danswer.connectors.google_drive.constants import (
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
|
||||
)
|
||||
from danswer.connectors.google_drive.constants import DB_CREDENTIALS_DICT_TOKEN_KEY
|
||||
from danswer.connectors.google_drive.constants import SCOPES
|
||||
from danswer.connectors.google_drive.constants import FETCH_GROUPS_SCOPES
|
||||
from danswer.connectors.google_drive.constants import FETCH_PERMISSIONS_SCOPES
|
||||
from danswer.db.credentials import update_credential_json
|
||||
from danswer.db.models import User
|
||||
from danswer.dynamic_configs.factory import get_dynamic_config_store
|
||||
@@ -34,15 +37,25 @@ from danswer.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def build_gdrive_scopes() -> list[str]:
|
||||
base_scopes: list[str] = BASE_SCOPES
|
||||
permissions_scopes: list[str] = FETCH_PERMISSIONS_SCOPES
|
||||
groups_scopes: list[str] = FETCH_GROUPS_SCOPES
|
||||
|
||||
if ENTERPRISE_EDITION_ENABLED:
|
||||
return base_scopes + permissions_scopes + groups_scopes
|
||||
return base_scopes + permissions_scopes
|
||||
|
||||
|
||||
def _build_frontend_google_drive_redirect() -> str:
|
||||
return f"{WEB_DOMAIN}/admin/connectors/google-drive/auth/callback"
|
||||
|
||||
|
||||
def get_google_drive_creds_for_authorized_user(
|
||||
token_json_str: str,
|
||||
token_json_str: str, scopes: list[str] = build_gdrive_scopes()
|
||||
) -> OAuthCredentials | None:
|
||||
creds_json = json.loads(token_json_str)
|
||||
creds = OAuthCredentials.from_authorized_user_info(creds_json, SCOPES)
|
||||
creds = OAuthCredentials.from_authorized_user_info(creds_json, scopes)
|
||||
if creds.valid:
|
||||
return creds
|
||||
|
||||
@@ -59,18 +72,67 @@ def get_google_drive_creds_for_authorized_user(
|
||||
return None
|
||||
|
||||
|
||||
def get_google_drive_creds_for_service_account(
|
||||
service_account_key_json_str: str,
|
||||
def _get_google_drive_creds_for_service_account(
|
||||
service_account_key_json_str: str, scopes: list[str] = build_gdrive_scopes()
|
||||
) -> ServiceAccountCredentials | None:
|
||||
service_account_key = json.loads(service_account_key_json_str)
|
||||
creds = ServiceAccountCredentials.from_service_account_info(
|
||||
service_account_key, scopes=SCOPES
|
||||
service_account_key, scopes=scopes
|
||||
)
|
||||
if not creds.valid or not creds.expired:
|
||||
creds.refresh(Request())
|
||||
return creds if creds.valid else None
|
||||
|
||||
|
||||
def get_google_drive_creds(
|
||||
credentials: dict[str, str], scopes: list[str] = build_gdrive_scopes()
|
||||
) -> tuple[ServiceAccountCredentials | OAuthCredentials, dict[str, str] | None]:
|
||||
oauth_creds = None
|
||||
service_creds = None
|
||||
new_creds_dict = None
|
||||
if DB_CREDENTIALS_DICT_TOKEN_KEY in credentials:
|
||||
access_token_json_str = cast(str, credentials[DB_CREDENTIALS_DICT_TOKEN_KEY])
|
||||
oauth_creds = get_google_drive_creds_for_authorized_user(
|
||||
token_json_str=access_token_json_str, scopes=scopes
|
||||
)
|
||||
|
||||
# tell caller to update token stored in DB if it has changed
|
||||
# (e.g. the token has been refreshed)
|
||||
new_creds_json_str = oauth_creds.to_json() if oauth_creds else ""
|
||||
if new_creds_json_str != access_token_json_str:
|
||||
new_creds_dict = {DB_CREDENTIALS_DICT_TOKEN_KEY: new_creds_json_str}
|
||||
|
||||
elif DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY in credentials:
|
||||
service_account_key_json_str = credentials[
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY
|
||||
]
|
||||
service_creds = _get_google_drive_creds_for_service_account(
|
||||
service_account_key_json_str=service_account_key_json_str,
|
||||
scopes=scopes,
|
||||
)
|
||||
|
||||
# "Impersonate" a user if one is specified
|
||||
delegated_user_email = cast(
|
||||
str | None, credentials.get(DB_CREDENTIALS_DICT_DELEGATED_USER_KEY)
|
||||
)
|
||||
if delegated_user_email:
|
||||
service_creds = (
|
||||
service_creds.with_subject(delegated_user_email)
|
||||
if service_creds
|
||||
else None
|
||||
)
|
||||
|
||||
creds: ServiceAccountCredentials | OAuthCredentials | None = (
|
||||
oauth_creds or service_creds
|
||||
)
|
||||
if creds is None:
|
||||
raise PermissionError(
|
||||
"Unable to access Google Drive - unknown credential structure."
|
||||
)
|
||||
|
||||
return creds, new_creds_dict
|
||||
|
||||
|
||||
def verify_csrf(credential_id: int, state: str) -> None:
|
||||
csrf = get_dynamic_config_store().load(KV_CRED_KEY.format(str(credential_id)))
|
||||
if csrf != state:
|
||||
@@ -84,7 +146,7 @@ def get_auth_url(credential_id: int) -> str:
|
||||
credential_json = json.loads(creds_str)
|
||||
flow = InstalledAppFlow.from_client_config(
|
||||
credential_json,
|
||||
scopes=SCOPES,
|
||||
scopes=build_gdrive_scopes(),
|
||||
redirect_uri=_build_frontend_google_drive_redirect(),
|
||||
)
|
||||
auth_url, _ = flow.authorization_url(prompt="consent")
|
||||
@@ -107,7 +169,7 @@ def update_credential_access_tokens(
|
||||
app_credentials = get_google_app_cred()
|
||||
flow = InstalledAppFlow.from_client_config(
|
||||
app_credentials.model_dump(),
|
||||
scopes=SCOPES,
|
||||
scopes=build_gdrive_scopes(),
|
||||
redirect_uri=_build_frontend_google_drive_redirect(),
|
||||
)
|
||||
flow.fetch_token(code=auth_code)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
DB_CREDENTIALS_DICT_TOKEN_KEY = "google_drive_tokens"
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "google_drive_service_account_key"
|
||||
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY = "google_drive_delegated_user"
|
||||
SCOPES = [
|
||||
"https://www.googleapis.com/auth/drive.readonly",
|
||||
"https://www.googleapis.com/auth/drive.metadata.readonly",
|
||||
]
|
||||
|
||||
BASE_SCOPES = ["https://www.googleapis.com/auth/drive.readonly"]
|
||||
FETCH_PERMISSIONS_SCOPES = ["https://www.googleapis.com/auth/drive.metadata.readonly"]
|
||||
FETCH_GROUPS_SCOPES = ["https://www.googleapis.com/auth/cloud-identity.groups.readonly"]
|
||||
|
||||
@@ -113,6 +113,9 @@ class DocumentBase(BaseModel):
|
||||
# The default title is semantic_identifier though unless otherwise specified
|
||||
title: str | None = None
|
||||
from_ingestion_api: bool = False
|
||||
# Anything else that may be useful that is specific to this particular connector type that other
|
||||
# parts of the code may need. If you're unsure, this can be left as None
|
||||
additional_info: Any = None
|
||||
|
||||
def get_title_for_document_index(
|
||||
self,
|
||||
|
||||
@@ -211,7 +211,7 @@ def handle_message(
|
||||
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
if message_info.email:
|
||||
add_non_web_user_if_not_exists(message_info.email, db_session)
|
||||
add_non_web_user_if_not_exists(db_session, message_info.email)
|
||||
|
||||
# first check if we need to respond with a standard answer
|
||||
used_standard_answer = handle_standard_answers(
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import cast
|
||||
from typing import Optional
|
||||
from typing import TypeVar
|
||||
|
||||
from fastapi import HTTPException
|
||||
from retry import retry
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.models.blocks import DividerBlock
|
||||
@@ -153,15 +154,23 @@ def handle_regular_answer(
|
||||
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
if len(new_message_request.messages) > 1:
|
||||
persona = cast(
|
||||
Persona,
|
||||
fetch_persona_by_id(
|
||||
db_session,
|
||||
new_message_request.persona_id,
|
||||
user=None,
|
||||
get_editable=False,
|
||||
),
|
||||
)
|
||||
if new_message_request.persona_config:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Slack bot does not support persona config",
|
||||
)
|
||||
|
||||
elif new_message_request.persona_id:
|
||||
persona = cast(
|
||||
Persona,
|
||||
fetch_persona_by_id(
|
||||
db_session,
|
||||
new_message_request.persona_id,
|
||||
user=None,
|
||||
get_editable=False,
|
||||
),
|
||||
)
|
||||
|
||||
llm, _ = get_llms_for_persona(persona)
|
||||
|
||||
# In cases of threads, split the available tokens between docs and thread context
|
||||
|
||||
@@ -178,14 +178,8 @@ def delete_search_doc_message_relationship(
|
||||
|
||||
|
||||
def delete_tool_call_for_message_id(message_id: int, db_session: Session) -> None:
|
||||
chat_message = (
|
||||
db_session.query(ChatMessage).filter(ChatMessage.id == message_id).first()
|
||||
)
|
||||
if chat_message and chat_message.tool_call_id:
|
||||
stmt = delete(ToolCall).where(ToolCall.id == chat_message.tool_call_id)
|
||||
db_session.execute(stmt)
|
||||
chat_message.tool_call_id = None
|
||||
|
||||
stmt = delete(ToolCall).where(ToolCall.message_id == message_id)
|
||||
db_session.execute(stmt)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@@ -232,7 +226,7 @@ def create_chat_session(
|
||||
db_session: Session,
|
||||
description: str,
|
||||
user_id: UUID | None,
|
||||
persona_id: int,
|
||||
persona_id: int | None, # Can be none if temporary persona is used
|
||||
llm_override: LLMOverride | None = None,
|
||||
prompt_override: PromptOverride | None = None,
|
||||
one_shot: bool = False,
|
||||
@@ -394,7 +388,7 @@ def get_chat_messages_by_session(
|
||||
)
|
||||
|
||||
if prefetch_tool_calls:
|
||||
stmt = stmt.options(joinedload(ChatMessage.tool_call))
|
||||
stmt = stmt.options(joinedload(ChatMessage.tool_calls))
|
||||
result = db_session.scalars(stmt).unique().all()
|
||||
else:
|
||||
result = db_session.scalars(stmt).all()
|
||||
@@ -480,7 +474,7 @@ def create_new_chat_message(
|
||||
alternate_assistant_id: int | None = None,
|
||||
# Maps the citation number [n] to the DB SearchDoc
|
||||
citations: dict[int, int] | None = None,
|
||||
tool_call: ToolCall | None = None,
|
||||
tool_calls: list[ToolCall] | None = None,
|
||||
commit: bool = True,
|
||||
reserved_message_id: int | None = None,
|
||||
overridden_model: str | None = None,
|
||||
@@ -500,7 +494,7 @@ def create_new_chat_message(
|
||||
existing_message.message_type = message_type
|
||||
existing_message.citations = citations
|
||||
existing_message.files = files
|
||||
existing_message.tool_call = tool_call if tool_call else None
|
||||
existing_message.tool_calls = tool_calls if tool_calls else []
|
||||
existing_message.error = error
|
||||
existing_message.alternate_assistant_id = alternate_assistant_id
|
||||
existing_message.overridden_model = overridden_model
|
||||
@@ -519,7 +513,7 @@ def create_new_chat_message(
|
||||
message_type=message_type,
|
||||
citations=citations,
|
||||
files=files,
|
||||
tool_call=tool_call if tool_call else None,
|
||||
tool_calls=tool_calls if tool_calls else [],
|
||||
error=error,
|
||||
alternate_assistant_id=alternate_assistant_id,
|
||||
overridden_model=overridden_model,
|
||||
@@ -604,6 +598,7 @@ def get_doc_query_identifiers_from_model(
|
||||
chat_session: ChatSession,
|
||||
user_id: UUID | None,
|
||||
db_session: Session,
|
||||
enforce_chat_session_id_for_search_docs: bool,
|
||||
) -> list[tuple[str, int]]:
|
||||
"""Given a list of search_doc_ids"""
|
||||
search_docs = (
|
||||
@@ -623,7 +618,8 @@ def get_doc_query_identifiers_from_model(
|
||||
for doc in search_docs
|
||||
]
|
||||
):
|
||||
raise ValueError("Invalid reference doc, not from this chat session.")
|
||||
if enforce_chat_session_id_for_search_docs:
|
||||
raise ValueError("Invalid reference doc, not from this chat session.")
|
||||
except IndexError:
|
||||
# This happens when the doc has no chat_messages associated with it.
|
||||
# which happens as an edge case where the chat message failed to save
|
||||
@@ -753,13 +749,14 @@ def translate_db_message_to_chat_message_detail(
|
||||
time_sent=chat_message.time_sent,
|
||||
citations=chat_message.citations,
|
||||
files=chat_message.files or [],
|
||||
tool_call=ToolCallFinalResult(
|
||||
tool_name=chat_message.tool_call.tool_name,
|
||||
tool_args=chat_message.tool_call.tool_arguments,
|
||||
tool_result=chat_message.tool_call.tool_result,
|
||||
)
|
||||
if chat_message.tool_call
|
||||
else None,
|
||||
tool_calls=[
|
||||
ToolCallFinalResult(
|
||||
tool_name=tool_call.tool_name,
|
||||
tool_args=tool_call.tool_arguments,
|
||||
tool_result=tool_call.tool_result,
|
||||
)
|
||||
for tool_call in chat_message.tool_calls
|
||||
],
|
||||
alternate_assistant_id=chat_message.alternate_assistant_id,
|
||||
overridden_model=chat_message.overridden_model,
|
||||
)
|
||||
|
||||
@@ -12,6 +12,7 @@ from sqlalchemy.orm import Session
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.db.connector import fetch_connector_by_id
|
||||
from danswer.db.credentials import fetch_credential_by_id
|
||||
from danswer.db.enums import AccessType
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import IndexAttempt
|
||||
@@ -24,6 +25,10 @@ from danswer.db.models import UserGroup__ConnectorCredentialPair
|
||||
from danswer.db.models import UserRole
|
||||
from danswer.server.models import StatusResponse
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.db.external_perm import delete_user__ext_group_for_cc_pair__no_commit
|
||||
from ee.danswer.external_permissions.permission_sync_function_map import (
|
||||
check_if_valid_sync_source,
|
||||
)
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -74,7 +79,7 @@ def _add_user_filters(
|
||||
.correlate(ConnectorCredentialPair)
|
||||
)
|
||||
else:
|
||||
where_clause |= ConnectorCredentialPair.is_public == True # noqa: E712
|
||||
where_clause |= ConnectorCredentialPair.access_type == AccessType.PUBLIC
|
||||
|
||||
return stmt.where(where_clause)
|
||||
|
||||
@@ -94,8 +99,7 @@ def get_connector_credential_pairs(
|
||||
) # noqa
|
||||
if ids:
|
||||
stmt = stmt.where(ConnectorCredentialPair.id.in_(ids))
|
||||
results = db_session.scalars(stmt)
|
||||
return list(results.all())
|
||||
return list(db_session.scalars(stmt).all())
|
||||
|
||||
|
||||
def add_deletion_failure_message(
|
||||
@@ -309,9 +313,9 @@ def associate_default_cc_pair(db_session: Session) -> None:
|
||||
association = ConnectorCredentialPair(
|
||||
connector_id=0,
|
||||
credential_id=0,
|
||||
access_type=AccessType.PUBLIC,
|
||||
name="DefaultCCPair",
|
||||
status=ConnectorCredentialPairStatus.ACTIVE,
|
||||
is_public=True,
|
||||
)
|
||||
db_session.add(association)
|
||||
db_session.commit()
|
||||
@@ -336,8 +340,9 @@ def add_credential_to_connector(
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
cc_pair_name: str | None,
|
||||
is_public: bool,
|
||||
access_type: AccessType,
|
||||
groups: list[int] | None,
|
||||
auto_sync_options: dict | None = None,
|
||||
) -> StatusResponse:
|
||||
connector = fetch_connector_by_id(connector_id, db_session)
|
||||
credential = fetch_credential_by_id(credential_id, user, db_session)
|
||||
@@ -345,6 +350,13 @@ def add_credential_to_connector(
|
||||
if connector is None:
|
||||
raise HTTPException(status_code=404, detail="Connector does not exist")
|
||||
|
||||
if access_type == AccessType.SYNC:
|
||||
if not check_if_valid_sync_source(connector.source):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Connector of type {connector.source} does not support SYNC access type",
|
||||
)
|
||||
|
||||
if credential is None:
|
||||
error_msg = (
|
||||
f"Credential {credential_id} does not exist or does not belong to user"
|
||||
@@ -375,12 +387,13 @@ def add_credential_to_connector(
|
||||
credential_id=credential_id,
|
||||
name=cc_pair_name,
|
||||
status=ConnectorCredentialPairStatus.ACTIVE,
|
||||
is_public=is_public,
|
||||
access_type=access_type,
|
||||
auto_sync_options=auto_sync_options,
|
||||
)
|
||||
db_session.add(association)
|
||||
db_session.flush() # make sure the association has an id
|
||||
|
||||
if groups:
|
||||
if groups and access_type != AccessType.SYNC:
|
||||
_relate_groups_to_cc_pair__no_commit(
|
||||
db_session=db_session,
|
||||
cc_pair_id=association.id,
|
||||
@@ -423,6 +436,10 @@ def remove_credential_from_connector(
|
||||
)
|
||||
|
||||
if association is not None:
|
||||
delete_user__ext_group_for_cc_pair__no_commit(
|
||||
db_session=db_session,
|
||||
cc_pair_id=association.id,
|
||||
)
|
||||
db_session.delete(association)
|
||||
db_session.commit()
|
||||
return StatusResponse(
|
||||
|
||||
@@ -4,7 +4,6 @@ from collections.abc import Generator
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_
|
||||
from sqlalchemy import delete
|
||||
@@ -17,14 +16,17 @@ from sqlalchemy.dialects.postgresql import insert
|
||||
from sqlalchemy.engine.util import TransactionalContext
|
||||
from sqlalchemy.exc import OperationalError
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql.expression import null
|
||||
|
||||
from danswer.configs.constants import DEFAULT_BOOST
|
||||
from danswer.db.enums import AccessType
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.feedback import delete_document_feedback_for_documents__no_commit
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import Credential
|
||||
from danswer.db.models import Document as DbDocument
|
||||
from danswer.db.models import DocumentByConnectorCredentialPair
|
||||
from danswer.db.models import User
|
||||
from danswer.db.tag import delete_document_tags_for_documents__no_commit
|
||||
from danswer.db.utils import model_to_dict
|
||||
from danswer.document_index.interfaces import DocumentMetadata
|
||||
@@ -126,7 +128,18 @@ def get_documents_by_ids(
|
||||
return list(documents)
|
||||
|
||||
|
||||
def get_document_connector_cnts(
|
||||
def get_document_connector_count(
|
||||
db_session: Session,
|
||||
document_id: str,
|
||||
) -> int:
|
||||
results = get_document_connector_counts(db_session, [document_id])
|
||||
if not results or len(results) == 0:
|
||||
return 0
|
||||
|
||||
return results[0][1]
|
||||
|
||||
|
||||
def get_document_connector_counts(
|
||||
db_session: Session,
|
||||
document_ids: list[str],
|
||||
) -> Sequence[tuple[str, int]]:
|
||||
@@ -141,7 +154,7 @@ def get_document_connector_cnts(
|
||||
return db_session.execute(stmt).all() # type: ignore
|
||||
|
||||
|
||||
def get_document_cnts_for_cc_pairs(
|
||||
def get_document_counts_for_cc_pairs(
|
||||
db_session: Session, cc_pair_identifiers: list[ConnectorCredentialPairIdentifier]
|
||||
) -> Sequence[tuple[int, int, int]]:
|
||||
stmt = (
|
||||
@@ -175,16 +188,14 @@ def get_document_cnts_for_cc_pairs(
|
||||
def get_access_info_for_document(
|
||||
db_session: Session,
|
||||
document_id: str,
|
||||
) -> tuple[str, list[UUID | None], bool] | None:
|
||||
) -> tuple[str, list[str | None], bool] | None:
|
||||
"""Gets access info for a single document by calling the get_access_info_for_documents function
|
||||
and passing a list with a single document ID.
|
||||
|
||||
Args:
|
||||
db_session (Session): The database session to use.
|
||||
document_id (str): The document ID to fetch access info for.
|
||||
|
||||
Returns:
|
||||
Optional[Tuple[str, List[UUID | None], bool]]: A tuple containing the document ID, a list of user IDs,
|
||||
Optional[Tuple[str, List[str | None], bool]]: A tuple containing the document ID, a list of user emails,
|
||||
and a boolean indicating if the document is globally public, or None if no results are found.
|
||||
"""
|
||||
results = get_access_info_for_documents(db_session, [document_id])
|
||||
@@ -197,19 +208,27 @@ def get_access_info_for_document(
|
||||
def get_access_info_for_documents(
|
||||
db_session: Session,
|
||||
document_ids: list[str],
|
||||
) -> Sequence[tuple[str, list[UUID | None], bool]]:
|
||||
) -> Sequence[tuple[str, list[str | None], bool]]:
|
||||
"""Gets back all relevant access info for the given documents. This includes
|
||||
the user_ids for cc pairs that the document is associated with + whether any
|
||||
of the associated cc pairs are intending to make the document globally public.
|
||||
Returns the list where each element contains:
|
||||
- Document ID (which is also the ID of the DocumentByConnectorCredentialPair)
|
||||
- List of emails of Danswer users with direct access to the doc (includes a "None" element if
|
||||
the connector was set up by an admin when auth was off
|
||||
- bool for whether the document is public (the document later can also be marked public by
|
||||
automatic permission sync step)
|
||||
"""
|
||||
stmt = select(
|
||||
DocumentByConnectorCredentialPair.id,
|
||||
func.array_agg(func.coalesce(User.email, null())).label("user_emails"),
|
||||
func.bool_or(ConnectorCredentialPair.access_type == AccessType.PUBLIC).label(
|
||||
"public_doc"
|
||||
),
|
||||
).where(DocumentByConnectorCredentialPair.id.in_(document_ids))
|
||||
|
||||
stmt = (
|
||||
select(
|
||||
DocumentByConnectorCredentialPair.id,
|
||||
func.array_agg(Credential.user_id).label("user_ids"),
|
||||
func.bool_or(ConnectorCredentialPair.is_public).label("public_doc"),
|
||||
)
|
||||
.where(DocumentByConnectorCredentialPair.id.in_(document_ids))
|
||||
.join(
|
||||
stmt.join(
|
||||
Credential,
|
||||
DocumentByConnectorCredentialPair.credential_id == Credential.id,
|
||||
)
|
||||
@@ -222,6 +241,13 @@ def get_access_info_for_documents(
|
||||
== ConnectorCredentialPair.credential_id,
|
||||
),
|
||||
)
|
||||
.outerjoin(
|
||||
User,
|
||||
and_(
|
||||
Credential.user_id == User.id,
|
||||
ConnectorCredentialPair.access_type != AccessType.SYNC,
|
||||
),
|
||||
)
|
||||
# don't include CC pairs that are being deleted
|
||||
# NOTE: CC pairs can never go from DELETING to any other state -> it's safe to ignore them
|
||||
.where(ConnectorCredentialPair.status != ConnectorCredentialPairStatus.DELETING)
|
||||
@@ -267,9 +293,19 @@ def upsert_documents(
|
||||
for doc in seen_documents.values()
|
||||
]
|
||||
)
|
||||
# for now, there are no columns to update. If more metadata is added, then this
|
||||
# needs to change to an `on_conflict_do_update`
|
||||
on_conflict_stmt = insert_stmt.on_conflict_do_nothing()
|
||||
|
||||
on_conflict_stmt = insert_stmt.on_conflict_do_update(
|
||||
index_elements=["id"], # Conflict target
|
||||
set_={
|
||||
"from_ingestion_api": insert_stmt.excluded.from_ingestion_api,
|
||||
"boost": insert_stmt.excluded.boost,
|
||||
"hidden": insert_stmt.excluded.hidden,
|
||||
"semantic_id": insert_stmt.excluded.semantic_id,
|
||||
"link": insert_stmt.excluded.link,
|
||||
"primary_owners": insert_stmt.excluded.primary_owners,
|
||||
"secondary_owners": insert_stmt.excluded.secondary_owners,
|
||||
},
|
||||
)
|
||||
db_session.execute(on_conflict_stmt)
|
||||
db_session.commit()
|
||||
|
||||
@@ -350,11 +386,34 @@ def upsert_documents_complete(
|
||||
|
||||
|
||||
def delete_document_by_connector_credential_pair__no_commit(
|
||||
db_session: Session,
|
||||
document_id: str,
|
||||
connector_credential_pair_identifier: ConnectorCredentialPairIdentifier
|
||||
| None = None,
|
||||
) -> None:
|
||||
"""Deletes a single document by cc pair relationship entry.
|
||||
Foreign key rows are left in place.
|
||||
The implicit assumption is that the document itself still has other cc_pair
|
||||
references and needs to continue existing.
|
||||
"""
|
||||
delete_documents_by_connector_credential_pair__no_commit(
|
||||
db_session=db_session,
|
||||
document_ids=[document_id],
|
||||
connector_credential_pair_identifier=connector_credential_pair_identifier,
|
||||
)
|
||||
|
||||
|
||||
def delete_documents_by_connector_credential_pair__no_commit(
|
||||
db_session: Session,
|
||||
document_ids: list[str],
|
||||
connector_credential_pair_identifier: ConnectorCredentialPairIdentifier
|
||||
| None = None,
|
||||
) -> None:
|
||||
"""This deletes just the document by cc pair entries for a particular cc pair.
|
||||
Foreign key rows are left in place.
|
||||
The implicit assumption is that the document itself still has other cc_pair
|
||||
references and needs to continue existing.
|
||||
"""
|
||||
stmt = delete(DocumentByConnectorCredentialPair).where(
|
||||
DocumentByConnectorCredentialPair.id.in_(document_ids)
|
||||
)
|
||||
@@ -377,8 +436,9 @@ def delete_documents__no_commit(db_session: Session, document_ids: list[str]) ->
|
||||
def delete_documents_complete__no_commit(
|
||||
db_session: Session, document_ids: list[str]
|
||||
) -> None:
|
||||
"""This completely deletes the documents from the db, including all foreign key relationships"""
|
||||
logger.info(f"Deleting {len(document_ids)} documents from the DB")
|
||||
delete_document_by_connector_credential_pair__no_commit(db_session, document_ids)
|
||||
delete_documents_by_connector_credential_pair__no_commit(db_session, document_ids)
|
||||
delete_document_feedback_for_documents__no_commit(
|
||||
document_ids=document_ids, db_session=db_session
|
||||
)
|
||||
|
||||
@@ -14,6 +14,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.db.connector_credential_pair import get_cc_pair_groups_for_ids
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from danswer.db.enums import AccessType
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import Document
|
||||
@@ -180,7 +181,7 @@ def _check_if_cc_pairs_are_owned_by_groups(
|
||||
ids=missing_cc_pair_ids,
|
||||
)
|
||||
for cc_pair in cc_pairs:
|
||||
if not cc_pair.is_public:
|
||||
if cc_pair.access_type != AccessType.PUBLIC:
|
||||
raise ValueError(
|
||||
f"Connector Credential Pair with ID: '{cc_pair.id}'"
|
||||
" is not owned by the specified groups"
|
||||
@@ -569,7 +570,7 @@ def construct_document_select_by_docset(
|
||||
return stmt
|
||||
|
||||
|
||||
def fetch_document_set_for_document(
|
||||
def fetch_document_sets_for_document(
|
||||
document_id: str,
|
||||
db_session: Session,
|
||||
) -> list[str]:
|
||||
@@ -704,7 +705,7 @@ def check_document_sets_are_public(
|
||||
ConnectorCredentialPair.id.in_(
|
||||
connector_credential_pair_ids # type:ignore
|
||||
),
|
||||
ConnectorCredentialPair.is_public.is_(False),
|
||||
ConnectorCredentialPair.access_type != AccessType.PUBLIC,
|
||||
)
|
||||
.limit(1)
|
||||
.first()
|
||||
|
||||
@@ -137,8 +137,8 @@ def get_sqlalchemy_engine() -> Engine:
|
||||
)
|
||||
_SYNC_ENGINE = create_engine(
|
||||
connection_string,
|
||||
pool_size=40,
|
||||
max_overflow=10,
|
||||
pool_size=5,
|
||||
max_overflow=0,
|
||||
pool_pre_ping=POSTGRES_POOL_PRE_PING,
|
||||
pool_recycle=POSTGRES_POOL_RECYCLE,
|
||||
)
|
||||
@@ -156,8 +156,8 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
|
||||
connect_args={
|
||||
"server_settings": {"application_name": POSTGRES_APP_NAME + "_async"}
|
||||
},
|
||||
pool_size=40,
|
||||
max_overflow=10,
|
||||
pool_size=5,
|
||||
max_overflow=0,
|
||||
pool_pre_ping=POSTGRES_POOL_PRE_PING,
|
||||
pool_recycle=POSTGRES_POOL_RECYCLE,
|
||||
)
|
||||
|
||||
@@ -51,3 +51,9 @@ class ConnectorCredentialPairStatus(str, PyEnum):
|
||||
|
||||
def is_active(self) -> bool:
|
||||
return self == ConnectorCredentialPairStatus.ACTIVE
|
||||
|
||||
|
||||
class AccessType(str, PyEnum):
|
||||
PUBLIC = "public"
|
||||
PRIVATE = "private"
|
||||
SYNC = "sync"
|
||||
|
||||
@@ -16,6 +16,7 @@ from sqlalchemy.orm import Session
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.configs.constants import SearchFeedbackType
|
||||
from danswer.db.chat import get_chat_message
|
||||
from danswer.db.enums import AccessType
|
||||
from danswer.db.models import ChatMessageFeedback
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import Document as DbDocument
|
||||
@@ -94,7 +95,7 @@ def _add_user_filters(
|
||||
.correlate(CCPair)
|
||||
)
|
||||
else:
|
||||
where_clause |= CCPair.is_public == True # noqa: E712
|
||||
where_clause |= CCPair.access_type == AccessType.PUBLIC
|
||||
|
||||
return stmt.where(where_clause)
|
||||
|
||||
|
||||
@@ -4,9 +4,11 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
|
||||
from danswer.db.models import DocumentSet
|
||||
from danswer.db.models import LLMProvider as LLMProviderModel
|
||||
from danswer.db.models import LLMProvider__UserGroup
|
||||
from danswer.db.models import SearchSettings
|
||||
from danswer.db.models import Tool as ToolModel
|
||||
from danswer.db.models import User
|
||||
from danswer.db.models import User__UserGroup
|
||||
from danswer.server.manage.embedding.models import CloudEmbeddingProvider
|
||||
@@ -60,13 +62,21 @@ def upsert_cloud_embedding_provider(
|
||||
|
||||
|
||||
def upsert_llm_provider(
|
||||
llm_provider: LLMProviderUpsertRequest, db_session: Session
|
||||
llm_provider: LLMProviderUpsertRequest,
|
||||
db_session: Session,
|
||||
is_creation: bool = True,
|
||||
) -> FullLLMProvider:
|
||||
existing_llm_provider = db_session.scalar(
|
||||
select(LLMProviderModel).where(LLMProviderModel.name == llm_provider.name)
|
||||
)
|
||||
if existing_llm_provider and is_creation:
|
||||
raise ValueError(f"LLM Provider with name {llm_provider.name} already exists")
|
||||
|
||||
if not existing_llm_provider:
|
||||
if not is_creation:
|
||||
raise ValueError(
|
||||
f"LLM Provider with name {llm_provider.name} does not exist"
|
||||
)
|
||||
existing_llm_provider = LLMProviderModel(name=llm_provider.name)
|
||||
db_session.add(existing_llm_provider)
|
||||
|
||||
@@ -103,6 +113,20 @@ def fetch_existing_embedding_providers(
|
||||
return list(db_session.scalars(select(CloudEmbeddingProviderModel)).all())
|
||||
|
||||
|
||||
def fetch_existing_doc_sets(
|
||||
db_session: Session, doc_ids: list[int]
|
||||
) -> list[DocumentSet]:
|
||||
return list(
|
||||
db_session.scalars(select(DocumentSet).where(DocumentSet.id.in_(doc_ids))).all()
|
||||
)
|
||||
|
||||
|
||||
def fetch_existing_tools(db_session: Session, tool_ids: list[int]) -> list[ToolModel]:
|
||||
return list(
|
||||
db_session.scalars(select(ToolModel).where(ToolModel.id.in_(tool_ids))).all()
|
||||
)
|
||||
|
||||
|
||||
def fetch_existing_llm_providers(
|
||||
db_session: Session,
|
||||
user: User | None = None,
|
||||
|
||||
@@ -39,6 +39,7 @@ from danswer.configs.constants import DEFAULT_BOOST
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import FileOrigin
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.db.enums import AccessType
|
||||
from danswer.configs.constants import NotificationType
|
||||
from danswer.configs.constants import SearchFeedbackType
|
||||
from danswer.configs.constants import TokenRateLimitScope
|
||||
@@ -108,7 +109,7 @@ class OAuthAccount(SQLAlchemyBaseOAuthAccountTableUUID, Base):
|
||||
|
||||
class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
oauth_accounts: Mapped[list[OAuthAccount]] = relationship(
|
||||
"OAuthAccount", lazy="joined"
|
||||
"OAuthAccount", lazy="joined", cascade="all, delete-orphan"
|
||||
)
|
||||
role: Mapped[UserRole] = mapped_column(
|
||||
Enum(UserRole, native_enum=False, default=UserRole.BASIC)
|
||||
@@ -122,7 +123,13 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
# if specified, controls the assistants that are shown to the user + their order
|
||||
# if not specified, all assistants are shown
|
||||
chosen_assistants: Mapped[list[int]] = mapped_column(
|
||||
postgresql.JSONB(), nullable=True
|
||||
postgresql.JSONB(), nullable=False, default=[-2, -1, 0]
|
||||
)
|
||||
visible_assistants: Mapped[list[int]] = mapped_column(
|
||||
postgresql.JSONB(), nullable=False, default=[]
|
||||
)
|
||||
hidden_assistants: Mapped[list[int]] = mapped_column(
|
||||
postgresql.JSONB(), nullable=False, default=[]
|
||||
)
|
||||
|
||||
oidc_expiry: Mapped[datetime.datetime] = mapped_column(
|
||||
@@ -170,7 +177,9 @@ class InputPrompt(Base):
|
||||
active: Mapped[bool] = mapped_column(Boolean)
|
||||
user: Mapped[User | None] = relationship("User", back_populates="input_prompts")
|
||||
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
|
||||
user_id: Mapped[UUID | None] = mapped_column(
|
||||
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
|
||||
|
||||
class InputPrompt__User(Base):
|
||||
@@ -214,7 +223,9 @@ class Notification(Base):
|
||||
notif_type: Mapped[NotificationType] = mapped_column(
|
||||
Enum(NotificationType, native_enum=False)
|
||||
)
|
||||
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
|
||||
user_id: Mapped[UUID | None] = mapped_column(
|
||||
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
dismissed: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
last_shown: Mapped[datetime.datetime] = mapped_column(DateTime(timezone=True))
|
||||
first_shown: Mapped[datetime.datetime] = mapped_column(DateTime(timezone=True))
|
||||
@@ -249,7 +260,7 @@ class Persona__User(Base):
|
||||
|
||||
persona_id: Mapped[int] = mapped_column(ForeignKey("persona.id"), primary_key=True)
|
||||
user_id: Mapped[UUID | None] = mapped_column(
|
||||
ForeignKey("user.id"), primary_key=True, nullable=True
|
||||
ForeignKey("user.id", ondelete="CASCADE"), primary_key=True, nullable=True
|
||||
)
|
||||
|
||||
|
||||
@@ -260,7 +271,7 @@ class DocumentSet__User(Base):
|
||||
ForeignKey("document_set.id"), primary_key=True
|
||||
)
|
||||
user_id: Mapped[UUID | None] = mapped_column(
|
||||
ForeignKey("user.id"), primary_key=True, nullable=True
|
||||
ForeignKey("user.id", ondelete="CASCADE"), primary_key=True, nullable=True
|
||||
)
|
||||
|
||||
|
||||
@@ -384,10 +395,20 @@ class ConnectorCredentialPair(Base):
|
||||
# controls whether the documents indexed by this CC pair are visible to all
|
||||
# or if they are only visible to those with that are given explicit access
|
||||
# (e.g. via owning the credential or being a part of a group that is given access)
|
||||
is_public: Mapped[bool] = mapped_column(
|
||||
Boolean,
|
||||
default=True,
|
||||
nullable=False,
|
||||
access_type: Mapped[AccessType] = mapped_column(
|
||||
Enum(AccessType, native_enum=False), nullable=False
|
||||
)
|
||||
|
||||
# special info needed for the auto-sync feature. The exact structure depends on the
|
||||
|
||||
# source type (defined in the connector's `source` field)
|
||||
# E.g. for google_drive perm sync:
|
||||
# {"customer_id": "123567", "company_domain": "@danswer.ai"}
|
||||
auto_sync_options: Mapped[dict[str, Any] | None] = mapped_column(
|
||||
postgresql.JSONB(), nullable=True
|
||||
)
|
||||
last_time_perm_sync: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
# Time finished, not used for calculating backend jobs which uses time started (created)
|
||||
last_successful_index_time: Mapped[datetime.datetime | None] = mapped_column(
|
||||
@@ -418,6 +439,7 @@ class ConnectorCredentialPair(Base):
|
||||
|
||||
class Document(Base):
|
||||
__tablename__ = "document"
|
||||
# NOTE: if more sensitive data is added here for display, make sure to add user/group permission
|
||||
|
||||
# this should correspond to the ID of the document
|
||||
# (as is passed around in Danswer)
|
||||
@@ -461,7 +483,18 @@ class Document(Base):
|
||||
secondary_owners: Mapped[list[str] | None] = mapped_column(
|
||||
postgresql.ARRAY(String), nullable=True
|
||||
)
|
||||
# TODO if more sensitive data is added here for display, make sure to add user/group permission
|
||||
# Permission sync columns
|
||||
# Email addresses are saved at the document level for externally synced permissions
|
||||
# This is becuase the normal flow of assigning permissions is through the cc_pair
|
||||
# doesn't apply here
|
||||
external_user_emails: Mapped[list[str] | None] = mapped_column(
|
||||
postgresql.ARRAY(String), nullable=True
|
||||
)
|
||||
# These group ids have been prefixed by the source type
|
||||
external_user_group_ids: Mapped[list[str] | None] = mapped_column(
|
||||
postgresql.ARRAY(String), nullable=True
|
||||
)
|
||||
is_public: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
|
||||
retrieval_feedbacks: Mapped[list["DocumentRetrievalFeedback"]] = relationship(
|
||||
"DocumentRetrievalFeedback", back_populates="document"
|
||||
@@ -541,7 +574,9 @@ class Credential(Base):
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
credential_json: Mapped[dict[str, Any]] = mapped_column(EncryptedJson())
|
||||
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
|
||||
user_id: Mapped[UUID | None] = mapped_column(
|
||||
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
# if `true`, then all Admins will have access to the credential
|
||||
admin_public: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
@@ -854,8 +889,10 @@ class ToolCall(Base):
|
||||
tool_arguments: Mapped[dict[str, JSON_ro]] = mapped_column(postgresql.JSONB())
|
||||
tool_result: Mapped[JSON_ro] = mapped_column(postgresql.JSONB())
|
||||
|
||||
message_id: Mapped[int] = mapped_column(ForeignKey("chat_message.id"))
|
||||
|
||||
message: Mapped["ChatMessage"] = relationship(
|
||||
"ChatMessage", back_populates="tool_call"
|
||||
"ChatMessage", back_populates="tool_calls"
|
||||
)
|
||||
|
||||
|
||||
@@ -863,8 +900,12 @@ class ChatSession(Base):
|
||||
__tablename__ = "chat_session"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
|
||||
persona_id: Mapped[int] = mapped_column(ForeignKey("persona.id"))
|
||||
user_id: Mapped[UUID | None] = mapped_column(
|
||||
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
persona_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("persona.id"), nullable=True
|
||||
)
|
||||
description: Mapped[str] = mapped_column(Text)
|
||||
# One-shot direct answering, currently the two types of chats are not mixed
|
||||
one_shot: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
@@ -898,7 +939,6 @@ class ChatSession(Base):
|
||||
prompt_override: Mapped[PromptOverride | None] = mapped_column(
|
||||
PydanticType(PromptOverride), nullable=True
|
||||
)
|
||||
|
||||
time_updated: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
@@ -907,7 +947,6 @@ class ChatSession(Base):
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
|
||||
user: Mapped[User] = relationship("User", back_populates="chat_sessions")
|
||||
folder: Mapped["ChatFolder"] = relationship(
|
||||
"ChatFolder", back_populates="chat_sessions"
|
||||
@@ -982,14 +1021,9 @@ class ChatMessage(Base):
|
||||
)
|
||||
# NOTE: Should always be attached to the `assistant` message.
|
||||
# represents the tool calls used to generate this message
|
||||
|
||||
tool_call_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("tool_call.id"), nullable=True
|
||||
)
|
||||
# NOTE: Should always be attached to the `assistant` message.
|
||||
# represents the tool calls used to generate this message
|
||||
tool_call: Mapped["ToolCall"] = relationship(
|
||||
"ToolCall", back_populates="message", foreign_keys=[tool_call_id]
|
||||
tool_calls: Mapped[list["ToolCall"]] = relationship(
|
||||
"ToolCall",
|
||||
back_populates="message",
|
||||
)
|
||||
standard_answers: Mapped[list["StandardAnswer"]] = relationship(
|
||||
"StandardAnswer",
|
||||
@@ -1005,7 +1039,9 @@ class ChatFolder(Base):
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
# Only null if auth is off
|
||||
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
|
||||
user_id: Mapped[UUID | None] = mapped_column(
|
||||
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
name: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
display_priority: Mapped[int] = mapped_column(Integer, nullable=True, default=0)
|
||||
|
||||
@@ -1136,7 +1172,9 @@ class DocumentSet(Base):
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String, unique=True)
|
||||
description: Mapped[str] = mapped_column(String)
|
||||
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
|
||||
user_id: Mapped[UUID | None] = mapped_column(
|
||||
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
# Whether changes to the document set have been propagated
|
||||
is_up_to_date: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
# If `False`, then the document set is not visible to users who are not explicitly
|
||||
@@ -1180,7 +1218,9 @@ class Prompt(Base):
|
||||
__tablename__ = "prompt"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
|
||||
user_id: Mapped[UUID | None] = mapped_column(
|
||||
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
name: Mapped[str] = mapped_column(String)
|
||||
description: Mapped[str] = mapped_column(String)
|
||||
system_prompt: Mapped[str] = mapped_column(Text)
|
||||
@@ -1215,9 +1255,13 @@ class Tool(Base):
|
||||
openapi_schema: Mapped[dict[str, Any] | None] = mapped_column(
|
||||
postgresql.JSONB(), nullable=True
|
||||
)
|
||||
|
||||
custom_headers: Mapped[list[dict[str, str]] | None] = mapped_column(
|
||||
postgresql.JSONB(), nullable=True
|
||||
)
|
||||
# user who created / owns the tool. Will be None for built-in tools.
|
||||
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
|
||||
user_id: Mapped[UUID | None] = mapped_column(
|
||||
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
|
||||
user: Mapped[User | None] = relationship("User", back_populates="custom_tools")
|
||||
# Relationship to Persona through the association table
|
||||
@@ -1241,7 +1285,9 @@ class Persona(Base):
|
||||
__tablename__ = "persona"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
|
||||
user_id: Mapped[UUID | None] = mapped_column(
|
||||
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
name: Mapped[str] = mapped_column(String)
|
||||
description: Mapped[str] = mapped_column(String)
|
||||
# Number of chunks to pass to the LLM for generation.
|
||||
@@ -1270,9 +1316,18 @@ class Persona(Base):
|
||||
starter_messages: Mapped[list[StarterMessage] | None] = mapped_column(
|
||||
postgresql.JSONB(), nullable=True
|
||||
)
|
||||
# Default personas are configured via backend during deployment
|
||||
search_start_date: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), default=None
|
||||
)
|
||||
# Built-in personas are configured via backend during deployment
|
||||
# Treated specially (cannot be user edited etc.)
|
||||
default_persona: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
builtin_persona: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
|
||||
# Default personas are personas created by admins and are automatically added
|
||||
# to all users' assistants list.
|
||||
is_default_persona: Mapped[bool] = mapped_column(
|
||||
Boolean, default=False, nullable=False
|
||||
)
|
||||
# controls whether the persona is available to be selected by users
|
||||
is_visible: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
# controls the ordering of personas in the UI
|
||||
@@ -1323,10 +1378,10 @@ class Persona(Base):
|
||||
# Default personas loaded via yaml cannot have the same name
|
||||
__table_args__ = (
|
||||
Index(
|
||||
"_default_persona_name_idx",
|
||||
"_builtin_persona_name_idx",
|
||||
"name",
|
||||
unique=True,
|
||||
postgresql_where=(default_persona == True), # noqa: E712
|
||||
postgresql_where=(builtin_persona == True), # noqa: E712
|
||||
),
|
||||
)
|
||||
|
||||
@@ -1350,55 +1405,6 @@ class ChannelConfig(TypedDict):
|
||||
follow_up_tags: NotRequired[list[str]]
|
||||
|
||||
|
||||
class StandardAnswerCategory(Base):
|
||||
__tablename__ = "standard_answer_category"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String, unique=True)
|
||||
standard_answers: Mapped[list["StandardAnswer"]] = relationship(
|
||||
"StandardAnswer",
|
||||
secondary=StandardAnswer__StandardAnswerCategory.__table__,
|
||||
back_populates="categories",
|
||||
)
|
||||
slack_bot_configs: Mapped[list["SlackBotConfig"]] = relationship(
|
||||
"SlackBotConfig",
|
||||
secondary=SlackBotConfig__StandardAnswerCategory.__table__,
|
||||
back_populates="standard_answer_categories",
|
||||
)
|
||||
|
||||
|
||||
class StandardAnswer(Base):
|
||||
__tablename__ = "standard_answer"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
keyword: Mapped[str] = mapped_column(String)
|
||||
answer: Mapped[str] = mapped_column(String)
|
||||
active: Mapped[bool] = mapped_column(Boolean)
|
||||
match_regex: Mapped[bool] = mapped_column(Boolean)
|
||||
match_any_keywords: Mapped[bool] = mapped_column(Boolean)
|
||||
|
||||
__table_args__ = (
|
||||
Index(
|
||||
"unique_keyword_active",
|
||||
keyword,
|
||||
active,
|
||||
unique=True,
|
||||
postgresql_where=(active == True), # noqa: E712
|
||||
),
|
||||
)
|
||||
|
||||
categories: Mapped[list[StandardAnswerCategory]] = relationship(
|
||||
"StandardAnswerCategory",
|
||||
secondary=StandardAnswer__StandardAnswerCategory.__table__,
|
||||
back_populates="standard_answers",
|
||||
)
|
||||
chat_messages: Mapped[list[ChatMessage]] = relationship(
|
||||
"ChatMessage",
|
||||
secondary=ChatMessage__StandardAnswer.__table__,
|
||||
back_populates="standard_answers",
|
||||
)
|
||||
|
||||
|
||||
class SlackBotResponseType(str, PyEnum):
|
||||
QUOTES = "quotes"
|
||||
CITATIONS = "citations"
|
||||
@@ -1424,7 +1430,7 @@ class SlackBotConfig(Base):
|
||||
)
|
||||
|
||||
persona: Mapped[Persona | None] = relationship("Persona")
|
||||
standard_answer_categories: Mapped[list[StandardAnswerCategory]] = relationship(
|
||||
standard_answer_categories: Mapped[list["StandardAnswerCategory"]] = relationship(
|
||||
"StandardAnswerCategory",
|
||||
secondary=SlackBotConfig__StandardAnswerCategory.__table__,
|
||||
back_populates="slack_bot_configs",
|
||||
@@ -1486,7 +1492,9 @@ class SamlAccount(Base):
|
||||
__tablename__ = "saml"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
user_id: Mapped[int] = mapped_column(ForeignKey("user.id"), unique=True)
|
||||
user_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("user.id", ondelete="CASCADE"), unique=True
|
||||
)
|
||||
encrypted_cookie: Mapped[str] = mapped_column(Text, unique=True)
|
||||
expires_at: Mapped[datetime.datetime] = mapped_column(DateTime(timezone=True))
|
||||
updated_at: Mapped[datetime.datetime] = mapped_column(
|
||||
@@ -1505,7 +1513,7 @@ class User__UserGroup(Base):
|
||||
ForeignKey("user_group.id"), primary_key=True
|
||||
)
|
||||
user_id: Mapped[UUID | None] = mapped_column(
|
||||
ForeignKey("user.id"), primary_key=True, nullable=True
|
||||
ForeignKey("user.id", ondelete="CASCADE"), primary_key=True, nullable=True
|
||||
)
|
||||
|
||||
|
||||
@@ -1654,94 +1662,70 @@ class TokenRateLimit__UserGroup(Base):
|
||||
)
|
||||
|
||||
|
||||
class StandardAnswerCategory(Base):
|
||||
__tablename__ = "standard_answer_category"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String, unique=True)
|
||||
standard_answers: Mapped[list["StandardAnswer"]] = relationship(
|
||||
"StandardAnswer",
|
||||
secondary=StandardAnswer__StandardAnswerCategory.__table__,
|
||||
back_populates="categories",
|
||||
)
|
||||
slack_bot_configs: Mapped[list["SlackBotConfig"]] = relationship(
|
||||
"SlackBotConfig",
|
||||
secondary=SlackBotConfig__StandardAnswerCategory.__table__,
|
||||
back_populates="standard_answer_categories",
|
||||
)
|
||||
|
||||
|
||||
class StandardAnswer(Base):
|
||||
__tablename__ = "standard_answer"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
keyword: Mapped[str] = mapped_column(String)
|
||||
answer: Mapped[str] = mapped_column(String)
|
||||
active: Mapped[bool] = mapped_column(Boolean)
|
||||
match_regex: Mapped[bool] = mapped_column(Boolean)
|
||||
match_any_keywords: Mapped[bool] = mapped_column(Boolean)
|
||||
|
||||
__table_args__ = (
|
||||
Index(
|
||||
"unique_keyword_active",
|
||||
keyword,
|
||||
active,
|
||||
unique=True,
|
||||
postgresql_where=(active == True), # noqa: E712
|
||||
),
|
||||
)
|
||||
|
||||
categories: Mapped[list[StandardAnswerCategory]] = relationship(
|
||||
"StandardAnswerCategory",
|
||||
secondary=StandardAnswer__StandardAnswerCategory.__table__,
|
||||
back_populates="standard_answers",
|
||||
)
|
||||
chat_messages: Mapped[list[ChatMessage]] = relationship(
|
||||
"ChatMessage",
|
||||
secondary=ChatMessage__StandardAnswer.__table__,
|
||||
back_populates="standard_answers",
|
||||
)
|
||||
|
||||
|
||||
"""Tables related to Permission Sync"""
|
||||
|
||||
|
||||
class PermissionSyncStatus(str, PyEnum):
|
||||
IN_PROGRESS = "in_progress"
|
||||
SUCCESS = "success"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
class PermissionSyncJobType(str, PyEnum):
|
||||
USER_LEVEL = "user_level"
|
||||
GROUP_LEVEL = "group_level"
|
||||
|
||||
|
||||
class PermissionSyncRun(Base):
|
||||
"""Represents one run of a permission sync job. For some given cc_pair, it is either sync-ing
|
||||
the users or it is sync-ing the groups"""
|
||||
|
||||
__tablename__ = "permission_sync_run"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
# Not strictly needed but makes it easy to use without fetching from cc_pair
|
||||
source_type: Mapped[DocumentSource] = mapped_column(
|
||||
Enum(DocumentSource, native_enum=False)
|
||||
)
|
||||
# Currently all sync jobs are handled as a group permission sync or a user permission sync
|
||||
update_type: Mapped[PermissionSyncJobType] = mapped_column(
|
||||
Enum(PermissionSyncJobType)
|
||||
)
|
||||
cc_pair_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("connector_credential_pair.id"), nullable=True
|
||||
)
|
||||
status: Mapped[PermissionSyncStatus] = mapped_column(Enum(PermissionSyncStatus))
|
||||
error_msg: Mapped[str | None] = mapped_column(Text, default=None)
|
||||
updated_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
cc_pair: Mapped[ConnectorCredentialPair] = relationship("ConnectorCredentialPair")
|
||||
|
||||
|
||||
class ExternalPermission(Base):
|
||||
class User__ExternalUserGroupId(Base):
|
||||
"""Maps user info both internal and external to the name of the external group
|
||||
This maps the user to all of their external groups so that the external group name can be
|
||||
attached to the ACL list matching during query time. User level permissions can be handled by
|
||||
directly adding the Danswer user to the doc ACL list"""
|
||||
|
||||
__tablename__ = "external_permission"
|
||||
__tablename__ = "user__external_user_group_id"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
|
||||
# Email is needed because we want to keep track of users not in Danswer to simplify process
|
||||
# when the user joins
|
||||
user_email: Mapped[str] = mapped_column(String)
|
||||
source_type: Mapped[DocumentSource] = mapped_column(
|
||||
Enum(DocumentSource, native_enum=False)
|
||||
)
|
||||
external_permission_group: Mapped[str] = mapped_column(String)
|
||||
user = relationship("User")
|
||||
|
||||
|
||||
class EmailToExternalUserCache(Base):
|
||||
"""A way to map users IDs in the external tool to a user in Danswer or at least an email for
|
||||
when the user joins. Used as a cache for when fetching external groups which have their own
|
||||
user ids, this can easily be mapped back to users already known in Danswer without needing
|
||||
to call external APIs to get the user emails.
|
||||
|
||||
This way when groups are updated in the external tool and we need to update the mapping of
|
||||
internal users to the groups, we can sync the internal users to the external groups they are
|
||||
part of using this.
|
||||
|
||||
Ie. User Chris is part of groups alpha, beta, and we can update this if Chris is no longer
|
||||
part of alpha in some external tool.
|
||||
"""
|
||||
|
||||
__tablename__ = "email_to_external_user_cache"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
external_user_id: Mapped[str] = mapped_column(String)
|
||||
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
|
||||
# Email is needed because we want to keep track of users not in Danswer to simplify process
|
||||
# when the user joins
|
||||
user_email: Mapped[str] = mapped_column(String)
|
||||
source_type: Mapped[DocumentSource] = mapped_column(
|
||||
Enum(DocumentSource, native_enum=False)
|
||||
)
|
||||
|
||||
user = relationship("User")
|
||||
user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id"), primary_key=True)
|
||||
# These group ids have been prefixed by the source type
|
||||
external_user_group_id: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
cc_pair_id: Mapped[int] = mapped_column(ForeignKey("connector_credential_pair.id"))
|
||||
|
||||
|
||||
class UsageReport(Base):
|
||||
@@ -1757,7 +1741,7 @@ class UsageReport(Base):
|
||||
|
||||
# if None, report was auto-generated
|
||||
requestor_user_id: Mapped[UUID | None] = mapped_column(
|
||||
ForeignKey("user.id"), nullable=True
|
||||
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from functools import lru_cache
|
||||
from uuid import UUID
|
||||
|
||||
@@ -178,6 +179,7 @@ def create_update_persona(
|
||||
except ValueError as e:
|
||||
logger.exception("Failed to create persona")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
return PersonaSnapshot.from_model(persona)
|
||||
|
||||
|
||||
@@ -210,6 +212,22 @@ def update_persona_shared_users(
|
||||
)
|
||||
|
||||
|
||||
def update_persona_public_status(
|
||||
persona_id: int,
|
||||
is_public: bool,
|
||||
db_session: Session,
|
||||
user: User | None,
|
||||
) -> None:
|
||||
persona = fetch_persona_by_id(
|
||||
db_session=db_session, persona_id=persona_id, user=user, get_editable=True
|
||||
)
|
||||
if user and user.role != UserRole.ADMIN and persona.user_id != user.id:
|
||||
raise ValueError("You don't have permission to modify this persona")
|
||||
|
||||
persona.is_public = is_public
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def get_prompts(
|
||||
user_id: UUID | None,
|
||||
db_session: Session,
|
||||
@@ -242,7 +260,7 @@ def get_personas(
|
||||
stmt = _add_user_filters(stmt=stmt, user=user, get_editable=get_editable)
|
||||
|
||||
if not include_default:
|
||||
stmt = stmt.where(Persona.default_persona.is_(False))
|
||||
stmt = stmt.where(Persona.builtin_persona.is_(False))
|
||||
if not include_slack_bot_personas:
|
||||
stmt = stmt.where(not_(Persona.name.startswith(SLACK_BOT_PERSONA_PREFIX)))
|
||||
if not include_deleted:
|
||||
@@ -290,7 +308,7 @@ def mark_delete_persona_by_name(
|
||||
) -> None:
|
||||
stmt = (
|
||||
update(Persona)
|
||||
.where(Persona.name == persona_name, Persona.default_persona == is_default)
|
||||
.where(Persona.name == persona_name, Persona.builtin_persona == is_default)
|
||||
.values(deleted=True)
|
||||
)
|
||||
|
||||
@@ -390,7 +408,6 @@ def upsert_persona(
|
||||
document_set_ids: list[int] | None = None,
|
||||
tool_ids: list[int] | None = None,
|
||||
persona_id: int | None = None,
|
||||
default_persona: bool = False,
|
||||
commit: bool = True,
|
||||
icon_color: str | None = None,
|
||||
icon_shape: int | None = None,
|
||||
@@ -398,6 +415,9 @@ def upsert_persona(
|
||||
display_priority: int | None = None,
|
||||
is_visible: bool = True,
|
||||
remove_image: bool | None = None,
|
||||
search_start_date: datetime | None = None,
|
||||
builtin_persona: bool = False,
|
||||
is_default_persona: bool = False,
|
||||
chunks_above: int = CONTEXT_CHUNKS_ABOVE,
|
||||
chunks_below: int = CONTEXT_CHUNKS_BELOW,
|
||||
) -> Persona:
|
||||
@@ -438,8 +458,8 @@ def upsert_persona(
|
||||
validate_persona_tools(tools)
|
||||
|
||||
if persona:
|
||||
if not default_persona and persona.default_persona:
|
||||
raise ValueError("Cannot update default persona with non-default.")
|
||||
if not builtin_persona and persona.builtin_persona:
|
||||
raise ValueError("Cannot update builtin persona with non-builtin.")
|
||||
|
||||
# this checks if the user has permission to edit the persona
|
||||
persona = fetch_persona_by_id(
|
||||
@@ -454,7 +474,7 @@ def upsert_persona(
|
||||
persona.llm_relevance_filter = llm_relevance_filter
|
||||
persona.llm_filter_extraction = llm_filter_extraction
|
||||
persona.recency_bias = recency_bias
|
||||
persona.default_persona = default_persona
|
||||
persona.builtin_persona = builtin_persona
|
||||
persona.llm_model_provider_override = llm_model_provider_override
|
||||
persona.llm_model_version_override = llm_model_version_override
|
||||
persona.starter_messages = starter_messages
|
||||
@@ -466,6 +486,8 @@ def upsert_persona(
|
||||
persona.uploaded_image_id = uploaded_image_id
|
||||
persona.display_priority = display_priority
|
||||
persona.is_visible = is_visible
|
||||
persona.search_start_date = search_start_date
|
||||
persona.is_default_persona = is_default_persona
|
||||
|
||||
# Do not delete any associations manually added unless
|
||||
# a new updated list is provided
|
||||
@@ -493,7 +515,7 @@ def upsert_persona(
|
||||
llm_relevance_filter=llm_relevance_filter,
|
||||
llm_filter_extraction=llm_filter_extraction,
|
||||
recency_bias=recency_bias,
|
||||
default_persona=default_persona,
|
||||
builtin_persona=builtin_persona,
|
||||
prompts=prompts or [],
|
||||
document_sets=document_sets or [],
|
||||
llm_model_provider_override=llm_model_provider_override,
|
||||
@@ -505,6 +527,8 @@ def upsert_persona(
|
||||
uploaded_image_id=uploaded_image_id,
|
||||
display_priority=display_priority,
|
||||
is_visible=is_visible,
|
||||
search_start_date=search_start_date,
|
||||
is_default_persona=is_default_persona,
|
||||
)
|
||||
db_session.add(persona)
|
||||
|
||||
@@ -534,7 +558,7 @@ def delete_old_default_personas(
|
||||
Need a more graceful fix later or those need to never have IDs"""
|
||||
stmt = (
|
||||
update(Persona)
|
||||
.where(Persona.default_persona, Persona.id > 0)
|
||||
.where(Persona.builtin_persona, Persona.id > 0)
|
||||
.values(deleted=True, name=func.concat(Persona.name, "_old"))
|
||||
)
|
||||
|
||||
@@ -551,6 +575,7 @@ def update_persona_visibility(
|
||||
persona = fetch_persona_by_id(
|
||||
db_session=db_session, persona_id=persona_id, user=user, get_editable=True
|
||||
)
|
||||
|
||||
persona.is_visible = is_visible
|
||||
db_session.commit()
|
||||
|
||||
@@ -563,13 +588,15 @@ def validate_persona_tools(tools: list[Tool]) -> None:
|
||||
)
|
||||
|
||||
|
||||
def get_prompts_by_ids(prompt_ids: list[int], db_session: Session) -> Sequence[Prompt]:
|
||||
def get_prompts_by_ids(prompt_ids: list[int], db_session: Session) -> list[Prompt]:
|
||||
"""Unsafe, can fetch prompts from all users"""
|
||||
if not prompt_ids:
|
||||
return []
|
||||
prompts = db_session.scalars(select(Prompt).where(Prompt.id.in_(prompt_ids))).all()
|
||||
prompts = db_session.scalars(
|
||||
select(Prompt).where(Prompt.id.in_(prompt_ids)).where(Prompt.deleted.is_(False))
|
||||
).all()
|
||||
|
||||
return prompts
|
||||
return list(prompts)
|
||||
|
||||
|
||||
def get_prompt_by_id(
|
||||
@@ -650,9 +677,7 @@ def get_persona_by_id(
|
||||
result = db_session.execute(persona_stmt)
|
||||
persona = result.scalar_one_or_none()
|
||||
if persona is None:
|
||||
raise ValueError(
|
||||
f"Persona with ID {persona_id} does not exist or does not belong to user"
|
||||
)
|
||||
raise ValueError(f"Persona with ID {persona_id} does not exist")
|
||||
return persona
|
||||
|
||||
# or check if user owns persona
|
||||
@@ -715,7 +740,7 @@ def delete_persona_by_name(
|
||||
persona_name: str, db_session: Session, is_default: bool = True
|
||||
) -> None:
|
||||
stmt = delete(Persona).where(
|
||||
Persona.name == persona_name, Persona.default_persona == is_default
|
||||
Persona.name == persona_name, Persona.builtin_persona == is_default
|
||||
)
|
||||
|
||||
db_session.execute(stmt)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -14,8 +15,11 @@ from danswer.db.models import User
|
||||
from danswer.db.persona import get_default_prompt
|
||||
from danswer.db.persona import mark_persona_as_deleted
|
||||
from danswer.db.persona import upsert_persona
|
||||
from danswer.db.standard_answer import fetch_standard_answer_categories_by_ids
|
||||
from danswer.search.enums import RecencyBiasSetting
|
||||
from danswer.utils.errors import EERequiredError
|
||||
from danswer.utils.variable_functionality import (
|
||||
fetch_versioned_implementation_with_fallback,
|
||||
)
|
||||
|
||||
|
||||
def _build_persona_name(channel_names: list[str]) -> str:
|
||||
@@ -62,7 +66,7 @@ def create_slack_bot_persona(
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
is_public=True,
|
||||
default_persona=False,
|
||||
is_default_persona=False,
|
||||
db_session=db_session,
|
||||
commit=False,
|
||||
)
|
||||
@@ -70,6 +74,10 @@ def create_slack_bot_persona(
|
||||
return persona
|
||||
|
||||
|
||||
def _no_ee_standard_answer_categories(*args: Any, **kwargs: Any) -> list:
|
||||
return []
|
||||
|
||||
|
||||
def insert_slack_bot_config(
|
||||
persona_id: int | None,
|
||||
channel_config: ChannelConfig,
|
||||
@@ -78,14 +86,29 @@ def insert_slack_bot_config(
|
||||
enable_auto_filters: bool,
|
||||
db_session: Session,
|
||||
) -> SlackBotConfig:
|
||||
existing_standard_answer_categories = fetch_standard_answer_categories_by_ids(
|
||||
standard_answer_category_ids=standard_answer_category_ids,
|
||||
db_session=db_session,
|
||||
)
|
||||
if len(existing_standard_answer_categories) != len(standard_answer_category_ids):
|
||||
raise ValueError(
|
||||
f"Some or all categories with ids {standard_answer_category_ids} do not exist"
|
||||
versioned_fetch_standard_answer_categories_by_ids = (
|
||||
fetch_versioned_implementation_with_fallback(
|
||||
"danswer.db.standard_answer",
|
||||
"fetch_standard_answer_categories_by_ids",
|
||||
_no_ee_standard_answer_categories,
|
||||
)
|
||||
)
|
||||
existing_standard_answer_categories = (
|
||||
versioned_fetch_standard_answer_categories_by_ids(
|
||||
standard_answer_category_ids=standard_answer_category_ids,
|
||||
db_session=db_session,
|
||||
)
|
||||
)
|
||||
|
||||
if len(existing_standard_answer_categories) != len(standard_answer_category_ids):
|
||||
if len(existing_standard_answer_categories) == 0:
|
||||
raise EERequiredError(
|
||||
"Standard answers are a paid Enterprise Edition feature - enable EE or remove standard answer categories"
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Some or all categories with ids {standard_answer_category_ids} do not exist"
|
||||
)
|
||||
|
||||
slack_bot_config = SlackBotConfig(
|
||||
persona_id=persona_id,
|
||||
@@ -117,9 +140,18 @@ def update_slack_bot_config(
|
||||
f"Unable to find slack bot config with ID {slack_bot_config_id}"
|
||||
)
|
||||
|
||||
existing_standard_answer_categories = fetch_standard_answer_categories_by_ids(
|
||||
standard_answer_category_ids=standard_answer_category_ids,
|
||||
db_session=db_session,
|
||||
versioned_fetch_standard_answer_categories_by_ids = (
|
||||
fetch_versioned_implementation_with_fallback(
|
||||
"danswer.db.standard_answer",
|
||||
"fetch_standard_answer_categories_by_ids",
|
||||
_no_ee_standard_answer_categories,
|
||||
)
|
||||
)
|
||||
existing_standard_answer_categories = (
|
||||
versioned_fetch_standard_answer_categories_by_ids(
|
||||
standard_answer_category_ids=standard_answer_category_ids,
|
||||
db_session=db_session,
|
||||
)
|
||||
)
|
||||
if len(existing_standard_answer_categories) != len(standard_answer_category_ids):
|
||||
raise ValueError(
|
||||
|
||||
@@ -1,202 +0,0 @@
|
||||
from collections.abc import Sequence
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.db.models import StandardAnswer
|
||||
from danswer.db.models import StandardAnswerCategory
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def check_category_validity(category_name: str) -> bool:
|
||||
"""If a category name is too long, it should not be used (it will cause an error in Postgres
|
||||
as the unique constraint can only apply to entries that are less than 2704 bytes).
|
||||
|
||||
Additionally, extremely long categories are not really usable / useful."""
|
||||
if len(category_name) > 255:
|
||||
logger.error(
|
||||
f"Category with name '{category_name}' is too long, cannot be used"
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def insert_standard_answer_category(
|
||||
category_name: str, db_session: Session
|
||||
) -> StandardAnswerCategory:
|
||||
if not check_category_validity(category_name):
|
||||
raise ValueError(f"Invalid category name: {category_name}")
|
||||
standard_answer_category = StandardAnswerCategory(name=category_name)
|
||||
db_session.add(standard_answer_category)
|
||||
db_session.commit()
|
||||
|
||||
return standard_answer_category
|
||||
|
||||
|
||||
def insert_standard_answer(
|
||||
keyword: str,
|
||||
answer: str,
|
||||
category_ids: list[int],
|
||||
match_regex: bool,
|
||||
match_any_keywords: bool,
|
||||
db_session: Session,
|
||||
) -> StandardAnswer:
|
||||
existing_categories = fetch_standard_answer_categories_by_ids(
|
||||
standard_answer_category_ids=category_ids,
|
||||
db_session=db_session,
|
||||
)
|
||||
if len(existing_categories) != len(category_ids):
|
||||
raise ValueError(f"Some or all categories with ids {category_ids} do not exist")
|
||||
|
||||
standard_answer = StandardAnswer(
|
||||
keyword=keyword,
|
||||
answer=answer,
|
||||
categories=existing_categories,
|
||||
active=True,
|
||||
match_regex=match_regex,
|
||||
match_any_keywords=match_any_keywords,
|
||||
)
|
||||
db_session.add(standard_answer)
|
||||
db_session.commit()
|
||||
return standard_answer
|
||||
|
||||
|
||||
def update_standard_answer(
|
||||
standard_answer_id: int,
|
||||
keyword: str,
|
||||
answer: str,
|
||||
category_ids: list[int],
|
||||
match_regex: bool,
|
||||
match_any_keywords: bool,
|
||||
db_session: Session,
|
||||
) -> StandardAnswer:
|
||||
standard_answer = db_session.scalar(
|
||||
select(StandardAnswer).where(StandardAnswer.id == standard_answer_id)
|
||||
)
|
||||
if standard_answer is None:
|
||||
raise ValueError(f"No standard answer with id {standard_answer_id}")
|
||||
|
||||
existing_categories = fetch_standard_answer_categories_by_ids(
|
||||
standard_answer_category_ids=category_ids,
|
||||
db_session=db_session,
|
||||
)
|
||||
if len(existing_categories) != len(category_ids):
|
||||
raise ValueError(f"Some or all categories with ids {category_ids} do not exist")
|
||||
|
||||
standard_answer.keyword = keyword
|
||||
standard_answer.answer = answer
|
||||
standard_answer.categories = list(existing_categories)
|
||||
standard_answer.match_regex = match_regex
|
||||
standard_answer.match_any_keywords = match_any_keywords
|
||||
|
||||
db_session.commit()
|
||||
|
||||
return standard_answer
|
||||
|
||||
|
||||
def remove_standard_answer(
|
||||
standard_answer_id: int,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
standard_answer = db_session.scalar(
|
||||
select(StandardAnswer).where(StandardAnswer.id == standard_answer_id)
|
||||
)
|
||||
if standard_answer is None:
|
||||
raise ValueError(f"No standard answer with id {standard_answer_id}")
|
||||
|
||||
standard_answer.active = False
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def update_standard_answer_category(
|
||||
standard_answer_category_id: int,
|
||||
category_name: str,
|
||||
db_session: Session,
|
||||
) -> StandardAnswerCategory:
|
||||
standard_answer_category = db_session.scalar(
|
||||
select(StandardAnswerCategory).where(
|
||||
StandardAnswerCategory.id == standard_answer_category_id
|
||||
)
|
||||
)
|
||||
if standard_answer_category is None:
|
||||
raise ValueError(
|
||||
f"No standard answer category with id {standard_answer_category_id}"
|
||||
)
|
||||
|
||||
if not check_category_validity(category_name):
|
||||
raise ValueError(f"Invalid category name: {category_name}")
|
||||
|
||||
standard_answer_category.name = category_name
|
||||
|
||||
db_session.commit()
|
||||
|
||||
return standard_answer_category
|
||||
|
||||
|
||||
def fetch_standard_answer_category(
|
||||
standard_answer_category_id: int,
|
||||
db_session: Session,
|
||||
) -> StandardAnswerCategory | None:
|
||||
return db_session.scalar(
|
||||
select(StandardAnswerCategory).where(
|
||||
StandardAnswerCategory.id == standard_answer_category_id
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def fetch_standard_answer_categories_by_ids(
|
||||
standard_answer_category_ids: list[int],
|
||||
db_session: Session,
|
||||
) -> Sequence[StandardAnswerCategory]:
|
||||
return db_session.scalars(
|
||||
select(StandardAnswerCategory).where(
|
||||
StandardAnswerCategory.id.in_(standard_answer_category_ids)
|
||||
)
|
||||
).all()
|
||||
|
||||
|
||||
def fetch_standard_answer_categories(
|
||||
db_session: Session,
|
||||
) -> Sequence[StandardAnswerCategory]:
|
||||
return db_session.scalars(select(StandardAnswerCategory)).all()
|
||||
|
||||
|
||||
def fetch_standard_answer(
|
||||
standard_answer_id: int,
|
||||
db_session: Session,
|
||||
) -> StandardAnswer | None:
|
||||
return db_session.scalar(
|
||||
select(StandardAnswer).where(StandardAnswer.id == standard_answer_id)
|
||||
)
|
||||
|
||||
|
||||
def fetch_standard_answers(db_session: Session) -> Sequence[StandardAnswer]:
|
||||
return db_session.scalars(
|
||||
select(StandardAnswer).where(StandardAnswer.active.is_(True))
|
||||
).all()
|
||||
|
||||
|
||||
def create_initial_default_standard_answer_category(db_session: Session) -> None:
|
||||
default_category_id = 0
|
||||
default_category_name = "General"
|
||||
default_category = fetch_standard_answer_category(
|
||||
standard_answer_category_id=default_category_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
if default_category is not None:
|
||||
if default_category.name != default_category_name:
|
||||
raise ValueError(
|
||||
"DB is not in a valid initial state. "
|
||||
"Default standard answer category does not have expected name."
|
||||
)
|
||||
return
|
||||
|
||||
standard_answer_category = StandardAnswerCategory(
|
||||
id=default_category_id,
|
||||
name=default_category_name,
|
||||
)
|
||||
db_session.add(standard_answer_category)
|
||||
db_session.commit()
|
||||
@@ -5,6 +5,7 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.db.models import Tool
|
||||
from danswer.server.features.tool.models import Header
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -25,6 +26,7 @@ def create_tool(
|
||||
name: str,
|
||||
description: str | None,
|
||||
openapi_schema: dict[str, Any] | None,
|
||||
custom_headers: list[Header] | None,
|
||||
user_id: UUID | None,
|
||||
db_session: Session,
|
||||
) -> Tool:
|
||||
@@ -33,6 +35,9 @@ def create_tool(
|
||||
description=description,
|
||||
in_code_tool_id=None,
|
||||
openapi_schema=openapi_schema,
|
||||
custom_headers=[header.dict() for header in custom_headers]
|
||||
if custom_headers
|
||||
else [],
|
||||
user_id=user_id,
|
||||
)
|
||||
db_session.add(new_tool)
|
||||
@@ -45,6 +50,7 @@ def update_tool(
|
||||
name: str | None,
|
||||
description: str | None,
|
||||
openapi_schema: dict[str, Any] | None,
|
||||
custom_headers: list[Header] | None,
|
||||
user_id: UUID | None,
|
||||
db_session: Session,
|
||||
) -> Tool:
|
||||
@@ -60,6 +66,8 @@ def update_tool(
|
||||
tool.openapi_schema = openapi_schema
|
||||
if user_id is not None:
|
||||
tool.user_id = user_id
|
||||
if custom_headers is not None:
|
||||
tool.custom_headers = [header.dict() for header in custom_headers]
|
||||
db_session.commit()
|
||||
|
||||
return tool
|
||||
|
||||
@@ -2,6 +2,7 @@ from collections.abc import Sequence
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi_users.password import PasswordHelper
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -22,8 +23,23 @@ def list_users(
|
||||
return db_session.scalars(stmt).unique().all()
|
||||
|
||||
|
||||
def get_users_by_emails(
|
||||
db_session: Session, emails: list[str]
|
||||
) -> tuple[list[User], list[str]]:
|
||||
# Use distinct to avoid duplicates
|
||||
stmt = select(User).filter(User.email.in_(emails)) # type: ignore
|
||||
found_users = list(db_session.scalars(stmt).unique().all()) # Convert to list
|
||||
found_users_emails = [user.email for user in found_users]
|
||||
missing_user_emails = [email for email in emails if email not in found_users_emails]
|
||||
return found_users, missing_user_emails
|
||||
|
||||
|
||||
def get_user_by_email(email: str, db_session: Session) -> User | None:
|
||||
user = db_session.query(User).filter(User.email == email).first() # type: ignore
|
||||
user = (
|
||||
db_session.query(User)
|
||||
.filter(func.lower(User.email) == func.lower(email))
|
||||
.first()
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
@@ -34,20 +50,50 @@ def fetch_user_by_id(db_session: Session, user_id: UUID) -> User | None:
|
||||
return user
|
||||
|
||||
|
||||
def add_non_web_user_if_not_exists(email: str, db_session: Session) -> User:
|
||||
user = get_user_by_email(email, db_session)
|
||||
if user is not None:
|
||||
return user
|
||||
|
||||
def _generate_non_web_user(email: str) -> User:
|
||||
fastapi_users_pw_helper = PasswordHelper()
|
||||
password = fastapi_users_pw_helper.generate()
|
||||
hashed_pass = fastapi_users_pw_helper.hash(password)
|
||||
user = User(
|
||||
return User(
|
||||
email=email,
|
||||
hashed_password=hashed_pass,
|
||||
has_web_login=False,
|
||||
role=UserRole.BASIC,
|
||||
)
|
||||
|
||||
|
||||
def add_non_web_user_if_not_exists(db_session: Session, email: str) -> User:
|
||||
user = get_user_by_email(email, db_session)
|
||||
if user is not None:
|
||||
return user
|
||||
|
||||
user = _generate_non_web_user(email=email)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
return user
|
||||
|
||||
|
||||
def add_non_web_user_if_not_exists__no_commit(db_session: Session, email: str) -> User:
|
||||
user = get_user_by_email(email, db_session)
|
||||
if user is not None:
|
||||
return user
|
||||
|
||||
user = _generate_non_web_user(email=email)
|
||||
db_session.add(user)
|
||||
db_session.flush() # generate id
|
||||
return user
|
||||
|
||||
|
||||
def batch_add_non_web_user_if_not_exists__no_commit(
|
||||
db_session: Session, emails: list[str]
|
||||
) -> list[User]:
|
||||
found_users, missing_user_emails = get_users_by_emails(db_session, emails)
|
||||
|
||||
new_users: list[User] = []
|
||||
for email in missing_user_emails:
|
||||
new_users.append(_generate_non_web_user(email=email))
|
||||
|
||||
db_session.add_all(new_users)
|
||||
db_session.flush() # generate ids
|
||||
|
||||
return found_users + new_users
|
||||
|
||||
@@ -177,6 +177,30 @@ class Updatable(abc.ABC):
|
||||
- Whether the document is hidden or not, hidden documents are not returned from search
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def update_single(self, update_request: UpdateRequest) -> None:
|
||||
"""
|
||||
Updates some set of chunks for a document. The document and fields to update
|
||||
are specified in the update request. Each update request in the list applies
|
||||
its changes to a list of document ids.
|
||||
None values mean that the field does not need an update.
|
||||
|
||||
The rationale for a single update function is that it allows retries and parallelism
|
||||
to happen at a higher / more strategic level, is simpler to read, and allows
|
||||
us to individually handle error conditions per document.
|
||||
|
||||
Parameters:
|
||||
- update_request: for a list of document ids in the update request, apply the same updates
|
||||
to all of the documents with those ids.
|
||||
|
||||
Return:
|
||||
- an HTTPStatus code. The code can used to decide whether to fail immediately,
|
||||
retry, etc. Although this method likely hits an HTTP API behind the
|
||||
scenes, the usage of HTTPStatus is a convenience and the interface is not
|
||||
actually HTTP specific.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def update(self, update_requests: list[UpdateRequest]) -> None:
|
||||
"""
|
||||
|
||||
@@ -377,6 +377,91 @@ class VespaIndex(DocumentIndex):
|
||||
time.monotonic() - update_start,
|
||||
)
|
||||
|
||||
def update_single(self, update_request: UpdateRequest) -> None:
|
||||
"""Note: if the document id does not exist, the update will be a no-op and the
|
||||
function will complete with no errors or exceptions.
|
||||
Handle other exceptions if you wish to implement retry behavior
|
||||
"""
|
||||
if len(update_request.document_ids) != 1:
|
||||
raise ValueError("update_request must contain a single document id")
|
||||
|
||||
# Handle Vespa character limitations
|
||||
# Mutating update_request but it's not used later anyway
|
||||
update_request.document_ids = [
|
||||
replace_invalid_doc_id_characters(doc_id)
|
||||
for doc_id in update_request.document_ids
|
||||
]
|
||||
|
||||
# update_start = time.monotonic()
|
||||
|
||||
# Fetch all chunks for each document ahead of time
|
||||
index_names = [self.index_name]
|
||||
if self.secondary_index_name:
|
||||
index_names.append(self.secondary_index_name)
|
||||
|
||||
chunk_id_start_time = time.monotonic()
|
||||
all_doc_chunk_ids: list[str] = []
|
||||
for index_name in index_names:
|
||||
for document_id in update_request.document_ids:
|
||||
# this calls vespa and can raise http exceptions
|
||||
doc_chunk_ids = get_all_vespa_ids_for_document_id(
|
||||
document_id=document_id,
|
||||
index_name=index_name,
|
||||
filters=None,
|
||||
get_large_chunks=True,
|
||||
)
|
||||
all_doc_chunk_ids.extend(doc_chunk_ids)
|
||||
logger.debug(
|
||||
f"Took {time.monotonic() - chunk_id_start_time:.2f} seconds to fetch all Vespa chunk IDs"
|
||||
)
|
||||
|
||||
# Build the _VespaUpdateRequest objects
|
||||
update_dict: dict[str, dict] = {"fields": {}}
|
||||
if update_request.boost is not None:
|
||||
update_dict["fields"][BOOST] = {"assign": update_request.boost}
|
||||
if update_request.document_sets is not None:
|
||||
update_dict["fields"][DOCUMENT_SETS] = {
|
||||
"assign": {
|
||||
document_set: 1 for document_set in update_request.document_sets
|
||||
}
|
||||
}
|
||||
if update_request.access is not None:
|
||||
update_dict["fields"][ACCESS_CONTROL_LIST] = {
|
||||
"assign": {acl_entry: 1 for acl_entry in update_request.access.to_acl()}
|
||||
}
|
||||
if update_request.hidden is not None:
|
||||
update_dict["fields"][HIDDEN] = {"assign": update_request.hidden}
|
||||
|
||||
if not update_dict["fields"]:
|
||||
logger.error("Update request received but nothing to update")
|
||||
return
|
||||
|
||||
processed_update_requests: list[_VespaUpdateRequest] = []
|
||||
for document_id in update_request.document_ids:
|
||||
for doc_chunk_id in all_doc_chunk_ids:
|
||||
processed_update_requests.append(
|
||||
_VespaUpdateRequest(
|
||||
document_id=document_id,
|
||||
url=f"{DOCUMENT_ID_ENDPOINT.format(index_name=self.index_name)}/{doc_chunk_id}",
|
||||
update_request=update_dict,
|
||||
)
|
||||
)
|
||||
|
||||
with httpx.Client(http2=True) as http_client:
|
||||
for update in processed_update_requests:
|
||||
http_client.put(
|
||||
update.url,
|
||||
headers={"Content-Type": "application/json"},
|
||||
json=update.update_request,
|
||||
)
|
||||
|
||||
# logger.debug(
|
||||
# "Finished updating Vespa documents in %.2f seconds",
|
||||
# time.monotonic() - update_start,
|
||||
# )
|
||||
|
||||
return
|
||||
|
||||
def delete(self, doc_ids: list[str]) -> None:
|
||||
logger.info(f"Deleting {len(doc_ids)} documents from Vespa")
|
||||
|
||||
|
||||
@@ -13,8 +13,6 @@ class ChatFileType(str, Enum):
|
||||
DOC = "document"
|
||||
# Plain text only contain the text
|
||||
PLAIN_TEXT = "plain_text"
|
||||
# csv types contain the binary data
|
||||
CSV = "csv"
|
||||
|
||||
|
||||
class FileDescriptor(TypedDict):
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import base64
|
||||
from collections.abc import Callable
|
||||
from io import BytesIO
|
||||
from typing import Any
|
||||
@@ -17,27 +16,6 @@ from danswer.file_store.models import InMemoryChatFile
|
||||
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
|
||||
|
||||
def save_base64_image(base64_image: str) -> str:
|
||||
with get_session_context_manager() as db_session:
|
||||
if base64_image.startswith("data:image"):
|
||||
base64_image = base64_image.split(",", 1)[1]
|
||||
|
||||
image_data = base64.b64decode(base64_image)
|
||||
|
||||
unique_id = str(uuid4())
|
||||
|
||||
file_io = BytesIO(image_data)
|
||||
file_store = get_default_file_store(db_session)
|
||||
file_store.save_file(
|
||||
file_name=unique_id,
|
||||
content=file_io,
|
||||
display_name="GeneratedImage",
|
||||
file_origin=FileOrigin.CHAT_IMAGE_GEN,
|
||||
file_type="image/png",
|
||||
)
|
||||
return unique_id
|
||||
|
||||
|
||||
def load_chat_file(
|
||||
file_descriptor: FileDescriptor, db_session: Session
|
||||
) -> InMemoryChatFile:
|
||||
|
||||
@@ -265,7 +265,13 @@ def index_doc_batch(
|
||||
Note that the documents should already be batched at this point so that it does not inflate the
|
||||
memory requirements"""
|
||||
|
||||
no_access = DocumentAccess.build(user_ids=[], user_groups=[], is_public=False)
|
||||
no_access = DocumentAccess.build(
|
||||
user_emails=[],
|
||||
user_groups=[],
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
is_public=False,
|
||||
)
|
||||
|
||||
ctx = index_doc_batch_prepare(
|
||||
document_batch=document_batch,
|
||||
|
||||
@@ -16,7 +16,6 @@ from danswer.chat.models import LlmDoc
|
||||
from danswer.chat.models import StreamStopInfo
|
||||
from danswer.chat.models import StreamStopReason
|
||||
from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.file_store.utils import InMemoryChatFile
|
||||
from danswer.llm.answering.models import AnswerStyleConfig
|
||||
from danswer.llm.answering.models import PreviousMessage
|
||||
@@ -41,13 +40,11 @@ from danswer.llm.answering.stream_processing.utils import map_document_id_order
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.llm.interfaces import ToolChoiceOptions
|
||||
from danswer.natural_language_processing.utils import get_tokenizer
|
||||
from danswer.tools.analysis.analysis_tool import CSVAnalysisTool
|
||||
from danswer.tools.custom.custom_tool_prompt_builder import (
|
||||
build_user_message_for_custom_tool_for_non_tool_calling_llm,
|
||||
)
|
||||
from danswer.tools.force import filter_tools_for_force_tool_use
|
||||
from danswer.tools.force import ForceUseTool
|
||||
from danswer.tools.graphing.graphing_tool import GraphingTool
|
||||
from danswer.tools.images.image_generation_tool import IMAGE_GENERATION_RESPONSE_ID
|
||||
from danswer.tools.images.image_generation_tool import ImageGenerationResponse
|
||||
from danswer.tools.images.image_generation_tool import ImageGenerationTool
|
||||
@@ -71,7 +68,6 @@ from danswer.tools.tool_runner import ToolRunner
|
||||
from danswer.tools.tool_selection import select_single_tool_for_non_tool_calling_llm
|
||||
from danswer.tools.utils import explicit_tool_calling_supported
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import MAX_TOOL_CALLS
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -165,10 +161,6 @@ class Answer:
|
||||
self.skip_gen_ai_answer_generation = skip_gen_ai_answer_generation
|
||||
self._is_cancelled = False
|
||||
|
||||
self.final_context_docs: list = []
|
||||
self.current_streamed_output: list = []
|
||||
self.processing_stream: list = []
|
||||
|
||||
def _update_prompt_builder_for_search_tool(
|
||||
self, prompt_builder: AnswerPromptBuilder, final_context_documents: list[LlmDoc]
|
||||
) -> None:
|
||||
@@ -187,6 +179,7 @@ class Answer:
|
||||
if self.answer_style_config.citation_config
|
||||
else False
|
||||
),
|
||||
history_message=self.single_message_history or "",
|
||||
)
|
||||
)
|
||||
elif self.answer_style_config.quotes_config:
|
||||
@@ -204,50 +197,41 @@ class Answer:
|
||||
) -> Iterator[
|
||||
str | StreamStopInfo | ToolCallKickoff | ToolResponse | ToolCallFinalResult
|
||||
]:
|
||||
tool_calls = 0
|
||||
initiated = False
|
||||
prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
|
||||
|
||||
while tool_calls < (1 if self.force_use_tool.force_use else MAX_TOOL_CALLS):
|
||||
if initiated:
|
||||
yield StreamStopInfo(stop_reason=StreamStopReason.NEW_RESPONSE)
|
||||
initiated = True
|
||||
|
||||
prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
|
||||
|
||||
tool_call_chunk: AIMessageChunk | None = None
|
||||
|
||||
if self.force_use_tool.force_use and self.force_use_tool.args is not None:
|
||||
tool_call_chunk = AIMessageChunk(content="")
|
||||
tool_call_chunk.tool_calls = [
|
||||
{
|
||||
"name": self.force_use_tool.tool_name,
|
||||
"args": self.force_use_tool.args,
|
||||
"id": str(uuid4()),
|
||||
}
|
||||
]
|
||||
|
||||
else:
|
||||
prompt_builder.update_system_prompt(
|
||||
default_build_system_message(self.prompt_config)
|
||||
tool_call_chunk: AIMessageChunk | None = None
|
||||
if self.force_use_tool.force_use and self.force_use_tool.args is not None:
|
||||
# if we are forcing a tool WITH args specified, we don't need to check which tools to run
|
||||
# / need to generate the args
|
||||
tool_call_chunk = AIMessageChunk(
|
||||
content="",
|
||||
)
|
||||
tool_call_chunk.tool_calls = [
|
||||
{
|
||||
"name": self.force_use_tool.tool_name,
|
||||
"args": self.force_use_tool.args,
|
||||
"id": str(uuid4()),
|
||||
}
|
||||
]
|
||||
else:
|
||||
# if tool calling is supported, first try the raw message
|
||||
# to see if we don't need to use any tools
|
||||
prompt_builder.update_system_prompt(
|
||||
default_build_system_message(self.prompt_config)
|
||||
)
|
||||
prompt_builder.update_user_prompt(
|
||||
default_build_user_message(
|
||||
self.question, self.prompt_config, self.latest_query_files
|
||||
)
|
||||
prompt_builder.update_user_prompt(
|
||||
default_build_user_message(
|
||||
self.question,
|
||||
self.prompt_config,
|
||||
self.latest_query_files,
|
||||
tool_calls,
|
||||
)
|
||||
)
|
||||
prompt = prompt_builder.build()
|
||||
final_tool_definitions = [
|
||||
tool.tool_definition()
|
||||
for tool in filter_tools_for_force_tool_use(
|
||||
self.tools, self.force_use_tool
|
||||
)
|
||||
prompt = prompt_builder.build()
|
||||
]
|
||||
|
||||
final_tool_definitions = [
|
||||
tool.tool_definition()
|
||||
for tool in filter_tools_for_force_tool_use(
|
||||
self.tools, self.force_use_tool
|
||||
)
|
||||
]
|
||||
|
||||
print(final_tool_definitions)
|
||||
for message in self.llm.stream(
|
||||
prompt=prompt,
|
||||
tools=final_tool_definitions if final_tool_definitions else None,
|
||||
@@ -274,129 +258,67 @@ class Answer:
|
||||
)
|
||||
|
||||
if not tool_call_chunk:
|
||||
logger.info("Skipped tool call but generated message")
|
||||
return
|
||||
return # no tool call needed
|
||||
|
||||
tool_call_requests = tool_call_chunk.tool_calls
|
||||
print(tool_call_requests)
|
||||
# if we have a tool call, we need to call the tool
|
||||
tool_call_requests = tool_call_chunk.tool_calls
|
||||
for tool_call_request in tool_call_requests:
|
||||
known_tools_by_name = [
|
||||
tool for tool in self.tools if tool.name == tool_call_request["name"]
|
||||
]
|
||||
|
||||
for tool_call_request in tool_call_requests:
|
||||
tool_calls += 1
|
||||
known_tools_by_name = [
|
||||
tool
|
||||
for tool in self.tools
|
||||
if tool.name == tool_call_request["name"]
|
||||
]
|
||||
|
||||
if not known_tools_by_name:
|
||||
logger.error(
|
||||
"Tool call requested with unknown name field. \n"
|
||||
f"self.tools: {self.tools}"
|
||||
f"tool_call_request: {tool_call_request}"
|
||||
)
|
||||
if self.tools:
|
||||
tool = self.tools[0]
|
||||
else:
|
||||
continue
|
||||
if not known_tools_by_name:
|
||||
logger.error(
|
||||
"Tool call requested with unknown name field. \n"
|
||||
f"self.tools: {self.tools}"
|
||||
f"tool_call_request: {tool_call_request}"
|
||||
)
|
||||
if self.tools:
|
||||
tool = self.tools[0]
|
||||
else:
|
||||
tool = known_tools_by_name[0]
|
||||
continue
|
||||
else:
|
||||
tool = known_tools_by_name[0]
|
||||
tool_args = (
|
||||
self.force_use_tool.args
|
||||
if self.force_use_tool.tool_name == tool.name
|
||||
and self.force_use_tool.args
|
||||
else tool_call_request["args"]
|
||||
)
|
||||
|
||||
tool_args = (
|
||||
self.force_use_tool.args
|
||||
if self.force_use_tool.tool_name == tool.name
|
||||
and self.force_use_tool.args
|
||||
else tool_call_request["args"]
|
||||
)
|
||||
tool_runner = ToolRunner(tool, tool_args)
|
||||
yield tool_runner.kickoff()
|
||||
yield from tool_runner.tool_responses()
|
||||
|
||||
tool_runner = ToolRunner(tool, tool_args)
|
||||
tool_kickoff = tool_runner.kickoff()
|
||||
yield tool_kickoff
|
||||
tool_call_summary = ToolCallSummary(
|
||||
tool_call_request=tool_call_chunk,
|
||||
tool_call_result=build_tool_message(
|
||||
tool_call_request, tool_runner.tool_message_content()
|
||||
),
|
||||
)
|
||||
|
||||
yield from tool_runner.tool_responses()
|
||||
|
||||
tool_responses = []
|
||||
for response in tool_runner.tool_responses():
|
||||
tool_responses.append(response)
|
||||
yield response
|
||||
|
||||
tool_call_summary = ToolCallSummary(
|
||||
tool_call_request=tool_call_chunk,
|
||||
tool_call_result=build_tool_message(
|
||||
tool_call_request, tool_runner.tool_message_content()
|
||||
),
|
||||
)
|
||||
|
||||
if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}:
|
||||
self._update_prompt_builder_for_search_tool(prompt_builder, [])
|
||||
|
||||
elif tool.name == ImageGenerationTool._NAME:
|
||||
img_urls = [
|
||||
img_generation_result["url"]
|
||||
for img_generation_result in tool_runner.tool_final_result().tool_result
|
||||
]
|
||||
prompt_builder.update_user_prompt(
|
||||
build_image_generation_user_prompt(
|
||||
query=self.question, img_urls=img_urls
|
||||
)
|
||||
)
|
||||
|
||||
yield tool_runner.tool_final_result()
|
||||
|
||||
# Update message history with tool call and response
|
||||
self.message_history.append(
|
||||
PreviousMessage(
|
||||
message=self.question,
|
||||
message_type=MessageType.USER,
|
||||
token_count=len(self.llm_tokenizer.encode(self.question)),
|
||||
tool_call=None,
|
||||
files=[],
|
||||
if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}:
|
||||
self._update_prompt_builder_for_search_tool(prompt_builder, [])
|
||||
elif tool.name == ImageGenerationTool._NAME:
|
||||
img_urls = [
|
||||
img_generation_result["url"]
|
||||
for img_generation_result in tool_runner.tool_final_result().tool_result
|
||||
]
|
||||
prompt_builder.update_user_prompt(
|
||||
build_image_generation_user_prompt(
|
||||
query=self.question, img_urls=img_urls
|
||||
)
|
||||
)
|
||||
yield tool_runner.tool_final_result()
|
||||
|
||||
self.message_history.append(
|
||||
PreviousMessage(
|
||||
message=str(tool_call_request),
|
||||
message_type=MessageType.ASSISTANT,
|
||||
token_count=len(
|
||||
self.llm_tokenizer.encode(str(tool_call_request))
|
||||
),
|
||||
tool_call=None,
|
||||
files=[],
|
||||
)
|
||||
)
|
||||
self.message_history.append(
|
||||
PreviousMessage(
|
||||
message="\n".join(str(response) for response in tool_responses),
|
||||
message_type=MessageType.SYSTEM,
|
||||
token_count=len(
|
||||
self.llm_tokenizer.encode(
|
||||
"\n".join(str(response) for response in tool_responses)
|
||||
)
|
||||
),
|
||||
tool_call=None,
|
||||
files=[],
|
||||
)
|
||||
)
|
||||
prompt = prompt_builder.build(tool_call_summary=tool_call_summary)
|
||||
|
||||
# Generate response based on updated message history
|
||||
prompt = prompt_builder.build(tool_call_summary=tool_call_summary)
|
||||
yield from self._process_llm_stream(
|
||||
prompt=prompt,
|
||||
tools=[tool.tool_definition() for tool in self.tools],
|
||||
)
|
||||
|
||||
response_content = ""
|
||||
for content in self._process_llm_stream(prompt=prompt, tools=None):
|
||||
if isinstance(content, str):
|
||||
response_content += content
|
||||
yield content
|
||||
|
||||
# Update message history with LLM response
|
||||
self.message_history.append(
|
||||
PreviousMessage(
|
||||
message=response_content,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
token_count=len(self.llm_tokenizer.encode(response_content)),
|
||||
tool_call=None,
|
||||
files=[],
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
# This method processes the LLM stream and yields the content or stop information
|
||||
def _process_llm_stream(
|
||||
@@ -425,234 +347,139 @@ class Answer:
|
||||
) -> Iterator[
|
||||
str | StreamStopInfo | ToolCallKickoff | ToolResponse | ToolCallFinalResult
|
||||
]:
|
||||
tool_calls = 0
|
||||
initiated = False
|
||||
while tool_calls < (1 if self.force_use_tool.force_use else MAX_TOOL_CALLS):
|
||||
if initiated:
|
||||
yield StreamStopInfo(stop_reason=StreamStopReason.NEW_RESPONSE)
|
||||
prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
|
||||
chosen_tool_and_args: tuple[Tool, dict] | None = None
|
||||
|
||||
initiated = True
|
||||
prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
|
||||
chosen_tool_and_args: tuple[Tool, dict] | None = None
|
||||
if self.force_use_tool.force_use:
|
||||
# if we are forcing a tool, we don't need to check which tools to run
|
||||
tool = next(
|
||||
iter(
|
||||
[
|
||||
tool
|
||||
for tool in self.tools
|
||||
if tool.name == self.force_use_tool.tool_name
|
||||
]
|
||||
),
|
||||
None,
|
||||
)
|
||||
if not tool:
|
||||
raise RuntimeError(f"Tool '{self.force_use_tool.tool_name}' not found")
|
||||
|
||||
if self.force_use_tool.force_use:
|
||||
# if we are forcing a tool, we don't need to check which tools to run
|
||||
tool = next(
|
||||
iter(
|
||||
[
|
||||
tool
|
||||
for tool in self.tools
|
||||
if tool.name == self.force_use_tool.tool_name
|
||||
]
|
||||
),
|
||||
None,
|
||||
)
|
||||
if not tool:
|
||||
raise RuntimeError(
|
||||
f"Tool '{self.force_use_tool.tool_name}' not found"
|
||||
)
|
||||
|
||||
tool_args = (
|
||||
self.force_use_tool.args
|
||||
if self.force_use_tool.args is not None
|
||||
else tool.get_args_for_non_tool_calling_llm(
|
||||
query=self.question,
|
||||
history=self.message_history,
|
||||
llm=self.llm,
|
||||
force_run=True,
|
||||
)
|
||||
)
|
||||
|
||||
if tool_args is None:
|
||||
raise RuntimeError(f"Tool '{tool.name}' did not return args")
|
||||
|
||||
chosen_tool_and_args = (tool, tool_args)
|
||||
else:
|
||||
tool_options = check_which_tools_should_run_for_non_tool_calling_llm(
|
||||
tools=self.tools,
|
||||
tool_args = (
|
||||
self.force_use_tool.args
|
||||
if self.force_use_tool.args is not None
|
||||
else tool.get_args_for_non_tool_calling_llm(
|
||||
query=self.question,
|
||||
history=self.message_history,
|
||||
llm=self.llm,
|
||||
force_run=True,
|
||||
)
|
||||
)
|
||||
|
||||
available_tools_and_args = [
|
||||
(self.tools[ind], args)
|
||||
for ind, args in enumerate(tool_options)
|
||||
if args is not None
|
||||
]
|
||||
if tool_args is None:
|
||||
raise RuntimeError(f"Tool '{tool.name}' did not return args")
|
||||
|
||||
logger.info(
|
||||
f"Selecting single tool from tools: {[(tool.name, args) for tool, args in available_tools_and_args]}"
|
||||
chosen_tool_and_args = (tool, tool_args)
|
||||
else:
|
||||
tool_options = check_which_tools_should_run_for_non_tool_calling_llm(
|
||||
tools=self.tools,
|
||||
query=self.question,
|
||||
history=self.message_history,
|
||||
llm=self.llm,
|
||||
)
|
||||
|
||||
available_tools_and_args = [
|
||||
(self.tools[ind], args)
|
||||
for ind, args in enumerate(tool_options)
|
||||
if args is not None
|
||||
]
|
||||
|
||||
logger.info(
|
||||
f"Selecting single tool from tools: {[(tool.name, args) for tool, args in available_tools_and_args]}"
|
||||
)
|
||||
|
||||
chosen_tool_and_args = (
|
||||
select_single_tool_for_non_tool_calling_llm(
|
||||
tools_and_args=available_tools_and_args,
|
||||
history=self.message_history,
|
||||
query=self.question,
|
||||
llm=self.llm,
|
||||
)
|
||||
if available_tools_and_args
|
||||
else None
|
||||
)
|
||||
|
||||
chosen_tool_and_args = (
|
||||
select_single_tool_for_non_tool_calling_llm(
|
||||
tools_and_args=available_tools_and_args,
|
||||
history=self.message_history,
|
||||
query=self.question,
|
||||
llm=self.llm,
|
||||
)
|
||||
if available_tools_and_args
|
||||
else None
|
||||
logger.notice(f"Chosen tool: {chosen_tool_and_args}")
|
||||
|
||||
if not chosen_tool_and_args:
|
||||
prompt_builder.update_system_prompt(
|
||||
default_build_system_message(self.prompt_config)
|
||||
)
|
||||
prompt_builder.update_user_prompt(
|
||||
default_build_user_message(
|
||||
self.question, self.prompt_config, self.latest_query_files
|
||||
)
|
||||
)
|
||||
prompt = prompt_builder.build()
|
||||
yield from self._process_llm_stream(
|
||||
prompt=prompt,
|
||||
tools=None,
|
||||
)
|
||||
return
|
||||
|
||||
logger.notice(f"Chosen tool: {chosen_tool_and_args}")
|
||||
|
||||
if not chosen_tool_and_args:
|
||||
prompt_builder.update_system_prompt(
|
||||
default_build_system_message(self.prompt_config)
|
||||
)
|
||||
prompt_builder.update_user_prompt(
|
||||
default_build_user_message(
|
||||
self.question, self.prompt_config, self.latest_query_files
|
||||
)
|
||||
)
|
||||
prompt = prompt_builder.build()
|
||||
yield from self._process_llm_stream(
|
||||
prompt=prompt,
|
||||
tools=None,
|
||||
)
|
||||
return
|
||||
tool_calls += 1
|
||||
|
||||
tool, tool_args = chosen_tool_and_args
|
||||
print("tool args")
|
||||
print(tool_args)
|
||||
|
||||
tool_runner = ToolRunner(tool, tool_args, self.llm)
|
||||
yield tool_runner.kickoff()
|
||||
tool_responses = []
|
||||
file_name = tool_runner.args["filename"]
|
||||
|
||||
print(f"file ame is {file_name}")
|
||||
|
||||
csv_file = None
|
||||
for message in self.message_history:
|
||||
if message.files:
|
||||
csv_file = next(
|
||||
(file for file in message.files if file.filename == file_name),
|
||||
None,
|
||||
)
|
||||
if csv_file:
|
||||
break
|
||||
print(self.latest_query_files)
|
||||
if csv_file is None:
|
||||
raise ValueError(
|
||||
f"CSV file with name '{file_name}' not found in latest query files."
|
||||
)
|
||||
print("csv file found")
|
||||
|
||||
tool_runner.args["filename"] = csv_file.content
|
||||
tool, tool_args = chosen_tool_and_args
|
||||
tool_runner = ToolRunner(tool, tool_args)
|
||||
yield tool_runner.kickoff()
|
||||
|
||||
if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}:
|
||||
final_context_documents = None
|
||||
for response in tool_runner.tool_responses():
|
||||
tool_responses.append(response)
|
||||
if response.id == FINAL_CONTEXT_DOCUMENTS_ID:
|
||||
final_context_documents = cast(list[LlmDoc], response.response)
|
||||
yield response
|
||||
|
||||
if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}:
|
||||
final_context_documents = None
|
||||
for response in tool_runner.tool_responses():
|
||||
if response.id == FINAL_CONTEXT_DOCUMENTS_ID:
|
||||
final_context_documents = cast(list[LlmDoc], response.response)
|
||||
yield response
|
||||
|
||||
if final_context_documents is None:
|
||||
raise RuntimeError(
|
||||
f"{tool.name} did not return final context documents"
|
||||
)
|
||||
|
||||
self._update_prompt_builder_for_search_tool(
|
||||
prompt_builder, final_context_documents
|
||||
if final_context_documents is None:
|
||||
raise RuntimeError(
|
||||
f"{tool.name} did not return final context documents"
|
||||
)
|
||||
elif tool.name == ImageGenerationTool._NAME:
|
||||
for response in tool_runner.tool_responses():
|
||||
if response.id == IMAGE_GENERATION_RESPONSE_ID:
|
||||
img_generation_response = cast(
|
||||
list[ImageGenerationResponse], response.response
|
||||
)
|
||||
# img_urls = [img.url for img in img_generation_response]
|
||||
prompt_builder.update_user_prompt(
|
||||
build_image_generation_user_prompt(
|
||||
query=self.question,
|
||||
img_urls=[img.url for img in img_generation_response],
|
||||
)
|
||||
)
|
||||
|
||||
yield response
|
||||
elif tool.name == CSVAnalysisTool._NAME:
|
||||
for response in tool_runner.tool_responses():
|
||||
yield response
|
||||
|
||||
elif tool.name == GraphingTool._NAME:
|
||||
for response in tool_runner.tool_responses():
|
||||
print("RESOS")
|
||||
print(response)
|
||||
prompt_builder.update_user_prompt(
|
||||
build_image_generation_user_prompt(
|
||||
query=self.question,
|
||||
# img_urls=img_urls,
|
||||
)
|
||||
)
|
||||
else:
|
||||
prompt_builder.update_user_prompt(
|
||||
HumanMessage(
|
||||
content=build_user_message_for_custom_tool_for_non_tool_calling_llm(
|
||||
self.question,
|
||||
tool.name,
|
||||
*tool_runner.tool_responses(),
|
||||
)
|
||||
)
|
||||
)
|
||||
final_result = tool_runner.tool_final_result()
|
||||
yield final_result
|
||||
|
||||
# Update message history
|
||||
self.message_history.extend(
|
||||
[
|
||||
PreviousMessage(
|
||||
message=str(self.question),
|
||||
message_type=MessageType.USER,
|
||||
token_count=len(self.llm_tokenizer.encode(str(self.question))),
|
||||
tool_call=None,
|
||||
files=[],
|
||||
),
|
||||
PreviousMessage(
|
||||
message=f"Tool used: {tool.name}",
|
||||
message_type=MessageType.ASSISTANT,
|
||||
token_count=len(
|
||||
self.llm_tokenizer.encode(f"Tool used: {tool.name}")
|
||||
),
|
||||
tool_call=None,
|
||||
files=[],
|
||||
),
|
||||
PreviousMessage(
|
||||
message=str(final_result),
|
||||
message_type=MessageType.SYSTEM,
|
||||
token_count=len(self.llm_tokenizer.encode(str(final_result))),
|
||||
tool_call=None,
|
||||
files=[],
|
||||
),
|
||||
]
|
||||
self._update_prompt_builder_for_search_tool(
|
||||
prompt_builder, final_context_documents
|
||||
)
|
||||
elif tool.name == ImageGenerationTool._NAME:
|
||||
img_urls = []
|
||||
for response in tool_runner.tool_responses():
|
||||
if response.id == IMAGE_GENERATION_RESPONSE_ID:
|
||||
img_generation_response = cast(
|
||||
list[ImageGenerationResponse], response.response
|
||||
)
|
||||
img_urls = [img.url for img in img_generation_response]
|
||||
|
||||
# Generate response based on updated message history
|
||||
prompt = prompt_builder.build()
|
||||
yield response
|
||||
|
||||
response_content = ""
|
||||
for content in self._process_llm_stream(prompt=prompt, tools=None):
|
||||
if isinstance(content, str):
|
||||
response_content += content
|
||||
yield content
|
||||
|
||||
# Update message history with LLM response
|
||||
self.message_history.append(
|
||||
PreviousMessage(
|
||||
message=response_content,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
token_count=len(self.llm_tokenizer.encode(response_content)),
|
||||
tool_call=None,
|
||||
files=[],
|
||||
prompt_builder.update_user_prompt(
|
||||
build_image_generation_user_prompt(
|
||||
query=self.question,
|
||||
img_urls=img_urls,
|
||||
)
|
||||
)
|
||||
else:
|
||||
prompt_builder.update_user_prompt(
|
||||
HumanMessage(
|
||||
content=build_user_message_for_custom_tool_for_non_tool_calling_llm(
|
||||
self.question,
|
||||
tool.name,
|
||||
*tool_runner.tool_responses(),
|
||||
)
|
||||
)
|
||||
)
|
||||
final = tool_runner.tool_final_result()
|
||||
|
||||
yield final
|
||||
|
||||
prompt = prompt_builder.build()
|
||||
|
||||
yield from self._process_llm_stream(prompt=prompt, tools=None)
|
||||
|
||||
@property
|
||||
def processed_streamed_output(self) -> AnswerStream:
|
||||
@@ -669,8 +496,6 @@ class Answer:
|
||||
else self._raw_output_for_non_explicit_tool_calling_llms()
|
||||
)
|
||||
|
||||
self.processing_stream = []
|
||||
|
||||
def _process_stream(
|
||||
stream: Iterator[ToolCallKickoff | ToolResponse | str | StreamStopInfo],
|
||||
) -> AnswerStream:
|
||||
@@ -711,70 +536,64 @@ class Answer:
|
||||
|
||||
yield message
|
||||
else:
|
||||
process_answer_stream_fn = _get_answer_stream_processor(
|
||||
context_docs=final_context_docs or [],
|
||||
# if doc selection is enabled, then search_results will be None,
|
||||
# so we need to use the final_context_docs
|
||||
doc_id_to_rank_map=map_document_id_order(
|
||||
search_results or final_context_docs or []
|
||||
),
|
||||
answer_style_configs=self.answer_style_config,
|
||||
)
|
||||
# assumes all tool responses will come first, then the final answer
|
||||
break
|
||||
|
||||
stream_stop_info = None
|
||||
new_kickoff = None
|
||||
if not self.skip_gen_ai_answer_generation:
|
||||
process_answer_stream_fn = _get_answer_stream_processor(
|
||||
context_docs=final_context_docs or [],
|
||||
# if doc selection is enabled, then search_results will be None,
|
||||
# so we need to use the final_context_docs
|
||||
doc_id_to_rank_map=map_document_id_order(
|
||||
search_results or final_context_docs or []
|
||||
),
|
||||
answer_style_configs=self.answer_style_config,
|
||||
)
|
||||
|
||||
def _stream() -> Iterator[str]:
|
||||
nonlocal stream_stop_info
|
||||
nonlocal new_kickoff
|
||||
stream_stop_info = None
|
||||
|
||||
yield cast(str, message)
|
||||
for item in stream:
|
||||
if isinstance(item, StreamStopInfo):
|
||||
stream_stop_info = item
|
||||
return
|
||||
if isinstance(item, ToolCallKickoff):
|
||||
new_kickoff = item
|
||||
return
|
||||
else:
|
||||
yield cast(str, item)
|
||||
def _stream() -> Iterator[str]:
|
||||
nonlocal stream_stop_info
|
||||
yield cast(str, message)
|
||||
for item in stream:
|
||||
if isinstance(item, StreamStopInfo):
|
||||
stream_stop_info = item
|
||||
return
|
||||
|
||||
yield from process_answer_stream_fn(_stream())
|
||||
# this should never happen, but we're seeing weird behavior here so handling for now
|
||||
if not isinstance(item, str):
|
||||
logger.error(
|
||||
f"Received non-string item in answer stream: {item}. Skipping."
|
||||
)
|
||||
continue
|
||||
|
||||
if stream_stop_info:
|
||||
yield stream_stop_info
|
||||
yield item
|
||||
|
||||
# handle new tool call (continuation of message)
|
||||
if new_kickoff:
|
||||
yield new_kickoff
|
||||
yield from process_answer_stream_fn(_stream())
|
||||
|
||||
if stream_stop_info:
|
||||
yield stream_stop_info
|
||||
|
||||
processed_stream = []
|
||||
for processed_packet in _process_stream(output_generator):
|
||||
if (
|
||||
isinstance(processed_packet, StreamStopInfo)
|
||||
and processed_packet.stop_reason == StreamStopReason.NEW_RESPONSE
|
||||
):
|
||||
self.current_streamed_output = self.processing_stream
|
||||
self.processing_stream = []
|
||||
|
||||
self.processing_stream.append(processed_packet)
|
||||
processed_stream.append(processed_packet)
|
||||
yield processed_packet
|
||||
self.current_streamed_output = self.processing_stream
|
||||
self._processed_stream = self.processing_stream
|
||||
|
||||
self._processed_stream = processed_stream
|
||||
|
||||
@property
|
||||
def llm_answer(self) -> str:
|
||||
answer = ""
|
||||
if not self._processed_stream and not self.current_streamed_output:
|
||||
return ""
|
||||
for packet in self.current_streamed_output or self._processed_stream or []:
|
||||
for packet in self.processed_streamed_output:
|
||||
if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece:
|
||||
answer += packet.answer_piece
|
||||
|
||||
return answer
|
||||
|
||||
@property
|
||||
def citations(self) -> list[CitationInfo]:
|
||||
citations: list[CitationInfo] = []
|
||||
for packet in self.current_streamed_output:
|
||||
for packet in self.processed_streamed_output:
|
||||
if isinstance(packet, CitationInfo):
|
||||
citations.append(packet)
|
||||
|
||||
@@ -789,4 +608,5 @@ class Answer:
|
||||
if not self.is_connected():
|
||||
logger.debug("Answer stream has been cancelled")
|
||||
self._is_cancelled = not self.is_connected()
|
||||
|
||||
return self._is_cancelled
|
||||
|
||||
@@ -33,7 +33,7 @@ class PreviousMessage(BaseModel):
|
||||
token_count: int
|
||||
message_type: MessageType
|
||||
files: list[InMemoryChatFile]
|
||||
tool_call: ToolCallFinalResult | None
|
||||
tool_calls: list[ToolCallFinalResult]
|
||||
|
||||
@classmethod
|
||||
def from_chat_message(
|
||||
@@ -51,13 +51,14 @@ class PreviousMessage(BaseModel):
|
||||
for file in available_files
|
||||
if str(file.file_id) in message_file_ids
|
||||
],
|
||||
tool_call=ToolCallFinalResult(
|
||||
tool_name=chat_message.tool_call.tool_name,
|
||||
tool_args=chat_message.tool_call.tool_arguments,
|
||||
tool_result=chat_message.tool_call.tool_result,
|
||||
)
|
||||
if chat_message.tool_call
|
||||
else None,
|
||||
tool_calls=[
|
||||
ToolCallFinalResult(
|
||||
tool_name=tool_call.tool_name,
|
||||
tool_args=tool_call.tool_arguments,
|
||||
tool_result=tool_call.tool_result,
|
||||
)
|
||||
for tool_call in chat_message.tool_calls
|
||||
],
|
||||
)
|
||||
|
||||
def to_langchain_msg(self) -> BaseMessage:
|
||||
|
||||
@@ -36,10 +36,7 @@ def default_build_system_message(
|
||||
|
||||
|
||||
def default_build_user_message(
|
||||
user_query: str,
|
||||
prompt_config: PromptConfig,
|
||||
files: list[InMemoryChatFile] = [],
|
||||
previous_tool_calls: int = 0,
|
||||
user_query: str, prompt_config: PromptConfig, files: list[InMemoryChatFile] = []
|
||||
) -> HumanMessage:
|
||||
user_prompt = (
|
||||
CHAT_USER_CONTEXT_FREE_PROMPT.format(
|
||||
@@ -48,12 +45,6 @@ def default_build_user_message(
|
||||
if prompt_config.task_prompt
|
||||
else user_query
|
||||
)
|
||||
if previous_tool_calls > 0:
|
||||
user_prompt = (
|
||||
f"You have already generated the above so do not call a tool if not necessary. "
|
||||
f"Remember the query is: `{user_prompt}`"
|
||||
)
|
||||
|
||||
user_prompt = user_prompt.strip()
|
||||
user_msg = HumanMessage(
|
||||
content=build_content_with_imgs(user_prompt, files) if files else user_prompt
|
||||
@@ -122,25 +113,25 @@ class AnswerPromptBuilder:
|
||||
|
||||
final_messages_with_tokens.append(self.user_message_and_token_cnt)
|
||||
|
||||
# if tool_call_summary:
|
||||
# final_messages_with_tokens.append(
|
||||
# (
|
||||
# tool_call_summary.tool_call_request,
|
||||
# check_message_tokens(
|
||||
# tool_call_summary.tool_call_request,
|
||||
# self.llm_tokenizer_encode_func,
|
||||
# ),
|
||||
# )
|
||||
# )
|
||||
# final_messages_with_tokens.append(
|
||||
# (
|
||||
# tool_call_summary.tool_call_result,
|
||||
# check_message_tokens(
|
||||
# tool_call_summary.tool_call_result,
|
||||
# self.llm_tokenizer_encode_func,
|
||||
# ),
|
||||
# )
|
||||
# )
|
||||
if tool_call_summary:
|
||||
final_messages_with_tokens.append(
|
||||
(
|
||||
tool_call_summary.tool_call_request,
|
||||
check_message_tokens(
|
||||
tool_call_summary.tool_call_request,
|
||||
self.llm_tokenizer_encode_func,
|
||||
),
|
||||
)
|
||||
)
|
||||
final_messages_with_tokens.append(
|
||||
(
|
||||
tool_call_summary.tool_call_result,
|
||||
check_message_tokens(
|
||||
tool_call_summary.tool_call_result,
|
||||
self.llm_tokenizer_encode_func,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
return drop_messages_history_overflow(
|
||||
final_messages_with_tokens, self.max_tokens
|
||||
|
||||
@@ -29,6 +29,9 @@ from danswer.prompts.token_counts import CITATION_REMINDER_TOKEN_CNT
|
||||
from danswer.prompts.token_counts import CITATION_STATEMENT_TOKEN_CNT
|
||||
from danswer.prompts.token_counts import LANGUAGE_HINT_TOKEN_CNT
|
||||
from danswer.search.models import InferenceChunk
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def get_prompt_tokens(prompt_config: PromptConfig) -> int:
|
||||
@@ -156,6 +159,7 @@ def build_citations_user_message(
|
||||
user_prompt = CITATIONS_PROMPT_FOR_TOOL_CALLING.format(
|
||||
task_prompt=task_prompt_with_reminder,
|
||||
user_query=question,
|
||||
history_block=history_message,
|
||||
)
|
||||
|
||||
user_prompt = user_prompt.strip()
|
||||
|
||||
@@ -85,6 +85,15 @@ def extract_citations_from_stream(
|
||||
curr_segment += token
|
||||
llm_out += token
|
||||
|
||||
# Handle code blocks without language tags
|
||||
if "`" in curr_segment:
|
||||
if curr_segment.endswith("`"):
|
||||
continue
|
||||
elif "```" in curr_segment:
|
||||
piece_that_comes_after = curr_segment.split("```")[1][0]
|
||||
if piece_that_comes_after == "\n" and in_code_block(llm_out):
|
||||
curr_segment = curr_segment.replace("```", "```plaintext")
|
||||
|
||||
citation_pattern = r"\[(\d+)\]"
|
||||
|
||||
citations_found = list(re.finditer(citation_pattern, curr_segment))
|
||||
|
||||
@@ -29,7 +29,6 @@ from danswer.configs.model_configs import GEN_AI_TEMPERATURE
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.llm.interfaces import LLMConfig
|
||||
from danswer.llm.interfaces import ToolChoiceOptions
|
||||
from danswer.tools.graphing.models import GRAPHING_RESPONSE_ID
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
@@ -99,7 +98,6 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
message_dict = {"role": message.role, "content": message.content}
|
||||
elif isinstance(message, HumanMessage):
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
|
||||
elif isinstance(message, AIMessage):
|
||||
message_dict = {"role": "assistant", "content": message.content}
|
||||
if message.tool_calls:
|
||||
@@ -126,21 +124,12 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
"name": message.name,
|
||||
}
|
||||
elif isinstance(message, ToolMessage):
|
||||
if message.id == GRAPHING_RESPONSE_ID:
|
||||
message_dict = {
|
||||
"tool_call_id": message.tool_call_id,
|
||||
"role": "tool",
|
||||
"name": message.name or "",
|
||||
"content": "a graph",
|
||||
}
|
||||
else:
|
||||
message_dict = {
|
||||
"tool_call_id": message.tool_call_id,
|
||||
"role": "tool",
|
||||
"name": message.name or "",
|
||||
"content": "a graph",
|
||||
}
|
||||
|
||||
message_dict = {
|
||||
"tool_call_id": message.tool_call_id,
|
||||
"role": "tool",
|
||||
"name": message.name or "",
|
||||
"content": message.content,
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
if "name" in message.additional_kwargs:
|
||||
|
||||
@@ -24,6 +24,8 @@ class WellKnownLLMProviderDescriptor(BaseModel):
|
||||
|
||||
OPENAI_PROVIDER_NAME = "openai"
|
||||
OPEN_AI_MODEL_NAMES = [
|
||||
"o1-mini",
|
||||
"o1-preview",
|
||||
"gpt-4",
|
||||
"gpt-4o",
|
||||
"gpt-4o-mini",
|
||||
|
||||
@@ -47,7 +47,9 @@ if TYPE_CHECKING:
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def litellm_exception_to_error_msg(e: Exception, llm: LLM) -> str:
|
||||
def litellm_exception_to_error_msg(
|
||||
e: Exception, llm: LLM, fallback_to_error_msg: bool = False
|
||||
) -> str:
|
||||
error_msg = str(e)
|
||||
|
||||
if isinstance(e, BadRequestError):
|
||||
@@ -94,7 +96,7 @@ def litellm_exception_to_error_msg(e: Exception, llm: LLM) -> str:
|
||||
error_msg = "Request timed out: The operation took too long to complete. Please try again."
|
||||
elif isinstance(e, APIError):
|
||||
error_msg = f"API error: An error occurred while communicating with the API. Details: {str(e)}"
|
||||
else:
|
||||
elif not fallback_to_error_msg:
|
||||
error_msg = "An unexpected error occurred while processing your request. Please try again later."
|
||||
return error_msg
|
||||
|
||||
@@ -112,7 +114,7 @@ def translate_danswer_msg_to_langchain(
|
||||
content = build_content_with_imgs(msg.message, files)
|
||||
|
||||
if msg.message_type == MessageType.SYSTEM:
|
||||
return SystemMessage(content=content)
|
||||
raise ValueError("System messages are not currently part of history")
|
||||
if msg.message_type == MessageType.ASSISTANT:
|
||||
return AIMessage(content=content)
|
||||
if msg.message_type == MessageType.USER:
|
||||
@@ -133,21 +135,6 @@ def translate_history_to_basemessages(
|
||||
return history_basemessages, history_token_counts
|
||||
|
||||
|
||||
def _process_csv_file(file: InMemoryChatFile) -> str:
|
||||
import pandas as pd
|
||||
import io
|
||||
|
||||
df = pd.read_csv(io.StringIO(file.content.decode("utf-8")))
|
||||
csv_preview = df.head().to_string()
|
||||
|
||||
file_name_section = (
|
||||
f"CSV FILE NAME: {file.filename}\n"
|
||||
if file.filename
|
||||
else "CSV FILE (NO NAME PROVIDED):\n"
|
||||
)
|
||||
return f"{file_name_section}{CODE_BLOCK_PAT.format(csv_preview)}\n\n\n"
|
||||
|
||||
|
||||
def _build_content(
|
||||
message: str,
|
||||
files: list[InMemoryChatFile] | None = None,
|
||||
@@ -158,26 +145,16 @@ def _build_content(
|
||||
if files
|
||||
else None
|
||||
)
|
||||
csv_files = (
|
||||
[file for file in files if file.file_type == ChatFileType.CSV]
|
||||
if files
|
||||
else None
|
||||
)
|
||||
|
||||
if not text_files and not csv_files:
|
||||
if not text_files:
|
||||
return message
|
||||
|
||||
final_message_with_files = "FILES:\n\n"
|
||||
for file in text_files or []:
|
||||
for file in text_files:
|
||||
file_content = file.content.decode("utf-8")
|
||||
file_name_section = f"DOCUMENT: {file.filename}\n" if file.filename else ""
|
||||
final_message_with_files += (
|
||||
f"{file_name_section}{CODE_BLOCK_PAT.format(file_content.strip())}\n\n\n"
|
||||
)
|
||||
|
||||
for file in csv_files or []:
|
||||
final_message_with_files += _process_csv_file(file)
|
||||
|
||||
final_message_with_files += message
|
||||
|
||||
return final_message_with_files
|
||||
|
||||
@@ -51,7 +51,6 @@ from danswer.db.credentials import create_initial_public_credential
|
||||
from danswer.db.document import check_docs_exist
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.engine import init_sqlalchemy_engine
|
||||
from danswer.db.engine import warm_up_connections
|
||||
from danswer.db.index_attempt import cancel_indexing_attempts_past_model
|
||||
from danswer.db.index_attempt import expire_index_attempts
|
||||
from danswer.db.llm import fetch_default_provider
|
||||
@@ -62,7 +61,6 @@ from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.db.search_settings import get_secondary_search_settings
|
||||
from danswer.db.search_settings import update_current_search_settings
|
||||
from danswer.db.search_settings import update_secondary_search_settings
|
||||
from danswer.db.standard_answer import create_initial_default_standard_answer_category
|
||||
from danswer.db.swap_index import check_index_swap
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.document_index.interfaces import DocumentIndex
|
||||
@@ -111,6 +109,8 @@ from danswer.server.query_and_chat.query_backend import (
|
||||
from danswer.server.query_and_chat.query_backend import basic_router as query_router
|
||||
from danswer.server.settings.api import admin_router as settings_admin_router
|
||||
from danswer.server.settings.api import basic_router as settings_router
|
||||
from danswer.server.settings.store import load_settings
|
||||
from danswer.server.settings.store import store_settings
|
||||
from danswer.server.token_rate_limits.api import (
|
||||
router as token_rate_limit_settings_router,
|
||||
)
|
||||
@@ -125,10 +125,10 @@ from danswer.utils.telemetry import RecordType
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
from danswer.utils.variable_functionality import global_version
|
||||
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
from shared_configs.configs import CORS_ALLOWED_ORIGIN
|
||||
from shared_configs.configs import MODEL_SERVER_HOST
|
||||
from shared_configs.configs import MODEL_SERVER_PORT
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@@ -186,9 +186,6 @@ def setup_postgres(db_session: Session) -> None:
|
||||
create_initial_default_connector(db_session)
|
||||
associate_default_cc_pair(db_session)
|
||||
|
||||
logger.notice("Verifying default standard answer category exists.")
|
||||
create_initial_default_standard_answer_category(db_session)
|
||||
|
||||
logger.notice("Loading default Prompts and Personas")
|
||||
delete_old_default_personas(db_session)
|
||||
load_chat_yamls()
|
||||
@@ -245,6 +242,12 @@ def update_default_multipass_indexing(db_session: Session) -> None:
|
||||
)
|
||||
update_current_search_settings(db_session, updated_settings)
|
||||
|
||||
# Update settings with GPU availability
|
||||
settings = load_settings()
|
||||
settings.gpu_enabled = gpu_available
|
||||
store_settings(settings)
|
||||
logger.notice(f"Updated settings with GPU availability: {gpu_available}")
|
||||
|
||||
else:
|
||||
logger.debug(
|
||||
"Existing docs or connectors found. Skipping multipass indexing update."
|
||||
@@ -365,7 +368,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
|
||||
logger.notice("Generative AI Q&A disabled")
|
||||
|
||||
# fill up Postgres connection pools
|
||||
await warm_up_connections()
|
||||
# await warm_up_connections()
|
||||
|
||||
# We cache this at the beginning so there is no delay in the first telemetry
|
||||
get_or_generate_uuid()
|
||||
@@ -591,7 +594,7 @@ def get_application() -> FastAPI:
|
||||
|
||||
application.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"], # Change this to the list of allowed origins if needed
|
||||
allow_origins=CORS_ALLOWED_ORIGIN, # Configurable via environment variable
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
|
||||
@@ -26,6 +26,7 @@ from danswer.db.chat import translate_db_message_to_chat_message_detail
|
||||
from danswer.db.chat import translate_db_search_doc_to_server_search_doc
|
||||
from danswer.db.chat import update_search_docs_table_with_relevance
|
||||
from danswer.db.engine import get_session_context_manager
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import User
|
||||
from danswer.db.persona import get_prompt_by_id
|
||||
from danswer.llm.answering.answer import Answer
|
||||
@@ -60,7 +61,7 @@ from danswer.tools.tool import ToolResponse
|
||||
from danswer.tools.tool_runner import ToolCallKickoff
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.timing import log_generator_function_time
|
||||
|
||||
from ee.danswer.server.query_and_chat.utils import create_temporary_persona
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -118,7 +119,17 @@ def stream_answer_objects(
|
||||
one_shot=True,
|
||||
danswerbot_flow=danswerbot_flow,
|
||||
)
|
||||
llm, fast_llm = get_llms_for_persona(persona=chat_session.persona)
|
||||
|
||||
temporary_persona: Persona | None = None
|
||||
if query_req.persona_config is not None:
|
||||
new_persona = create_temporary_persona(
|
||||
db_session=db_session, persona_config=query_req.persona_config, user=user
|
||||
)
|
||||
temporary_persona = new_persona
|
||||
|
||||
persona = temporary_persona if temporary_persona else chat_session.persona
|
||||
|
||||
llm, fast_llm = get_llms_for_persona(persona=persona)
|
||||
|
||||
llm_tokenizer = get_tokenizer(
|
||||
model_name=llm.config.model_name,
|
||||
@@ -153,11 +164,11 @@ def stream_answer_objects(
|
||||
prompt_id=query_req.prompt_id, user=None, db_session=db_session
|
||||
)
|
||||
if prompt is None:
|
||||
if not chat_session.persona.prompts:
|
||||
if not persona.prompts:
|
||||
raise RuntimeError(
|
||||
"Persona does not have any prompts - this should never happen"
|
||||
)
|
||||
prompt = chat_session.persona.prompts[0]
|
||||
prompt = persona.prompts[0]
|
||||
|
||||
# Create the first User query message
|
||||
new_user_message = create_new_chat_message(
|
||||
@@ -174,9 +185,7 @@ def stream_answer_objects(
|
||||
prompt_config = PromptConfig.from_model(prompt)
|
||||
document_pruning_config = DocumentPruningConfig(
|
||||
max_chunks=int(
|
||||
chat_session.persona.num_chunks
|
||||
if chat_session.persona.num_chunks is not None
|
||||
else default_num_chunks
|
||||
persona.num_chunks if persona.num_chunks is not None else default_num_chunks
|
||||
),
|
||||
max_tokens=max_document_tokens,
|
||||
)
|
||||
@@ -187,16 +196,16 @@ def stream_answer_objects(
|
||||
evaluation_type=LLMEvaluationType.SKIP
|
||||
if DISABLE_LLM_DOC_RELEVANCE
|
||||
else query_req.evaluation_type,
|
||||
persona=chat_session.persona,
|
||||
persona=persona,
|
||||
retrieval_options=query_req.retrieval_options,
|
||||
prompt_config=prompt_config,
|
||||
llm=llm,
|
||||
fast_llm=fast_llm,
|
||||
pruning_config=document_pruning_config,
|
||||
bypass_acl=bypass_acl,
|
||||
chunks_above=query_req.chunks_above,
|
||||
chunks_below=query_req.chunks_below,
|
||||
full_doc=query_req.full_doc,
|
||||
bypass_acl=bypass_acl,
|
||||
)
|
||||
|
||||
answer_config = AnswerStyleConfig(
|
||||
@@ -209,13 +218,15 @@ def stream_answer_objects(
|
||||
question=query_msg.message,
|
||||
answer_style_config=answer_config,
|
||||
prompt_config=PromptConfig.from_model(prompt),
|
||||
llm=get_main_llm_from_tuple(get_llms_for_persona(persona=chat_session.persona)),
|
||||
llm=get_main_llm_from_tuple(get_llms_for_persona(persona=persona)),
|
||||
single_message_history=history_str,
|
||||
tools=[search_tool],
|
||||
force_use_tool=ForceUseTool(
|
||||
force_use=True,
|
||||
tool_name=search_tool.name,
|
||||
args={"query": rephrased_query},
|
||||
tools=[search_tool] if search_tool else [],
|
||||
force_use_tool=(
|
||||
ForceUseTool(
|
||||
tool_name=search_tool.name,
|
||||
args={"query": rephrased_query},
|
||||
force_use=True,
|
||||
)
|
||||
),
|
||||
# for now, don't use tool calling for this flow, as we haven't
|
||||
# tested quotes with tool calling too much yet
|
||||
@@ -223,9 +234,7 @@ def stream_answer_objects(
|
||||
return_contexts=query_req.return_contexts,
|
||||
skip_gen_ai_answer_generation=query_req.skip_gen_ai_answer_generation,
|
||||
)
|
||||
|
||||
# won't be any ImageGenerationDisplay responses since that tool is never passed in
|
||||
|
||||
for packet in cast(AnswerObjectIterator, answer.processed_streamed_output):
|
||||
# for one-shot flow, don't currently do anything with these
|
||||
if isinstance(packet, ToolResponse):
|
||||
@@ -261,6 +270,7 @@ def stream_answer_objects(
|
||||
applied_time_cutoff=search_response_summary.final_filters.time_cutoff,
|
||||
recency_bias_multiplier=search_response_summary.recency_bias_multiplier,
|
||||
)
|
||||
|
||||
yield initial_response
|
||||
|
||||
elif packet.id == SEARCH_DOC_CONTENT_ID:
|
||||
@@ -287,6 +297,7 @@ def stream_answer_objects(
|
||||
relevance_summary=evaluation_response,
|
||||
)
|
||||
yield evaluation_response
|
||||
|
||||
else:
|
||||
yield packet
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from pydantic import model_validator
|
||||
@@ -8,6 +10,8 @@ from danswer.chat.models import DanswerQuotes
|
||||
from danswer.chat.models import QADocsResponse
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.search.enums import LLMEvaluationType
|
||||
from danswer.search.enums import RecencyBiasSetting
|
||||
from danswer.search.enums import SearchType
|
||||
from danswer.search.models import ChunkContext
|
||||
from danswer.search.models import RerankingDetails
|
||||
from danswer.search.models import RetrievalDetails
|
||||
@@ -23,10 +27,49 @@ class ThreadMessage(BaseModel):
|
||||
role: MessageType = MessageType.USER
|
||||
|
||||
|
||||
class PromptConfig(BaseModel):
|
||||
name: str
|
||||
description: str = ""
|
||||
system_prompt: str
|
||||
task_prompt: str = ""
|
||||
include_citations: bool = True
|
||||
datetime_aware: bool = True
|
||||
|
||||
|
||||
class DocumentSetConfig(BaseModel):
|
||||
id: int
|
||||
|
||||
|
||||
class ToolConfig(BaseModel):
|
||||
id: int
|
||||
|
||||
|
||||
class PersonaConfig(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
search_type: SearchType = SearchType.SEMANTIC
|
||||
num_chunks: float | None = None
|
||||
llm_relevance_filter: bool = False
|
||||
llm_filter_extraction: bool = False
|
||||
recency_bias: RecencyBiasSetting = RecencyBiasSetting.AUTO
|
||||
llm_model_provider_override: str | None = None
|
||||
llm_model_version_override: str | None = None
|
||||
|
||||
prompts: list[PromptConfig] = Field(default_factory=list)
|
||||
prompt_ids: list[int] = Field(default_factory=list)
|
||||
|
||||
document_set_ids: list[int] = Field(default_factory=list)
|
||||
tools: list[ToolConfig] = Field(default_factory=list)
|
||||
tool_ids: list[int] = Field(default_factory=list)
|
||||
custom_tools_openapi: list[dict[str, Any]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class DirectQARequest(ChunkContext):
|
||||
persona_config: PersonaConfig | None = None
|
||||
persona_id: int | None = None
|
||||
|
||||
messages: list[ThreadMessage]
|
||||
prompt_id: int | None
|
||||
persona_id: int
|
||||
prompt_id: int | None = None
|
||||
multilingual_query_expansion: list[str] | None = None
|
||||
retrieval_options: RetrievalDetails = Field(default_factory=RetrievalDetails)
|
||||
rerank_settings: RerankingDetails | None = None
|
||||
@@ -43,6 +86,12 @@ class DirectQARequest(ChunkContext):
|
||||
# If True, skips generative an AI response to the search query
|
||||
skip_gen_ai_answer_generation: bool = False
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_persona_fields(self) -> "DirectQARequest":
|
||||
if (self.persona_config is None) == (self.persona_id is None):
|
||||
raise ValueError("Exactly one of persona_config or persona_id must be set")
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_chain_of_thought_and_prompt_id(self) -> "DirectQARequest":
|
||||
if self.chain_of_thought and self.prompt_id is not None:
|
||||
|
||||
@@ -39,46 +39,6 @@ CHAT_USER_CONTEXT_FREE_PROMPT = f"""
|
||||
""".strip()
|
||||
|
||||
|
||||
GRAPHING_QUERY_REPHRASE_GRAPH = f"""
|
||||
Given the following conversation and a follow-up input, generate Python code using matplotlib to create the requested graph.
|
||||
IMPORTANT: The code should be complete and executable, including data generation if necessary.
|
||||
Focus on creating a clear and informative visualization based on the user's request.
|
||||
Guidelines:
|
||||
- Import matplotlib.pyplot as plt at the beginning of the code.
|
||||
- Generate sample data if specific data is not provided in the conversation.
|
||||
- Use appropriate graph types (line, bar, scatter, pie) based on the nature of the data and request.
|
||||
- Include proper labeling (title, x-axis, y-axis, legend) in the graph.
|
||||
- Use plt.figure() to create the figure and assign it to a variable named 'fig'.
|
||||
- Do not include plt.show() at the end of the code.
|
||||
{GENERAL_SEP_PAT}
|
||||
Chat History:
|
||||
{{chat_history}}
|
||||
{GENERAL_SEP_PAT}
|
||||
Follow Up Input: {{question}}
|
||||
Python Code for Graph (Respond with only the Python code to generate the graph):
|
||||
```python
|
||||
# Your code here
|
||||
```
|
||||
""".strip()
|
||||
|
||||
|
||||
GRAPHING_GET_FILE_NAME_PROMPT = f"""
|
||||
Given the following conversation, a follow-up input, and a list of available CSV files,
|
||||
provide the name of the CSV file to analyze.
|
||||
|
||||
{GENERAL_SEP_PAT}
|
||||
Chat History:
|
||||
{{chat_history}}
|
||||
{GENERAL_SEP_PAT}
|
||||
Follow Up Input: {{question}}
|
||||
{GENERAL_SEP_PAT}
|
||||
Available CSV Files:
|
||||
{{file_list}}
|
||||
{GENERAL_SEP_PAT}
|
||||
CSV File Name to Analyze:
|
||||
"""
|
||||
|
||||
|
||||
# Design considerations for the below:
|
||||
# - In case of uncertainty, favor yes search so place the "yes" sections near the start of the
|
||||
# prompt and after the no section as well to deemphasize the no section
|
||||
|
||||
@@ -109,6 +109,9 @@ CITATIONS_PROMPT_FOR_TOOL_CALLING = f"""
|
||||
Refer to the provided context documents when responding to me.{DEFAULT_IGNORE_STATEMENT} \
|
||||
You should always get right to the point, and never use extraneous language.
|
||||
|
||||
CHAT HISTORY:
|
||||
{{history_block}}
|
||||
|
||||
{{task_prompt}}
|
||||
|
||||
{QUESTION_PAT.upper()}
|
||||
|
||||
@@ -84,6 +84,7 @@ class SavedSearchSettings(InferenceSettings, IndexingSetting):
|
||||
# Multilingual Expansion
|
||||
multilingual_expansion=search_settings.multilingual_expansion,
|
||||
rerank_api_url=search_settings.rerank_api_url,
|
||||
disable_rerank_for_streaming=search_settings.disable_rerank_for_streaming,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -67,6 +67,9 @@ def retrieval_preprocessing(
|
||||
]
|
||||
|
||||
time_filter = preset_filters.time_cutoff
|
||||
if time_filter is None and persona:
|
||||
time_filter = persona.search_start_date
|
||||
|
||||
source_filter = preset_filters.source_type
|
||||
|
||||
auto_detect_time_filter = True
|
||||
@@ -154,7 +157,7 @@ def retrieval_preprocessing(
|
||||
final_filters = IndexFilters(
|
||||
source_type=preset_filters.source_type or predicted_source_filters,
|
||||
document_set=preset_filters.document_set,
|
||||
time_cutoff=preset_filters.time_cutoff or predicted_time_cutoff,
|
||||
time_cutoff=time_filter or predicted_time_cutoff,
|
||||
tags=preset_filters.tags, # Tags are never auto-extracted
|
||||
access_control_list=user_acl_filters,
|
||||
)
|
||||
|
||||
@@ -16,8 +16,9 @@ from danswer.db.connector_credential_pair import remove_credential_from_connecto
|
||||
from danswer.db.connector_credential_pair import (
|
||||
update_connector_credential_pair_from_id,
|
||||
)
|
||||
from danswer.db.document import get_document_cnts_for_cc_pairs
|
||||
from danswer.db.document import get_document_counts_for_cc_pairs
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.enums import AccessType
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.index_attempt import cancel_indexing_attempts_for_ccpair
|
||||
from danswer.db.index_attempt import cancel_indexing_attempts_past_model
|
||||
@@ -95,7 +96,7 @@ def get_cc_pair_full_info(
|
||||
)
|
||||
|
||||
document_count_info_list = list(
|
||||
get_document_cnts_for_cc_pairs(
|
||||
get_document_counts_for_cc_pairs(
|
||||
db_session=db_session,
|
||||
cc_pair_identifiers=[cc_pair_identifier],
|
||||
)
|
||||
@@ -201,7 +202,7 @@ def associate_credential_to_connector(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
target_group_ids=metadata.groups,
|
||||
object_is_public=metadata.is_public,
|
||||
object_is_public=metadata.access_type == AccessType.PUBLIC,
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -211,7 +212,8 @@ def associate_credential_to_connector(
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
cc_pair_name=metadata.name,
|
||||
is_public=True if metadata.is_public is None else metadata.is_public,
|
||||
access_type=metadata.access_type,
|
||||
auto_sync_options=metadata.auto_sync_options,
|
||||
groups=metadata.groups,
|
||||
)
|
||||
|
||||
|
||||
@@ -62,8 +62,9 @@ from danswer.db.credentials import delete_gmail_service_account_credentials
|
||||
from danswer.db.credentials import delete_google_drive_service_account_credentials
|
||||
from danswer.db.credentials import fetch_credential_by_id
|
||||
from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
|
||||
from danswer.db.document import get_document_cnts_for_cc_pairs
|
||||
from danswer.db.document import get_document_counts_for_cc_pairs
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.enums import AccessType
|
||||
from danswer.db.index_attempt import create_index_attempt
|
||||
from danswer.db.index_attempt import get_index_attempts_for_cc_pair
|
||||
from danswer.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
|
||||
@@ -511,7 +512,7 @@ def get_connector_indexing_status(
|
||||
for index_attempt in latest_index_attempts
|
||||
}
|
||||
|
||||
document_count_info = get_document_cnts_for_cc_pairs(
|
||||
document_count_info = get_document_counts_for_cc_pairs(
|
||||
db_session=db_session,
|
||||
cc_pair_identifiers=cc_pair_identifiers,
|
||||
)
|
||||
@@ -559,7 +560,7 @@ def get_connector_indexing_status(
|
||||
cc_pair_status=cc_pair.status,
|
||||
connector=ConnectorSnapshot.from_connector_db_model(connector),
|
||||
credential=CredentialSnapshot.from_credential_db_model(credential),
|
||||
public_doc=cc_pair.is_public,
|
||||
access_type=cc_pair.access_type,
|
||||
owner=credential.user.email if credential.user else "",
|
||||
groups=group_cc_pair_relationships_dict.get(cc_pair.id, []),
|
||||
last_finished_status=(
|
||||
@@ -668,12 +669,15 @@ def create_connector_with_mock_credential(
|
||||
credential = create_credential(
|
||||
mock_credential, user=user, db_session=db_session
|
||||
)
|
||||
access_type = (
|
||||
AccessType.PUBLIC if connector_data.is_public else AccessType.PRIVATE
|
||||
)
|
||||
response = add_credential_to_connector(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
connector_id=cast(int, connector_response.id), # will aways be an int
|
||||
credential_id=credential.id,
|
||||
is_public=connector_data.is_public or False,
|
||||
access_type=access_type,
|
||||
cc_pair_name=connector_data.name,
|
||||
groups=connector_data.groups,
|
||||
)
|
||||
@@ -968,7 +972,7 @@ def get_basic_connector_indexing_status(
|
||||
)
|
||||
for cc_pair in cc_pairs
|
||||
]
|
||||
document_count_info = get_document_cnts_for_cc_pairs(
|
||||
document_count_info = get_document_counts_for_cc_pairs(
|
||||
db_session=db_session,
|
||||
cc_pair_identifiers=cc_pair_identifiers,
|
||||
)
|
||||
|
||||
@@ -10,6 +10,7 @@ from danswer.configs.app_configs import MASK_CREDENTIAL_PREFIX
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.models import DocumentErrorSummary
|
||||
from danswer.connectors.models import InputType
|
||||
from danswer.db.enums import AccessType
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
from danswer.db.models import Connector
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
@@ -218,7 +219,7 @@ class CCPairFullInfo(BaseModel):
|
||||
number_of_index_attempts: int
|
||||
last_index_attempt_status: IndexingStatus | None
|
||||
latest_deletion_attempt: DeletionAttemptSnapshot | None
|
||||
is_public: bool
|
||||
access_type: AccessType
|
||||
is_editable_for_current_user: bool
|
||||
deletion_failure_message: str | None
|
||||
|
||||
@@ -261,7 +262,7 @@ class CCPairFullInfo(BaseModel):
|
||||
number_of_index_attempts=number_of_index_attempts,
|
||||
last_index_attempt_status=last_indexing_status,
|
||||
latest_deletion_attempt=latest_deletion_attempt,
|
||||
is_public=cc_pair_model.is_public,
|
||||
access_type=cc_pair_model.access_type,
|
||||
is_editable_for_current_user=is_editable_for_current_user,
|
||||
deletion_failure_message=cc_pair_model.deletion_failure_message,
|
||||
)
|
||||
@@ -288,7 +289,7 @@ class ConnectorIndexingStatus(BaseModel):
|
||||
credential: CredentialSnapshot
|
||||
owner: str
|
||||
groups: list[int]
|
||||
public_doc: bool
|
||||
access_type: AccessType
|
||||
last_finished_status: IndexingStatus | None
|
||||
last_status: IndexingStatus | None
|
||||
last_success: datetime | None
|
||||
@@ -306,7 +307,8 @@ class ConnectorCredentialPairIdentifier(BaseModel):
|
||||
|
||||
class ConnectorCredentialPairMetadata(BaseModel):
|
||||
name: str | None = None
|
||||
is_public: bool | None = None
|
||||
access_type: AccessType
|
||||
auto_sync_options: dict[str, Any] | None = None
|
||||
groups: list[int] = Field(default_factory=list)
|
||||
|
||||
|
||||
|
||||
@@ -47,7 +47,6 @@ class DocumentSet(BaseModel):
|
||||
description: str
|
||||
cc_pair_descriptors: list[ConnectorCredentialPairDescriptor]
|
||||
is_up_to_date: bool
|
||||
contains_non_public: bool
|
||||
is_public: bool
|
||||
# For Private Document Sets, who should be able to access these
|
||||
users: list[UUID]
|
||||
@@ -59,12 +58,6 @@ class DocumentSet(BaseModel):
|
||||
id=document_set_model.id,
|
||||
name=document_set_model.name,
|
||||
description=document_set_model.description,
|
||||
contains_non_public=any(
|
||||
[
|
||||
not cc_pair.is_public
|
||||
for cc_pair in document_set_model.connector_credential_pairs
|
||||
]
|
||||
),
|
||||
cc_pair_descriptors=[
|
||||
ConnectorCredentialPairDescriptor(
|
||||
id=cc_pair.id,
|
||||
|
||||
@@ -3,6 +3,7 @@ from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Query
|
||||
from fastapi import UploadFile
|
||||
from pydantic import BaseModel
|
||||
@@ -20,6 +21,7 @@ from danswer.db.persona import get_personas
|
||||
from danswer.db.persona import mark_persona_as_deleted
|
||||
from danswer.db.persona import mark_persona_as_not_deleted
|
||||
from danswer.db.persona import update_all_personas_display_priority
|
||||
from danswer.db.persona import update_persona_public_status
|
||||
from danswer.db.persona import update_persona_shared_users
|
||||
from danswer.db.persona import update_persona_visibility
|
||||
from danswer.file_store.file_store import get_default_file_store
|
||||
@@ -43,6 +45,10 @@ class IsVisibleRequest(BaseModel):
|
||||
is_visible: bool
|
||||
|
||||
|
||||
class IsPublicRequest(BaseModel):
|
||||
is_public: bool
|
||||
|
||||
|
||||
@admin_router.patch("/{persona_id}/visible")
|
||||
def patch_persona_visibility(
|
||||
persona_id: int,
|
||||
@@ -58,6 +64,25 @@ def patch_persona_visibility(
|
||||
)
|
||||
|
||||
|
||||
@basic_router.patch("/{persona_id}/public")
|
||||
def patch_user_presona_public_status(
|
||||
persona_id: int,
|
||||
is_public_request: IsPublicRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
try:
|
||||
update_persona_public_status(
|
||||
persona_id=persona_id,
|
||||
is_public=is_public_request.is_public,
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.exception("Failed to update persona public status")
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
|
||||
|
||||
@admin_router.put("/display-priority")
|
||||
def patch_persona_display_priority(
|
||||
display_priority_request: DisplayPriorityRequest,
|
||||
@@ -122,25 +147,6 @@ def upload_file(
|
||||
return {"file_id": file_id}
|
||||
|
||||
|
||||
@admin_router.post("/upload-csv")
|
||||
def upload_csv(
|
||||
file: UploadFile,
|
||||
db_session: Session = Depends(get_session),
|
||||
_: User | None = Depends(current_user),
|
||||
) -> dict[str, str]:
|
||||
file_store = get_default_file_store(db_session)
|
||||
file_type = ChatFileType.CSV
|
||||
file_id = str(uuid.uuid4())
|
||||
file_store.save_file(
|
||||
file_name=file_id,
|
||||
content=file.file,
|
||||
display_name=file.filename,
|
||||
file_origin=FileOrigin.CHAT_UPLOAD,
|
||||
file_type=file.content_type or file_type.value,
|
||||
)
|
||||
return {"file_id": file_id}
|
||||
|
||||
|
||||
"""Endpoints for all"""
|
||||
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from datetime import datetime
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -12,7 +13,6 @@ from danswer.server.features.tool.api import ToolSnapshot
|
||||
from danswer.server.models import MinimalUserSnapshot
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@@ -38,6 +38,9 @@ class CreatePersonaRequest(BaseModel):
|
||||
icon_shape: int | None = None
|
||||
uploaded_image_id: str | None = None # New field for uploaded image
|
||||
remove_image: bool | None = None
|
||||
is_default_persona: bool = False
|
||||
display_priority: int | None = None
|
||||
search_start_date: datetime | None = None
|
||||
|
||||
|
||||
class PersonaSnapshot(BaseModel):
|
||||
@@ -54,7 +57,7 @@ class PersonaSnapshot(BaseModel):
|
||||
llm_model_provider_override: str | None
|
||||
llm_model_version_override: str | None
|
||||
starter_messages: list[StarterMessage] | None
|
||||
default_persona: bool
|
||||
builtin_persona: bool
|
||||
prompts: list[PromptSnapshot]
|
||||
tools: list[ToolSnapshot]
|
||||
document_sets: list[DocumentSet]
|
||||
@@ -63,6 +66,8 @@ class PersonaSnapshot(BaseModel):
|
||||
icon_color: str | None
|
||||
icon_shape: int | None
|
||||
uploaded_image_id: str | None = None
|
||||
is_default_persona: bool
|
||||
search_start_date: datetime | None = None
|
||||
|
||||
@classmethod
|
||||
def from_model(
|
||||
@@ -93,7 +98,8 @@ class PersonaSnapshot(BaseModel):
|
||||
llm_model_provider_override=persona.llm_model_provider_override,
|
||||
llm_model_version_override=persona.llm_model_version_override,
|
||||
starter_messages=persona.starter_messages,
|
||||
default_persona=persona.default_persona,
|
||||
builtin_persona=persona.builtin_persona,
|
||||
is_default_persona=persona.is_default_persona,
|
||||
prompts=[PromptSnapshot.from_model(prompt) for prompt in persona.prompts],
|
||||
tools=[ToolSnapshot.from_model(tool) for tool in persona.tools],
|
||||
document_sets=[
|
||||
@@ -108,6 +114,7 @@ class PersonaSnapshot(BaseModel):
|
||||
icon_color=persona.icon_color,
|
||||
icon_shape=persona.icon_shape,
|
||||
uploaded_image_id=persona.uploaded_image_id,
|
||||
search_start_date=persona.search_start_date,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -15,6 +15,8 @@ from danswer.db.tools import delete_tool
|
||||
from danswer.db.tools import get_tool_by_id
|
||||
from danswer.db.tools import get_tools
|
||||
from danswer.db.tools import update_tool
|
||||
from danswer.server.features.tool.models import CustomToolCreate
|
||||
from danswer.server.features.tool.models import CustomToolUpdate
|
||||
from danswer.server.features.tool.models import ToolSnapshot
|
||||
from danswer.tools.custom.openapi_parsing import MethodSpec
|
||||
from danswer.tools.custom.openapi_parsing import openapi_to_method_specs
|
||||
@@ -24,18 +26,6 @@ router = APIRouter(prefix="/tool")
|
||||
admin_router = APIRouter(prefix="/admin/tool")
|
||||
|
||||
|
||||
class CustomToolCreate(BaseModel):
|
||||
name: str
|
||||
description: str | None = None
|
||||
definition: dict[str, Any]
|
||||
|
||||
|
||||
class CustomToolUpdate(BaseModel):
|
||||
name: str | None = None
|
||||
description: str | None = None
|
||||
definition: dict[str, Any] | None = None
|
||||
|
||||
|
||||
def _validate_tool_definition(definition: dict[str, Any]) -> None:
|
||||
try:
|
||||
validate_openapi_schema(definition)
|
||||
@@ -54,6 +44,7 @@ def create_custom_tool(
|
||||
name=tool_data.name,
|
||||
description=tool_data.description,
|
||||
openapi_schema=tool_data.definition,
|
||||
custom_headers=tool_data.custom_headers,
|
||||
user_id=user.id if user else None,
|
||||
db_session=db_session,
|
||||
)
|
||||
@@ -74,6 +65,7 @@ def update_custom_tool(
|
||||
name=tool_data.name,
|
||||
description=tool_data.description,
|
||||
openapi_schema=tool_data.definition,
|
||||
custom_headers=tool_data.custom_headers,
|
||||
user_id=user.id if user else None,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
@@ -12,6 +12,7 @@ class ToolSnapshot(BaseModel):
|
||||
definition: dict[str, Any] | None
|
||||
display_name: str
|
||||
in_code_tool_id: str | None
|
||||
custom_headers: list[Any] | None
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, tool: Tool) -> "ToolSnapshot":
|
||||
@@ -22,4 +23,24 @@ class ToolSnapshot(BaseModel):
|
||||
definition=tool.openapi_schema,
|
||||
display_name=tool.display_name or tool.name,
|
||||
in_code_tool_id=tool.in_code_tool_id,
|
||||
custom_headers=tool.custom_headers,
|
||||
)
|
||||
|
||||
|
||||
class Header(BaseModel):
|
||||
key: str
|
||||
value: str
|
||||
|
||||
|
||||
class CustomToolCreate(BaseModel):
|
||||
name: str
|
||||
description: str | None = None
|
||||
definition: dict[str, Any]
|
||||
custom_headers: list[Header] | None = None
|
||||
|
||||
|
||||
class CustomToolUpdate(BaseModel):
|
||||
name: str | None = None
|
||||
description: str | None = None
|
||||
definition: dict[str, Any] | None = None
|
||||
custom_headers: list[Header] | None = None
|
||||
|
||||
@@ -11,6 +11,7 @@ from sqlalchemy.orm import Session
|
||||
from danswer.auth.users import current_admin_user
|
||||
from danswer.auth.users import current_curator_or_admin_user
|
||||
from danswer.configs.app_configs import GENERATIVE_MODEL_ACCESS_CHECK_FREQ
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import KV_GEN_AI_KEY_CHECK_TIME
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair
|
||||
@@ -146,7 +147,7 @@ def create_deletion_attempt_for_connector_id(
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
from danswer.background.celery.celery_app import (
|
||||
cleanup_connector_credential_pair_task,
|
||||
check_for_connector_deletion_task,
|
||||
)
|
||||
|
||||
connector_id = connector_credential_pair_identifier.connector_id
|
||||
@@ -191,9 +192,10 @@ def create_deletion_attempt_for_connector_id(
|
||||
cc_pair_id=cc_pair.id,
|
||||
status=ConnectorCredentialPairStatus.DELETING,
|
||||
)
|
||||
# actually kick off the deletion
|
||||
cleanup_connector_credential_pair_task.apply_async(
|
||||
kwargs=dict(connector_id=connector_id, credential_id=credential_id),
|
||||
|
||||
# run the beat task to pick up this deletion early
|
||||
check_for_connector_deletion_task.apply_async(
|
||||
priority=DanswerCeleryPriority.HIGH,
|
||||
)
|
||||
|
||||
if cc_pair.connector.source == DocumentSource.FILE:
|
||||
|
||||
@@ -18,6 +18,9 @@ class TestEmbeddingRequest(BaseModel):
|
||||
api_url: str | None = None
|
||||
model_name: str | None = None
|
||||
|
||||
# This disables the "model_" protected namespace for pydantic
|
||||
model_config = {"protected_namespaces": ()}
|
||||
|
||||
|
||||
class CloudEmbeddingProvider(BaseModel):
|
||||
provider_type: EmbeddingProvider
|
||||
|
||||
@@ -3,6 +3,7 @@ from collections.abc import Callable
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import current_admin_user
|
||||
@@ -17,6 +18,7 @@ from danswer.llm.factory import get_default_llms
|
||||
from danswer.llm.factory import get_llm
|
||||
from danswer.llm.llm_provider_options import fetch_available_well_known_llms
|
||||
from danswer.llm.llm_provider_options import WellKnownLLMProviderDescriptor
|
||||
from danswer.llm.utils import litellm_exception_to_error_msg
|
||||
from danswer.llm.utils import test_llm
|
||||
from danswer.server.manage.llm.models import FullLLMProvider
|
||||
from danswer.server.manage.llm.models import LLMProviderDescriptor
|
||||
@@ -77,7 +79,10 @@ def test_llm_configuration(
|
||||
)
|
||||
|
||||
if error:
|
||||
raise HTTPException(status_code=400, detail=error)
|
||||
client_error_msg = litellm_exception_to_error_msg(
|
||||
error, llm, fallback_to_error_msg=True
|
||||
)
|
||||
raise HTTPException(status_code=400, detail=client_error_msg)
|
||||
|
||||
|
||||
@admin_router.post("/test/default")
|
||||
@@ -118,10 +123,22 @@ def list_llm_providers(
|
||||
@admin_router.put("/provider")
|
||||
def put_llm_provider(
|
||||
llm_provider: LLMProviderUpsertRequest,
|
||||
is_creation: bool = Query(
|
||||
True,
|
||||
description="True if updating an existing provider, False if creating a new one",
|
||||
),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> FullLLMProvider:
|
||||
return upsert_llm_provider(llm_provider=llm_provider, db_session=db_session)
|
||||
try:
|
||||
return upsert_llm_provider(
|
||||
llm_provider=llm_provider,
|
||||
db_session=db_session,
|
||||
is_creation=is_creation,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.exception("Failed to upsert LLM Provider")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@admin_router.delete("/provider/{provider_id}")
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -17,13 +15,12 @@ from danswer.db.models import AllowedAnswerFilters
|
||||
from danswer.db.models import ChannelConfig
|
||||
from danswer.db.models import SlackBotConfig as SlackBotConfigModel
|
||||
from danswer.db.models import SlackBotResponseType
|
||||
from danswer.db.models import StandardAnswer as StandardAnswerModel
|
||||
from danswer.db.models import StandardAnswerCategory as StandardAnswerCategoryModel
|
||||
from danswer.db.models import User
|
||||
from danswer.search.models import SavedSearchSettings
|
||||
from danswer.server.features.persona.models import PersonaSnapshot
|
||||
from danswer.server.models import FullUserSnapshot
|
||||
from danswer.server.models import InvitedUserSnapshot
|
||||
from ee.danswer.server.manage.models import StandardAnswerCategory
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -43,6 +40,9 @@ class AuthTypeResponse(BaseModel):
|
||||
|
||||
class UserPreferences(BaseModel):
|
||||
chosen_assistants: list[int] | None = None
|
||||
hidden_assistants: list[int] = []
|
||||
visible_assistants: list[int] = []
|
||||
|
||||
default_model: str | None = None
|
||||
|
||||
|
||||
@@ -76,6 +76,8 @@ class UserInfo(BaseModel):
|
||||
UserPreferences(
|
||||
chosen_assistants=user.chosen_assistants,
|
||||
default_model=user.default_model,
|
||||
hidden_assistants=user.hidden_assistants,
|
||||
visible_assistants=user.visible_assistants,
|
||||
)
|
||||
),
|
||||
# set to None if TRACK_EXTERNAL_IDP_EXPIRY is False so that we avoid cases
|
||||
@@ -119,95 +121,6 @@ class HiddenUpdateRequest(BaseModel):
|
||||
hidden: bool
|
||||
|
||||
|
||||
class StandardAnswerCategoryCreationRequest(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
class StandardAnswerCategory(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
|
||||
@classmethod
|
||||
def from_model(
|
||||
cls, standard_answer_category: StandardAnswerCategoryModel
|
||||
) -> "StandardAnswerCategory":
|
||||
return cls(
|
||||
id=standard_answer_category.id,
|
||||
name=standard_answer_category.name,
|
||||
)
|
||||
|
||||
|
||||
class StandardAnswer(BaseModel):
|
||||
id: int
|
||||
keyword: str
|
||||
answer: str
|
||||
categories: list[StandardAnswerCategory]
|
||||
match_regex: bool
|
||||
match_any_keywords: bool
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, standard_answer_model: StandardAnswerModel) -> "StandardAnswer":
|
||||
return cls(
|
||||
id=standard_answer_model.id,
|
||||
keyword=standard_answer_model.keyword,
|
||||
answer=standard_answer_model.answer,
|
||||
match_regex=standard_answer_model.match_regex,
|
||||
match_any_keywords=standard_answer_model.match_any_keywords,
|
||||
categories=[
|
||||
StandardAnswerCategory.from_model(standard_answer_category_model)
|
||||
for standard_answer_category_model in standard_answer_model.categories
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class StandardAnswerCreationRequest(BaseModel):
|
||||
keyword: str
|
||||
answer: str
|
||||
categories: list[int]
|
||||
match_regex: bool
|
||||
match_any_keywords: bool
|
||||
|
||||
@field_validator("categories", mode="before")
|
||||
@classmethod
|
||||
def validate_categories(cls, value: list[int]) -> list[int]:
|
||||
if len(value) < 1:
|
||||
raise ValueError(
|
||||
"At least one category must be attached to a standard answer"
|
||||
)
|
||||
return value
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_only_match_any_if_not_regex(self) -> Any:
|
||||
if self.match_regex and self.match_any_keywords:
|
||||
raise ValueError(
|
||||
"Can only match any keywords in keyword mode, not regex mode"
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_keyword_if_regex(self) -> Any:
|
||||
if not self.match_regex:
|
||||
# no validation for keywords
|
||||
return self
|
||||
|
||||
try:
|
||||
re.compile(self.keyword)
|
||||
return self
|
||||
except re.error as err:
|
||||
if isinstance(err.pattern, bytes):
|
||||
raise ValueError(
|
||||
f'invalid regex pattern r"{err.pattern.decode()}" in `keyword`: {err.msg}'
|
||||
)
|
||||
else:
|
||||
pattern = f'r"{err.pattern}"' if err.pattern is not None else ""
|
||||
raise ValueError(
|
||||
" ".join(
|
||||
["invalid regex pattern", pattern, f"in `keyword`: {err.msg}"]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class SlackBotTokens(BaseModel):
|
||||
bot_token: str
|
||||
app_token: str
|
||||
@@ -233,6 +146,7 @@ class SlackBotConfigCreationRequest(BaseModel):
|
||||
# list of user emails
|
||||
follow_up_tags: list[str] | None = None
|
||||
response_type: SlackBotResponseType
|
||||
# XXX this is going away soon
|
||||
standard_answer_categories: list[int] = Field(default_factory=list)
|
||||
|
||||
@field_validator("answer_filters", mode="before")
|
||||
@@ -257,6 +171,7 @@ class SlackBotConfig(BaseModel):
|
||||
persona: PersonaSnapshot | None
|
||||
channel_config: ChannelConfig
|
||||
response_type: SlackBotResponseType
|
||||
# XXX this is going away soon
|
||||
standard_answer_categories: list[StandardAnswerCategory]
|
||||
enable_auto_filters: bool
|
||||
|
||||
@@ -275,6 +190,7 @@ class SlackBotConfig(BaseModel):
|
||||
),
|
||||
channel_config=slack_bot_config_model.channel_config,
|
||||
response_type=slack_bot_config_model.response_type,
|
||||
# XXX this is going away soon
|
||||
standard_answer_categories=[
|
||||
StandardAnswerCategory.from_model(standard_answer_category_model)
|
||||
for standard_answer_category_model in slack_bot_config_model.standard_answer_categories
|
||||
|
||||
@@ -108,6 +108,7 @@ def create_slack_bot_config(
|
||||
persona_id=persona_id,
|
||||
channel_config=channel_config,
|
||||
response_type=slack_bot_config_creation_request.response_type,
|
||||
# XXX this is going away soon
|
||||
standard_answer_category_ids=slack_bot_config_creation_request.standard_answer_categories,
|
||||
db_session=db_session,
|
||||
enable_auto_filters=slack_bot_config_creation_request.enable_auto_filters,
|
||||
|
||||
@@ -31,13 +31,18 @@ from danswer.configs.app_configs import VALID_EMAIL_DOMAINS
|
||||
from danswer.configs.constants import AuthType
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.models import AccessToken
|
||||
from danswer.db.models import DocumentSet__User
|
||||
from danswer.db.models import Persona__User
|
||||
from danswer.db.models import SamlAccount
|
||||
from danswer.db.models import User
|
||||
from danswer.db.models import User__UserGroup
|
||||
from danswer.db.users import get_user_by_email
|
||||
from danswer.db.users import list_users
|
||||
from danswer.dynamic_configs.factory import get_dynamic_config_store
|
||||
from danswer.server.manage.models import AllUsersResponse
|
||||
from danswer.server.manage.models import UserByEmail
|
||||
from danswer.server.manage.models import UserInfo
|
||||
from danswer.server.manage.models import UserPreferences
|
||||
from danswer.server.manage.models import UserRoleResponse
|
||||
from danswer.server.manage.models import UserRoleUpdateRequest
|
||||
from danswer.server.models import FullUserSnapshot
|
||||
@@ -45,6 +50,7 @@ from danswer.server.models import InvitedUserSnapshot
|
||||
from danswer.server.models import MinimalUserSnapshot
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.db.api_key import is_api_key_email_address
|
||||
from ee.danswer.db.external_perm import delete_user__ext_group_for_user__no_commit
|
||||
from ee.danswer.db.user_group import remove_curator_status__no_commit
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -237,10 +243,25 @@ async def delete_user(
|
||||
db_session.expunge(user_to_delete)
|
||||
|
||||
try:
|
||||
# Delete related OAuthAccounts first
|
||||
for oauth_account in user_to_delete.oauth_accounts:
|
||||
db_session.delete(oauth_account)
|
||||
|
||||
delete_user__ext_group_for_user__no_commit(
|
||||
db_session=db_session,
|
||||
user_id=user_to_delete.id,
|
||||
)
|
||||
db_session.query(SamlAccount).filter(
|
||||
SamlAccount.user_id == user_to_delete.id
|
||||
).delete()
|
||||
db_session.query(DocumentSet__User).filter(
|
||||
DocumentSet__User.user_id == user_to_delete.id
|
||||
).delete()
|
||||
db_session.query(Persona__User).filter(
|
||||
Persona__User.user_id == user_to_delete.id
|
||||
).delete()
|
||||
db_session.query(User__UserGroup).filter(
|
||||
User__UserGroup.user_id == user_to_delete.id
|
||||
).delete()
|
||||
db_session.delete(user_to_delete)
|
||||
db_session.commit()
|
||||
|
||||
@@ -254,6 +275,10 @@ async def delete_user(
|
||||
|
||||
logger.info(f"Deleted user {user_to_delete.email}")
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
full_traceback = traceback.format_exc()
|
||||
logger.error(f"Full stack trace:\n{full_traceback}")
|
||||
db_session.rollback()
|
||||
logger.error(f"Error deleting user {user_to_delete.email}: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Error deleting user")
|
||||
@@ -423,3 +448,64 @@ def update_user_assistant_list(
|
||||
.values(chosen_assistants=request.chosen_assistants)
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def update_assistant_list(
|
||||
preferences: UserPreferences, assistant_id: int, show: bool
|
||||
) -> UserPreferences:
|
||||
visible_assistants = preferences.visible_assistants or []
|
||||
hidden_assistants = preferences.hidden_assistants or []
|
||||
chosen_assistants = preferences.chosen_assistants or []
|
||||
|
||||
if show:
|
||||
if assistant_id not in visible_assistants:
|
||||
visible_assistants.append(assistant_id)
|
||||
if assistant_id in hidden_assistants:
|
||||
hidden_assistants.remove(assistant_id)
|
||||
if assistant_id not in chosen_assistants:
|
||||
chosen_assistants.append(assistant_id)
|
||||
else:
|
||||
if assistant_id in visible_assistants:
|
||||
visible_assistants.remove(assistant_id)
|
||||
if assistant_id not in hidden_assistants:
|
||||
hidden_assistants.append(assistant_id)
|
||||
if assistant_id in chosen_assistants:
|
||||
chosen_assistants.remove(assistant_id)
|
||||
|
||||
preferences.visible_assistants = visible_assistants
|
||||
preferences.hidden_assistants = hidden_assistants
|
||||
preferences.chosen_assistants = chosen_assistants
|
||||
return preferences
|
||||
|
||||
|
||||
@router.patch("/user/assistant-list/update/{assistant_id}")
|
||||
def update_user_assistant_visibility(
|
||||
assistant_id: int,
|
||||
show: bool,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
if user is None:
|
||||
if AUTH_TYPE == AuthType.DISABLED:
|
||||
store = get_dynamic_config_store()
|
||||
no_auth_user = fetch_no_auth_user(store)
|
||||
preferences = no_auth_user.preferences
|
||||
updated_preferences = update_assistant_list(preferences, assistant_id, show)
|
||||
set_no_auth_user_preferences(store, updated_preferences)
|
||||
return
|
||||
else:
|
||||
raise RuntimeError("This should never happen")
|
||||
|
||||
user_preferences = UserInfo.from_model(user).preferences
|
||||
updated_preferences = update_assistant_list(user_preferences, assistant_id, show)
|
||||
|
||||
db_session.execute(
|
||||
update(User)
|
||||
.where(User.id == user.id) # type: ignore
|
||||
.values(
|
||||
hidden_assistants=updated_preferences.hidden_assistants,
|
||||
visible_assistants=updated_preferences.visible_assistants,
|
||||
chosen_assistants=updated_preferences.chosen_assistants,
|
||||
)
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
@@ -164,7 +164,7 @@ def get_chat_session(
|
||||
chat_session_id=session_id,
|
||||
description=chat_session.description,
|
||||
persona_id=chat_session.persona_id,
|
||||
persona_name=chat_session.persona.name,
|
||||
persona_name=chat_session.persona.name if chat_session.persona else None,
|
||||
current_alternate_model=chat_session.current_alternate_model,
|
||||
messages=[
|
||||
translate_db_message_to_chat_message_detail(
|
||||
@@ -281,17 +281,14 @@ async def is_disconnected(request: Request) -> Callable[[], bool]:
|
||||
def is_disconnected_sync() -> bool:
|
||||
future = asyncio.run_coroutine_threadsafe(request.is_disconnected(), main_loop)
|
||||
try:
|
||||
result = not future.result(timeout=0.01)
|
||||
return result
|
||||
return not future.result(timeout=0.01)
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("Asyncio timed out while checking client connection")
|
||||
return True
|
||||
except asyncio.CancelledError:
|
||||
logger.error("Asyncio timed out")
|
||||
return True
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
logger.critical(
|
||||
f"An unexpected error occurred with the disconnect check coroutine: {error_msg}"
|
||||
f"An unexpected error occured with the disconnect check coroutine: {error_msg}"
|
||||
)
|
||||
return True
|
||||
|
||||
@@ -520,6 +517,7 @@ def upload_files_for_chat(
|
||||
image_content_types = {"image/jpeg", "image/png", "image/webp"}
|
||||
text_content_types = {
|
||||
"text/plain",
|
||||
"text/csv",
|
||||
"text/markdown",
|
||||
"text/x-markdown",
|
||||
"text/x-config",
|
||||
@@ -529,9 +527,6 @@ def upload_files_for_chat(
|
||||
"text/xml",
|
||||
"application/x-yaml",
|
||||
}
|
||||
csv_content_types = {
|
||||
"text/csv",
|
||||
}
|
||||
document_content_types = {
|
||||
"application/pdf",
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
@@ -541,10 +536,8 @@ def upload_files_for_chat(
|
||||
"application/epub+zip",
|
||||
}
|
||||
|
||||
allowed_content_types = (
|
||||
image_content_types.union(text_content_types)
|
||||
.union(document_content_types)
|
||||
.union(csv_content_types)
|
||||
allowed_content_types = image_content_types.union(text_content_types).union(
|
||||
document_content_types
|
||||
)
|
||||
|
||||
for file in files:
|
||||
@@ -552,12 +545,8 @@ def upload_files_for_chat(
|
||||
if file.content_type in image_content_types:
|
||||
error_detail = "Unsupported image file type. Supported image types include .jpg, .jpeg, .png, .webp."
|
||||
elif file.content_type in text_content_types:
|
||||
error_detail = "Unsupported text file type. Supported text types include .txt, .md, .mdx, .conf, "
|
||||
error_detail = "Unsupported text file type. Supported text types include .txt, .csv, .md, .mdx, .conf, "
|
||||
".log, .tsv."
|
||||
elif file.content_type in csv_content_types:
|
||||
error_detail = (
|
||||
"Unsupported csv file type. Supported CSV types include .csv "
|
||||
)
|
||||
else:
|
||||
error_detail = (
|
||||
"Unsupported document file type. Supported document types include .pdf, .docx, .pptx, .xlsx, "
|
||||
@@ -583,8 +572,6 @@ def upload_files_for_chat(
|
||||
file_type = ChatFileType.IMAGE
|
||||
elif file.content_type in document_content_types:
|
||||
file_type = ChatFileType.DOC
|
||||
elif file.content_type in csv_content_types:
|
||||
file_type = ChatFileType.CSV
|
||||
else:
|
||||
file_type = ChatFileType.PLAIN_TEXT
|
||||
|
||||
@@ -597,7 +584,6 @@ def upload_files_for_chat(
|
||||
file_origin=FileOrigin.CHAT_UPLOAD,
|
||||
file_type=file.content_type or file_type.value,
|
||||
)
|
||||
print(f"FILE TYPE IS {file_type}")
|
||||
|
||||
# if the file is a doc, extract text and store that so we don't need
|
||||
# to re-extract it every time we send a message
|
||||
|
||||
@@ -136,7 +136,7 @@ class RenameChatSessionResponse(BaseModel):
|
||||
class ChatSessionDetails(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
persona_id: int
|
||||
persona_id: int | None = None
|
||||
time_created: str
|
||||
shared_status: ChatSessionSharedStatus
|
||||
folder_id: int | None = None
|
||||
@@ -178,7 +178,7 @@ class ChatMessageDetail(BaseModel):
|
||||
chat_session_id: int | None = None
|
||||
citations: dict[int, int] | None = None
|
||||
files: list[FileDescriptor]
|
||||
tool_call: ToolCallFinalResult | None
|
||||
tool_calls: list[ToolCallFinalResult]
|
||||
|
||||
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
|
||||
initial_dict = super().model_dump(mode="json", *args, **kwargs) # type: ignore
|
||||
@@ -196,8 +196,8 @@ class SearchSessionDetailResponse(BaseModel):
|
||||
class ChatSessionDetailResponse(BaseModel):
|
||||
chat_session_id: int
|
||||
description: str
|
||||
persona_id: int
|
||||
persona_name: str
|
||||
persona_id: int | None = None
|
||||
persona_name: str | None
|
||||
messages: list[ChatMessageDetail]
|
||||
time_created: datetime
|
||||
shared_status: ChatSessionSharedStatus
|
||||
|
||||
@@ -37,6 +37,7 @@ class Settings(BaseModel):
|
||||
search_page_enabled: bool = True
|
||||
default_page: PageType = PageType.SEARCH
|
||||
maximum_chat_retention_days: int | None = None
|
||||
gpu_enabled: bool | None = None
|
||||
|
||||
def check_validity(self) -> None:
|
||||
chat_page_enabled = self.chat_page_enabled
|
||||
|
||||
@@ -1,22 +1,19 @@
|
||||
import base64
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
|
||||
class DateTimeAndBytesEncoder(json.JSONEncoder):
|
||||
"""Custom JSON encoder that converts datetime objects to ISO format strings and bytes to base64."""
|
||||
class DateTimeEncoder(json.JSONEncoder):
|
||||
"""Custom JSON encoder that converts datetime objects to ISO format strings."""
|
||||
|
||||
def default(self, obj: Any) -> Any:
|
||||
if isinstance(obj, datetime):
|
||||
return obj.isoformat()
|
||||
elif isinstance(obj, bytes):
|
||||
return base64.b64encode(obj).decode("utf-8")
|
||||
return super().default(obj)
|
||||
|
||||
|
||||
def get_json_line(
|
||||
json_dict: dict[str, Any], encoder: type[json.JSONEncoder] = DateTimeAndBytesEncoder
|
||||
json_dict: dict[str, Any], encoder: type[json.JSONEncoder] = DateTimeEncoder
|
||||
) -> str:
|
||||
"""
|
||||
Convert a dictionary to a JSON string with datetime handling, and add a newline.
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user