Compare commits

..

52 Commits

Author SHA1 Message Date
Weves
5f82de7c45 Debug test 2024-09-23 11:05:27 -07:00
pablodanswer
45f67368a2 Add support for o1 (#2538)
* add o1 support + bump litellm/openai

* ports

* update exception message for testing
2024-09-22 23:16:28 +00:00
pablodanswer
014ba9e220 Begin distinguishing upsert operations for clarity (#2535)
* additional clarity for llm provider creation / updates

* update provider APIs

* update typing (minor)
2024-09-21 22:36:22 +00:00
pablodanswer
ba64543dd7 Updated modals for clarity (#2529)
* udpated modals for clarity

* fix build
2024-09-21 19:55:54 +00:00
pablodanswer
18c62a0c24 Add additional custom tooling configuration (#2426)
* add custom headers

* add tool seeding

* squash

* tmep

* validated

* rm

* update typing

* update alembic

* update import name

* reformat

* alembic
2024-09-20 23:12:52 +00:00
Chris Weaver
33f555922c Fix duplicate users from slack / web (#2530) 2024-09-20 21:51:33 +00:00
pablodanswer
05f6f6d5b5 update default search assistant selection (#2527)
* update default search assistant selection

* update language
2024-09-20 21:21:44 +00:00
hagen-danswer
19dae1d870 Wrote tests for the chat apis (#2525)
* Wrote tests for the chat apis

* slight changes to the case
2024-09-20 19:00:03 +00:00
rkuo-danswer
6d859bd37c try adding build essential (#2526) 2024-09-20 11:51:44 -07:00
pablodanswer
122e3fa3fa Access type (#2523) 2024-09-20 11:16:37 -07:00
pablodanswer
87b542b335 align alembic 2024-09-20 11:13:00 -07:00
pablodanswer
00229d2abe Add start date to persona (#2407)
* add start date to persona

* remove logs

* rename

* update assistant editor

* update alembic

* update alembic

* update alembic

* udpate alembic

* remove rebase artifacts
2024-09-20 16:39:34 +00:00
pablodanswer
5f2644985c Route name (#2520)
* clearer refresh logic

* rename path
2024-09-20 15:44:28 +00:00
pablodanswer
c82a36ad68 Saml account fastapi deletion (#2512)
* saml account fastapi deletion

* update error detail
2024-09-20 00:20:50 +00:00
hagen-danswer
16d1c19d9f Added bool to disable chat_session_id check for search_docs for api 2024-09-19 17:36:46 -07:00
pablodanswer
9f179940f8 Asana connector (community originated) (#2485)
* initial Asana connector

* hint on how to get Asana workspace ID

* re-format with black

* re-order imports

* update asana connector for clarity

* minor robustification

* minor update to naming

* update for best practice

* update connector

---------

Co-authored-by: Daniel Naber <naber@danielnaber.de>
2024-09-19 23:54:18 +00:00
pablodanswer
8a8e2b310e Assistants panel rework (#2509)
* update user model

* squash - update assistant gallery

* rework assistant display logic + ux

* update tool + assistant display

* update a couple function names

* update typing + some logic

* remove unnecessary comments

* finalize functionality

* updated logic

* fully functional

* remove logs + ports

* small update to logic

* update typing

* allow seeding of display priority

* reorder migrations

* update for alembic
2024-09-19 23:36:15 +00:00
hagen-danswer
2274cab554 Added permission syncing (#2340)
* Added permission syncing on the backend

* Rewored to work with celery

alembic fix

fixed test

* frontend changes

* got groups working

* added comments and fixed public docs

* fixed merge issues

* frontend complete!

* frontend cleanup and mypy fixes

* refactored connector access_type selection

* mypy fixes

* minor refactor and frontend improvements

* get to fetch

* renames and comments

* minor change to var names

* got curator stuff working

* addressed pablo's comments

* refactored user_external_group to reference users table

* implemented polling

* small refactor

* fixed a whoopsies on the frontend

* added scripts to seed dummy docs and test query times

* fixed frontend build issue

* alembic fix

* handled is_public overlap

* yuhong feedback

* added more checks for sync

* black

* mypy

* fixed circular import

* todos

* alembic fix

* alembic
2024-09-19 22:07:36 +00:00
pablodanswer
ef104e9a82 Non-spotfix deletion of users (#2499)
* add description / robustify

* additional minor robustification (ideally we organized cascades slightly better)

* update deletion for simplicity

* minor typing update
2024-09-19 20:02:36 +00:00
hagen-danswer
a575d7f1eb Citations prompt for slack now includes thread history (#2510) 2024-09-19 19:31:26 +00:00
pablodanswer
f404c4b448 Move code block default language creation to citation processing (#2501)
* move code block default language creation to citaiton processing

* add test cases

* update copy
2024-09-19 06:00:58 +00:00
rkuo-danswer
3884f1d70a Bugfix/larger test runner (#2508)
* add pip retries to the github workflows too

* let's try running on amd64 ... docker builds are unusually flaky

* bump

* try large

* no yaml anchors

* switch back down to Amd64

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-09-19 05:36:07 +00:00
rkuo-danswer
bc9d5fece7 prevent trying to submit to jobclient when it can't take any more work (reduces log spam) (#2482) 2024-09-19 04:01:15 +00:00
rkuo-danswer
bb279a8580 add pip retries. should help with github's occasional flaky network during build/test (#2506) 2024-09-19 00:46:41 +00:00
pablodanswer
a9403016c9 fix basic auth (#2505) 2024-09-18 22:45:58 +00:00
hagen-danswer
f3cea79c1c Deleting a connector should redirect to the indexing status page (#2504)
* Deleting a connector should redirect to the indexing status page

* minor update to dev background jobs

* update refresh logic

* remove print statement

---------

Co-authored-by: pablodanswer <pablo@danswer.ai>
2024-09-18 21:38:35 +00:00
hagen-danswer
54bb79303c corrected error message (#2502) 2024-09-18 19:13:28 +00:00
pablodanswer
d3dfabb20e fix parentheses (#2486) 2024-09-18 18:39:23 +00:00
pablodanswer
7d1ec1095c proper z index for chat bubbles (#2500) 2024-09-18 18:02:50 +00:00
rkuo-danswer
f531d071af Feature/background deletion (#2337)
* first cut at redis

* some new helper functions for the db

* ignore kombu tables in alembic migrations (used by celery)

* multiline commands for readability, add vespa_metadata_sync queue to worker

* typo fix

* fix returning tuple fields

* add constants

* fix _get_access_for_document

* docstrings!

* fix double function declaration and typing

* fix type hinting

* add a global redis pool

* Add get_document function

* use task_logger in various celery tasks

* add celeryconfig.py to simplify configuration. Will be used in a subsequent commit

* Add celery redis helper. used in a subsequent PR

* kombu warning getting spammy since celery is not self managing its queue in Postgres any more

* add last_modified and last_synced to documents

* fix task naming convention

* use celeryconfig.py

* the big one. adds queues and tasks, updates functions to use the queues with priorities, etc

* change vespa index log line to debug

* mypy fixes

* update alembic migration

* fix fence ordering, rename to "monitor", fix fetch_versioned_implementation call

* mypy

* switch to monotonic time

* fix startup dependencies on redis

* rebase alembic migration

* kombu cleanup - fail silently

* mypy

* add redis_host environment override

* update REDIS_HOST env var in docker-compose.dev.yml

* update the rest of the docker files

* in flight

* harden indexing-status endpoint against db changes happening in the background.  Needs further improvement but OK for now.

* allow no task syncs to run because we create certain objects with no entries but initially marked as out of date

* add back writing to vespa on indexing

* actually working connector deletion

* update contributing guide

* backporting fixes from background_deletion

* renaming cache to cache_volume

* add redis password to various deployments

* try setting up pr testing for helm

* fix indent

* hopefully this release version actually exists

* fix command line option to --chart-dirs

* fetch-depth 0

* edit values.yaml

* try setting ct working directory

* bypass testing only on change for now

* move files and lint them

* update helm testing

* some issues suggest using --config works

* add vespa repo

* add postgresql repo

* increase timeout

* try amd64 runner

* fix redis password reference

* add comment to helm chart testing workflow

* rename helm testing workflow to disable it

* adding clarifying comments

* address code review

* missed a file

* remove commented warning ... just not needed

* fix imports

* refactor to use update_single

* mypy fixes

* add vespa test

* add db refresh to connector deletion

* code review fixes

* move monitor_usergroup_taskset to ee, improve logging

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-09-18 16:50:11 +00:00
Chris Weaver
4218814385 Add flow to query history CSV (#2492) 2024-09-18 14:23:56 +00:00
rkuo-danswer
e662e3b57d clarify ssl cert reqs (#2494)
Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-09-18 05:35:57 +00:00
pablodanswer
2073820e33 Update default assistants to all visible (#2490)
* update default assistants to all visible

* update with catch-all

* minor update

* update
2024-09-18 02:08:11 +00:00
Chris Weaver
5f25b243c5 Add back llm_chunks_indices (#2491) 2024-09-18 01:21:31 +00:00
pablodanswer
a9427f190a Extend time range (contributor submission) (#2484)
* added new options for time range; removed duplicated code

* refactor + remove unused code

---------

Co-authored-by: Zoltan Szabo <zoltan.szabo@eaudeweb.ro>
2024-09-17 22:36:25 +00:00
pablodanswer
18fbe9d7e8 Warn users of gpu-sensitive operation (#2488)
* warn users of gpu-sensitive operation

* update copy
2024-09-17 21:59:43 +00:00
Chris Weaver
75c9b1cafe Fix concatenate string with toolcallkickoff issue (#2487) 2024-09-17 21:25:06 +00:00
rkuo-danswer
632a8f700b Feature/celery backend db number (#2475)
* use separate database number for celery result backend

* add comments

* add env var for celery's result_expires

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-09-17 21:06:36 +00:00
pablodanswer
cd58c96014 Memoize AI message component (#2483)
* memoize AI message component

* rename memoized file

* remove "zz"

* update name

* memoize for coverage

* add display name
2024-09-17 18:47:23 +00:00
pablodanswer
c5032d25c9 Minor clarity update for connectors (#2480) 2024-09-17 10:25:39 -07:00
pablodanswer
72acde6fd4 Handle tool errors in display properly (can show valueError to user) (#2481)
* handle tool errors in display properly (can show valueerrors to user)

* update for clarity
2024-09-17 17:08:46 +00:00
rkuo-danswer
5596a68d08 harden migration (#2476)
* harden migration

* remove duplicate line
2024-09-17 16:44:53 +00:00
Weves
5b18409c89 Change user-message to user-prompt 2024-09-16 21:53:27 -07:00
Chris Weaver
84272af5ac Add back scrolling to ExceptionTraceModal (#2473) 2024-09-17 02:25:53 +00:00
pablodanswer
6bef70c8b7 ensure disabled gets propagated 2024-09-16 19:27:31 -07:00
pablodanswer
7f7559e3d2 Allow users to share assistants (#2434)
* enable assistant sharing

* functional

* remove logs

* revert ports

* remove accidental update

* minor updates to copy

* update formatting

* update for merge queue
2024-09-17 01:35:29 +00:00
Chris Weaver
7ba829a585 Add top_documents to APIs (#2469)
* Add top_documents

* Fix test

---------

Co-authored-by: hagen-danswer <hagen@danswer.ai>
2024-09-16 23:48:33 +00:00
trial-danswer
8b2ecb4eab EE movement followup for Standard Answers (#2467)
* Move StandardAnswer to EE section of danswer/db/models

* Move StandardAnswer DB layer to EE

* Add EERequiredError for distinct error handling here

* Handle EE fallback for slack bot config

* Migrate all standard answer models to ee

* Flagging categories for removal

* Add missing versioned impl for update_slack_bot_config

---------

Co-authored-by: danswer-trial <danswer-trial@danswer-trials-MacBook-Pro.local>
2024-09-16 22:05:53 +00:00
pablodanswer
2dd3870504 Add ability to specify persona in API request (#2302)
* persona

* all prepared excluding configuration

* more sensical model structure

* update tstream

* type updates

* rm

* quick and simple updates

* minor updates

* te

* ensure typing + naming

* remove old todo + rebase update

* remove unnecessary check
2024-09-16 21:31:01 +00:00
pablodanswer
df464fc54b Allow for CORS Origin Setting (#2449)
* allow setting of CORS origin

* simplify

* add environment variable + rename

* slightly more efficient

* simplify so mypy doens't complain

* temp

* go back to my preferred formatting
2024-09-16 18:54:36 +00:00
pablodanswer
96b98fbc4a Make it impossible to switch to non-image (#2440)
* make it impossible to switch to non-image

* revert ports

* proper provider support

* remove unused imports

* minor rename

* simplify interface

* remove logs
2024-09-16 18:35:40 +00:00
trial-danswer
66cf67d04d hotfix: sqlalchemy default -> server_default (#2442)
Co-authored-by: danswer-trial <danswer-trial@danswer-trials-MacBook-Pro.local>
2024-09-16 17:49:01 +00:00
292 changed files with 7722 additions and 7869 deletions

View File

@@ -27,6 +27,11 @@ jobs:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Install build-essential
run: |
sudo apt-get update
sudo apt-get install -y build-essential
- name: Backend Image Docker Build and Push
uses: docker/build-push-action@v5
with:

View File

@@ -37,9 +37,9 @@ jobs:
backend/requirements/model_server.txt
- run: |
python -m pip install --upgrade pip
pip install -r backend/requirements/default.txt
pip install -r backend/requirements/dev.txt
pip install -r backend/requirements/model_server.txt
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
- name: Set up chart-testing
uses: helm/chart-testing-action@v2.6.1

View File

@@ -24,9 +24,9 @@ jobs:
backend/requirements/model_server.txt
- run: |
python -m pip install --upgrade pip
pip install -r backend/requirements/default.txt
pip install -r backend/requirements/dev.txt
pip install -r backend/requirements/model_server.txt
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
- name: Run MyPy
run: |

View File

@@ -39,8 +39,8 @@ jobs:
- name: Install Dependencies
run: |
python -m pip install --upgrade pip
pip install -r backend/requirements/default.txt
pip install -r backend/requirements/dev.txt
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
- name: Run Tests
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"

View File

@@ -29,8 +29,8 @@ jobs:
- name: Install Dependencies
run: |
python -m pip install --upgrade pip
pip install -r backend/requirements/default.txt
pip install -r backend/requirements/dev.txt
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
- name: Run Tests
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"

View File

@@ -13,8 +13,7 @@ env:
jobs:
integration-tests:
runs-on:
group: 'arm64-image-builders'
runs-on: Amd64
steps:
- name: Checkout code
uses: actions/checkout@v4
@@ -41,7 +40,7 @@ jobs:
with:
context: ./backend
file: ./backend/Dockerfile
platforms: linux/arm64
platforms: linux/amd64
tags: danswer/danswer-backend:it
cache-from: type=registry,ref=danswer/danswer-backend:it
cache-to: |
@@ -53,7 +52,7 @@ jobs:
with:
context: ./backend
file: ./backend/Dockerfile.model_server
platforms: linux/arm64
platforms: linux/amd64
tags: danswer/danswer-model-server:it
cache-from: type=registry,ref=danswer/danswer-model-server:it
cache-to: |
@@ -65,7 +64,7 @@ jobs:
with:
context: ./backend
file: ./backend/tests/integration/Dockerfile
platforms: linux/arm64
platforms: linux/amd64
tags: danswer/integration-test-runner:it
cache-from: type=registry,ref=danswer/integration-test-runner:it
cache-to: |

View File

@@ -41,6 +41,8 @@ RUN apt-get update && \
COPY ./requirements/default.txt /tmp/requirements.txt
COPY ./requirements/ee.txt /tmp/ee-requirements.txt
RUN pip install --no-cache-dir --upgrade \
--retries 5 \
--timeout 30 \
-r /tmp/requirements.txt \
-r /tmp/ee-requirements.txt && \
pip uninstall -y py && \

View File

@@ -15,7 +15,10 @@ ENV DANSWER_VERSION=${DANSWER_VERSION} \
RUN echo "DANSWER_VERSION: ${DANSWER_VERSION}"
COPY ./requirements/model_server.txt /tmp/requirements.txt
RUN pip install --no-cache-dir --upgrade -r /tmp/requirements.txt
RUN pip install --no-cache-dir --upgrade \
--retries 5 \
--timeout 30 \
-r /tmp/requirements.txt
RUN apt-get remove -y --allow-remove-essential perl-base && \
apt-get autoremove -y

Binary file not shown.

Before

Width:  |  Height:  |  Size: 27 KiB

View File

@@ -0,0 +1,102 @@
"""add_user_delete_cascades
Revision ID: 1b8206b29c5d
Revises: 35e6853a51d5
Create Date: 2024-09-18 11:48:59.418726
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "1b8206b29c5d"
down_revision = "35e6853a51d5"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.drop_constraint("credential_user_id_fkey", "credential", type_="foreignkey")
op.create_foreign_key(
"credential_user_id_fkey",
"credential",
"user",
["user_id"],
["id"],
ondelete="CASCADE",
)
op.drop_constraint("chat_session_user_id_fkey", "chat_session", type_="foreignkey")
op.create_foreign_key(
"chat_session_user_id_fkey",
"chat_session",
"user",
["user_id"],
["id"],
ondelete="CASCADE",
)
op.drop_constraint("chat_folder_user_id_fkey", "chat_folder", type_="foreignkey")
op.create_foreign_key(
"chat_folder_user_id_fkey",
"chat_folder",
"user",
["user_id"],
["id"],
ondelete="CASCADE",
)
op.drop_constraint("prompt_user_id_fkey", "prompt", type_="foreignkey")
op.create_foreign_key(
"prompt_user_id_fkey", "prompt", "user", ["user_id"], ["id"], ondelete="CASCADE"
)
op.drop_constraint("notification_user_id_fkey", "notification", type_="foreignkey")
op.create_foreign_key(
"notification_user_id_fkey",
"notification",
"user",
["user_id"],
["id"],
ondelete="CASCADE",
)
op.drop_constraint("inputprompt_user_id_fkey", "inputprompt", type_="foreignkey")
op.create_foreign_key(
"inputprompt_user_id_fkey",
"inputprompt",
"user",
["user_id"],
["id"],
ondelete="CASCADE",
)
def downgrade() -> None:
op.drop_constraint("credential_user_id_fkey", "credential", type_="foreignkey")
op.create_foreign_key(
"credential_user_id_fkey", "credential", "user", ["user_id"], ["id"]
)
op.drop_constraint("chat_session_user_id_fkey", "chat_session", type_="foreignkey")
op.create_foreign_key(
"chat_session_user_id_fkey", "chat_session", "user", ["user_id"], ["id"]
)
op.drop_constraint("chat_folder_user_id_fkey", "chat_folder", type_="foreignkey")
op.create_foreign_key(
"chat_folder_user_id_fkey", "chat_folder", "user", ["user_id"], ["id"]
)
op.drop_constraint("prompt_user_id_fkey", "prompt", type_="foreignkey")
op.create_foreign_key("prompt_user_id_fkey", "prompt", "user", ["user_id"], ["id"])
op.drop_constraint("notification_user_id_fkey", "notification", type_="foreignkey")
op.create_foreign_key(
"notification_user_id_fkey", "notification", "user", ["user_id"], ["id"]
)
op.drop_constraint("inputprompt_user_id_fkey", "inputprompt", type_="foreignkey")
op.create_foreign_key(
"inputprompt_user_id_fkey", "inputprompt", "user", ["user_id"], ["id"]
)

View File

@@ -0,0 +1,64 @@
"""server default chosen assistants
Revision ID: 35e6853a51d5
Revises: c99d76fcd298
Create Date: 2024-09-13 13:20:32.885317
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "35e6853a51d5"
down_revision = "c99d76fcd298"
branch_labels = None
depends_on = None
DEFAULT_ASSISTANTS = [-2, -1, 0]
def upgrade() -> None:
# Step 1: Update any NULL values to the default value
# This upgrades existing users without ordered assistant
# to have default assistants set to visible assistants which are
# accessible by them.
op.execute(
"""
UPDATE "user" u
SET chosen_assistants = (
SELECT jsonb_agg(
p.id ORDER BY
COALESCE(p.display_priority, 2147483647) ASC,
p.id ASC
)
FROM persona p
LEFT JOIN persona__user pu ON p.id = pu.persona_id AND pu.user_id = u.id
WHERE p.is_visible = true
AND (p.is_public = true OR pu.user_id IS NOT NULL)
)
WHERE chosen_assistants IS NULL
OR chosen_assistants = 'null'
OR jsonb_typeof(chosen_assistants) = 'null'
OR (jsonb_typeof(chosen_assistants) = 'string' AND chosen_assistants = '"null"')
"""
)
# Step 2: Alter the column to make it non-nullable
op.alter_column(
"user",
"chosen_assistants",
type_=postgresql.JSONB(astext_type=sa.Text()),
nullable=False,
server_default=sa.text(f"'{DEFAULT_ASSISTANTS}'::jsonb"),
)
def downgrade() -> None:
op.alter_column(
"user",
"chosen_assistants",
type_=postgresql.JSONB(astext_type=sa.Text()),
nullable=True,
server_default=None,
)

View File

@@ -1,65 +0,0 @@
"""single tool call per message
Revision ID: 4e8e7ae58189
Revises: 5c7fdadae813
Create Date: 2024-09-09 10:07:58.008838
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "4e8e7ae58189"
down_revision = "5c7fdadae813"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Create the new column
op.add_column(
"chat_message", sa.Column("tool_call_id", sa.Integer(), nullable=True)
)
op.create_foreign_key(
"fk_chat_message_tool_call",
"chat_message",
"tool_call",
["tool_call_id"],
["id"],
)
# Migrate existing data
op.execute(
"UPDATE chat_message SET tool_call_id = (SELECT id FROM tool_call WHERE tool_call.message_id = chat_message.id LIMIT 1)"
)
# Drop the old relationship
op.drop_constraint("tool_call_message_id_fkey", "tool_call", type_="foreignkey")
op.drop_column("tool_call", "message_id")
# Add a unique constraint to ensure one-to-one relationship
op.create_unique_constraint(
"uq_chat_message_tool_call_id", "chat_message", ["tool_call_id"]
)
def downgrade() -> None:
# Add back the old column
op.add_column(
"tool_call",
sa.Column("message_id", sa.INTEGER(), autoincrement=False, nullable=True),
)
op.create_foreign_key(
"tool_call_message_id_fkey", "tool_call", "chat_message", ["message_id"], ["id"]
)
# Migrate data back
op.execute(
"UPDATE tool_call SET message_id = (SELECT id FROM chat_message WHERE chat_message.tool_call_id = tool_call.id)"
)
# Drop the new column
op.drop_constraint("fk_chat_message_tool_call", "chat_message", type_="foreignkey")
op.drop_column("chat_message", "tool_call_id")

View File

@@ -1,7 +1,7 @@
"""Add last synced and last modified to document table
Revision ID: 52a219fb5233
Revises: f17bf3b0d9f1
Revises: f7e58d357687
Create Date: 2024-08-28 17:40:46.077470
"""

View File

@@ -0,0 +1,79 @@
"""assistant_rework
Revision ID: 55546a7967ee
Revises: 61ff3651add4
Create Date: 2024-09-18 17:00:23.755399
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "55546a7967ee"
down_revision = "61ff3651add4"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Reworking persona and user tables for new assistant features
# keep track of user's chosen assistants separate from their `ordering`
op.add_column("persona", sa.Column("builtin_persona", sa.Boolean(), nullable=True))
op.execute("UPDATE persona SET builtin_persona = default_persona")
op.alter_column("persona", "builtin_persona", nullable=False)
op.drop_index("_default_persona_name_idx", table_name="persona")
op.create_index(
"_builtin_persona_name_idx",
"persona",
["name"],
unique=True,
postgresql_where=sa.text("builtin_persona = true"),
)
op.add_column(
"user", sa.Column("visible_assistants", postgresql.JSONB(), nullable=True)
)
op.add_column(
"user", sa.Column("hidden_assistants", postgresql.JSONB(), nullable=True)
)
op.execute(
"UPDATE \"user\" SET visible_assistants = '[]'::jsonb, hidden_assistants = '[]'::jsonb"
)
op.alter_column(
"user",
"visible_assistants",
nullable=False,
server_default=sa.text("'[]'::jsonb"),
)
op.alter_column(
"user",
"hidden_assistants",
nullable=False,
server_default=sa.text("'[]'::jsonb"),
)
op.drop_column("persona", "default_persona")
op.add_column(
"persona", sa.Column("is_default_persona", sa.Boolean(), nullable=True)
)
def downgrade() -> None:
# Reverting changes made in upgrade
op.drop_column("user", "hidden_assistants")
op.drop_column("user", "visible_assistants")
op.drop_index("_builtin_persona_name_idx", table_name="persona")
op.drop_column("persona", "is_default_persona")
op.add_column("persona", sa.Column("default_persona", sa.Boolean(), nullable=True))
op.execute("UPDATE persona SET default_persona = builtin_persona")
op.alter_column("persona", "default_persona", nullable=False)
op.drop_column("persona", "builtin_persona")
op.create_index(
"_default_persona_name_idx",
"persona",
["name"],
unique=True,
postgresql_where=sa.text("default_persona = true"),
)

View File

@@ -0,0 +1,162 @@
"""Add Permission Syncing
Revision ID: 61ff3651add4
Revises: 1b8206b29c5d
Create Date: 2024-09-05 13:57:11.770413
"""
import fastapi_users_db_sqlalchemy
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "61ff3651add4"
down_revision = "1b8206b29c5d"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Admin user who set up connectors will lose access to the docs temporarily
# only way currently to give back access is to rerun from beginning
op.add_column(
"connector_credential_pair",
sa.Column(
"access_type",
sa.String(),
nullable=True,
),
)
op.execute(
"UPDATE connector_credential_pair SET access_type = 'PUBLIC' WHERE is_public = true"
)
op.execute(
"UPDATE connector_credential_pair SET access_type = 'PRIVATE' WHERE is_public = false"
)
op.alter_column("connector_credential_pair", "access_type", nullable=False)
op.add_column(
"connector_credential_pair",
sa.Column(
"auto_sync_options",
postgresql.JSONB(astext_type=sa.Text()),
nullable=True,
),
)
op.add_column(
"connector_credential_pair",
sa.Column("last_time_perm_sync", sa.DateTime(timezone=True), nullable=True),
)
op.drop_column("connector_credential_pair", "is_public")
op.add_column(
"document",
sa.Column("external_user_emails", postgresql.ARRAY(sa.String()), nullable=True),
)
op.add_column(
"document",
sa.Column(
"external_user_group_ids", postgresql.ARRAY(sa.String()), nullable=True
),
)
op.add_column(
"document",
sa.Column("is_public", sa.Boolean(), nullable=True),
)
op.create_table(
"user__external_user_group_id",
sa.Column(
"user_id", fastapi_users_db_sqlalchemy.generics.GUID(), nullable=False
),
sa.Column("external_user_group_id", sa.String(), nullable=False),
sa.Column("cc_pair_id", sa.Integer(), nullable=False),
sa.PrimaryKeyConstraint("user_id"),
)
op.drop_column("external_permission", "user_id")
op.drop_column("email_to_external_user_cache", "user_id")
op.drop_table("permission_sync_run")
op.drop_table("external_permission")
op.drop_table("email_to_external_user_cache")
def downgrade() -> None:
op.add_column(
"connector_credential_pair",
sa.Column("is_public", sa.BOOLEAN(), nullable=True),
)
op.execute(
"UPDATE connector_credential_pair SET is_public = (access_type = 'PUBLIC')"
)
op.alter_column("connector_credential_pair", "is_public", nullable=False)
op.drop_column("connector_credential_pair", "auto_sync_options")
op.drop_column("connector_credential_pair", "access_type")
op.drop_column("connector_credential_pair", "last_time_perm_sync")
op.drop_column("document", "external_user_emails")
op.drop_column("document", "external_user_group_ids")
op.drop_column("document", "is_public")
op.drop_table("user__external_user_group_id")
# Drop the enum type at the end of the downgrade
op.create_table(
"permission_sync_run",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column(
"source_type",
sa.String(),
nullable=False,
),
sa.Column("update_type", sa.String(), nullable=False),
sa.Column("cc_pair_id", sa.Integer(), nullable=True),
sa.Column(
"status",
sa.String(),
nullable=False,
),
sa.Column("error_msg", sa.Text(), nullable=True),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.ForeignKeyConstraint(
["cc_pair_id"],
["connector_credential_pair.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"external_permission",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.UUID(), nullable=True),
sa.Column("user_email", sa.String(), nullable=False),
sa.Column(
"source_type",
sa.String(),
nullable=False,
),
sa.Column("external_permission_group", sa.String(), nullable=False),
sa.ForeignKeyConstraint(
["user_id"],
["user.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"email_to_external_user_cache",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("external_user_id", sa.String(), nullable=False),
sa.Column("user_id", sa.UUID(), nullable=True),
sa.Column("user_email", sa.String(), nullable=False),
sa.ForeignKeyConstraint(
["user_id"],
["user.id"],
),
sa.PrimaryKeyConstraint("id"),
)

View File

@@ -0,0 +1,27 @@
"""persona_start_date
Revision ID: 797089dfb4d2
Revises: 55546a7967ee
Create Date: 2024-09-11 14:51:49.785835
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "797089dfb4d2"
down_revision = "55546a7967ee"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"persona",
sa.Column("search_start_date", sa.DateTime(timezone=True), nullable=True),
)
def downgrade() -> None:
op.drop_column("persona", "search_start_date")

View File

@@ -0,0 +1,43 @@
"""non nullable default persona
Revision ID: bd2921608c3a
Revises: 797089dfb4d2
Create Date: 2024-09-20 10:28:37.992042
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "bd2921608c3a"
down_revision = "797089dfb4d2"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Set existing NULL values to False
op.execute(
"UPDATE persona SET is_default_persona = FALSE WHERE is_default_persona IS NULL"
)
# Alter the column to be not nullable with a default value of False
op.alter_column(
"persona",
"is_default_persona",
existing_type=sa.Boolean(),
nullable=False,
server_default=sa.text("false"),
)
def downgrade() -> None:
# Revert the changes
op.alter_column(
"persona",
"is_default_persona",
existing_type=sa.Boolean(),
nullable=True,
server_default=None,
)

View File

@@ -0,0 +1,31 @@
"""add nullable to persona id in Chat Session
Revision ID: c99d76fcd298
Revises: 5c7fdadae813
Create Date: 2024-07-09 19:27:01.579697
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "c99d76fcd298"
down_revision = "5c7fdadae813"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.alter_column(
"chat_session", "persona_id", existing_type=sa.INTEGER(), nullable=True
)
def downgrade() -> None:
op.alter_column(
"chat_session",
"persona_id",
existing_type=sa.INTEGER(),
nullable=False,
)

View File

@@ -1,7 +1,7 @@
"""standard answer match_regex flag
Revision ID: efb35676026c
Revises: 52a219fb5233
Revises: 0ebb1d516877
Create Date: 2024-09-11 13:55:46.101149
"""
@@ -19,7 +19,9 @@ def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column(
"standard_answer",
sa.Column("match_regex", sa.Boolean(), nullable=False, default=False),
sa.Column(
"match_regex", sa.Boolean(), nullable=False, server_default=sa.false()
),
)
# ### end Alembic commands ###

View File

@@ -0,0 +1,26 @@
"""add custom headers to tools
Revision ID: f32615f71aeb
Revises: bd2921608c3a
Create Date: 2024-09-12 20:26:38.932377
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "f32615f71aeb"
down_revision = "bd2921608c3a"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"tool", sa.Column("custom_headers", postgresql.JSONB(), nullable=True)
)
def downgrade() -> None:
op.drop_column("tool", "custom_headers")

View File

@@ -1,7 +1,7 @@
"""add has_web_login column to user
Revision ID: f7e58d357687
Revises: bceb1e139447
Revises: ba98eba0f66a
Create Date: 2024-09-07 20:20:54.522620
"""

View File

@@ -1,7 +1,7 @@
from sqlalchemy.orm import Session
from danswer.access.models import DocumentAccess
from danswer.access.utils import prefix_user
from danswer.access.utils import prefix_user_email
from danswer.configs.constants import PUBLIC_DOC_PAT
from danswer.db.document import get_access_info_for_document
from danswer.db.document import get_access_info_for_documents
@@ -18,10 +18,13 @@ def _get_access_for_document(
document_id=document_id,
)
if not info:
return DocumentAccess.build(user_ids=[], user_groups=[], is_public=False)
return DocumentAccess.build(user_ids=info[1], user_groups=[], is_public=info[2])
return DocumentAccess.build(
user_emails=info[1] if info and info[1] else [],
user_groups=[],
external_user_emails=[],
external_user_group_ids=[],
is_public=info[2] if info else False,
)
def get_access_for_document(
@@ -34,6 +37,16 @@ def get_access_for_document(
return versioned_get_access_for_document_fn(document_id, db_session) # type: ignore
def get_null_document_access() -> DocumentAccess:
return DocumentAccess(
user_emails=set(),
user_groups=set(),
is_public=False,
external_user_emails=set(),
external_user_group_ids=set(),
)
def _get_access_for_documents(
document_ids: list[str],
db_session: Session,
@@ -42,13 +55,27 @@ def _get_access_for_documents(
db_session=db_session,
document_ids=document_ids,
)
return {
document_id: DocumentAccess.build(
user_ids=user_ids, user_groups=[], is_public=is_public
doc_access = {
document_id: DocumentAccess(
user_emails=set([email for email in user_emails if email]),
# MIT version will wipe all groups and external groups on update
user_groups=set(),
is_public=is_public,
external_user_emails=set(),
external_user_group_ids=set(),
)
for document_id, user_ids, is_public in document_access_info
for document_id, user_emails, is_public in document_access_info
}
# Sometimes the document has not be indexed by the indexing job yet, in those cases
# the document does not exist and so we use least permissive. Specifically the EE version
# checks the MIT version permissions and creates a superset. This ensures that this flow
# does not fail even if the Document has not yet been indexed.
for doc_id in document_ids:
if doc_id not in doc_access:
doc_access[doc_id] = get_null_document_access()
return doc_access
def get_access_for_documents(
document_ids: list[str],
@@ -70,7 +97,7 @@ def _get_acl_for_user(user: User | None, db_session: Session) -> set[str]:
matches one entry in the returned set.
"""
if user:
return {prefix_user(str(user.id)), PUBLIC_DOC_PAT}
return {prefix_user_email(user.email), PUBLIC_DOC_PAT}
return {PUBLIC_DOC_PAT}

View File

@@ -1,30 +1,72 @@
from dataclasses import dataclass
from uuid import UUID
from danswer.access.utils import prefix_user
from danswer.access.utils import prefix_external_group
from danswer.access.utils import prefix_user_email
from danswer.access.utils import prefix_user_group
from danswer.configs.constants import PUBLIC_DOC_PAT
@dataclass(frozen=True)
class DocumentAccess:
user_ids: set[str] # stringified UUIDs
user_groups: set[str] # names of user groups associated with this document
class ExternalAccess:
# Emails of external users with access to the doc externally
external_user_emails: set[str]
# Names or external IDs of groups with access to the doc
external_user_group_ids: set[str]
# Whether the document is public in the external system or Danswer
is_public: bool
def to_acl(self) -> list[str]:
return (
[prefix_user(user_id) for user_id in self.user_ids]
@dataclass(frozen=True)
class DocumentAccess(ExternalAccess):
# User emails for Danswer users, None indicates admin
user_emails: set[str | None]
# Names of user groups associated with this document
user_groups: set[str]
def to_acl(self) -> set[str]:
return set(
[
prefix_user_email(user_email)
for user_email in self.user_emails
if user_email
]
+ [prefix_user_group(group_name) for group_name in self.user_groups]
+ [
prefix_user_email(user_email)
for user_email in self.external_user_emails
]
+ [
# The group names are already prefixed by the source type
# This adds an additional prefix of "external_group:"
prefix_external_group(group_name)
for group_name in self.external_user_group_ids
]
+ ([PUBLIC_DOC_PAT] if self.is_public else [])
)
@classmethod
def build(
cls, user_ids: list[UUID | None], user_groups: list[str], is_public: bool
cls,
user_emails: list[str | None],
user_groups: list[str],
external_user_emails: list[str],
external_user_group_ids: list[str],
is_public: bool,
) -> "DocumentAccess":
return cls(
user_ids={str(user_id) for user_id in user_ids if user_id},
external_user_emails={
prefix_user_email(external_email)
for external_email in external_user_emails
},
external_user_group_ids={
prefix_external_group(external_group_id)
for external_group_id in external_user_group_ids
},
user_emails={
prefix_user_email(user_email)
for user_email in user_emails
if user_email
},
user_groups=set(user_groups),
is_public=is_public,
)

View File

@@ -1,10 +1,24 @@
def prefix_user(user_id: str) -> str:
"""Prefixes a user ID to eliminate collision with group names.
This assumes that groups are prefixed with a different prefix."""
return f"user_id:{user_id}"
from danswer.configs.constants import DocumentSource
def prefix_user_email(user_email: str) -> str:
"""Prefixes a user email to eliminate collision with group names.
This applies to both a Danswer user and an External user, this is to make the query time
more efficient"""
return f"user_email:{user_email}"
def prefix_user_group(user_group_name: str) -> str:
"""Prefixes a user group name to eliminate collision with user IDs.
"""Prefixes a user group name to eliminate collision with user emails.
This assumes that user ids are prefixed with a different prefix."""
return f"group:{user_group_name}"
def prefix_external_group(ext_group_name: str) -> str:
"""Prefixes an external group name to eliminate collision with user emails / Danswer groups."""
return f"external_group:{ext_group_name}"
def prefix_group_w_source(ext_group_name: str, source: DocumentSource) -> str:
"""External groups may collide across sources, every source needs its own prefix."""
return f"{source.value.upper()}_{ext_group_name}"

View File

@@ -300,17 +300,27 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
async def authenticate(
self, credentials: OAuth2PasswordRequestForm
) -> Optional[User]:
user = await super().authenticate(credentials)
if user is None:
try:
user = await self.get_by_email(credentials.username)
if not user.has_web_login:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="NO_WEB_LOGIN_AND_HAS_NO_PASSWORD",
)
except exceptions.UserNotExists:
pass
try:
user = await self.get_by_email(credentials.username)
except exceptions.UserNotExists:
self.password_helper.hash(credentials.password)
return None
if not user.has_web_login:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="NO_WEB_LOGIN_AND_HAS_NO_PASSWORD",
)
verified, updated_password_hash = self.password_helper.verify_and_update(
credentials.password, user.hashed_password
)
if not verified:
return None
if updated_password_hash is not None:
await self.user_db.update(user, {"hashed_password": updated_password_hash})
return user

View File

@@ -21,18 +21,17 @@ from redis import Redis
from sqlalchemy import inspect
from sqlalchemy import text
from sqlalchemy.orm import Session
from sqlalchemy.orm.exc import ObjectDeletedError
from danswer.access.access import get_access_for_document
from danswer.background.celery.celery_redis import RedisConnectorCredentialPair
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.celery.celery_redis import RedisDocumentSet
from danswer.background.celery.celery_redis import RedisUserGroup
from danswer.background.celery.celery_utils import extract_ids_from_runnable_connector
from danswer.background.celery.celery_utils import should_kick_off_deletion_of_cc_pair
from danswer.background.celery.celery_utils import should_prune_cc_pair
from danswer.background.connector_deletion import delete_connector_credential_pair
from danswer.background.connector_deletion import delete_connector_credential_pair_batch
from danswer.background.task_utils import build_celery_task_wrapper
from danswer.background.task_utils import name_cc_cleanup_task
from danswer.background.task_utils import name_cc_prune_task
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
@@ -43,29 +42,42 @@ from danswer.configs.constants import POSTGRES_CELERY_WORKER_APP_NAME
from danswer.configs.constants import PostgresAdvisoryLocks
from danswer.connectors.factory import instantiate_connector
from danswer.connectors.models import InputType
from danswer.db.connector import fetch_connector_by_id
from danswer.db.connector_credential_pair import add_deletion_failure_message
from danswer.db.connector_credential_pair import (
get_connector_credential_pair,
delete_connector_credential_pair__no_commit,
)
from danswer.db.connector_credential_pair import get_connector_credential_pair
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
from danswer.db.document import count_documents_by_needs_sync
from danswer.db.document import delete_document_by_connector_credential_pair__no_commit
from danswer.db.document import delete_documents_complete__no_commit
from danswer.db.document import get_document
from danswer.db.document import get_document_connector_count
from danswer.db.document import get_documents_for_connector_credential_pair
from danswer.db.document import mark_document_as_synced
from danswer.db.document_set import delete_document_set
from danswer.db.document_set import fetch_document_set_for_document
from danswer.db.document_set import delete_document_set_cc_pair_relationship__no_commit
from danswer.db.document_set import fetch_document_sets
from danswer.db.document_set import fetch_document_sets_for_document
from danswer.db.document_set import get_document_set_by_id
from danswer.db.document_set import mark_document_set_as_synced
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.engine import init_sqlalchemy_engine
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.enums import IndexingStatus
from danswer.db.index_attempt import delete_index_attempts
from danswer.db.index_attempt import get_last_attempt
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import DocumentSet
from danswer.db.models import UserGroup
from danswer.db.search_settings import get_current_search_settings
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import UpdateRequest
from danswer.redis.redis_pool import RedisPool
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
from danswer.utils.logger import ColoredFormatter
from danswer.utils.logger import PlainFormatter
from danswer.utils.logger import setup_logger
@@ -93,56 +105,6 @@ celery_app.config_from_object(
#
# If imports from this module are needed, use local imports to avoid circular importing
#####
@build_celery_task_wrapper(name_cc_cleanup_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def cleanup_connector_credential_pair_task(
connector_id: int,
credential_id: int,
) -> int:
"""Connector deletion task. This is run as an async task because it is a somewhat slow job.
Needs to potentially update a large number of Postgres and Vespa docs, including deleting them
or updating the ACL"""
engine = get_sqlalchemy_engine()
with Session(engine) as db_session:
# validate that the connector / credential pair is deletable
cc_pair = get_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
if not cc_pair:
raise ValueError(
f"Cannot run deletion attempt - connector_credential_pair with Connector ID: "
f"{connector_id} and Credential ID: {credential_id} does not exist."
)
try:
deletion_attempt_disallowed_reason = check_deletion_attempt_is_allowed(
connector_credential_pair=cc_pair, db_session=db_session
)
if deletion_attempt_disallowed_reason:
raise ValueError(deletion_attempt_disallowed_reason)
# The bulk of the work is in here, updates Postgres and Vespa
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
document_index = get_default_document_index(
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
)
return delete_connector_credential_pair(
db_session=db_session,
document_index=document_index,
cc_pair=cc_pair,
)
except Exception as e:
stack_trace = traceback.format_exc()
error_message = f"Error: {str(e)}\n\nStack Trace:\n{stack_trace}"
add_deletion_failure_message(db_session, cc_pair.id, error_message)
task_logger.exception(
f"Failed to run connector_deletion. "
f"connector_id={connector_id} credential_id={credential_id}"
)
raise e
@build_celery_task_wrapper(name_cc_prune_task)
@@ -166,11 +128,11 @@ def prune_documents_task(connector_id: int, credential_id: int) -> None:
return
runnable_connector = instantiate_connector(
cc_pair.connector.source,
InputType.PRUNE,
cc_pair.connector.connector_specific_config,
cc_pair.credential,
db_session,
db_session=db_session,
source=cc_pair.connector.source,
input_type=InputType.PRUNE,
connector_specific_config=cc_pair.connector.connector_specific_config,
credential=cc_pair.credential,
)
all_connector_doc_ids: set[str] = extract_ids_from_runnable_connector(
@@ -218,6 +180,11 @@ def prune_documents_task(connector_id: int, credential_id: int) -> None:
def try_generate_stale_document_sync_tasks(
db_session: Session, r: Redis, lock_beat: redis.lock.Lock
) -> int | None:
"""This picks up stale documents (typically from indexing) and queues them for sync to Vespa.
Returns an int if syncing is needed. The int represents the number of sync tasks generated.
Returns None if no syncing is required.
"""
# the fence is up, do nothing
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
return None
@@ -233,6 +200,8 @@ def try_generate_stale_document_sync_tasks(
f"Stale documents found (at least {stale_doc_count}). Generating sync tasks by cc pair."
)
task_logger.info("RedisConnector.generate_tasks starting by cc_pair.")
# rkuo: we could technically sync all stale docs in one big pass.
# but I feel it's more understandable to group the docs by cc_pair
total_tasks_generated = 0
@@ -255,7 +224,7 @@ def try_generate_stale_document_sync_tasks(
total_tasks_generated += tasks_generated
task_logger.info(
f"All per connector generate_tasks finished. total_tasks_generated={total_tasks_generated}"
f"RedisConnector.generate_tasks finished for all cc_pairs. total_tasks_generated={total_tasks_generated}"
)
r.set(RedisConnectorCredentialPair.get_fence_key(), total_tasks_generated)
@@ -265,6 +234,10 @@ def try_generate_stale_document_sync_tasks(
def try_generate_document_set_sync_tasks(
document_set: DocumentSet, db_session: Session, r: Redis, lock_beat: redis.lock.Lock
) -> int | None:
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
Note that syncing can still be required even if the number of sync tasks generated is zero.
Returns None if no syncing is required.
"""
lock_beat.reacquire()
rds = RedisDocumentSet(document_set.id)
@@ -310,6 +283,10 @@ def try_generate_document_set_sync_tasks(
def try_generate_user_group_sync_tasks(
usergroup: UserGroup, db_session: Session, r: Redis, lock_beat: redis.lock.Lock
) -> int | None:
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
Note that syncing can still be required even if the number of sync tasks generated is zero.
Returns None if no syncing is required.
"""
lock_beat.reacquire()
rug = RedisUserGroup(usergroup.id)
@@ -327,7 +304,9 @@ def try_generate_user_group_sync_tasks(
r.delete(rug.taskset_key)
# Add all documents that need to be updated into the queue
task_logger.info(f"generate_tasks starting. usergroup_id={usergroup.id}")
task_logger.info(
f"RedisUserGroup.generate_tasks starting. usergroup_id={usergroup.id}"
)
tasks_generated = rug.generate_tasks(celery_app, db_session, r, lock_beat)
if tasks_generated is None:
return None
@@ -339,7 +318,7 @@ def try_generate_user_group_sync_tasks(
# return 0
task_logger.info(
f"generate_tasks finished. "
f"RedisUserGroup.generate_tasks finished. "
f"usergroup_id={usergroup.id} tasks_generated={tasks_generated}"
)
@@ -348,6 +327,78 @@ def try_generate_user_group_sync_tasks(
return tasks_generated
def try_generate_document_cc_pair_cleanup_tasks(
cc_pair: ConnectorCredentialPair,
db_session: Session,
r: Redis,
lock_beat: redis.lock.Lock,
) -> int | None:
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
Note that syncing can still be required even if the number of sync tasks generated is zero.
Returns None if no syncing is required.
"""
lock_beat.reacquire()
rcd = RedisConnectorDeletion(cc_pair.id)
# don't generate sync tasks if tasks are still pending
if r.exists(rcd.fence_key):
return None
# we need to refresh the state of the object inside the fence
# to avoid a race condition with db.commit/fence deletion
# at the end of this taskset
try:
db_session.refresh(cc_pair)
except ObjectDeletedError:
return None
if cc_pair.status != ConnectorCredentialPairStatus.DELETING:
return None
search_settings = get_current_search_settings(db_session)
last_indexing = get_last_attempt(
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
search_settings_id=search_settings.id,
db_session=db_session,
)
if last_indexing:
if (
last_indexing.status == IndexingStatus.IN_PROGRESS
or last_indexing.status == IndexingStatus.NOT_STARTED
):
return None
# add tasks to celery and build up the task set to monitor in redis
r.delete(rcd.taskset_key)
# Add all documents that need to be updated into the queue
task_logger.info(
f"RedisConnectorDeletion.generate_tasks starting. cc_pair_id={cc_pair.id}"
)
tasks_generated = rcd.generate_tasks(celery_app, db_session, r, lock_beat)
if tasks_generated is None:
return None
# Currently we are allowing the sync to proceed with 0 tasks.
# It's possible for sets/groups to be generated initially with no entries
# and they still need to be marked as up to date.
# if tasks_generated == 0:
# return 0
task_logger.info(
f"RedisConnectorDeletion.generate_tasks finished. "
f"cc_pair_id={cc_pair.id} tasks_generated={tasks_generated}"
)
# set this only after all tasks have been added
r.set(rcd.fence_key, tasks_generated)
return tasks_generated
#####
# Periodic Tasks
#####
@@ -411,26 +462,37 @@ def check_for_vespa_sync_task() -> None:
@celery_app.task(
name="check_for_cc_pair_deletion_task",
name="check_for_connector_deletion_task",
soft_time_limit=JOB_TIMEOUT,
)
def check_for_cc_pair_deletion_task() -> None:
"""Runs periodically to check if any deletion tasks should be run"""
with Session(get_sqlalchemy_engine()) as db_session:
# check if any cc pairs are up for deletion
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
if should_kick_off_deletion_of_cc_pair(cc_pair, db_session):
task_logger.info(
f"Deleting the {cc_pair.name} connector credential pair"
)
def check_for_connector_deletion_task() -> None:
r = redis_pool.get_client()
cleanup_connector_credential_pair_task.apply_async(
kwargs=dict(
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
),
lock_beat = r.lock(
DanswerRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
try:
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return
with Session(get_sqlalchemy_engine()) as db_session:
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
try_generate_document_cc_pair_cleanup_tasks(
cc_pair, db_session, r, lock_beat
)
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception:
task_logger.exception("Unexpected exception")
finally:
if lock_beat.owned():
lock_beat.release()
@celery_app.task(
@@ -602,7 +664,7 @@ def vespa_metadata_sync_task(self: Task, document_id: str) -> bool:
return False
# document set sync
doc_sets = fetch_document_set_for_document(document_id, db_session)
doc_sets = fetch_document_sets_for_document(document_id, db_session)
update_doc_sets: set[str] = set(doc_sets)
# User group sync
@@ -617,8 +679,8 @@ def vespa_metadata_sync_task(self: Task, document_id: str) -> bool:
hidden=doc.hidden,
)
# update Vespa
document_index.update(update_requests=[update_request])
# update Vespa. OK if doc doesn't exist. Raises exception otherwise.
document_index.update_single(update_request=update_request)
# update db last. Worst case = we crash right before this and
# the sync might repeat again later
@@ -635,6 +697,92 @@ def vespa_metadata_sync_task(self: Task, document_id: str) -> bool:
return True
@celery_app.task(
name="document_by_cc_pair_cleanup_task",
bind=True,
soft_time_limit=45,
time_limit=60,
max_retries=3,
)
def document_by_cc_pair_cleanup_task(
self: Task, document_id: str, connector_id: int, credential_id: int
) -> bool:
task_logger.info(f"document_id={document_id}")
try:
with Session(get_sqlalchemy_engine()) as db_session:
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
document_index = get_default_document_index(
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
)
count = get_document_connector_count(db_session, document_id)
if count == 1:
# count == 1 means this is the only remaining cc_pair reference to the doc
# delete it from vespa and the db
document_index.delete(doc_ids=[document_id])
delete_documents_complete__no_commit(
db_session=db_session,
document_ids=[document_id],
)
elif count > 1:
# count > 1 means the document still has cc_pair references
doc = get_document(document_id, db_session)
if not doc:
return False
# the below functions do not include cc_pairs being deleted.
# i.e. they will correctly omit access for the current cc_pair
doc_access = get_access_for_document(
document_id=document_id, db_session=db_session
)
doc_sets = fetch_document_sets_for_document(document_id, db_session)
update_doc_sets: set[str] = set(doc_sets)
update_request = UpdateRequest(
document_ids=[document_id],
document_sets=update_doc_sets,
access=doc_access,
boost=doc.boost,
hidden=doc.hidden,
)
# update Vespa. OK if doc doesn't exist. Raises exception otherwise.
document_index.update_single(update_request=update_request)
# there are still other cc_pair references to the doc, so just resync to Vespa
delete_document_by_connector_credential_pair__no_commit(
db_session=db_session,
document_id=document_id,
connector_credential_pair_identifier=ConnectorCredentialPairIdentifier(
connector_id=connector_id,
credential_id=credential_id,
),
)
mark_document_as_synced(document_id, db_session)
else:
pass
# update_docs_last_modified__no_commit(
# db_session=db_session,
# document_ids=[document_id],
# )
db_session.commit()
except SoftTimeLimitExceeded:
task_logger.info(f"SoftTimeLimitExceeded exception. doc_id={document_id}")
except Exception as e:
task_logger.exception("Unexpected exception")
# Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
countdown = 2 ** (self.request.retries + 4)
self.retry(exc=e, countdown=countdown)
return True
@signals.task_postrun.connect
def celery_task_postrun(
sender: Any | None = None,
@@ -687,6 +835,14 @@ def celery_task_postrun(
r.srem(rug.taskset_key, task_id)
return
if task_id.startswith(RedisConnectorDeletion.PREFIX):
r = redis_pool.get_client()
cc_pair_id = RedisConnectorDeletion.get_id_from_task_id(task_id)
if cc_pair_id is not None:
rcd = RedisConnectorDeletion(cc_pair_id)
r.srem(rcd.taskset_key, task_id)
return
def monitor_connector_taskset(r: Redis) -> None:
fence_value = r.get(RedisConnectorCredentialPair.get_fence_key())
@@ -707,9 +863,7 @@ def monitor_connector_taskset(r: Redis) -> None:
task_logger.info(f"Successfully synced stale documents. count={initial_count}")
def monitor_document_set_taskset(
key_bytes: bytes, r: Redis, db_session: Session
) -> None:
def monitor_document_set_taskset(key_bytes: bytes, r: Redis) -> None:
fence_key = key_bytes.decode("utf-8")
document_set_id = RedisDocumentSet.get_id_from_fence_key(fence_key)
if document_set_id is None:
@@ -730,41 +884,47 @@ def monitor_document_set_taskset(
count = cast(int, r.scard(rds.taskset_key))
task_logger.info(
f"document_set_id={document_set_id} remaining={count} initial={initial_count}"
f"Document set sync: document_set_id={document_set_id} remaining={count} initial={initial_count}"
)
if count > 0:
return
document_set = cast(
DocumentSet,
get_document_set_by_id(db_session=db_session, document_set_id=document_set_id),
) # casting since we "know" a document set with this ID exists
if document_set:
if not document_set.connector_credential_pairs:
# if there are no connectors, then delete the document set.
delete_document_set(document_set_row=document_set, db_session=db_session)
task_logger.info(
f"Successfully deleted document set with ID: '{document_set_id}'!"
)
else:
mark_document_set_as_synced(document_set_id, db_session)
task_logger.info(
f"Successfully synced document set with ID: '{document_set_id}'!"
)
with Session(get_sqlalchemy_engine()) as db_session:
document_set = cast(
DocumentSet,
get_document_set_by_id(
db_session=db_session, document_set_id=document_set_id
),
) # casting since we "know" a document set with this ID exists
if document_set:
if not document_set.connector_credential_pairs:
# if there are no connectors, then delete the document set.
delete_document_set(
document_set_row=document_set, db_session=db_session
)
task_logger.info(
f"Successfully deleted document set with ID: '{document_set_id}'!"
)
else:
mark_document_set_as_synced(document_set_id, db_session)
task_logger.info(
f"Successfully synced document set with ID: '{document_set_id}'!"
)
r.delete(rds.taskset_key)
r.delete(rds.fence_key)
def monitor_usergroup_taskset(key_bytes: bytes, r: Redis, db_session: Session) -> None:
key = key_bytes.decode("utf-8")
usergroup_id = RedisUserGroup.get_id_from_fence_key(key)
if not usergroup_id:
task_logger.warning("Could not parse usergroup id from {key}")
def monitor_connector_deletion_taskset(key_bytes: bytes, r: Redis) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id = RedisConnectorDeletion.get_id_from_fence_key(fence_key)
if cc_pair_id is None:
task_logger.warning("could not parse document set id from {key}")
return
rug = RedisUserGroup(usergroup_id)
fence_value = r.get(rug.fence_key)
rcd = RedisConnectorDeletion(cc_pair_id)
fence_value = r.get(rcd.fence_key)
if fence_value is None:
return
@@ -774,44 +934,79 @@ def monitor_usergroup_taskset(key_bytes: bytes, r: Redis, db_session: Session) -
task_logger.error("The value is not an integer.")
return
count = cast(int, r.scard(rug.taskset_key))
count = cast(int, r.scard(rcd.taskset_key))
task_logger.info(
f"usergroup_id={usergroup_id} remaining={count} initial={initial_count}"
f"Connector deletion: cc_pair_id={cc_pair_id} remaining={count} initial={initial_count}"
)
if count > 0:
return
try:
fetch_user_group = fetch_versioned_implementation(
"danswer.db.user_group", "fetch_user_group"
)
except ModuleNotFoundError:
task_logger.exception(
"fetch_versioned_implementation failed to look up fetch_user_group."
)
return
with Session(get_sqlalchemy_engine()) as db_session:
cc_pair = get_connector_credential_pair_from_id(cc_pair_id, db_session)
if not cc_pair:
return
user_group: UserGroup | None = fetch_user_group(
db_session=db_session, user_group_id=usergroup_id
try:
# clean up the rest of the related Postgres entities
# index attempts
delete_index_attempts(
db_session=db_session,
cc_pair_id=cc_pair.id,
)
# document sets
delete_document_set_cc_pair_relationship__no_commit(
db_session=db_session,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
)
# user groups
cleanup_user_groups = fetch_versioned_implementation_with_fallback(
"danswer.db.user_group",
"delete_user_group_cc_pair_relationship__no_commit",
noop_fallback,
)
cleanup_user_groups(
cc_pair_id=cc_pair.id,
db_session=db_session,
)
# finally, delete the cc-pair
delete_connector_credential_pair__no_commit(
db_session=db_session,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
)
# if there are no credentials left, delete the connector
connector = fetch_connector_by_id(
db_session=db_session,
connector_id=cc_pair.connector_id,
)
if not connector or not len(connector.credentials):
task_logger.info(
"Found no credentials left for connector, deleting connector"
)
db_session.delete(connector)
db_session.commit()
except Exception as e:
stack_trace = traceback.format_exc()
error_message = f"Error: {str(e)}\n\nStack Trace:\n{stack_trace}"
add_deletion_failure_message(db_session, cc_pair.id, error_message)
task_logger.exception(
f"Failed to run connector_deletion. "
f"connector_id={cc_pair.connector_id} credential_id={cc_pair.credential_id}"
)
raise e
task_logger.info(
f"Successfully deleted connector_credential_pair with connector_id: '{cc_pair.connector_id}' "
f"and credential_id: '{cc_pair.credential_id}'. "
f"Deleted {initial_count} docs."
)
if user_group:
if user_group.is_up_for_deletion:
delete_user_group = fetch_versioned_implementation_with_fallback(
"danswer.db.user_group", "delete_user_group", noop_fallback
)
delete_user_group(db_session=db_session, user_group=user_group)
task_logger.info(f" Deleted usergroup. id='{usergroup_id}'")
else:
mark_user_group_as_synced = fetch_versioned_implementation_with_fallback(
"danswer.db.user_group", "mark_user_group_as_synced", noop_fallback
)
mark_user_group_as_synced(db_session=db_session, user_group=user_group)
task_logger.info(f"Synced usergroup. id='{usergroup_id}'")
r.delete(rug.taskset_key)
r.delete(rug.fence_key)
r.delete(rcd.taskset_key)
r.delete(rcd.fence_key)
@celery_app.task(name="monitor_vespa_sync", soft_time_limit=300)
@@ -835,17 +1030,24 @@ def monitor_vespa_sync() -> None:
if not lock_beat.acquire(blocking=False):
return
with Session(get_sqlalchemy_engine()) as db_session:
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
monitor_connector_taskset(r)
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
monitor_connector_taskset(r)
for key_bytes in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
monitor_document_set_taskset(key_bytes, r, db_session)
for key_bytes in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
monitor_document_set_taskset(key_bytes, r)
for key_bytes in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
monitor_usergroup_taskset(key_bytes, r, db_session)
for key_bytes in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
monitor_usergroup_taskset = fetch_versioned_implementation_with_fallback(
"danswer.background.celery_utils",
"monitor_usergroup_taskset",
noop_fallback,
)
monitor_usergroup_taskset(key_bytes, r)
for key_bytes in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
monitor_connector_deletion_taskset(key_bytes, r)
#
# r_celery = celery_app.broker_connection().channel().client
# length = celery_get_queue_length(DanswerCeleryQueues.VESPA_METADATA_SYNC, r_celery)
# task_logger.warning(f"queue={DanswerCeleryQueues.VESPA_METADATA_SYNC} length={length}")
@@ -889,6 +1091,12 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
for key in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorDeletion.TASKSET_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorDeletion.FENCE_PREFIX + "*"):
r.delete(key)
class CeleryTaskPlainFormatter(PlainFormatter):
def format(self, record: logging.LogRecord) -> str:
@@ -970,14 +1178,18 @@ celery_app.conf.beat_schedule = {
"schedule": timedelta(seconds=5),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
"check-for-cc-pair-deletion": {
"task": "check_for_cc_pair_deletion_task",
# don't need to check too often, since we kick off a deletion initially
# during the API call that actually marks the CC pair for deletion
"schedule": timedelta(minutes=1),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
}
celery_app.conf.beat_schedule.update(
{
"check-for-connector-deletion-task": {
"task": "check_for_connector_deletion_task",
# don't need to check too often, since we kick off a deletion initially
# during the API call that actually marks the CC pair for deletion
"schedule": timedelta(minutes=1),
"options": {"priority": DanswerCeleryPriority.HIGH},
},
}
)
celery_app.conf.beat_schedule.update(
{
"check-for-prune": {

View File

@@ -15,6 +15,7 @@ from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DanswerCeleryQueues
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.document import construct_document_select_for_connector_credential_pair
from danswer.db.document import (
construct_document_select_for_connector_credential_pair_by_needs_sync,
)
@@ -134,7 +135,7 @@ class RedisDocumentSet(RedisObjectHelper):
last_lock_time = current_time
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
@@ -189,7 +190,7 @@ class RedisUserGroup(RedisObjectHelper):
last_lock_time = current_time
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
@@ -211,6 +212,9 @@ class RedisUserGroup(RedisObjectHelper):
class RedisConnectorCredentialPair(RedisObjectHelper):
"""This class differs from the default in that the taskset used spans
all connectors and is not per connector."""
PREFIX = "connectorsync"
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
@@ -256,7 +260,7 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
last_lock_time = current_time
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
@@ -281,6 +285,64 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
return len(async_results)
class RedisConnectorDeletion(RedisObjectHelper):
PREFIX = "connectordeletion"
FENCE_PREFIX = PREFIX + "_fence"
TASKSET_PREFIX = PREFIX + "_taskset"
def generate_tasks(
self,
celery_app: Celery,
db_session: Session,
redis_client: Redis,
lock: redis.lock.Lock,
) -> int | None:
last_lock_time = time.monotonic()
async_results = []
cc_pair = get_connector_credential_pair_from_id(self._id, db_session)
if not cc_pair:
return None
stmt = construct_document_select_for_connector_credential_pair(
cc_pair.connector_id, cc_pair.credential_id
)
for doc in db_session.scalars(stmt).yield_per(1):
current_time = time.monotonic()
if current_time - last_lock_time >= (
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
):
lock.reacquire()
last_lock_time = current_time
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
# the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
# we prefix the task id so it's easier to keep track of who created the task
# aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
custom_task_id = f"{self.task_id_prefix}_{uuid4()}"
# add to the tracking taskset in redis BEFORE creating the celery task.
# note that for the moment we are using a single taskset key, not differentiated by cc_pair id
redis_client.sadd(self.taskset_key, custom_task_id)
# Priority on sync's triggered by new indexing should be medium
result = celery_app.send_task(
"document_by_cc_pair_cleanup_task",
kwargs=dict(
document_id=doc.id,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
),
queue=DanswerCeleryQueues.CONNECTOR_DELETION,
task_id=custom_task_id,
priority=DanswerCeleryPriority.MEDIUM,
)
async_results.append(result)
return len(async_results)
def celery_get_queue_length(queue: str, r: Redis) -> int:
"""This is a redis specific way to get the length of a celery queue.
It is priority aware and knows how to count across the multiple redis lists

View File

@@ -3,7 +3,7 @@ from datetime import timezone
from sqlalchemy.orm import Session
from danswer.background.task_utils import name_cc_cleanup_task
from danswer.background.celery.celery_redis import RedisConnectorDeletion
from danswer.background.task_utils import name_cc_prune_task
from danswer.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
from danswer.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
@@ -15,29 +15,44 @@ from danswer.connectors.interfaces import IdConnector
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.models import Document
from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
from danswer.db.connector_credential_pair import get_connector_credential_pair
from danswer.db.engine import get_db_current_time
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.enums import TaskStatus
from danswer.db.models import Connector
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import Credential
from danswer.db.models import TaskQueueState
from danswer.db.tasks import check_task_is_live_and_not_timed_out
from danswer.db.tasks import get_latest_task
from danswer.db.tasks import get_latest_task_by_type
from danswer.redis.redis_pool import RedisPool
from danswer.server.documents.models import DeletionAttemptSnapshot
from danswer.utils.logger import setup_logger
logger = setup_logger()
redis_pool = RedisPool()
def _get_deletion_status(
connector_id: int, credential_id: int, db_session: Session
) -> TaskQueueState | None:
cleanup_task_name = name_cc_cleanup_task(
connector_id=connector_id, credential_id=credential_id
"""We no longer store TaskQueueState in the DB for a deletion attempt.
This function populates TaskQueueState by just checking redis.
"""
cc_pair = get_connector_credential_pair(
connector_id=connector_id, credential_id=credential_id, db_session=db_session
)
if not cc_pair:
return None
rcd = RedisConnectorDeletion(cc_pair.id)
r = redis_pool.get_client()
if not r.exists(rcd.fence_key):
return None
return TaskQueueState(
task_id="", task_name=rcd.fence_key, status=TaskStatus.STARTED
)
return get_latest_task(task_name=cleanup_task_name, db_session=db_session)
def get_deletion_attempt_snapshot(
@@ -54,31 +69,6 @@ def get_deletion_attempt_snapshot(
)
def should_kick_off_deletion_of_cc_pair(
cc_pair: ConnectorCredentialPair, db_session: Session
) -> bool:
if cc_pair.status != ConnectorCredentialPairStatus.DELETING:
return False
if check_deletion_attempt_is_allowed(cc_pair, db_session):
return False
deletion_task = _get_deletion_status(
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
db_session=db_session,
)
if deletion_task and check_task_is_live_and_not_timed_out(
deletion_task,
db_session,
# 1 hour timeout
timeout=60 * 60,
):
return False
return True
def should_prune_cc_pair(
connector: Connector, credential: Credential, db_session: Session
) -> bool:

View File

@@ -1,5 +1,7 @@
# docs: https://docs.celeryq.dev/en/stable/userguide/configuration.html
from danswer.configs.app_configs import CELERY_RESULT_EXPIRES
from danswer.configs.app_configs import REDIS_DB_NUMBER_CELERY
from danswer.configs.app_configs import REDIS_DB_NUMBER_CELERY_RESULT_BACKEND
from danswer.configs.app_configs import REDIS_HOST
from danswer.configs.app_configs import REDIS_PASSWORD
from danswer.configs.app_configs import REDIS_PORT
@@ -27,7 +29,7 @@ if REDIS_SSL:
# example celery_broker_url: "redis://:password@localhost:6379/15"
broker_url = f"{REDIS_SCHEME}://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY}{SSL_QUERY_PARAMS}"
result_backend = f"{REDIS_SCHEME}://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY}{SSL_QUERY_PARAMS}"
result_backend = f"{REDIS_SCHEME}://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY_RESULT_BACKEND}{SSL_QUERY_PARAMS}"
# NOTE: prefetch 4 is significantly faster than prefetch 1 for small tasks
# however, prefetching is bad when tasks are lengthy as those tasks
@@ -42,3 +44,33 @@ broker_transport_options = {
task_default_priority = DanswerCeleryPriority.MEDIUM
task_acks_late = True
# It's possible we don't even need celery's result backend, in which case all of the optimization below
# might be irrelevant
result_expires = CELERY_RESULT_EXPIRES # 86400 seconds is the default
# Option 0: Defaults (json serializer, no compression)
# about 1.5 KB per queued task. 1KB in queue, 400B for result, 100 as a child entry in generator result
# Option 1: Reduces generator task result sizes by roughly 20%
# task_compression = "bzip2"
# task_serializer = "pickle"
# result_compression = "bzip2"
# result_serializer = "pickle"
# accept_content=["pickle"]
# Option 2: this significantly reduces the size of the result for generator tasks since the list of children
# can be large. small tasks change very little
# def pickle_bz2_encoder(data):
# return bz2.compress(pickle.dumps(data))
# def pickle_bz2_decoder(data):
# return pickle.loads(bz2.decompress(data))
# from kombu import serialization # To register custom serialization with Celery/Kombu
# serialization.register('pickle-bzip2', pickle_bz2_encoder, pickle_bz2_decoder, 'application/x-pickle-bz2', 'binary')
# task_serializer = "pickle-bzip2"
# result_serializer = "pickle-bzip2"
# accept_content=["pickle", "pickle-bzip2"]

View File

@@ -13,28 +13,16 @@ connector / credential pair from the access list
from sqlalchemy.orm import Session
from danswer.access.access import get_access_for_documents
from danswer.db.connector import fetch_connector_by_id
from danswer.db.connector_credential_pair import (
delete_connector_credential_pair__no_commit,
)
from danswer.db.document import delete_document_by_connector_credential_pair__no_commit
from danswer.db.document import delete_documents_by_connector_credential_pair__no_commit
from danswer.db.document import delete_documents_complete__no_commit
from danswer.db.document import get_document_connector_cnts
from danswer.db.document import get_documents_for_connector_credential_pair
from danswer.db.document import get_document_connector_counts
from danswer.db.document import prepare_to_modify_documents
from danswer.db.document_set import delete_document_set_cc_pair_relationship__no_commit
from danswer.db.document_set import fetch_document_sets_for_documents
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.index_attempt import delete_index_attempts
from danswer.db.models import ConnectorCredentialPair
from danswer.document_index.interfaces import DocumentIndex
from danswer.document_index.interfaces import UpdateRequest
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,
)
from danswer.utils.variable_functionality import noop_fallback
logger = setup_logger()
@@ -57,13 +45,15 @@ def delete_connector_credential_pair_batch(
with prepare_to_modify_documents(
db_session=db_session, document_ids=document_ids
):
document_connector_cnts = get_document_connector_cnts(
document_connector_counts = get_document_connector_counts(
db_session=db_session, document_ids=document_ids
)
# figure out which docs need to be completely deleted
document_ids_to_delete = [
document_id for document_id, cnt in document_connector_cnts if cnt == 1
document_id
for document_id, cnt in document_connector_counts
if cnt == 1
]
logger.debug(f"Deleting documents: {document_ids_to_delete}")
@@ -76,7 +66,7 @@ def delete_connector_credential_pair_batch(
# figure out which docs need to be updated
document_ids_to_update = [
document_id for document_id, cnt in document_connector_cnts if cnt > 1
document_id for document_id, cnt in document_connector_counts if cnt > 1
]
# maps document id to list of document set names
@@ -109,7 +99,7 @@ def delete_connector_credential_pair_batch(
document_index.update(update_requests=update_requests)
# clean up Postgres
delete_document_by_connector_credential_pair__no_commit(
delete_documents_by_connector_credential_pair__no_commit(
db_session=db_session,
document_ids=document_ids_to_update,
connector_credential_pair_identifier=ConnectorCredentialPairIdentifier(
@@ -118,78 +108,3 @@ def delete_connector_credential_pair_batch(
),
)
db_session.commit()
def delete_connector_credential_pair(
db_session: Session,
document_index: DocumentIndex,
cc_pair: ConnectorCredentialPair,
) -> int:
connector_id = cc_pair.connector_id
credential_id = cc_pair.credential_id
num_docs_deleted = 0
while True:
documents = get_documents_for_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
limit=_DELETION_BATCH_SIZE,
)
if not documents:
break
delete_connector_credential_pair_batch(
document_ids=[document.id for document in documents],
connector_id=connector_id,
credential_id=credential_id,
document_index=document_index,
)
num_docs_deleted += len(documents)
# clean up the rest of the related Postgres entities
# index attempts
delete_index_attempts(
db_session=db_session,
cc_pair_id=cc_pair.id,
)
# document sets
delete_document_set_cc_pair_relationship__no_commit(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
# user groups
cleanup_user_groups = fetch_versioned_implementation_with_fallback(
"danswer.db.user_group",
"delete_user_group_cc_pair_relationship__no_commit",
noop_fallback,
)
cleanup_user_groups(
cc_pair_id=cc_pair.id,
db_session=db_session,
)
# finally, delete the cc-pair
delete_connector_credential_pair__no_commit(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
# if there are no credentials left, delete the connector
connector = fetch_connector_by_id(
db_session=db_session,
connector_id=connector_id,
)
if not connector or not len(connector.credentials):
logger.info("Found no credentials left for connector, deleting connector")
db_session.delete(connector)
db_session.commit()
logger.notice(
"Successfully deleted connector_credential_pair with connector_id:"
f" '{connector_id}' and credential_id: '{credential_id}'. Deleted {num_docs_deleted} docs."
)
return num_docs_deleted

View File

@@ -56,11 +56,11 @@ def _get_connector_runner(
try:
runnable_connector = instantiate_connector(
attempt.connector_credential_pair.connector.source,
task,
attempt.connector_credential_pair.connector.connector_specific_config,
attempt.connector_credential_pair.credential,
db_session,
db_session=db_session,
source=attempt.connector_credential_pair.connector.source,
input_type=task,
connector_specific_config=attempt.connector_credential_pair.connector.connector_specific_config,
credential=attempt.connector_credential_pair.credential,
)
except Exception as e:
logger.exception(f"Unable to instantiate connector due to {e}")

View File

@@ -14,14 +14,6 @@ from danswer.db.tasks import mark_task_start
from danswer.db.tasks import register_task
def name_cc_cleanup_task(connector_id: int, credential_id: int) -> str:
return f"cleanup_connector_credential_pair_{connector_id}_{credential_id}"
def name_document_set_sync_task(document_set_id: int) -> str:
return f"sync_doc_set_{document_set_id}"
def name_cc_prune_task(
connector_id: int | None = None, credential_id: int | None = None
) -> str:

View File

@@ -211,7 +211,6 @@ def cleanup_indexing_jobs(
timeout_hours: int = CLEANUP_INDEXING_JOBS_TIMEOUT,
) -> dict[int, Future | SimpleJob]:
existing_jobs_copy = existing_jobs.copy()
# clean up completed jobs
with Session(get_sqlalchemy_engine()) as db_session:
for attempt_id, job in existing_jobs.items():
@@ -312,7 +311,12 @@ def kickoff_indexing_jobs(
indexing_attempt_count = 0
primary_client_full = False
secondary_client_full = False
for attempt, search_settings in new_indexing_attempts:
if primary_client_full and secondary_client_full:
break
use_secondary_index = (
search_settings.status == IndexModelStatus.FUTURE
if search_settings is not None
@@ -337,22 +341,28 @@ def kickoff_indexing_jobs(
)
continue
if use_secondary_index:
run = secondary_client.submit(
run_indexing_entrypoint,
attempt.id,
attempt.connector_credential_pair_id,
global_version.get_is_ee_version(),
pure=False,
)
if not use_secondary_index:
if not primary_client_full:
run = client.submit(
run_indexing_entrypoint,
attempt.id,
attempt.connector_credential_pair_id,
global_version.get_is_ee_version(),
pure=False,
)
if not run:
primary_client_full = True
else:
run = client.submit(
run_indexing_entrypoint,
attempt.id,
attempt.connector_credential_pair_id,
global_version.get_is_ee_version(),
pure=False,
)
if not secondary_client_full:
run = secondary_client.submit(
run_indexing_entrypoint,
attempt.id,
attempt.connector_credential_pair_id,
global_version.get_is_ee_version(),
pure=False,
)
if not run:
secondary_client_full = True
if run:
if indexing_attempt_count == 0:

View File

@@ -122,7 +122,7 @@ def load_personas_from_yaml(
prompt_ids=prompt_ids,
document_set_ids=doc_set_ids,
tool_ids=tool_ids,
default_persona=True,
builtin_persona=True,
is_public=True,
display_priority=existing_persona.display_priority
if existing_persona is not None

View File

@@ -11,7 +11,6 @@ from danswer.search.enums import SearchType
from danswer.search.models import RetrievalDocs
from danswer.search.models import SearchResponse
from danswer.tools.custom.base_tool_types import ToolResultType
from danswer.tools.graphing.models import GraphGenerationDisplay
class LlmDoc(BaseModel):
@@ -49,8 +48,6 @@ class QADocsResponse(RetrievalDocs):
class StreamStopReason(Enum):
CONTEXT_LENGTH = "context_length"
CANCELLED = "cancelled"
FINISHED = "finished"
NEW_RESPONSE = "new_response"
class StreamStopInfo(BaseModel):
@@ -176,7 +173,6 @@ AnswerQuestionPossibleReturn = (
| ImageGenerationDisplay
| CustomToolResponse
| StreamingError
| GraphGenerationDisplay
| StreamStopInfo
)

View File

@@ -18,8 +18,6 @@ from danswer.chat.models import MessageResponseIDInfo
from danswer.chat.models import MessageSpecificCitations
from danswer.chat.models import QADocsResponse
from danswer.chat.models import StreamingError
from danswer.chat.models import StreamStopInfo
from danswer.chat.models import StreamStopReason
from danswer.configs.chat_configs import BING_API_KEY
from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
@@ -74,16 +72,13 @@ from danswer.search.utils import relevant_sections_to_indices
from danswer.server.query_and_chat.models import ChatMessageDetail
from danswer.server.query_and_chat.models import CreateChatMessageRequest
from danswer.server.utils import get_json_line
from danswer.tools.analysis.analysis_tool import CSVAnalysisTool
from danswer.tools.built_in_tools import get_built_in_tool_by_id
from danswer.tools.custom.custom_tool import build_custom_tools_from_openapi_schema
from danswer.tools.custom.custom_tool import (
build_custom_tools_from_openapi_schema_and_headers,
)
from danswer.tools.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID
from danswer.tools.custom.custom_tool import CustomToolCallSummary
from danswer.tools.force import ForceUseTool
from danswer.tools.graphing.graphing_tool import GraphingResponse
from danswer.tools.graphing.graphing_tool import GraphingTool
from danswer.tools.graphing.models import GraphGenerationDisplay
from danswer.tools.graphing.models import GRAPHING_RESPONSE_ID
from danswer.tools.images.image_generation_tool import IMAGE_GENERATION_RESPONSE_ID
from danswer.tools.images.image_generation_tool import ImageGenerationResponse
from danswer.tools.images.image_generation_tool import ImageGenerationTool
@@ -250,7 +245,6 @@ def _get_force_search_settings(
ChatPacket = (
StreamingError
| QADocsResponse
| GraphingResponse
| LLMRelevanceFilterResponse
| FinalUsedContextDocsResponse
| ChatMessageDetail
@@ -259,7 +253,6 @@ ChatPacket = (
| CitationInfo
| ImageGenerationDisplay
| CustomToolResponse
| GraphGenerationDisplay
| MessageSpecificCitations
| MessageResponseIDInfo
)
@@ -280,6 +273,7 @@ def stream_chat_message_objects(
use_existing_user_message: bool = False,
litellm_additional_headers: dict[str, str] | None = None,
is_connected: Callable[[], bool] | None = None,
enforce_chat_session_id_for_search_docs: bool = True,
) -> ChatPacketStream:
"""Streams in order:
1. [conditional] Retrieved documents if a search needs to be run
@@ -451,6 +445,7 @@ def stream_chat_message_objects(
chat_session=chat_session,
user_id=user_id,
db_session=db_session,
enforce_chat_session_id_for_search_docs=enforce_chat_session_id_for_search_docs,
)
# Generates full documents currently
@@ -537,21 +532,8 @@ def stream_chat_message_objects(
tool_dict: dict[int, list[Tool]] = {} # tool_id to tool
for db_tool_model in persona.tools:
# handle in-code tools specially
if db_tool_model.in_code_tool_id:
tool_cls = get_built_in_tool_by_id(db_tool_model.id, db_session)
if (
tool_cls.__name__ == CSVAnalysisTool.__name__
and not latest_query_files
):
tool_dict[db_tool_model.id] = [CSVAnalysisTool()]
if (
tool_cls.__name__ == GraphingTool.__name__
and not latest_query_files
):
tool_dict[db_tool_model.id] = [GraphingTool(output_dir="output")]
if tool_cls.__name__ == SearchTool.__name__ and not latest_query_files:
search_tool = SearchTool(
db_session=db_session,
@@ -622,27 +604,24 @@ def stream_chat_message_objects(
]
continue
# handle all custom tools
if db_tool_model.openapi_schema:
tool_dict[db_tool_model.id] = cast(
list[Tool],
build_custom_tools_from_openapi_schema(
build_custom_tools_from_openapi_schema_and_headers(
db_tool_model.openapi_schema,
dynamic_schema_info=DynamicSchemaInfo(
chat_session_id=chat_session_id,
message_id=user_message.id if user_message else None,
),
custom_headers=db_tool_model.custom_headers,
),
)
tools: list[Tool] = []
for tool_list in tool_dict.values():
tools.extend(tool_list)
# Saving Gen AI answer and responding with message info
tool_name_to_tool_id: dict[str, int] = {}
for tool_id, tool_list in tool_dict.items():
for tool in tool_list:
tool_name_to_tool_id[tool.name] = tool_id
# factor in tool definition size when pruning
document_pruning_config.tool_num_tokens = compute_all_tool_tokens(
@@ -688,185 +667,102 @@ def stream_chat_message_objects(
ai_message_files = None # any files to associate with the AI message e.g. dall-e generated images
dropped_indices = None
tool_result = None
yielded_message_id_info = True
for packet in answer.processed_streamed_output:
if isinstance(packet, StreamStopInfo):
if packet.stop_reason is not StreamStopReason.NEW_RESPONSE:
break
if isinstance(packet, ToolResponse):
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
(
qa_docs_response,
reference_db_search_docs,
dropped_indices,
) = _handle_search_tool_response_summary(
packet=packet,
db_session=db_session,
selected_search_docs=selected_db_search_docs,
# Deduping happens at the last step to avoid harming quality by dropping content early on
dedupe_docs=(
retrieval_options.dedupe_docs
if retrieval_options
else False
),
)
yield qa_docs_response
elif packet.id == SECTION_RELEVANCE_LIST_ID:
relevance_sections = packet.response
db_citations = None
if reference_db_search_docs is not None:
llm_indices = relevant_sections_to_indices(
relevance_sections=relevance_sections,
items=[
translate_db_search_doc_to_server_search_doc(doc)
for doc in reference_db_search_docs
],
)
if reference_db_search_docs:
db_citations = _translate_citations(
citations_list=answer.citations,
db_docs=reference_db_search_docs,
if dropped_indices:
llm_indices = drop_llm_indices(
llm_indices=llm_indices,
search_docs=reference_db_search_docs,
dropped_indices=dropped_indices,
)
yield LLMRelevanceFilterResponse(
llm_selected_doc_indices=llm_indices
)
elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID:
yield FinalUsedContextDocsResponse(
final_context_docs=packet.response
)
# Saving Gen AI answer and responding with message info
if tool_result is None:
tool_call = None
else:
tool_call = ToolCall(
tool_id=tool_name_to_tool_id[tool_result.tool_name],
tool_name=tool_result.tool_name,
tool_arguments={
k: v if not isinstance(v, bytes) else v.decode("utf-8")
for k, v in tool_result.tool_args.items()
},
tool_result=tool_result.tool_result,
elif packet.id == IMAGE_GENERATION_RESPONSE_ID:
img_generation_response = cast(
list[ImageGenerationResponse], packet.response
)
gen_ai_response_message = partial_response(
reserved_message_id=reserved_message_id,
message=answer.llm_answer,
rephrased_query=cast(
QADocsResponse, qa_docs_response
).rephrased_query
if qa_docs_response is not None
else None,
reference_docs=reference_db_search_docs,
files=ai_message_files,
token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
citations=cast(MessageSpecificCitations, db_citations).citation_map
if db_citations is not None
else None,
error=None,
tool_call=tool_call,
)
db_session.commit() # actually save user / assistant message
msg_detail_response = translate_db_message_to_chat_message_detail(
gen_ai_response_message
)
yield msg_detail_response
reserved_message_id = reserve_message_id(
db_session=db_session,
chat_session_id=chat_session_id,
parent_message=gen_ai_response_message.id
if user_message is not None
else gen_ai_response_message.id,
message_type=MessageType.ASSISTANT,
)
yielded_message_id_info = False
partial_response = partial(
create_new_chat_message,
chat_session_id=chat_session_id,
parent_message=gen_ai_response_message,
prompt_id=prompt_id,
overridden_model=overridden_model,
message_type=MessageType.ASSISTANT,
alternate_assistant_id=new_msg_req.alternate_assistant_id,
db_session=db_session,
commit=False,
)
reference_db_search_docs = None
file_ids = save_files_from_urls(
[img.url for img in img_generation_response]
)
ai_message_files = [
FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE)
for file_id in file_ids
]
yield ImageGenerationDisplay(
file_ids=[str(file_id) for file_id in file_ids]
)
elif packet.id == INTERNET_SEARCH_RESPONSE_ID:
(
qa_docs_response,
reference_db_search_docs,
) = _handle_internet_search_tool_response_summary(
packet=packet,
db_session=db_session,
)
yield qa_docs_response
elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
custom_tool_response = cast(CustomToolCallSummary, packet.response)
yield CustomToolResponse(
response=custom_tool_response.tool_result,
tool_name=custom_tool_response.tool_name,
)
else:
if not yielded_message_id_info:
yield MessageResponseIDInfo(
user_message_id=gen_ai_response_message.id,
reserved_assistant_message_id=reserved_message_id,
)
yielded_message_id_info = True
if isinstance(packet, ToolResponse):
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
(
qa_docs_response,
reference_db_search_docs,
dropped_indices,
) = _handle_search_tool_response_summary(
packet=packet,
db_session=db_session,
selected_search_docs=selected_db_search_docs,
# Deduping happens at the last step to avoid harming quality by dropping content early on
dedupe_docs=retrieval_options.dedupe_docs
if retrieval_options
else False,
)
yield qa_docs_response
elif packet.id == SECTION_RELEVANCE_LIST_ID:
relevance_sections = packet.response
if reference_db_search_docs is not None:
llm_indices = relevant_sections_to_indices(
relevance_sections=relevance_sections,
items=[
translate_db_search_doc_to_server_search_doc(doc)
for doc in reference_db_search_docs
],
)
if dropped_indices:
llm_indices = drop_llm_indices(
llm_indices=llm_indices,
search_docs=reference_db_search_docs,
dropped_indices=dropped_indices,
)
yield LLMRelevanceFilterResponse(
llm_selected_doc_indices=llm_indices
)
elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID:
yield FinalUsedContextDocsResponse(
final_context_docs=packet.response
)
elif packet.id == GRAPHING_RESPONSE_ID:
graph_generation = cast(GraphingResponse, packet.response)
yield graph_generation
# yield GraphGenerationDisplay(
# file_id=graph_generation.extra_graph_display.file_id,
# line_graph=graph_generation.extra_graph_display.line_graph,
# )
elif packet.id == IMAGE_GENERATION_RESPONSE_ID:
img_generation_response = cast(
list[ImageGenerationResponse], packet.response
)
file_ids = save_files_from_urls(
[img.url for img in img_generation_response]
)
ai_message_files = [
FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE)
for file_id in file_ids
]
yield ImageGenerationDisplay(
file_ids=[str(file_id) for file_id in file_ids]
)
elif packet.id == INTERNET_SEARCH_RESPONSE_ID:
(
qa_docs_response,
reference_db_search_docs,
) = _handle_internet_search_tool_response_summary(
packet=packet,
db_session=db_session,
)
yield qa_docs_response
elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
custom_tool_response = cast(
CustomToolCallSummary, packet.response
)
yield CustomToolResponse(
response=custom_tool_response.tool_result,
tool_name=custom_tool_response.tool_name,
)
else:
if isinstance(packet, ToolCallFinalResult):
tool_result = packet
yield cast(ChatPacket, packet)
if isinstance(packet, ToolCallFinalResult):
tool_result = packet
yield cast(ChatPacket, packet)
logger.debug("Reached end of stream")
except Exception as e:
error_msg = str(e)
logger.exception(f"Failed to process chat message: {error_msg}")
except ValueError as e:
logger.exception("Failed to process chat message.")
error_msg = str(e)
yield StreamingError(error=error_msg)
db_session.rollback()
return
except Exception as e:
logger.exception("Failed to process chat message.")
error_msg = str(e)
stack_trace = traceback.format_exc()
client_error_msg = litellm_exception_to_error_msg(e, llm)
if llm.config.api_key and len(llm.config.api_key) > 2:
@@ -887,8 +783,11 @@ def stream_chat_message_objects(
)
yield AllCitations(citations=answer.citations)
if answer.llm_answer == "":
return
# Saving Gen AI answer and responding with message info
tool_name_to_tool_id: dict[str, int] = {}
for tool_id, tool_list in tool_dict.items():
for tool in tool_list:
tool_name_to_tool_id[tool.name] = tool_id
gen_ai_response_message = partial_response(
reserved_message_id=reserved_message_id,
@@ -903,14 +802,18 @@ def stream_chat_message_objects(
if message_specific_citations
else None,
error=None,
tool_call=ToolCall(
tool_id=tool_name_to_tool_id[tool_result.tool_name],
tool_name=tool_result.tool_name,
tool_arguments=tool_result.tool_args,
tool_result=tool_result.tool_result,
)
if tool_result
else None,
tool_calls=(
[
ToolCall(
tool_id=tool_name_to_tool_id[tool_result.tool_name],
tool_name=tool_result.tool_name,
tool_arguments=tool_result.tool_args,
tool_result=tool_result.tool_result,
)
]
if tool_result
else []
),
)
logger.debug("Committing messages")

View File

@@ -135,7 +135,7 @@ POSTGRES_PASSWORD = urllib.parse.quote_plus(
os.environ.get("POSTGRES_PASSWORD") or "password"
)
POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost"
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5433"
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432"
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
# defaults to False
@@ -159,11 +159,18 @@ REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD") or ""
REDIS_DB_NUMBER = int(os.environ.get("REDIS_DB_NUMBER", 0))
# Used by celery as broker and backend
REDIS_DB_NUMBER_CELERY = int(os.environ.get("REDIS_DB_NUMBER_CELERY", 15))
REDIS_DB_NUMBER_CELERY_RESULT_BACKEND = int(
os.environ.get("REDIS_DB_NUMBER_CELERY_RESULT_BACKEND", 14)
)
REDIS_DB_NUMBER_CELERY = int(os.environ.get("REDIS_DB_NUMBER_CELERY", 15)) # broker
REDIS_SSL_CERT_REQS = os.getenv("REDIS_SSL_CERT_REQS", "CERT_NONE")
# https://docs.celeryq.dev/en/stable/userguide/configuration.html#redis-backend-settings
# should be one of "required", "optional", or "none"
REDIS_SSL_CERT_REQS = os.getenv("REDIS_SSL_CERT_REQS", "none")
REDIS_SSL_CA_CERTS = os.getenv("REDIS_SSL_CA_CERTS", "")
CELERY_RESULT_EXPIRES = int(os.environ.get("CELERY_RESULT_EXPIRES", 86400)) # seconds
#####
# Connector Configs
#####

View File

@@ -99,6 +99,7 @@ class DocumentSource(str, Enum):
CLICKUP = "clickup"
MEDIAWIKI = "mediawiki"
WIKIPEDIA = "wikipedia"
ASANA = "asana"
S3 = "s3"
R2 = "r2"
GOOGLE_CLOUD_STORAGE = "google_cloud_storage"
@@ -133,6 +134,12 @@ class AuthType(str, Enum):
SAML = "saml"
class SessionType(str, Enum):
CHAT = "Chat"
SEARCH = "Search"
SLACK = "Slack"
class QAFeedbackType(str, Enum):
LIKE = "like" # User likes the answer, used for metrics
DISLIKE = "dislike" # User dislikes the answer, used for metrics
@@ -165,7 +172,6 @@ class FileOrigin(str, Enum):
CONNECTOR = "connector"
GENERATED_REPORT = "generated_report"
OTHER = "other"
GRAPH_GEN = "graph_gen"
class PostgresAdvisoryLocks(Enum):
@@ -182,6 +188,8 @@ class DanswerCeleryQueues:
class DanswerRedisLocks:
CHECK_VESPA_SYNC_BEAT_LOCK = "da_lock:check_vespa_sync_beat"
MONITOR_VESPA_SYNC_BEAT_LOCK = "da_lock:monitor_vespa_sync_beat"
CHECK_CONNECTOR_DELETION_BEAT_LOCK = "da_lock:check_connector_deletion_beat"
MONITOR_CONNECTOR_DELETION_BEAT_LOCK = "da_lock:monitor_connector_deletion_beat"
class DanswerCeleryPriority(int, Enum):

View File

@@ -0,0 +1,233 @@
import time
from collections.abc import Iterator
from datetime import datetime
from typing import Dict
import asana # type: ignore
from danswer.utils.logger import setup_logger
logger = setup_logger()
# https://github.com/Asana/python-asana/tree/master?tab=readme-ov-file#documentation-for-api-endpoints
class AsanaTask:
def __init__(
self,
id: str,
title: str,
text: str,
link: str,
last_modified: datetime,
project_gid: str,
project_name: str,
) -> None:
self.id = id
self.title = title
self.text = text
self.link = link
self.last_modified = last_modified
self.project_gid = project_gid
self.project_name = project_name
def __str__(self) -> str:
return f"ID: {self.id}\nTitle: {self.title}\nLast modified: {self.last_modified}\nText: {self.text}"
class AsanaAPI:
def __init__(
self, api_token: str, workspace_gid: str, team_gid: str | None
) -> None:
self._user = None # type: ignore
self.workspace_gid = workspace_gid
self.team_gid = team_gid
self.configuration = asana.Configuration()
self.api_client = asana.ApiClient(self.configuration)
self.tasks_api = asana.TasksApi(self.api_client)
self.stories_api = asana.StoriesApi(self.api_client)
self.users_api = asana.UsersApi(self.api_client)
self.project_api = asana.ProjectsApi(self.api_client)
self.workspaces_api = asana.WorkspacesApi(self.api_client)
self.api_error_count = 0
self.configuration.access_token = api_token
self.task_count = 0
def get_tasks(
self, project_gids: list[str] | None, start_date: str
) -> Iterator[AsanaTask]:
"""Get all tasks from the projects with the given gids that were modified since the given date.
If project_gids is None, get all tasks from all projects in the workspace."""
logger.info("Starting to fetch Asana projects")
projects = self.project_api.get_projects(
opts={
"workspace": self.workspace_gid,
"opt_fields": "gid,name,archived,modified_at",
}
)
start_seconds = int(time.mktime(datetime.now().timetuple()))
projects_list = []
project_count = 0
for project_info in projects:
project_gid = project_info["gid"]
if project_gids is None or project_gid in project_gids:
projects_list.append(project_gid)
else:
logger.debug(
f"Skipping project: {project_gid} - not in accepted project_gids"
)
project_count += 1
if project_count % 100 == 0:
logger.info(f"Processed {project_count} projects")
logger.info(f"Found {len(projects_list)} projects to process")
for project_gid in projects_list:
for task in self._get_tasks_for_project(
project_gid, start_date, start_seconds
):
yield task
logger.info(f"Completed fetching {self.task_count} tasks from Asana")
if self.api_error_count > 0:
logger.warning(
f"Encountered {self.api_error_count} API errors during task fetching"
)
def _get_tasks_for_project(
self, project_gid: str, start_date: str, start_seconds: int
) -> Iterator[AsanaTask]:
project = self.project_api.get_project(project_gid, opts={})
if project["archived"]:
logger.info(f"Skipping archived project: {project['name']} ({project_gid})")
return []
if not project["team"] or not project["team"]["gid"]:
logger.info(
f"Skipping project without a team: {project['name']} ({project_gid})"
)
return []
if project["privacy_setting"] == "private":
if self.team_gid and project["team"]["gid"] != self.team_gid:
logger.info(
f"Skipping private project not in configured team: {project['name']} ({project_gid})"
)
return []
else:
logger.info(
f"Processing private project in configured team: {project['name']} ({project_gid})"
)
simple_start_date = start_date.split(".")[0].split("+")[0]
logger.info(
f"Fetching tasks modified since {simple_start_date} for project: {project['name']} ({project_gid})"
)
opts = {
"opt_fields": "name,memberships,memberships.project,completed_at,completed_by,created_at,"
"created_by,custom_fields,dependencies,due_at,due_on,external,html_notes,liked,likes,"
"modified_at,notes,num_hearts,parent,projects,resource_subtype,resource_type,start_on,"
"workspace,permalink_url",
"modified_since": start_date,
}
tasks_from_api = self.tasks_api.get_tasks_for_project(project_gid, opts)
for data in tasks_from_api:
self.task_count += 1
if self.task_count % 10 == 0:
end_seconds = time.mktime(datetime.now().timetuple())
runtime_seconds = end_seconds - start_seconds
if runtime_seconds > 0:
logger.info(
f"Processed {self.task_count} tasks in {runtime_seconds:.0f} seconds "
f"({self.task_count / runtime_seconds:.2f} tasks/second)"
)
logger.debug(f"Processing Asana task: {data['name']}")
text = self._construct_task_text(data)
try:
text += self._fetch_and_add_comments(data["gid"])
last_modified_date = self.format_date(data["modified_at"])
text += f"Last modified: {last_modified_date}\n"
task = AsanaTask(
id=data["gid"],
title=data["name"],
text=text,
link=data["permalink_url"],
last_modified=datetime.fromisoformat(data["modified_at"]),
project_gid=project_gid,
project_name=project["name"],
)
yield task
except Exception:
logger.error(
f"Error processing task {data['gid']} in project {project_gid}",
exc_info=True,
)
self.api_error_count += 1
def _construct_task_text(self, data: Dict) -> str:
text = f"{data['name']}\n\n"
if data["notes"]:
text += f"{data['notes']}\n\n"
if data["created_by"] and data["created_by"]["gid"]:
creator = self.get_user(data["created_by"]["gid"])["name"]
created_date = self.format_date(data["created_at"])
text += f"Created by: {creator} on {created_date}\n"
if data["due_on"]:
due_date = self.format_date(data["due_on"])
text += f"Due date: {due_date}\n"
if data["completed_at"]:
completed_date = self.format_date(data["completed_at"])
text += f"Completed on: {completed_date}\n"
text += "\n"
return text
def _fetch_and_add_comments(self, task_gid: str) -> str:
text = ""
stories_opts: Dict[str, str] = {}
story_start = time.time()
stories = self.stories_api.get_stories_for_task(task_gid, stories_opts)
story_count = 0
comment_count = 0
for story in stories:
story_count += 1
if story["resource_subtype"] == "comment_added":
comment = self.stories_api.get_story(
story["gid"], opts={"opt_fields": "text,created_by,created_at"}
)
commenter = self.get_user(comment["created_by"]["gid"])["name"]
text += f"Comment by {commenter}: {comment['text']}\n\n"
comment_count += 1
story_duration = time.time() - story_start
logger.debug(
f"Processed {story_count} stories (including {comment_count} comments) in {story_duration:.2f} seconds"
)
return text
def get_user(self, user_gid: str) -> Dict:
if self._user is not None:
return self._user
self._user = self.users_api.get_user(user_gid, {"opt_fields": "name,email"})
if not self._user:
logger.warning(f"Unable to fetch user information for user_gid: {user_gid}")
return {"name": "Unknown"}
return self._user
def format_date(self, date_str: str) -> str:
date = datetime.fromisoformat(date_str)
return time.strftime("%Y-%m-%d", date.timetuple())
def get_time(self) -> str:
return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())

View File

@@ -0,0 +1,120 @@
import datetime
from typing import Any
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.asana import asana_api
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.utils.logger import setup_logger
logger = setup_logger()
class AsanaConnector(LoadConnector, PollConnector):
def __init__(
self,
asana_workspace_id: str,
asana_project_ids: str | None = None,
asana_team_id: str | None = None,
batch_size: int = INDEX_BATCH_SIZE,
continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE,
) -> None:
self.workspace_id = asana_workspace_id
self.project_ids_to_index: list[str] | None = (
asana_project_ids.split(",") if asana_project_ids is not None else None
)
self.asana_team_id = asana_team_id
self.batch_size = batch_size
self.continue_on_failure = continue_on_failure
logger.info(
f"AsanaConnector initialized with workspace_id: {asana_workspace_id}"
)
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self.api_token = credentials["asana_api_token_secret"]
self.asana_client = asana_api.AsanaAPI(
api_token=self.api_token,
workspace_gid=self.workspace_id,
team_gid=self.asana_team_id,
)
logger.info("Asana credentials loaded and API client initialized")
return None
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch | None
) -> GenerateDocumentsOutput:
start_time = datetime.datetime.fromtimestamp(start).isoformat()
logger.info(f"Starting Asana poll from {start_time}")
asana = asana_api.AsanaAPI(
api_token=self.api_token,
workspace_gid=self.workspace_id,
team_gid=self.asana_team_id,
)
docs_batch: list[Document] = []
tasks = asana.get_tasks(self.project_ids_to_index, start_time)
for task in tasks:
doc = self._message_to_doc(task)
docs_batch.append(doc)
if len(docs_batch) >= self.batch_size:
logger.info(f"Yielding batch of {len(docs_batch)} documents")
yield docs_batch
docs_batch = []
if docs_batch:
logger.info(f"Yielding final batch of {len(docs_batch)} documents")
yield docs_batch
logger.info("Asana poll completed")
def load_from_state(self) -> GenerateDocumentsOutput:
logger.notice("Starting full index of all Asana tasks")
return self.poll_source(start=0, end=None)
def _message_to_doc(self, task: asana_api.AsanaTask) -> Document:
logger.debug(f"Converting Asana task {task.id} to Document")
return Document(
id=task.id,
sections=[Section(link=task.link, text=task.text)],
doc_updated_at=task.last_modified,
source=DocumentSource.ASANA,
semantic_identifier=task.title,
metadata={
"group": task.project_gid,
"project": task.project_name,
},
)
if __name__ == "__main__":
import time
import os
logger.notice("Starting Asana connector test")
connector = AsanaConnector(
os.environ["WORKSPACE_ID"],
os.environ["PROJECT_IDS"],
os.environ["TEAM_ID"],
)
connector.load_credentials(
{
"asana_api_token_secret": os.environ["API_TOKEN"],
}
)
logger.info("Loading all documents from Asana")
all_docs = connector.load_from_state()
current = time.time()
one_day_ago = current - 24 * 60 * 60 # 1 day
logger.info("Polling for documents updated in the last 24 hours")
latest_docs = connector.poll_source(one_day_ago, current)
for docs in latest_docs:
for doc in docs:
print(doc.id)
logger.notice("Asana connector test completed")

View File

@@ -4,6 +4,7 @@ from typing import Type
from sqlalchemy.orm import Session
from danswer.configs.constants import DocumentSource
from danswer.connectors.asana.connector import AsanaConnector
from danswer.connectors.axero.connector import AxeroConnector
from danswer.connectors.blob.connector import BlobStorageConnector
from danswer.connectors.bookstack.connector import BookstackConnector
@@ -91,6 +92,7 @@ def identify_connector_class(
DocumentSource.CLICKUP: ClickupConnector,
DocumentSource.MEDIAWIKI: MediaWikiConnector,
DocumentSource.WIKIPEDIA: WikipediaConnector,
DocumentSource.ASANA: AsanaConnector,
DocumentSource.S3: BlobStorageConnector,
DocumentSource.R2: BlobStorageConnector,
DocumentSource.GOOGLE_CLOUD_STORAGE: BlobStorageConnector,
@@ -124,11 +126,11 @@ def identify_connector_class(
def instantiate_connector(
db_session: Session,
source: DocumentSource,
input_type: InputType,
connector_specific_config: dict[str, Any],
credential: Credential,
db_session: Session,
) -> BaseConnector:
connector_class = identify_connector_class(source, input_type)
connector = connector_class(**connector_specific_config)

View File

@@ -6,7 +6,6 @@ from datetime import timezone
from enum import Enum
from itertools import chain
from typing import Any
from typing import cast
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
@@ -21,19 +20,13 @@ from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.connectors.cross_connector_utils.retry_wrapper import retry_builder
from danswer.connectors.google_drive.connector_auth import (
get_google_drive_creds_for_authorized_user,
)
from danswer.connectors.google_drive.connector_auth import (
get_google_drive_creds_for_service_account,
)
from danswer.connectors.google_drive.connector_auth import get_google_drive_creds
from danswer.connectors.google_drive.constants import (
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY,
)
from danswer.connectors.google_drive.constants import (
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
)
from danswer.connectors.google_drive.constants import DB_CREDENTIALS_DICT_TOKEN_KEY
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
@@ -407,42 +400,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector):
(2) A credential which holds a service account key JSON file, which
can then be used to impersonate any user in the workspace.
"""
creds: OAuthCredentials | ServiceAccountCredentials | None = None
new_creds_dict = None
if DB_CREDENTIALS_DICT_TOKEN_KEY in credentials:
access_token_json_str = cast(
str, credentials[DB_CREDENTIALS_DICT_TOKEN_KEY]
)
creds = get_google_drive_creds_for_authorized_user(
token_json_str=access_token_json_str
)
# tell caller to update token stored in DB if it has changed
# (e.g. the token has been refreshed)
new_creds_json_str = creds.to_json() if creds else ""
if new_creds_json_str != access_token_json_str:
new_creds_dict = {DB_CREDENTIALS_DICT_TOKEN_KEY: new_creds_json_str}
if DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY in credentials:
service_account_key_json_str = credentials[
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY
]
creds = get_google_drive_creds_for_service_account(
service_account_key_json_str=service_account_key_json_str
)
# "Impersonate" a user if one is specified
delegated_user_email = cast(
str | None, credentials.get(DB_CREDENTIALS_DICT_DELEGATED_USER_KEY)
)
if delegated_user_email:
creds = creds.with_subject(delegated_user_email) if creds else None # type: ignore
if creds is None:
raise PermissionError(
"Unable to access Google Drive - unknown credential structure."
)
creds, new_creds_dict = get_google_drive_creds(credentials)
self.creds = creds
return new_creds_dict
@@ -509,6 +467,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector):
file["modifiedTime"]
).astimezone(timezone.utc),
metadata={} if text_contents else {IGNORE_FOR_QA: "True"},
additional_info=file.get("id"),
)
)
except Exception as e:

View File

@@ -10,11 +10,13 @@ from google.oauth2.service_account import Credentials as ServiceAccountCredentia
from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore
from sqlalchemy.orm import Session
from danswer.configs.app_configs import ENTERPRISE_EDITION_ENABLED
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import KV_CRED_KEY
from danswer.configs.constants import KV_GOOGLE_DRIVE_CRED_KEY
from danswer.configs.constants import KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY
from danswer.connectors.google_drive.constants import BASE_SCOPES
from danswer.connectors.google_drive.constants import (
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY,
)
@@ -22,7 +24,8 @@ from danswer.connectors.google_drive.constants import (
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
)
from danswer.connectors.google_drive.constants import DB_CREDENTIALS_DICT_TOKEN_KEY
from danswer.connectors.google_drive.constants import SCOPES
from danswer.connectors.google_drive.constants import FETCH_GROUPS_SCOPES
from danswer.connectors.google_drive.constants import FETCH_PERMISSIONS_SCOPES
from danswer.db.credentials import update_credential_json
from danswer.db.models import User
from danswer.dynamic_configs.factory import get_dynamic_config_store
@@ -34,15 +37,25 @@ from danswer.utils.logger import setup_logger
logger = setup_logger()
def build_gdrive_scopes() -> list[str]:
base_scopes: list[str] = BASE_SCOPES
permissions_scopes: list[str] = FETCH_PERMISSIONS_SCOPES
groups_scopes: list[str] = FETCH_GROUPS_SCOPES
if ENTERPRISE_EDITION_ENABLED:
return base_scopes + permissions_scopes + groups_scopes
return base_scopes + permissions_scopes
def _build_frontend_google_drive_redirect() -> str:
return f"{WEB_DOMAIN}/admin/connectors/google-drive/auth/callback"
def get_google_drive_creds_for_authorized_user(
token_json_str: str,
token_json_str: str, scopes: list[str] = build_gdrive_scopes()
) -> OAuthCredentials | None:
creds_json = json.loads(token_json_str)
creds = OAuthCredentials.from_authorized_user_info(creds_json, SCOPES)
creds = OAuthCredentials.from_authorized_user_info(creds_json, scopes)
if creds.valid:
return creds
@@ -59,18 +72,67 @@ def get_google_drive_creds_for_authorized_user(
return None
def get_google_drive_creds_for_service_account(
service_account_key_json_str: str,
def _get_google_drive_creds_for_service_account(
service_account_key_json_str: str, scopes: list[str] = build_gdrive_scopes()
) -> ServiceAccountCredentials | None:
service_account_key = json.loads(service_account_key_json_str)
creds = ServiceAccountCredentials.from_service_account_info(
service_account_key, scopes=SCOPES
service_account_key, scopes=scopes
)
if not creds.valid or not creds.expired:
creds.refresh(Request())
return creds if creds.valid else None
def get_google_drive_creds(
credentials: dict[str, str], scopes: list[str] = build_gdrive_scopes()
) -> tuple[ServiceAccountCredentials | OAuthCredentials, dict[str, str] | None]:
oauth_creds = None
service_creds = None
new_creds_dict = None
if DB_CREDENTIALS_DICT_TOKEN_KEY in credentials:
access_token_json_str = cast(str, credentials[DB_CREDENTIALS_DICT_TOKEN_KEY])
oauth_creds = get_google_drive_creds_for_authorized_user(
token_json_str=access_token_json_str, scopes=scopes
)
# tell caller to update token stored in DB if it has changed
# (e.g. the token has been refreshed)
new_creds_json_str = oauth_creds.to_json() if oauth_creds else ""
if new_creds_json_str != access_token_json_str:
new_creds_dict = {DB_CREDENTIALS_DICT_TOKEN_KEY: new_creds_json_str}
elif DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY in credentials:
service_account_key_json_str = credentials[
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY
]
service_creds = _get_google_drive_creds_for_service_account(
service_account_key_json_str=service_account_key_json_str,
scopes=scopes,
)
# "Impersonate" a user if one is specified
delegated_user_email = cast(
str | None, credentials.get(DB_CREDENTIALS_DICT_DELEGATED_USER_KEY)
)
if delegated_user_email:
service_creds = (
service_creds.with_subject(delegated_user_email)
if service_creds
else None
)
creds: ServiceAccountCredentials | OAuthCredentials | None = (
oauth_creds or service_creds
)
if creds is None:
raise PermissionError(
"Unable to access Google Drive - unknown credential structure."
)
return creds, new_creds_dict
def verify_csrf(credential_id: int, state: str) -> None:
csrf = get_dynamic_config_store().load(KV_CRED_KEY.format(str(credential_id)))
if csrf != state:
@@ -84,7 +146,7 @@ def get_auth_url(credential_id: int) -> str:
credential_json = json.loads(creds_str)
flow = InstalledAppFlow.from_client_config(
credential_json,
scopes=SCOPES,
scopes=build_gdrive_scopes(),
redirect_uri=_build_frontend_google_drive_redirect(),
)
auth_url, _ = flow.authorization_url(prompt="consent")
@@ -107,7 +169,7 @@ def update_credential_access_tokens(
app_credentials = get_google_app_cred()
flow = InstalledAppFlow.from_client_config(
app_credentials.model_dump(),
scopes=SCOPES,
scopes=build_gdrive_scopes(),
redirect_uri=_build_frontend_google_drive_redirect(),
)
flow.fetch_token(code=auth_code)

View File

@@ -1,7 +1,7 @@
DB_CREDENTIALS_DICT_TOKEN_KEY = "google_drive_tokens"
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "google_drive_service_account_key"
DB_CREDENTIALS_DICT_DELEGATED_USER_KEY = "google_drive_delegated_user"
SCOPES = [
"https://www.googleapis.com/auth/drive.readonly",
"https://www.googleapis.com/auth/drive.metadata.readonly",
]
BASE_SCOPES = ["https://www.googleapis.com/auth/drive.readonly"]
FETCH_PERMISSIONS_SCOPES = ["https://www.googleapis.com/auth/drive.metadata.readonly"]
FETCH_GROUPS_SCOPES = ["https://www.googleapis.com/auth/cloud-identity.groups.readonly"]

View File

@@ -113,6 +113,9 @@ class DocumentBase(BaseModel):
# The default title is semantic_identifier though unless otherwise specified
title: str | None = None
from_ingestion_api: bool = False
# Anything else that may be useful that is specific to this particular connector type that other
# parts of the code may need. If you're unsure, this can be left as None
additional_info: Any = None
def get_title_for_document_index(
self,

View File

@@ -211,7 +211,7 @@ def handle_message(
with Session(get_sqlalchemy_engine()) as db_session:
if message_info.email:
add_non_web_user_if_not_exists(message_info.email, db_session)
add_non_web_user_if_not_exists(db_session, message_info.email)
# first check if we need to respond with a standard answer
used_standard_answer = handle_standard_answers(

View File

@@ -5,6 +5,7 @@ from typing import cast
from typing import Optional
from typing import TypeVar
from fastapi import HTTPException
from retry import retry
from slack_sdk import WebClient
from slack_sdk.models.blocks import DividerBlock
@@ -153,15 +154,23 @@ def handle_regular_answer(
with Session(get_sqlalchemy_engine()) as db_session:
if len(new_message_request.messages) > 1:
persona = cast(
Persona,
fetch_persona_by_id(
db_session,
new_message_request.persona_id,
user=None,
get_editable=False,
),
)
if new_message_request.persona_config:
raise HTTPException(
status_code=403,
detail="Slack bot does not support persona config",
)
elif new_message_request.persona_id:
persona = cast(
Persona,
fetch_persona_by_id(
db_session,
new_message_request.persona_id,
user=None,
get_editable=False,
),
)
llm, _ = get_llms_for_persona(persona)
# In cases of threads, split the available tokens between docs and thread context

View File

@@ -178,14 +178,8 @@ def delete_search_doc_message_relationship(
def delete_tool_call_for_message_id(message_id: int, db_session: Session) -> None:
chat_message = (
db_session.query(ChatMessage).filter(ChatMessage.id == message_id).first()
)
if chat_message and chat_message.tool_call_id:
stmt = delete(ToolCall).where(ToolCall.id == chat_message.tool_call_id)
db_session.execute(stmt)
chat_message.tool_call_id = None
stmt = delete(ToolCall).where(ToolCall.message_id == message_id)
db_session.execute(stmt)
db_session.commit()
@@ -232,7 +226,7 @@ def create_chat_session(
db_session: Session,
description: str,
user_id: UUID | None,
persona_id: int,
persona_id: int | None, # Can be none if temporary persona is used
llm_override: LLMOverride | None = None,
prompt_override: PromptOverride | None = None,
one_shot: bool = False,
@@ -394,7 +388,7 @@ def get_chat_messages_by_session(
)
if prefetch_tool_calls:
stmt = stmt.options(joinedload(ChatMessage.tool_call))
stmt = stmt.options(joinedload(ChatMessage.tool_calls))
result = db_session.scalars(stmt).unique().all()
else:
result = db_session.scalars(stmt).all()
@@ -480,7 +474,7 @@ def create_new_chat_message(
alternate_assistant_id: int | None = None,
# Maps the citation number [n] to the DB SearchDoc
citations: dict[int, int] | None = None,
tool_call: ToolCall | None = None,
tool_calls: list[ToolCall] | None = None,
commit: bool = True,
reserved_message_id: int | None = None,
overridden_model: str | None = None,
@@ -500,7 +494,7 @@ def create_new_chat_message(
existing_message.message_type = message_type
existing_message.citations = citations
existing_message.files = files
existing_message.tool_call = tool_call if tool_call else None
existing_message.tool_calls = tool_calls if tool_calls else []
existing_message.error = error
existing_message.alternate_assistant_id = alternate_assistant_id
existing_message.overridden_model = overridden_model
@@ -519,7 +513,7 @@ def create_new_chat_message(
message_type=message_type,
citations=citations,
files=files,
tool_call=tool_call if tool_call else None,
tool_calls=tool_calls if tool_calls else [],
error=error,
alternate_assistant_id=alternate_assistant_id,
overridden_model=overridden_model,
@@ -604,6 +598,7 @@ def get_doc_query_identifiers_from_model(
chat_session: ChatSession,
user_id: UUID | None,
db_session: Session,
enforce_chat_session_id_for_search_docs: bool,
) -> list[tuple[str, int]]:
"""Given a list of search_doc_ids"""
search_docs = (
@@ -623,7 +618,8 @@ def get_doc_query_identifiers_from_model(
for doc in search_docs
]
):
raise ValueError("Invalid reference doc, not from this chat session.")
if enforce_chat_session_id_for_search_docs:
raise ValueError("Invalid reference doc, not from this chat session.")
except IndexError:
# This happens when the doc has no chat_messages associated with it.
# which happens as an edge case where the chat message failed to save
@@ -753,13 +749,14 @@ def translate_db_message_to_chat_message_detail(
time_sent=chat_message.time_sent,
citations=chat_message.citations,
files=chat_message.files or [],
tool_call=ToolCallFinalResult(
tool_name=chat_message.tool_call.tool_name,
tool_args=chat_message.tool_call.tool_arguments,
tool_result=chat_message.tool_call.tool_result,
)
if chat_message.tool_call
else None,
tool_calls=[
ToolCallFinalResult(
tool_name=tool_call.tool_name,
tool_args=tool_call.tool_arguments,
tool_result=tool_call.tool_result,
)
for tool_call in chat_message.tool_calls
],
alternate_assistant_id=chat_message.alternate_assistant_id,
overridden_model=chat_message.overridden_model,
)

View File

@@ -12,6 +12,7 @@ from sqlalchemy.orm import Session
from danswer.configs.constants import DocumentSource
from danswer.db.connector import fetch_connector_by_id
from danswer.db.credentials import fetch_credential_by_id
from danswer.db.enums import AccessType
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import IndexAttempt
@@ -24,6 +25,10 @@ from danswer.db.models import UserGroup__ConnectorCredentialPair
from danswer.db.models import UserRole
from danswer.server.models import StatusResponse
from danswer.utils.logger import setup_logger
from ee.danswer.db.external_perm import delete_user__ext_group_for_cc_pair__no_commit
from ee.danswer.external_permissions.permission_sync_function_map import (
check_if_valid_sync_source,
)
logger = setup_logger()
@@ -74,7 +79,7 @@ def _add_user_filters(
.correlate(ConnectorCredentialPair)
)
else:
where_clause |= ConnectorCredentialPair.is_public == True # noqa: E712
where_clause |= ConnectorCredentialPair.access_type == AccessType.PUBLIC
return stmt.where(where_clause)
@@ -94,8 +99,7 @@ def get_connector_credential_pairs(
) # noqa
if ids:
stmt = stmt.where(ConnectorCredentialPair.id.in_(ids))
results = db_session.scalars(stmt)
return list(results.all())
return list(db_session.scalars(stmt).all())
def add_deletion_failure_message(
@@ -309,9 +313,9 @@ def associate_default_cc_pair(db_session: Session) -> None:
association = ConnectorCredentialPair(
connector_id=0,
credential_id=0,
access_type=AccessType.PUBLIC,
name="DefaultCCPair",
status=ConnectorCredentialPairStatus.ACTIVE,
is_public=True,
)
db_session.add(association)
db_session.commit()
@@ -336,8 +340,9 @@ def add_credential_to_connector(
connector_id: int,
credential_id: int,
cc_pair_name: str | None,
is_public: bool,
access_type: AccessType,
groups: list[int] | None,
auto_sync_options: dict | None = None,
) -> StatusResponse:
connector = fetch_connector_by_id(connector_id, db_session)
credential = fetch_credential_by_id(credential_id, user, db_session)
@@ -345,6 +350,13 @@ def add_credential_to_connector(
if connector is None:
raise HTTPException(status_code=404, detail="Connector does not exist")
if access_type == AccessType.SYNC:
if not check_if_valid_sync_source(connector.source):
raise HTTPException(
status_code=400,
detail=f"Connector of type {connector.source} does not support SYNC access type",
)
if credential is None:
error_msg = (
f"Credential {credential_id} does not exist or does not belong to user"
@@ -375,12 +387,13 @@ def add_credential_to_connector(
credential_id=credential_id,
name=cc_pair_name,
status=ConnectorCredentialPairStatus.ACTIVE,
is_public=is_public,
access_type=access_type,
auto_sync_options=auto_sync_options,
)
db_session.add(association)
db_session.flush() # make sure the association has an id
if groups:
if groups and access_type != AccessType.SYNC:
_relate_groups_to_cc_pair__no_commit(
db_session=db_session,
cc_pair_id=association.id,
@@ -423,6 +436,10 @@ def remove_credential_from_connector(
)
if association is not None:
delete_user__ext_group_for_cc_pair__no_commit(
db_session=db_session,
cc_pair_id=association.id,
)
db_session.delete(association)
db_session.commit()
return StatusResponse(

View File

@@ -4,7 +4,6 @@ from collections.abc import Generator
from collections.abc import Sequence
from datetime import datetime
from datetime import timezone
from uuid import UUID
from sqlalchemy import and_
from sqlalchemy import delete
@@ -17,14 +16,17 @@ from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.engine.util import TransactionalContext
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import null
from danswer.configs.constants import DEFAULT_BOOST
from danswer.db.enums import AccessType
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.feedback import delete_document_feedback_for_documents__no_commit
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import Credential
from danswer.db.models import Document as DbDocument
from danswer.db.models import DocumentByConnectorCredentialPair
from danswer.db.models import User
from danswer.db.tag import delete_document_tags_for_documents__no_commit
from danswer.db.utils import model_to_dict
from danswer.document_index.interfaces import DocumentMetadata
@@ -126,7 +128,18 @@ def get_documents_by_ids(
return list(documents)
def get_document_connector_cnts(
def get_document_connector_count(
db_session: Session,
document_id: str,
) -> int:
results = get_document_connector_counts(db_session, [document_id])
if not results or len(results) == 0:
return 0
return results[0][1]
def get_document_connector_counts(
db_session: Session,
document_ids: list[str],
) -> Sequence[tuple[str, int]]:
@@ -141,7 +154,7 @@ def get_document_connector_cnts(
return db_session.execute(stmt).all() # type: ignore
def get_document_cnts_for_cc_pairs(
def get_document_counts_for_cc_pairs(
db_session: Session, cc_pair_identifiers: list[ConnectorCredentialPairIdentifier]
) -> Sequence[tuple[int, int, int]]:
stmt = (
@@ -175,16 +188,14 @@ def get_document_cnts_for_cc_pairs(
def get_access_info_for_document(
db_session: Session,
document_id: str,
) -> tuple[str, list[UUID | None], bool] | None:
) -> tuple[str, list[str | None], bool] | None:
"""Gets access info for a single document by calling the get_access_info_for_documents function
and passing a list with a single document ID.
Args:
db_session (Session): The database session to use.
document_id (str): The document ID to fetch access info for.
Returns:
Optional[Tuple[str, List[UUID | None], bool]]: A tuple containing the document ID, a list of user IDs,
Optional[Tuple[str, List[str | None], bool]]: A tuple containing the document ID, a list of user emails,
and a boolean indicating if the document is globally public, or None if no results are found.
"""
results = get_access_info_for_documents(db_session, [document_id])
@@ -197,19 +208,27 @@ def get_access_info_for_document(
def get_access_info_for_documents(
db_session: Session,
document_ids: list[str],
) -> Sequence[tuple[str, list[UUID | None], bool]]:
) -> Sequence[tuple[str, list[str | None], bool]]:
"""Gets back all relevant access info for the given documents. This includes
the user_ids for cc pairs that the document is associated with + whether any
of the associated cc pairs are intending to make the document globally public.
Returns the list where each element contains:
- Document ID (which is also the ID of the DocumentByConnectorCredentialPair)
- List of emails of Danswer users with direct access to the doc (includes a "None" element if
the connector was set up by an admin when auth was off
- bool for whether the document is public (the document later can also be marked public by
automatic permission sync step)
"""
stmt = select(
DocumentByConnectorCredentialPair.id,
func.array_agg(func.coalesce(User.email, null())).label("user_emails"),
func.bool_or(ConnectorCredentialPair.access_type == AccessType.PUBLIC).label(
"public_doc"
),
).where(DocumentByConnectorCredentialPair.id.in_(document_ids))
stmt = (
select(
DocumentByConnectorCredentialPair.id,
func.array_agg(Credential.user_id).label("user_ids"),
func.bool_or(ConnectorCredentialPair.is_public).label("public_doc"),
)
.where(DocumentByConnectorCredentialPair.id.in_(document_ids))
.join(
stmt.join(
Credential,
DocumentByConnectorCredentialPair.credential_id == Credential.id,
)
@@ -222,6 +241,13 @@ def get_access_info_for_documents(
== ConnectorCredentialPair.credential_id,
),
)
.outerjoin(
User,
and_(
Credential.user_id == User.id,
ConnectorCredentialPair.access_type != AccessType.SYNC,
),
)
# don't include CC pairs that are being deleted
# NOTE: CC pairs can never go from DELETING to any other state -> it's safe to ignore them
.where(ConnectorCredentialPair.status != ConnectorCredentialPairStatus.DELETING)
@@ -267,9 +293,19 @@ def upsert_documents(
for doc in seen_documents.values()
]
)
# for now, there are no columns to update. If more metadata is added, then this
# needs to change to an `on_conflict_do_update`
on_conflict_stmt = insert_stmt.on_conflict_do_nothing()
on_conflict_stmt = insert_stmt.on_conflict_do_update(
index_elements=["id"], # Conflict target
set_={
"from_ingestion_api": insert_stmt.excluded.from_ingestion_api,
"boost": insert_stmt.excluded.boost,
"hidden": insert_stmt.excluded.hidden,
"semantic_id": insert_stmt.excluded.semantic_id,
"link": insert_stmt.excluded.link,
"primary_owners": insert_stmt.excluded.primary_owners,
"secondary_owners": insert_stmt.excluded.secondary_owners,
},
)
db_session.execute(on_conflict_stmt)
db_session.commit()
@@ -350,11 +386,34 @@ def upsert_documents_complete(
def delete_document_by_connector_credential_pair__no_commit(
db_session: Session,
document_id: str,
connector_credential_pair_identifier: ConnectorCredentialPairIdentifier
| None = None,
) -> None:
"""Deletes a single document by cc pair relationship entry.
Foreign key rows are left in place.
The implicit assumption is that the document itself still has other cc_pair
references and needs to continue existing.
"""
delete_documents_by_connector_credential_pair__no_commit(
db_session=db_session,
document_ids=[document_id],
connector_credential_pair_identifier=connector_credential_pair_identifier,
)
def delete_documents_by_connector_credential_pair__no_commit(
db_session: Session,
document_ids: list[str],
connector_credential_pair_identifier: ConnectorCredentialPairIdentifier
| None = None,
) -> None:
"""This deletes just the document by cc pair entries for a particular cc pair.
Foreign key rows are left in place.
The implicit assumption is that the document itself still has other cc_pair
references and needs to continue existing.
"""
stmt = delete(DocumentByConnectorCredentialPair).where(
DocumentByConnectorCredentialPair.id.in_(document_ids)
)
@@ -377,8 +436,9 @@ def delete_documents__no_commit(db_session: Session, document_ids: list[str]) ->
def delete_documents_complete__no_commit(
db_session: Session, document_ids: list[str]
) -> None:
"""This completely deletes the documents from the db, including all foreign key relationships"""
logger.info(f"Deleting {len(document_ids)} documents from the DB")
delete_document_by_connector_credential_pair__no_commit(db_session, document_ids)
delete_documents_by_connector_credential_pair__no_commit(db_session, document_ids)
delete_document_feedback_for_documents__no_commit(
document_ids=document_ids, db_session=db_session
)

View File

@@ -14,6 +14,7 @@ from sqlalchemy.orm import Session
from danswer.db.connector_credential_pair import get_cc_pair_groups_for_ids
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.enums import AccessType
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import Document
@@ -180,7 +181,7 @@ def _check_if_cc_pairs_are_owned_by_groups(
ids=missing_cc_pair_ids,
)
for cc_pair in cc_pairs:
if not cc_pair.is_public:
if cc_pair.access_type != AccessType.PUBLIC:
raise ValueError(
f"Connector Credential Pair with ID: '{cc_pair.id}'"
" is not owned by the specified groups"
@@ -569,7 +570,7 @@ def construct_document_select_by_docset(
return stmt
def fetch_document_set_for_document(
def fetch_document_sets_for_document(
document_id: str,
db_session: Session,
) -> list[str]:
@@ -704,7 +705,7 @@ def check_document_sets_are_public(
ConnectorCredentialPair.id.in_(
connector_credential_pair_ids # type:ignore
),
ConnectorCredentialPair.is_public.is_(False),
ConnectorCredentialPair.access_type != AccessType.PUBLIC,
)
.limit(1)
.first()

View File

@@ -137,8 +137,8 @@ def get_sqlalchemy_engine() -> Engine:
)
_SYNC_ENGINE = create_engine(
connection_string,
pool_size=40,
max_overflow=10,
pool_size=5,
max_overflow=0,
pool_pre_ping=POSTGRES_POOL_PRE_PING,
pool_recycle=POSTGRES_POOL_RECYCLE,
)
@@ -156,8 +156,8 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
connect_args={
"server_settings": {"application_name": POSTGRES_APP_NAME + "_async"}
},
pool_size=40,
max_overflow=10,
pool_size=5,
max_overflow=0,
pool_pre_ping=POSTGRES_POOL_PRE_PING,
pool_recycle=POSTGRES_POOL_RECYCLE,
)

View File

@@ -51,3 +51,9 @@ class ConnectorCredentialPairStatus(str, PyEnum):
def is_active(self) -> bool:
return self == ConnectorCredentialPairStatus.ACTIVE
class AccessType(str, PyEnum):
PUBLIC = "public"
PRIVATE = "private"
SYNC = "sync"

View File

@@ -16,6 +16,7 @@ from sqlalchemy.orm import Session
from danswer.configs.constants import MessageType
from danswer.configs.constants import SearchFeedbackType
from danswer.db.chat import get_chat_message
from danswer.db.enums import AccessType
from danswer.db.models import ChatMessageFeedback
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import Document as DbDocument
@@ -94,7 +95,7 @@ def _add_user_filters(
.correlate(CCPair)
)
else:
where_clause |= CCPair.is_public == True # noqa: E712
where_clause |= CCPair.access_type == AccessType.PUBLIC
return stmt.where(where_clause)

View File

@@ -4,9 +4,11 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
from danswer.db.models import DocumentSet
from danswer.db.models import LLMProvider as LLMProviderModel
from danswer.db.models import LLMProvider__UserGroup
from danswer.db.models import SearchSettings
from danswer.db.models import Tool as ToolModel
from danswer.db.models import User
from danswer.db.models import User__UserGroup
from danswer.server.manage.embedding.models import CloudEmbeddingProvider
@@ -60,13 +62,21 @@ def upsert_cloud_embedding_provider(
def upsert_llm_provider(
llm_provider: LLMProviderUpsertRequest, db_session: Session
llm_provider: LLMProviderUpsertRequest,
db_session: Session,
is_creation: bool = True,
) -> FullLLMProvider:
existing_llm_provider = db_session.scalar(
select(LLMProviderModel).where(LLMProviderModel.name == llm_provider.name)
)
if existing_llm_provider and is_creation:
raise ValueError(f"LLM Provider with name {llm_provider.name} already exists")
if not existing_llm_provider:
if not is_creation:
raise ValueError(
f"LLM Provider with name {llm_provider.name} does not exist"
)
existing_llm_provider = LLMProviderModel(name=llm_provider.name)
db_session.add(existing_llm_provider)
@@ -103,6 +113,20 @@ def fetch_existing_embedding_providers(
return list(db_session.scalars(select(CloudEmbeddingProviderModel)).all())
def fetch_existing_doc_sets(
db_session: Session, doc_ids: list[int]
) -> list[DocumentSet]:
return list(
db_session.scalars(select(DocumentSet).where(DocumentSet.id.in_(doc_ids))).all()
)
def fetch_existing_tools(db_session: Session, tool_ids: list[int]) -> list[ToolModel]:
return list(
db_session.scalars(select(ToolModel).where(ToolModel.id.in_(tool_ids))).all()
)
def fetch_existing_llm_providers(
db_session: Session,
user: User | None = None,

View File

@@ -39,6 +39,7 @@ from danswer.configs.constants import DEFAULT_BOOST
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import FileOrigin
from danswer.configs.constants import MessageType
from danswer.db.enums import AccessType
from danswer.configs.constants import NotificationType
from danswer.configs.constants import SearchFeedbackType
from danswer.configs.constants import TokenRateLimitScope
@@ -108,7 +109,7 @@ class OAuthAccount(SQLAlchemyBaseOAuthAccountTableUUID, Base):
class User(SQLAlchemyBaseUserTableUUID, Base):
oauth_accounts: Mapped[list[OAuthAccount]] = relationship(
"OAuthAccount", lazy="joined"
"OAuthAccount", lazy="joined", cascade="all, delete-orphan"
)
role: Mapped[UserRole] = mapped_column(
Enum(UserRole, native_enum=False, default=UserRole.BASIC)
@@ -122,7 +123,13 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
# if specified, controls the assistants that are shown to the user + their order
# if not specified, all assistants are shown
chosen_assistants: Mapped[list[int]] = mapped_column(
postgresql.JSONB(), nullable=True
postgresql.JSONB(), nullable=False, default=[-2, -1, 0]
)
visible_assistants: Mapped[list[int]] = mapped_column(
postgresql.JSONB(), nullable=False, default=[]
)
hidden_assistants: Mapped[list[int]] = mapped_column(
postgresql.JSONB(), nullable=False, default=[]
)
oidc_expiry: Mapped[datetime.datetime] = mapped_column(
@@ -170,7 +177,9 @@ class InputPrompt(Base):
active: Mapped[bool] = mapped_column(Boolean)
user: Mapped[User | None] = relationship("User", back_populates="input_prompts")
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
)
class InputPrompt__User(Base):
@@ -214,7 +223,9 @@ class Notification(Base):
notif_type: Mapped[NotificationType] = mapped_column(
Enum(NotificationType, native_enum=False)
)
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
)
dismissed: Mapped[bool] = mapped_column(Boolean, default=False)
last_shown: Mapped[datetime.datetime] = mapped_column(DateTime(timezone=True))
first_shown: Mapped[datetime.datetime] = mapped_column(DateTime(timezone=True))
@@ -249,7 +260,7 @@ class Persona__User(Base):
persona_id: Mapped[int] = mapped_column(ForeignKey("persona.id"), primary_key=True)
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("user.id"), primary_key=True, nullable=True
ForeignKey("user.id", ondelete="CASCADE"), primary_key=True, nullable=True
)
@@ -260,7 +271,7 @@ class DocumentSet__User(Base):
ForeignKey("document_set.id"), primary_key=True
)
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("user.id"), primary_key=True, nullable=True
ForeignKey("user.id", ondelete="CASCADE"), primary_key=True, nullable=True
)
@@ -384,10 +395,20 @@ class ConnectorCredentialPair(Base):
# controls whether the documents indexed by this CC pair are visible to all
# or if they are only visible to those with that are given explicit access
# (e.g. via owning the credential or being a part of a group that is given access)
is_public: Mapped[bool] = mapped_column(
Boolean,
default=True,
nullable=False,
access_type: Mapped[AccessType] = mapped_column(
Enum(AccessType, native_enum=False), nullable=False
)
# special info needed for the auto-sync feature. The exact structure depends on the
# source type (defined in the connector's `source` field)
# E.g. for google_drive perm sync:
# {"customer_id": "123567", "company_domain": "@danswer.ai"}
auto_sync_options: Mapped[dict[str, Any] | None] = mapped_column(
postgresql.JSONB(), nullable=True
)
last_time_perm_sync: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
# Time finished, not used for calculating backend jobs which uses time started (created)
last_successful_index_time: Mapped[datetime.datetime | None] = mapped_column(
@@ -418,6 +439,7 @@ class ConnectorCredentialPair(Base):
class Document(Base):
__tablename__ = "document"
# NOTE: if more sensitive data is added here for display, make sure to add user/group permission
# this should correspond to the ID of the document
# (as is passed around in Danswer)
@@ -461,7 +483,18 @@ class Document(Base):
secondary_owners: Mapped[list[str] | None] = mapped_column(
postgresql.ARRAY(String), nullable=True
)
# TODO if more sensitive data is added here for display, make sure to add user/group permission
# Permission sync columns
# Email addresses are saved at the document level for externally synced permissions
# This is becuase the normal flow of assigning permissions is through the cc_pair
# doesn't apply here
external_user_emails: Mapped[list[str] | None] = mapped_column(
postgresql.ARRAY(String), nullable=True
)
# These group ids have been prefixed by the source type
external_user_group_ids: Mapped[list[str] | None] = mapped_column(
postgresql.ARRAY(String), nullable=True
)
is_public: Mapped[bool] = mapped_column(Boolean, default=False)
retrieval_feedbacks: Mapped[list["DocumentRetrievalFeedback"]] = relationship(
"DocumentRetrievalFeedback", back_populates="document"
@@ -541,7 +574,9 @@ class Credential(Base):
id: Mapped[int] = mapped_column(primary_key=True)
credential_json: Mapped[dict[str, Any]] = mapped_column(EncryptedJson())
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
)
# if `true`, then all Admins will have access to the credential
admin_public: Mapped[bool] = mapped_column(Boolean, default=True)
time_created: Mapped[datetime.datetime] = mapped_column(
@@ -854,8 +889,10 @@ class ToolCall(Base):
tool_arguments: Mapped[dict[str, JSON_ro]] = mapped_column(postgresql.JSONB())
tool_result: Mapped[JSON_ro] = mapped_column(postgresql.JSONB())
message_id: Mapped[int] = mapped_column(ForeignKey("chat_message.id"))
message: Mapped["ChatMessage"] = relationship(
"ChatMessage", back_populates="tool_call"
"ChatMessage", back_populates="tool_calls"
)
@@ -863,8 +900,12 @@ class ChatSession(Base):
__tablename__ = "chat_session"
id: Mapped[int] = mapped_column(primary_key=True)
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
persona_id: Mapped[int] = mapped_column(ForeignKey("persona.id"))
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
)
persona_id: Mapped[int | None] = mapped_column(
ForeignKey("persona.id"), nullable=True
)
description: Mapped[str] = mapped_column(Text)
# One-shot direct answering, currently the two types of chats are not mixed
one_shot: Mapped[bool] = mapped_column(Boolean, default=False)
@@ -898,7 +939,6 @@ class ChatSession(Base):
prompt_override: Mapped[PromptOverride | None] = mapped_column(
PydanticType(PromptOverride), nullable=True
)
time_updated: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
@@ -907,7 +947,6 @@ class ChatSession(Base):
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)
user: Mapped[User] = relationship("User", back_populates="chat_sessions")
folder: Mapped["ChatFolder"] = relationship(
"ChatFolder", back_populates="chat_sessions"
@@ -982,14 +1021,9 @@ class ChatMessage(Base):
)
# NOTE: Should always be attached to the `assistant` message.
# represents the tool calls used to generate this message
tool_call_id: Mapped[int | None] = mapped_column(
ForeignKey("tool_call.id"), nullable=True
)
# NOTE: Should always be attached to the `assistant` message.
# represents the tool calls used to generate this message
tool_call: Mapped["ToolCall"] = relationship(
"ToolCall", back_populates="message", foreign_keys=[tool_call_id]
tool_calls: Mapped[list["ToolCall"]] = relationship(
"ToolCall",
back_populates="message",
)
standard_answers: Mapped[list["StandardAnswer"]] = relationship(
"StandardAnswer",
@@ -1005,7 +1039,9 @@ class ChatFolder(Base):
id: Mapped[int] = mapped_column(primary_key=True)
# Only null if auth is off
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
)
name: Mapped[str | None] = mapped_column(String, nullable=True)
display_priority: Mapped[int] = mapped_column(Integer, nullable=True, default=0)
@@ -1136,7 +1172,9 @@ class DocumentSet(Base):
id: Mapped[int] = mapped_column(Integer, primary_key=True)
name: Mapped[str] = mapped_column(String, unique=True)
description: Mapped[str] = mapped_column(String)
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
)
# Whether changes to the document set have been propagated
is_up_to_date: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
# If `False`, then the document set is not visible to users who are not explicitly
@@ -1180,7 +1218,9 @@ class Prompt(Base):
__tablename__ = "prompt"
id: Mapped[int] = mapped_column(primary_key=True)
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
)
name: Mapped[str] = mapped_column(String)
description: Mapped[str] = mapped_column(String)
system_prompt: Mapped[str] = mapped_column(Text)
@@ -1215,9 +1255,13 @@ class Tool(Base):
openapi_schema: Mapped[dict[str, Any] | None] = mapped_column(
postgresql.JSONB(), nullable=True
)
custom_headers: Mapped[list[dict[str, str]] | None] = mapped_column(
postgresql.JSONB(), nullable=True
)
# user who created / owns the tool. Will be None for built-in tools.
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
)
user: Mapped[User | None] = relationship("User", back_populates="custom_tools")
# Relationship to Persona through the association table
@@ -1241,7 +1285,9 @@ class Persona(Base):
__tablename__ = "persona"
id: Mapped[int] = mapped_column(primary_key=True)
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
)
name: Mapped[str] = mapped_column(String)
description: Mapped[str] = mapped_column(String)
# Number of chunks to pass to the LLM for generation.
@@ -1270,9 +1316,18 @@ class Persona(Base):
starter_messages: Mapped[list[StarterMessage] | None] = mapped_column(
postgresql.JSONB(), nullable=True
)
# Default personas are configured via backend during deployment
search_start_date: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), default=None
)
# Built-in personas are configured via backend during deployment
# Treated specially (cannot be user edited etc.)
default_persona: Mapped[bool] = mapped_column(Boolean, default=False)
builtin_persona: Mapped[bool] = mapped_column(Boolean, default=False)
# Default personas are personas created by admins and are automatically added
# to all users' assistants list.
is_default_persona: Mapped[bool] = mapped_column(
Boolean, default=False, nullable=False
)
# controls whether the persona is available to be selected by users
is_visible: Mapped[bool] = mapped_column(Boolean, default=True)
# controls the ordering of personas in the UI
@@ -1323,10 +1378,10 @@ class Persona(Base):
# Default personas loaded via yaml cannot have the same name
__table_args__ = (
Index(
"_default_persona_name_idx",
"_builtin_persona_name_idx",
"name",
unique=True,
postgresql_where=(default_persona == True), # noqa: E712
postgresql_where=(builtin_persona == True), # noqa: E712
),
)
@@ -1350,55 +1405,6 @@ class ChannelConfig(TypedDict):
follow_up_tags: NotRequired[list[str]]
class StandardAnswerCategory(Base):
__tablename__ = "standard_answer_category"
id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column(String, unique=True)
standard_answers: Mapped[list["StandardAnswer"]] = relationship(
"StandardAnswer",
secondary=StandardAnswer__StandardAnswerCategory.__table__,
back_populates="categories",
)
slack_bot_configs: Mapped[list["SlackBotConfig"]] = relationship(
"SlackBotConfig",
secondary=SlackBotConfig__StandardAnswerCategory.__table__,
back_populates="standard_answer_categories",
)
class StandardAnswer(Base):
__tablename__ = "standard_answer"
id: Mapped[int] = mapped_column(primary_key=True)
keyword: Mapped[str] = mapped_column(String)
answer: Mapped[str] = mapped_column(String)
active: Mapped[bool] = mapped_column(Boolean)
match_regex: Mapped[bool] = mapped_column(Boolean)
match_any_keywords: Mapped[bool] = mapped_column(Boolean)
__table_args__ = (
Index(
"unique_keyword_active",
keyword,
active,
unique=True,
postgresql_where=(active == True), # noqa: E712
),
)
categories: Mapped[list[StandardAnswerCategory]] = relationship(
"StandardAnswerCategory",
secondary=StandardAnswer__StandardAnswerCategory.__table__,
back_populates="standard_answers",
)
chat_messages: Mapped[list[ChatMessage]] = relationship(
"ChatMessage",
secondary=ChatMessage__StandardAnswer.__table__,
back_populates="standard_answers",
)
class SlackBotResponseType(str, PyEnum):
QUOTES = "quotes"
CITATIONS = "citations"
@@ -1424,7 +1430,7 @@ class SlackBotConfig(Base):
)
persona: Mapped[Persona | None] = relationship("Persona")
standard_answer_categories: Mapped[list[StandardAnswerCategory]] = relationship(
standard_answer_categories: Mapped[list["StandardAnswerCategory"]] = relationship(
"StandardAnswerCategory",
secondary=SlackBotConfig__StandardAnswerCategory.__table__,
back_populates="slack_bot_configs",
@@ -1486,7 +1492,9 @@ class SamlAccount(Base):
__tablename__ = "saml"
id: Mapped[int] = mapped_column(primary_key=True)
user_id: Mapped[int] = mapped_column(ForeignKey("user.id"), unique=True)
user_id: Mapped[int] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), unique=True
)
encrypted_cookie: Mapped[str] = mapped_column(Text, unique=True)
expires_at: Mapped[datetime.datetime] = mapped_column(DateTime(timezone=True))
updated_at: Mapped[datetime.datetime] = mapped_column(
@@ -1505,7 +1513,7 @@ class User__UserGroup(Base):
ForeignKey("user_group.id"), primary_key=True
)
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("user.id"), primary_key=True, nullable=True
ForeignKey("user.id", ondelete="CASCADE"), primary_key=True, nullable=True
)
@@ -1654,94 +1662,70 @@ class TokenRateLimit__UserGroup(Base):
)
class StandardAnswerCategory(Base):
__tablename__ = "standard_answer_category"
id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column(String, unique=True)
standard_answers: Mapped[list["StandardAnswer"]] = relationship(
"StandardAnswer",
secondary=StandardAnswer__StandardAnswerCategory.__table__,
back_populates="categories",
)
slack_bot_configs: Mapped[list["SlackBotConfig"]] = relationship(
"SlackBotConfig",
secondary=SlackBotConfig__StandardAnswerCategory.__table__,
back_populates="standard_answer_categories",
)
class StandardAnswer(Base):
__tablename__ = "standard_answer"
id: Mapped[int] = mapped_column(primary_key=True)
keyword: Mapped[str] = mapped_column(String)
answer: Mapped[str] = mapped_column(String)
active: Mapped[bool] = mapped_column(Boolean)
match_regex: Mapped[bool] = mapped_column(Boolean)
match_any_keywords: Mapped[bool] = mapped_column(Boolean)
__table_args__ = (
Index(
"unique_keyword_active",
keyword,
active,
unique=True,
postgresql_where=(active == True), # noqa: E712
),
)
categories: Mapped[list[StandardAnswerCategory]] = relationship(
"StandardAnswerCategory",
secondary=StandardAnswer__StandardAnswerCategory.__table__,
back_populates="standard_answers",
)
chat_messages: Mapped[list[ChatMessage]] = relationship(
"ChatMessage",
secondary=ChatMessage__StandardAnswer.__table__,
back_populates="standard_answers",
)
"""Tables related to Permission Sync"""
class PermissionSyncStatus(str, PyEnum):
IN_PROGRESS = "in_progress"
SUCCESS = "success"
FAILED = "failed"
class PermissionSyncJobType(str, PyEnum):
USER_LEVEL = "user_level"
GROUP_LEVEL = "group_level"
class PermissionSyncRun(Base):
"""Represents one run of a permission sync job. For some given cc_pair, it is either sync-ing
the users or it is sync-ing the groups"""
__tablename__ = "permission_sync_run"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
# Not strictly needed but makes it easy to use without fetching from cc_pair
source_type: Mapped[DocumentSource] = mapped_column(
Enum(DocumentSource, native_enum=False)
)
# Currently all sync jobs are handled as a group permission sync or a user permission sync
update_type: Mapped[PermissionSyncJobType] = mapped_column(
Enum(PermissionSyncJobType)
)
cc_pair_id: Mapped[int | None] = mapped_column(
ForeignKey("connector_credential_pair.id"), nullable=True
)
status: Mapped[PermissionSyncStatus] = mapped_column(Enum(PermissionSyncStatus))
error_msg: Mapped[str | None] = mapped_column(Text, default=None)
updated_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
)
cc_pair: Mapped[ConnectorCredentialPair] = relationship("ConnectorCredentialPair")
class ExternalPermission(Base):
class User__ExternalUserGroupId(Base):
"""Maps user info both internal and external to the name of the external group
This maps the user to all of their external groups so that the external group name can be
attached to the ACL list matching during query time. User level permissions can be handled by
directly adding the Danswer user to the doc ACL list"""
__tablename__ = "external_permission"
__tablename__ = "user__external_user_group_id"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
# Email is needed because we want to keep track of users not in Danswer to simplify process
# when the user joins
user_email: Mapped[str] = mapped_column(String)
source_type: Mapped[DocumentSource] = mapped_column(
Enum(DocumentSource, native_enum=False)
)
external_permission_group: Mapped[str] = mapped_column(String)
user = relationship("User")
class EmailToExternalUserCache(Base):
"""A way to map users IDs in the external tool to a user in Danswer or at least an email for
when the user joins. Used as a cache for when fetching external groups which have their own
user ids, this can easily be mapped back to users already known in Danswer without needing
to call external APIs to get the user emails.
This way when groups are updated in the external tool and we need to update the mapping of
internal users to the groups, we can sync the internal users to the external groups they are
part of using this.
Ie. User Chris is part of groups alpha, beta, and we can update this if Chris is no longer
part of alpha in some external tool.
"""
__tablename__ = "email_to_external_user_cache"
id: Mapped[int] = mapped_column(Integer, primary_key=True)
external_user_id: Mapped[str] = mapped_column(String)
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
# Email is needed because we want to keep track of users not in Danswer to simplify process
# when the user joins
user_email: Mapped[str] = mapped_column(String)
source_type: Mapped[DocumentSource] = mapped_column(
Enum(DocumentSource, native_enum=False)
)
user = relationship("User")
user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id"), primary_key=True)
# These group ids have been prefixed by the source type
external_user_group_id: Mapped[str] = mapped_column(String, primary_key=True)
cc_pair_id: Mapped[int] = mapped_column(ForeignKey("connector_credential_pair.id"))
class UsageReport(Base):
@@ -1757,7 +1741,7 @@ class UsageReport(Base):
# if None, report was auto-generated
requestor_user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("user.id"), nullable=True
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
)
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()

View File

@@ -1,4 +1,5 @@
from collections.abc import Sequence
from datetime import datetime
from functools import lru_cache
from uuid import UUID
@@ -178,6 +179,7 @@ def create_update_persona(
except ValueError as e:
logger.exception("Failed to create persona")
raise HTTPException(status_code=400, detail=str(e))
return PersonaSnapshot.from_model(persona)
@@ -210,6 +212,22 @@ def update_persona_shared_users(
)
def update_persona_public_status(
persona_id: int,
is_public: bool,
db_session: Session,
user: User | None,
) -> None:
persona = fetch_persona_by_id(
db_session=db_session, persona_id=persona_id, user=user, get_editable=True
)
if user and user.role != UserRole.ADMIN and persona.user_id != user.id:
raise ValueError("You don't have permission to modify this persona")
persona.is_public = is_public
db_session.commit()
def get_prompts(
user_id: UUID | None,
db_session: Session,
@@ -242,7 +260,7 @@ def get_personas(
stmt = _add_user_filters(stmt=stmt, user=user, get_editable=get_editable)
if not include_default:
stmt = stmt.where(Persona.default_persona.is_(False))
stmt = stmt.where(Persona.builtin_persona.is_(False))
if not include_slack_bot_personas:
stmt = stmt.where(not_(Persona.name.startswith(SLACK_BOT_PERSONA_PREFIX)))
if not include_deleted:
@@ -290,7 +308,7 @@ def mark_delete_persona_by_name(
) -> None:
stmt = (
update(Persona)
.where(Persona.name == persona_name, Persona.default_persona == is_default)
.where(Persona.name == persona_name, Persona.builtin_persona == is_default)
.values(deleted=True)
)
@@ -390,7 +408,6 @@ def upsert_persona(
document_set_ids: list[int] | None = None,
tool_ids: list[int] | None = None,
persona_id: int | None = None,
default_persona: bool = False,
commit: bool = True,
icon_color: str | None = None,
icon_shape: int | None = None,
@@ -398,6 +415,9 @@ def upsert_persona(
display_priority: int | None = None,
is_visible: bool = True,
remove_image: bool | None = None,
search_start_date: datetime | None = None,
builtin_persona: bool = False,
is_default_persona: bool = False,
chunks_above: int = CONTEXT_CHUNKS_ABOVE,
chunks_below: int = CONTEXT_CHUNKS_BELOW,
) -> Persona:
@@ -438,8 +458,8 @@ def upsert_persona(
validate_persona_tools(tools)
if persona:
if not default_persona and persona.default_persona:
raise ValueError("Cannot update default persona with non-default.")
if not builtin_persona and persona.builtin_persona:
raise ValueError("Cannot update builtin persona with non-builtin.")
# this checks if the user has permission to edit the persona
persona = fetch_persona_by_id(
@@ -454,7 +474,7 @@ def upsert_persona(
persona.llm_relevance_filter = llm_relevance_filter
persona.llm_filter_extraction = llm_filter_extraction
persona.recency_bias = recency_bias
persona.default_persona = default_persona
persona.builtin_persona = builtin_persona
persona.llm_model_provider_override = llm_model_provider_override
persona.llm_model_version_override = llm_model_version_override
persona.starter_messages = starter_messages
@@ -466,6 +486,8 @@ def upsert_persona(
persona.uploaded_image_id = uploaded_image_id
persona.display_priority = display_priority
persona.is_visible = is_visible
persona.search_start_date = search_start_date
persona.is_default_persona = is_default_persona
# Do not delete any associations manually added unless
# a new updated list is provided
@@ -493,7 +515,7 @@ def upsert_persona(
llm_relevance_filter=llm_relevance_filter,
llm_filter_extraction=llm_filter_extraction,
recency_bias=recency_bias,
default_persona=default_persona,
builtin_persona=builtin_persona,
prompts=prompts or [],
document_sets=document_sets or [],
llm_model_provider_override=llm_model_provider_override,
@@ -505,6 +527,8 @@ def upsert_persona(
uploaded_image_id=uploaded_image_id,
display_priority=display_priority,
is_visible=is_visible,
search_start_date=search_start_date,
is_default_persona=is_default_persona,
)
db_session.add(persona)
@@ -534,7 +558,7 @@ def delete_old_default_personas(
Need a more graceful fix later or those need to never have IDs"""
stmt = (
update(Persona)
.where(Persona.default_persona, Persona.id > 0)
.where(Persona.builtin_persona, Persona.id > 0)
.values(deleted=True, name=func.concat(Persona.name, "_old"))
)
@@ -551,6 +575,7 @@ def update_persona_visibility(
persona = fetch_persona_by_id(
db_session=db_session, persona_id=persona_id, user=user, get_editable=True
)
persona.is_visible = is_visible
db_session.commit()
@@ -563,13 +588,15 @@ def validate_persona_tools(tools: list[Tool]) -> None:
)
def get_prompts_by_ids(prompt_ids: list[int], db_session: Session) -> Sequence[Prompt]:
def get_prompts_by_ids(prompt_ids: list[int], db_session: Session) -> list[Prompt]:
"""Unsafe, can fetch prompts from all users"""
if not prompt_ids:
return []
prompts = db_session.scalars(select(Prompt).where(Prompt.id.in_(prompt_ids))).all()
prompts = db_session.scalars(
select(Prompt).where(Prompt.id.in_(prompt_ids)).where(Prompt.deleted.is_(False))
).all()
return prompts
return list(prompts)
def get_prompt_by_id(
@@ -650,9 +677,7 @@ def get_persona_by_id(
result = db_session.execute(persona_stmt)
persona = result.scalar_one_or_none()
if persona is None:
raise ValueError(
f"Persona with ID {persona_id} does not exist or does not belong to user"
)
raise ValueError(f"Persona with ID {persona_id} does not exist")
return persona
# or check if user owns persona
@@ -715,7 +740,7 @@ def delete_persona_by_name(
persona_name: str, db_session: Session, is_default: bool = True
) -> None:
stmt = delete(Persona).where(
Persona.name == persona_name, Persona.default_persona == is_default
Persona.name == persona_name, Persona.builtin_persona == is_default
)
db_session.execute(stmt)

View File

@@ -1,4 +1,5 @@
from collections.abc import Sequence
from typing import Any
from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -14,8 +15,11 @@ from danswer.db.models import User
from danswer.db.persona import get_default_prompt
from danswer.db.persona import mark_persona_as_deleted
from danswer.db.persona import upsert_persona
from danswer.db.standard_answer import fetch_standard_answer_categories_by_ids
from danswer.search.enums import RecencyBiasSetting
from danswer.utils.errors import EERequiredError
from danswer.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,
)
def _build_persona_name(channel_names: list[str]) -> str:
@@ -62,7 +66,7 @@ def create_slack_bot_persona(
llm_model_version_override=None,
starter_messages=None,
is_public=True,
default_persona=False,
is_default_persona=False,
db_session=db_session,
commit=False,
)
@@ -70,6 +74,10 @@ def create_slack_bot_persona(
return persona
def _no_ee_standard_answer_categories(*args: Any, **kwargs: Any) -> list:
return []
def insert_slack_bot_config(
persona_id: int | None,
channel_config: ChannelConfig,
@@ -78,14 +86,29 @@ def insert_slack_bot_config(
enable_auto_filters: bool,
db_session: Session,
) -> SlackBotConfig:
existing_standard_answer_categories = fetch_standard_answer_categories_by_ids(
standard_answer_category_ids=standard_answer_category_ids,
db_session=db_session,
)
if len(existing_standard_answer_categories) != len(standard_answer_category_ids):
raise ValueError(
f"Some or all categories with ids {standard_answer_category_ids} do not exist"
versioned_fetch_standard_answer_categories_by_ids = (
fetch_versioned_implementation_with_fallback(
"danswer.db.standard_answer",
"fetch_standard_answer_categories_by_ids",
_no_ee_standard_answer_categories,
)
)
existing_standard_answer_categories = (
versioned_fetch_standard_answer_categories_by_ids(
standard_answer_category_ids=standard_answer_category_ids,
db_session=db_session,
)
)
if len(existing_standard_answer_categories) != len(standard_answer_category_ids):
if len(existing_standard_answer_categories) == 0:
raise EERequiredError(
"Standard answers are a paid Enterprise Edition feature - enable EE or remove standard answer categories"
)
else:
raise ValueError(
f"Some or all categories with ids {standard_answer_category_ids} do not exist"
)
slack_bot_config = SlackBotConfig(
persona_id=persona_id,
@@ -117,9 +140,18 @@ def update_slack_bot_config(
f"Unable to find slack bot config with ID {slack_bot_config_id}"
)
existing_standard_answer_categories = fetch_standard_answer_categories_by_ids(
standard_answer_category_ids=standard_answer_category_ids,
db_session=db_session,
versioned_fetch_standard_answer_categories_by_ids = (
fetch_versioned_implementation_with_fallback(
"danswer.db.standard_answer",
"fetch_standard_answer_categories_by_ids",
_no_ee_standard_answer_categories,
)
)
existing_standard_answer_categories = (
versioned_fetch_standard_answer_categories_by_ids(
standard_answer_category_ids=standard_answer_category_ids,
db_session=db_session,
)
)
if len(existing_standard_answer_categories) != len(standard_answer_category_ids):
raise ValueError(

View File

@@ -1,202 +0,0 @@
from collections.abc import Sequence
from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.db.models import StandardAnswer
from danswer.db.models import StandardAnswerCategory
from danswer.utils.logger import setup_logger
logger = setup_logger()
def check_category_validity(category_name: str) -> bool:
"""If a category name is too long, it should not be used (it will cause an error in Postgres
as the unique constraint can only apply to entries that are less than 2704 bytes).
Additionally, extremely long categories are not really usable / useful."""
if len(category_name) > 255:
logger.error(
f"Category with name '{category_name}' is too long, cannot be used"
)
return False
return True
def insert_standard_answer_category(
category_name: str, db_session: Session
) -> StandardAnswerCategory:
if not check_category_validity(category_name):
raise ValueError(f"Invalid category name: {category_name}")
standard_answer_category = StandardAnswerCategory(name=category_name)
db_session.add(standard_answer_category)
db_session.commit()
return standard_answer_category
def insert_standard_answer(
keyword: str,
answer: str,
category_ids: list[int],
match_regex: bool,
match_any_keywords: bool,
db_session: Session,
) -> StandardAnswer:
existing_categories = fetch_standard_answer_categories_by_ids(
standard_answer_category_ids=category_ids,
db_session=db_session,
)
if len(existing_categories) != len(category_ids):
raise ValueError(f"Some or all categories with ids {category_ids} do not exist")
standard_answer = StandardAnswer(
keyword=keyword,
answer=answer,
categories=existing_categories,
active=True,
match_regex=match_regex,
match_any_keywords=match_any_keywords,
)
db_session.add(standard_answer)
db_session.commit()
return standard_answer
def update_standard_answer(
standard_answer_id: int,
keyword: str,
answer: str,
category_ids: list[int],
match_regex: bool,
match_any_keywords: bool,
db_session: Session,
) -> StandardAnswer:
standard_answer = db_session.scalar(
select(StandardAnswer).where(StandardAnswer.id == standard_answer_id)
)
if standard_answer is None:
raise ValueError(f"No standard answer with id {standard_answer_id}")
existing_categories = fetch_standard_answer_categories_by_ids(
standard_answer_category_ids=category_ids,
db_session=db_session,
)
if len(existing_categories) != len(category_ids):
raise ValueError(f"Some or all categories with ids {category_ids} do not exist")
standard_answer.keyword = keyword
standard_answer.answer = answer
standard_answer.categories = list(existing_categories)
standard_answer.match_regex = match_regex
standard_answer.match_any_keywords = match_any_keywords
db_session.commit()
return standard_answer
def remove_standard_answer(
standard_answer_id: int,
db_session: Session,
) -> None:
standard_answer = db_session.scalar(
select(StandardAnswer).where(StandardAnswer.id == standard_answer_id)
)
if standard_answer is None:
raise ValueError(f"No standard answer with id {standard_answer_id}")
standard_answer.active = False
db_session.commit()
def update_standard_answer_category(
standard_answer_category_id: int,
category_name: str,
db_session: Session,
) -> StandardAnswerCategory:
standard_answer_category = db_session.scalar(
select(StandardAnswerCategory).where(
StandardAnswerCategory.id == standard_answer_category_id
)
)
if standard_answer_category is None:
raise ValueError(
f"No standard answer category with id {standard_answer_category_id}"
)
if not check_category_validity(category_name):
raise ValueError(f"Invalid category name: {category_name}")
standard_answer_category.name = category_name
db_session.commit()
return standard_answer_category
def fetch_standard_answer_category(
standard_answer_category_id: int,
db_session: Session,
) -> StandardAnswerCategory | None:
return db_session.scalar(
select(StandardAnswerCategory).where(
StandardAnswerCategory.id == standard_answer_category_id
)
)
def fetch_standard_answer_categories_by_ids(
standard_answer_category_ids: list[int],
db_session: Session,
) -> Sequence[StandardAnswerCategory]:
return db_session.scalars(
select(StandardAnswerCategory).where(
StandardAnswerCategory.id.in_(standard_answer_category_ids)
)
).all()
def fetch_standard_answer_categories(
db_session: Session,
) -> Sequence[StandardAnswerCategory]:
return db_session.scalars(select(StandardAnswerCategory)).all()
def fetch_standard_answer(
standard_answer_id: int,
db_session: Session,
) -> StandardAnswer | None:
return db_session.scalar(
select(StandardAnswer).where(StandardAnswer.id == standard_answer_id)
)
def fetch_standard_answers(db_session: Session) -> Sequence[StandardAnswer]:
return db_session.scalars(
select(StandardAnswer).where(StandardAnswer.active.is_(True))
).all()
def create_initial_default_standard_answer_category(db_session: Session) -> None:
default_category_id = 0
default_category_name = "General"
default_category = fetch_standard_answer_category(
standard_answer_category_id=default_category_id,
db_session=db_session,
)
if default_category is not None:
if default_category.name != default_category_name:
raise ValueError(
"DB is not in a valid initial state. "
"Default standard answer category does not have expected name."
)
return
standard_answer_category = StandardAnswerCategory(
id=default_category_id,
name=default_category_name,
)
db_session.add(standard_answer_category)
db_session.commit()

View File

@@ -5,6 +5,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.db.models import Tool
from danswer.server.features.tool.models import Header
from danswer.utils.logger import setup_logger
logger = setup_logger()
@@ -25,6 +26,7 @@ def create_tool(
name: str,
description: str | None,
openapi_schema: dict[str, Any] | None,
custom_headers: list[Header] | None,
user_id: UUID | None,
db_session: Session,
) -> Tool:
@@ -33,6 +35,9 @@ def create_tool(
description=description,
in_code_tool_id=None,
openapi_schema=openapi_schema,
custom_headers=[header.dict() for header in custom_headers]
if custom_headers
else [],
user_id=user_id,
)
db_session.add(new_tool)
@@ -45,6 +50,7 @@ def update_tool(
name: str | None,
description: str | None,
openapi_schema: dict[str, Any] | None,
custom_headers: list[Header] | None,
user_id: UUID | None,
db_session: Session,
) -> Tool:
@@ -60,6 +66,8 @@ def update_tool(
tool.openapi_schema = openapi_schema
if user_id is not None:
tool.user_id = user_id
if custom_headers is not None:
tool.custom_headers = [header.dict() for header in custom_headers]
db_session.commit()
return tool

View File

@@ -2,6 +2,7 @@ from collections.abc import Sequence
from uuid import UUID
from fastapi_users.password import PasswordHelper
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -22,8 +23,23 @@ def list_users(
return db_session.scalars(stmt).unique().all()
def get_users_by_emails(
db_session: Session, emails: list[str]
) -> tuple[list[User], list[str]]:
# Use distinct to avoid duplicates
stmt = select(User).filter(User.email.in_(emails)) # type: ignore
found_users = list(db_session.scalars(stmt).unique().all()) # Convert to list
found_users_emails = [user.email for user in found_users]
missing_user_emails = [email for email in emails if email not in found_users_emails]
return found_users, missing_user_emails
def get_user_by_email(email: str, db_session: Session) -> User | None:
user = db_session.query(User).filter(User.email == email).first() # type: ignore
user = (
db_session.query(User)
.filter(func.lower(User.email) == func.lower(email))
.first()
)
return user
@@ -34,20 +50,50 @@ def fetch_user_by_id(db_session: Session, user_id: UUID) -> User | None:
return user
def add_non_web_user_if_not_exists(email: str, db_session: Session) -> User:
user = get_user_by_email(email, db_session)
if user is not None:
return user
def _generate_non_web_user(email: str) -> User:
fastapi_users_pw_helper = PasswordHelper()
password = fastapi_users_pw_helper.generate()
hashed_pass = fastapi_users_pw_helper.hash(password)
user = User(
return User(
email=email,
hashed_password=hashed_pass,
has_web_login=False,
role=UserRole.BASIC,
)
def add_non_web_user_if_not_exists(db_session: Session, email: str) -> User:
user = get_user_by_email(email, db_session)
if user is not None:
return user
user = _generate_non_web_user(email=email)
db_session.add(user)
db_session.commit()
return user
def add_non_web_user_if_not_exists__no_commit(db_session: Session, email: str) -> User:
user = get_user_by_email(email, db_session)
if user is not None:
return user
user = _generate_non_web_user(email=email)
db_session.add(user)
db_session.flush() # generate id
return user
def batch_add_non_web_user_if_not_exists__no_commit(
db_session: Session, emails: list[str]
) -> list[User]:
found_users, missing_user_emails = get_users_by_emails(db_session, emails)
new_users: list[User] = []
for email in missing_user_emails:
new_users.append(_generate_non_web_user(email=email))
db_session.add_all(new_users)
db_session.flush() # generate ids
return found_users + new_users

View File

@@ -177,6 +177,30 @@ class Updatable(abc.ABC):
- Whether the document is hidden or not, hidden documents are not returned from search
"""
@abc.abstractmethod
def update_single(self, update_request: UpdateRequest) -> None:
"""
Updates some set of chunks for a document. The document and fields to update
are specified in the update request. Each update request in the list applies
its changes to a list of document ids.
None values mean that the field does not need an update.
The rationale for a single update function is that it allows retries and parallelism
to happen at a higher / more strategic level, is simpler to read, and allows
us to individually handle error conditions per document.
Parameters:
- update_request: for a list of document ids in the update request, apply the same updates
to all of the documents with those ids.
Return:
- an HTTPStatus code. The code can used to decide whether to fail immediately,
retry, etc. Although this method likely hits an HTTP API behind the
scenes, the usage of HTTPStatus is a convenience and the interface is not
actually HTTP specific.
"""
raise NotImplementedError
@abc.abstractmethod
def update(self, update_requests: list[UpdateRequest]) -> None:
"""

View File

@@ -377,6 +377,91 @@ class VespaIndex(DocumentIndex):
time.monotonic() - update_start,
)
def update_single(self, update_request: UpdateRequest) -> None:
"""Note: if the document id does not exist, the update will be a no-op and the
function will complete with no errors or exceptions.
Handle other exceptions if you wish to implement retry behavior
"""
if len(update_request.document_ids) != 1:
raise ValueError("update_request must contain a single document id")
# Handle Vespa character limitations
# Mutating update_request but it's not used later anyway
update_request.document_ids = [
replace_invalid_doc_id_characters(doc_id)
for doc_id in update_request.document_ids
]
# update_start = time.monotonic()
# Fetch all chunks for each document ahead of time
index_names = [self.index_name]
if self.secondary_index_name:
index_names.append(self.secondary_index_name)
chunk_id_start_time = time.monotonic()
all_doc_chunk_ids: list[str] = []
for index_name in index_names:
for document_id in update_request.document_ids:
# this calls vespa and can raise http exceptions
doc_chunk_ids = get_all_vespa_ids_for_document_id(
document_id=document_id,
index_name=index_name,
filters=None,
get_large_chunks=True,
)
all_doc_chunk_ids.extend(doc_chunk_ids)
logger.debug(
f"Took {time.monotonic() - chunk_id_start_time:.2f} seconds to fetch all Vespa chunk IDs"
)
# Build the _VespaUpdateRequest objects
update_dict: dict[str, dict] = {"fields": {}}
if update_request.boost is not None:
update_dict["fields"][BOOST] = {"assign": update_request.boost}
if update_request.document_sets is not None:
update_dict["fields"][DOCUMENT_SETS] = {
"assign": {
document_set: 1 for document_set in update_request.document_sets
}
}
if update_request.access is not None:
update_dict["fields"][ACCESS_CONTROL_LIST] = {
"assign": {acl_entry: 1 for acl_entry in update_request.access.to_acl()}
}
if update_request.hidden is not None:
update_dict["fields"][HIDDEN] = {"assign": update_request.hidden}
if not update_dict["fields"]:
logger.error("Update request received but nothing to update")
return
processed_update_requests: list[_VespaUpdateRequest] = []
for document_id in update_request.document_ids:
for doc_chunk_id in all_doc_chunk_ids:
processed_update_requests.append(
_VespaUpdateRequest(
document_id=document_id,
url=f"{DOCUMENT_ID_ENDPOINT.format(index_name=self.index_name)}/{doc_chunk_id}",
update_request=update_dict,
)
)
with httpx.Client(http2=True) as http_client:
for update in processed_update_requests:
http_client.put(
update.url,
headers={"Content-Type": "application/json"},
json=update.update_request,
)
# logger.debug(
# "Finished updating Vespa documents in %.2f seconds",
# time.monotonic() - update_start,
# )
return
def delete(self, doc_ids: list[str]) -> None:
logger.info(f"Deleting {len(doc_ids)} documents from Vespa")

View File

@@ -13,8 +13,6 @@ class ChatFileType(str, Enum):
DOC = "document"
# Plain text only contain the text
PLAIN_TEXT = "plain_text"
# csv types contain the binary data
CSV = "csv"
class FileDescriptor(TypedDict):

View File

@@ -1,4 +1,3 @@
import base64
from collections.abc import Callable
from io import BytesIO
from typing import Any
@@ -17,27 +16,6 @@ from danswer.file_store.models import InMemoryChatFile
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
def save_base64_image(base64_image: str) -> str:
with get_session_context_manager() as db_session:
if base64_image.startswith("data:image"):
base64_image = base64_image.split(",", 1)[1]
image_data = base64.b64decode(base64_image)
unique_id = str(uuid4())
file_io = BytesIO(image_data)
file_store = get_default_file_store(db_session)
file_store.save_file(
file_name=unique_id,
content=file_io,
display_name="GeneratedImage",
file_origin=FileOrigin.CHAT_IMAGE_GEN,
file_type="image/png",
)
return unique_id
def load_chat_file(
file_descriptor: FileDescriptor, db_session: Session
) -> InMemoryChatFile:

View File

@@ -265,7 +265,13 @@ def index_doc_batch(
Note that the documents should already be batched at this point so that it does not inflate the
memory requirements"""
no_access = DocumentAccess.build(user_ids=[], user_groups=[], is_public=False)
no_access = DocumentAccess.build(
user_emails=[],
user_groups=[],
external_user_emails=[],
external_user_group_ids=[],
is_public=False,
)
ctx = index_doc_batch_prepare(
document_batch=document_batch,

View File

@@ -16,7 +16,6 @@ from danswer.chat.models import LlmDoc
from danswer.chat.models import StreamStopInfo
from danswer.chat.models import StreamStopReason
from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE
from danswer.configs.constants import MessageType
from danswer.file_store.utils import InMemoryChatFile
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import PreviousMessage
@@ -41,13 +40,11 @@ from danswer.llm.answering.stream_processing.utils import map_document_id_order
from danswer.llm.interfaces import LLM
from danswer.llm.interfaces import ToolChoiceOptions
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.tools.analysis.analysis_tool import CSVAnalysisTool
from danswer.tools.custom.custom_tool_prompt_builder import (
build_user_message_for_custom_tool_for_non_tool_calling_llm,
)
from danswer.tools.force import filter_tools_for_force_tool_use
from danswer.tools.force import ForceUseTool
from danswer.tools.graphing.graphing_tool import GraphingTool
from danswer.tools.images.image_generation_tool import IMAGE_GENERATION_RESPONSE_ID
from danswer.tools.images.image_generation_tool import ImageGenerationResponse
from danswer.tools.images.image_generation_tool import ImageGenerationTool
@@ -71,7 +68,6 @@ from danswer.tools.tool_runner import ToolRunner
from danswer.tools.tool_selection import select_single_tool_for_non_tool_calling_llm
from danswer.tools.utils import explicit_tool_calling_supported
from danswer.utils.logger import setup_logger
from shared_configs.configs import MAX_TOOL_CALLS
logger = setup_logger()
@@ -165,10 +161,6 @@ class Answer:
self.skip_gen_ai_answer_generation = skip_gen_ai_answer_generation
self._is_cancelled = False
self.final_context_docs: list = []
self.current_streamed_output: list = []
self.processing_stream: list = []
def _update_prompt_builder_for_search_tool(
self, prompt_builder: AnswerPromptBuilder, final_context_documents: list[LlmDoc]
) -> None:
@@ -187,6 +179,7 @@ class Answer:
if self.answer_style_config.citation_config
else False
),
history_message=self.single_message_history or "",
)
)
elif self.answer_style_config.quotes_config:
@@ -204,50 +197,41 @@ class Answer:
) -> Iterator[
str | StreamStopInfo | ToolCallKickoff | ToolResponse | ToolCallFinalResult
]:
tool_calls = 0
initiated = False
prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
while tool_calls < (1 if self.force_use_tool.force_use else MAX_TOOL_CALLS):
if initiated:
yield StreamStopInfo(stop_reason=StreamStopReason.NEW_RESPONSE)
initiated = True
prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
tool_call_chunk: AIMessageChunk | None = None
if self.force_use_tool.force_use and self.force_use_tool.args is not None:
tool_call_chunk = AIMessageChunk(content="")
tool_call_chunk.tool_calls = [
{
"name": self.force_use_tool.tool_name,
"args": self.force_use_tool.args,
"id": str(uuid4()),
}
]
else:
prompt_builder.update_system_prompt(
default_build_system_message(self.prompt_config)
tool_call_chunk: AIMessageChunk | None = None
if self.force_use_tool.force_use and self.force_use_tool.args is not None:
# if we are forcing a tool WITH args specified, we don't need to check which tools to run
# / need to generate the args
tool_call_chunk = AIMessageChunk(
content="",
)
tool_call_chunk.tool_calls = [
{
"name": self.force_use_tool.tool_name,
"args": self.force_use_tool.args,
"id": str(uuid4()),
}
]
else:
# if tool calling is supported, first try the raw message
# to see if we don't need to use any tools
prompt_builder.update_system_prompt(
default_build_system_message(self.prompt_config)
)
prompt_builder.update_user_prompt(
default_build_user_message(
self.question, self.prompt_config, self.latest_query_files
)
prompt_builder.update_user_prompt(
default_build_user_message(
self.question,
self.prompt_config,
self.latest_query_files,
tool_calls,
)
)
prompt = prompt_builder.build()
final_tool_definitions = [
tool.tool_definition()
for tool in filter_tools_for_force_tool_use(
self.tools, self.force_use_tool
)
prompt = prompt_builder.build()
]
final_tool_definitions = [
tool.tool_definition()
for tool in filter_tools_for_force_tool_use(
self.tools, self.force_use_tool
)
]
print(final_tool_definitions)
for message in self.llm.stream(
prompt=prompt,
tools=final_tool_definitions if final_tool_definitions else None,
@@ -274,129 +258,67 @@ class Answer:
)
if not tool_call_chunk:
logger.info("Skipped tool call but generated message")
return
return # no tool call needed
tool_call_requests = tool_call_chunk.tool_calls
print(tool_call_requests)
# if we have a tool call, we need to call the tool
tool_call_requests = tool_call_chunk.tool_calls
for tool_call_request in tool_call_requests:
known_tools_by_name = [
tool for tool in self.tools if tool.name == tool_call_request["name"]
]
for tool_call_request in tool_call_requests:
tool_calls += 1
known_tools_by_name = [
tool
for tool in self.tools
if tool.name == tool_call_request["name"]
]
if not known_tools_by_name:
logger.error(
"Tool call requested with unknown name field. \n"
f"self.tools: {self.tools}"
f"tool_call_request: {tool_call_request}"
)
if self.tools:
tool = self.tools[0]
else:
continue
if not known_tools_by_name:
logger.error(
"Tool call requested with unknown name field. \n"
f"self.tools: {self.tools}"
f"tool_call_request: {tool_call_request}"
)
if self.tools:
tool = self.tools[0]
else:
tool = known_tools_by_name[0]
continue
else:
tool = known_tools_by_name[0]
tool_args = (
self.force_use_tool.args
if self.force_use_tool.tool_name == tool.name
and self.force_use_tool.args
else tool_call_request["args"]
)
tool_args = (
self.force_use_tool.args
if self.force_use_tool.tool_name == tool.name
and self.force_use_tool.args
else tool_call_request["args"]
)
tool_runner = ToolRunner(tool, tool_args)
yield tool_runner.kickoff()
yield from tool_runner.tool_responses()
tool_runner = ToolRunner(tool, tool_args)
tool_kickoff = tool_runner.kickoff()
yield tool_kickoff
tool_call_summary = ToolCallSummary(
tool_call_request=tool_call_chunk,
tool_call_result=build_tool_message(
tool_call_request, tool_runner.tool_message_content()
),
)
yield from tool_runner.tool_responses()
tool_responses = []
for response in tool_runner.tool_responses():
tool_responses.append(response)
yield response
tool_call_summary = ToolCallSummary(
tool_call_request=tool_call_chunk,
tool_call_result=build_tool_message(
tool_call_request, tool_runner.tool_message_content()
),
)
if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}:
self._update_prompt_builder_for_search_tool(prompt_builder, [])
elif tool.name == ImageGenerationTool._NAME:
img_urls = [
img_generation_result["url"]
for img_generation_result in tool_runner.tool_final_result().tool_result
]
prompt_builder.update_user_prompt(
build_image_generation_user_prompt(
query=self.question, img_urls=img_urls
)
)
yield tool_runner.tool_final_result()
# Update message history with tool call and response
self.message_history.append(
PreviousMessage(
message=self.question,
message_type=MessageType.USER,
token_count=len(self.llm_tokenizer.encode(self.question)),
tool_call=None,
files=[],
if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}:
self._update_prompt_builder_for_search_tool(prompt_builder, [])
elif tool.name == ImageGenerationTool._NAME:
img_urls = [
img_generation_result["url"]
for img_generation_result in tool_runner.tool_final_result().tool_result
]
prompt_builder.update_user_prompt(
build_image_generation_user_prompt(
query=self.question, img_urls=img_urls
)
)
yield tool_runner.tool_final_result()
self.message_history.append(
PreviousMessage(
message=str(tool_call_request),
message_type=MessageType.ASSISTANT,
token_count=len(
self.llm_tokenizer.encode(str(tool_call_request))
),
tool_call=None,
files=[],
)
)
self.message_history.append(
PreviousMessage(
message="\n".join(str(response) for response in tool_responses),
message_type=MessageType.SYSTEM,
token_count=len(
self.llm_tokenizer.encode(
"\n".join(str(response) for response in tool_responses)
)
),
tool_call=None,
files=[],
)
)
prompt = prompt_builder.build(tool_call_summary=tool_call_summary)
# Generate response based on updated message history
prompt = prompt_builder.build(tool_call_summary=tool_call_summary)
yield from self._process_llm_stream(
prompt=prompt,
tools=[tool.tool_definition() for tool in self.tools],
)
response_content = ""
for content in self._process_llm_stream(prompt=prompt, tools=None):
if isinstance(content, str):
response_content += content
yield content
# Update message history with LLM response
self.message_history.append(
PreviousMessage(
message=response_content,
message_type=MessageType.ASSISTANT,
token_count=len(self.llm_tokenizer.encode(response_content)),
tool_call=None,
files=[],
)
)
return
# This method processes the LLM stream and yields the content or stop information
def _process_llm_stream(
@@ -425,234 +347,139 @@ class Answer:
) -> Iterator[
str | StreamStopInfo | ToolCallKickoff | ToolResponse | ToolCallFinalResult
]:
tool_calls = 0
initiated = False
while tool_calls < (1 if self.force_use_tool.force_use else MAX_TOOL_CALLS):
if initiated:
yield StreamStopInfo(stop_reason=StreamStopReason.NEW_RESPONSE)
prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
chosen_tool_and_args: tuple[Tool, dict] | None = None
initiated = True
prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
chosen_tool_and_args: tuple[Tool, dict] | None = None
if self.force_use_tool.force_use:
# if we are forcing a tool, we don't need to check which tools to run
tool = next(
iter(
[
tool
for tool in self.tools
if tool.name == self.force_use_tool.tool_name
]
),
None,
)
if not tool:
raise RuntimeError(f"Tool '{self.force_use_tool.tool_name}' not found")
if self.force_use_tool.force_use:
# if we are forcing a tool, we don't need to check which tools to run
tool = next(
iter(
[
tool
for tool in self.tools
if tool.name == self.force_use_tool.tool_name
]
),
None,
)
if not tool:
raise RuntimeError(
f"Tool '{self.force_use_tool.tool_name}' not found"
)
tool_args = (
self.force_use_tool.args
if self.force_use_tool.args is not None
else tool.get_args_for_non_tool_calling_llm(
query=self.question,
history=self.message_history,
llm=self.llm,
force_run=True,
)
)
if tool_args is None:
raise RuntimeError(f"Tool '{tool.name}' did not return args")
chosen_tool_and_args = (tool, tool_args)
else:
tool_options = check_which_tools_should_run_for_non_tool_calling_llm(
tools=self.tools,
tool_args = (
self.force_use_tool.args
if self.force_use_tool.args is not None
else tool.get_args_for_non_tool_calling_llm(
query=self.question,
history=self.message_history,
llm=self.llm,
force_run=True,
)
)
available_tools_and_args = [
(self.tools[ind], args)
for ind, args in enumerate(tool_options)
if args is not None
]
if tool_args is None:
raise RuntimeError(f"Tool '{tool.name}' did not return args")
logger.info(
f"Selecting single tool from tools: {[(tool.name, args) for tool, args in available_tools_and_args]}"
chosen_tool_and_args = (tool, tool_args)
else:
tool_options = check_which_tools_should_run_for_non_tool_calling_llm(
tools=self.tools,
query=self.question,
history=self.message_history,
llm=self.llm,
)
available_tools_and_args = [
(self.tools[ind], args)
for ind, args in enumerate(tool_options)
if args is not None
]
logger.info(
f"Selecting single tool from tools: {[(tool.name, args) for tool, args in available_tools_and_args]}"
)
chosen_tool_and_args = (
select_single_tool_for_non_tool_calling_llm(
tools_and_args=available_tools_and_args,
history=self.message_history,
query=self.question,
llm=self.llm,
)
if available_tools_and_args
else None
)
chosen_tool_and_args = (
select_single_tool_for_non_tool_calling_llm(
tools_and_args=available_tools_and_args,
history=self.message_history,
query=self.question,
llm=self.llm,
)
if available_tools_and_args
else None
logger.notice(f"Chosen tool: {chosen_tool_and_args}")
if not chosen_tool_and_args:
prompt_builder.update_system_prompt(
default_build_system_message(self.prompt_config)
)
prompt_builder.update_user_prompt(
default_build_user_message(
self.question, self.prompt_config, self.latest_query_files
)
)
prompt = prompt_builder.build()
yield from self._process_llm_stream(
prompt=prompt,
tools=None,
)
return
logger.notice(f"Chosen tool: {chosen_tool_and_args}")
if not chosen_tool_and_args:
prompt_builder.update_system_prompt(
default_build_system_message(self.prompt_config)
)
prompt_builder.update_user_prompt(
default_build_user_message(
self.question, self.prompt_config, self.latest_query_files
)
)
prompt = prompt_builder.build()
yield from self._process_llm_stream(
prompt=prompt,
tools=None,
)
return
tool_calls += 1
tool, tool_args = chosen_tool_and_args
print("tool args")
print(tool_args)
tool_runner = ToolRunner(tool, tool_args, self.llm)
yield tool_runner.kickoff()
tool_responses = []
file_name = tool_runner.args["filename"]
print(f"file ame is {file_name}")
csv_file = None
for message in self.message_history:
if message.files:
csv_file = next(
(file for file in message.files if file.filename == file_name),
None,
)
if csv_file:
break
print(self.latest_query_files)
if csv_file is None:
raise ValueError(
f"CSV file with name '{file_name}' not found in latest query files."
)
print("csv file found")
tool_runner.args["filename"] = csv_file.content
tool, tool_args = chosen_tool_and_args
tool_runner = ToolRunner(tool, tool_args)
yield tool_runner.kickoff()
if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}:
final_context_documents = None
for response in tool_runner.tool_responses():
tool_responses.append(response)
if response.id == FINAL_CONTEXT_DOCUMENTS_ID:
final_context_documents = cast(list[LlmDoc], response.response)
yield response
if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}:
final_context_documents = None
for response in tool_runner.tool_responses():
if response.id == FINAL_CONTEXT_DOCUMENTS_ID:
final_context_documents = cast(list[LlmDoc], response.response)
yield response
if final_context_documents is None:
raise RuntimeError(
f"{tool.name} did not return final context documents"
)
self._update_prompt_builder_for_search_tool(
prompt_builder, final_context_documents
if final_context_documents is None:
raise RuntimeError(
f"{tool.name} did not return final context documents"
)
elif tool.name == ImageGenerationTool._NAME:
for response in tool_runner.tool_responses():
if response.id == IMAGE_GENERATION_RESPONSE_ID:
img_generation_response = cast(
list[ImageGenerationResponse], response.response
)
# img_urls = [img.url for img in img_generation_response]
prompt_builder.update_user_prompt(
build_image_generation_user_prompt(
query=self.question,
img_urls=[img.url for img in img_generation_response],
)
)
yield response
elif tool.name == CSVAnalysisTool._NAME:
for response in tool_runner.tool_responses():
yield response
elif tool.name == GraphingTool._NAME:
for response in tool_runner.tool_responses():
print("RESOS")
print(response)
prompt_builder.update_user_prompt(
build_image_generation_user_prompt(
query=self.question,
# img_urls=img_urls,
)
)
else:
prompt_builder.update_user_prompt(
HumanMessage(
content=build_user_message_for_custom_tool_for_non_tool_calling_llm(
self.question,
tool.name,
*tool_runner.tool_responses(),
)
)
)
final_result = tool_runner.tool_final_result()
yield final_result
# Update message history
self.message_history.extend(
[
PreviousMessage(
message=str(self.question),
message_type=MessageType.USER,
token_count=len(self.llm_tokenizer.encode(str(self.question))),
tool_call=None,
files=[],
),
PreviousMessage(
message=f"Tool used: {tool.name}",
message_type=MessageType.ASSISTANT,
token_count=len(
self.llm_tokenizer.encode(f"Tool used: {tool.name}")
),
tool_call=None,
files=[],
),
PreviousMessage(
message=str(final_result),
message_type=MessageType.SYSTEM,
token_count=len(self.llm_tokenizer.encode(str(final_result))),
tool_call=None,
files=[],
),
]
self._update_prompt_builder_for_search_tool(
prompt_builder, final_context_documents
)
elif tool.name == ImageGenerationTool._NAME:
img_urls = []
for response in tool_runner.tool_responses():
if response.id == IMAGE_GENERATION_RESPONSE_ID:
img_generation_response = cast(
list[ImageGenerationResponse], response.response
)
img_urls = [img.url for img in img_generation_response]
# Generate response based on updated message history
prompt = prompt_builder.build()
yield response
response_content = ""
for content in self._process_llm_stream(prompt=prompt, tools=None):
if isinstance(content, str):
response_content += content
yield content
# Update message history with LLM response
self.message_history.append(
PreviousMessage(
message=response_content,
message_type=MessageType.ASSISTANT,
token_count=len(self.llm_tokenizer.encode(response_content)),
tool_call=None,
files=[],
prompt_builder.update_user_prompt(
build_image_generation_user_prompt(
query=self.question,
img_urls=img_urls,
)
)
else:
prompt_builder.update_user_prompt(
HumanMessage(
content=build_user_message_for_custom_tool_for_non_tool_calling_llm(
self.question,
tool.name,
*tool_runner.tool_responses(),
)
)
)
final = tool_runner.tool_final_result()
yield final
prompt = prompt_builder.build()
yield from self._process_llm_stream(prompt=prompt, tools=None)
@property
def processed_streamed_output(self) -> AnswerStream:
@@ -669,8 +496,6 @@ class Answer:
else self._raw_output_for_non_explicit_tool_calling_llms()
)
self.processing_stream = []
def _process_stream(
stream: Iterator[ToolCallKickoff | ToolResponse | str | StreamStopInfo],
) -> AnswerStream:
@@ -711,70 +536,64 @@ class Answer:
yield message
else:
process_answer_stream_fn = _get_answer_stream_processor(
context_docs=final_context_docs or [],
# if doc selection is enabled, then search_results will be None,
# so we need to use the final_context_docs
doc_id_to_rank_map=map_document_id_order(
search_results or final_context_docs or []
),
answer_style_configs=self.answer_style_config,
)
# assumes all tool responses will come first, then the final answer
break
stream_stop_info = None
new_kickoff = None
if not self.skip_gen_ai_answer_generation:
process_answer_stream_fn = _get_answer_stream_processor(
context_docs=final_context_docs or [],
# if doc selection is enabled, then search_results will be None,
# so we need to use the final_context_docs
doc_id_to_rank_map=map_document_id_order(
search_results or final_context_docs or []
),
answer_style_configs=self.answer_style_config,
)
def _stream() -> Iterator[str]:
nonlocal stream_stop_info
nonlocal new_kickoff
stream_stop_info = None
yield cast(str, message)
for item in stream:
if isinstance(item, StreamStopInfo):
stream_stop_info = item
return
if isinstance(item, ToolCallKickoff):
new_kickoff = item
return
else:
yield cast(str, item)
def _stream() -> Iterator[str]:
nonlocal stream_stop_info
yield cast(str, message)
for item in stream:
if isinstance(item, StreamStopInfo):
stream_stop_info = item
return
yield from process_answer_stream_fn(_stream())
# this should never happen, but we're seeing weird behavior here so handling for now
if not isinstance(item, str):
logger.error(
f"Received non-string item in answer stream: {item}. Skipping."
)
continue
if stream_stop_info:
yield stream_stop_info
yield item
# handle new tool call (continuation of message)
if new_kickoff:
yield new_kickoff
yield from process_answer_stream_fn(_stream())
if stream_stop_info:
yield stream_stop_info
processed_stream = []
for processed_packet in _process_stream(output_generator):
if (
isinstance(processed_packet, StreamStopInfo)
and processed_packet.stop_reason == StreamStopReason.NEW_RESPONSE
):
self.current_streamed_output = self.processing_stream
self.processing_stream = []
self.processing_stream.append(processed_packet)
processed_stream.append(processed_packet)
yield processed_packet
self.current_streamed_output = self.processing_stream
self._processed_stream = self.processing_stream
self._processed_stream = processed_stream
@property
def llm_answer(self) -> str:
answer = ""
if not self._processed_stream and not self.current_streamed_output:
return ""
for packet in self.current_streamed_output or self._processed_stream or []:
for packet in self.processed_streamed_output:
if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece:
answer += packet.answer_piece
return answer
@property
def citations(self) -> list[CitationInfo]:
citations: list[CitationInfo] = []
for packet in self.current_streamed_output:
for packet in self.processed_streamed_output:
if isinstance(packet, CitationInfo):
citations.append(packet)
@@ -789,4 +608,5 @@ class Answer:
if not self.is_connected():
logger.debug("Answer stream has been cancelled")
self._is_cancelled = not self.is_connected()
return self._is_cancelled

View File

@@ -33,7 +33,7 @@ class PreviousMessage(BaseModel):
token_count: int
message_type: MessageType
files: list[InMemoryChatFile]
tool_call: ToolCallFinalResult | None
tool_calls: list[ToolCallFinalResult]
@classmethod
def from_chat_message(
@@ -51,13 +51,14 @@ class PreviousMessage(BaseModel):
for file in available_files
if str(file.file_id) in message_file_ids
],
tool_call=ToolCallFinalResult(
tool_name=chat_message.tool_call.tool_name,
tool_args=chat_message.tool_call.tool_arguments,
tool_result=chat_message.tool_call.tool_result,
)
if chat_message.tool_call
else None,
tool_calls=[
ToolCallFinalResult(
tool_name=tool_call.tool_name,
tool_args=tool_call.tool_arguments,
tool_result=tool_call.tool_result,
)
for tool_call in chat_message.tool_calls
],
)
def to_langchain_msg(self) -> BaseMessage:

View File

@@ -36,10 +36,7 @@ def default_build_system_message(
def default_build_user_message(
user_query: str,
prompt_config: PromptConfig,
files: list[InMemoryChatFile] = [],
previous_tool_calls: int = 0,
user_query: str, prompt_config: PromptConfig, files: list[InMemoryChatFile] = []
) -> HumanMessage:
user_prompt = (
CHAT_USER_CONTEXT_FREE_PROMPT.format(
@@ -48,12 +45,6 @@ def default_build_user_message(
if prompt_config.task_prompt
else user_query
)
if previous_tool_calls > 0:
user_prompt = (
f"You have already generated the above so do not call a tool if not necessary. "
f"Remember the query is: `{user_prompt}`"
)
user_prompt = user_prompt.strip()
user_msg = HumanMessage(
content=build_content_with_imgs(user_prompt, files) if files else user_prompt
@@ -122,25 +113,25 @@ class AnswerPromptBuilder:
final_messages_with_tokens.append(self.user_message_and_token_cnt)
# if tool_call_summary:
# final_messages_with_tokens.append(
# (
# tool_call_summary.tool_call_request,
# check_message_tokens(
# tool_call_summary.tool_call_request,
# self.llm_tokenizer_encode_func,
# ),
# )
# )
# final_messages_with_tokens.append(
# (
# tool_call_summary.tool_call_result,
# check_message_tokens(
# tool_call_summary.tool_call_result,
# self.llm_tokenizer_encode_func,
# ),
# )
# )
if tool_call_summary:
final_messages_with_tokens.append(
(
tool_call_summary.tool_call_request,
check_message_tokens(
tool_call_summary.tool_call_request,
self.llm_tokenizer_encode_func,
),
)
)
final_messages_with_tokens.append(
(
tool_call_summary.tool_call_result,
check_message_tokens(
tool_call_summary.tool_call_result,
self.llm_tokenizer_encode_func,
),
)
)
return drop_messages_history_overflow(
final_messages_with_tokens, self.max_tokens

View File

@@ -29,6 +29,9 @@ from danswer.prompts.token_counts import CITATION_REMINDER_TOKEN_CNT
from danswer.prompts.token_counts import CITATION_STATEMENT_TOKEN_CNT
from danswer.prompts.token_counts import LANGUAGE_HINT_TOKEN_CNT
from danswer.search.models import InferenceChunk
from danswer.utils.logger import setup_logger
logger = setup_logger()
def get_prompt_tokens(prompt_config: PromptConfig) -> int:
@@ -156,6 +159,7 @@ def build_citations_user_message(
user_prompt = CITATIONS_PROMPT_FOR_TOOL_CALLING.format(
task_prompt=task_prompt_with_reminder,
user_query=question,
history_block=history_message,
)
user_prompt = user_prompt.strip()

View File

@@ -85,6 +85,15 @@ def extract_citations_from_stream(
curr_segment += token
llm_out += token
# Handle code blocks without language tags
if "`" in curr_segment:
if curr_segment.endswith("`"):
continue
elif "```" in curr_segment:
piece_that_comes_after = curr_segment.split("```")[1][0]
if piece_that_comes_after == "\n" and in_code_block(llm_out):
curr_segment = curr_segment.replace("```", "```plaintext")
citation_pattern = r"\[(\d+)\]"
citations_found = list(re.finditer(citation_pattern, curr_segment))

View File

@@ -29,7 +29,6 @@ from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.llm.interfaces import LLM
from danswer.llm.interfaces import LLMConfig
from danswer.llm.interfaces import ToolChoiceOptions
from danswer.tools.graphing.models import GRAPHING_RESPONSE_ID
from danswer.utils.logger import setup_logger
@@ -99,7 +98,6 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
message_dict = {"role": message.role, "content": message.content}
elif isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content}
if message.tool_calls:
@@ -126,21 +124,12 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
"name": message.name,
}
elif isinstance(message, ToolMessage):
if message.id == GRAPHING_RESPONSE_ID:
message_dict = {
"tool_call_id": message.tool_call_id,
"role": "tool",
"name": message.name or "",
"content": "a graph",
}
else:
message_dict = {
"tool_call_id": message.tool_call_id,
"role": "tool",
"name": message.name or "",
"content": "a graph",
}
message_dict = {
"tool_call_id": message.tool_call_id,
"role": "tool",
"name": message.name or "",
"content": message.content,
}
else:
raise ValueError(f"Got unknown type {message}")
if "name" in message.additional_kwargs:

View File

@@ -24,6 +24,8 @@ class WellKnownLLMProviderDescriptor(BaseModel):
OPENAI_PROVIDER_NAME = "openai"
OPEN_AI_MODEL_NAMES = [
"o1-mini",
"o1-preview",
"gpt-4",
"gpt-4o",
"gpt-4o-mini",

View File

@@ -47,7 +47,9 @@ if TYPE_CHECKING:
logger = setup_logger()
def litellm_exception_to_error_msg(e: Exception, llm: LLM) -> str:
def litellm_exception_to_error_msg(
e: Exception, llm: LLM, fallback_to_error_msg: bool = False
) -> str:
error_msg = str(e)
if isinstance(e, BadRequestError):
@@ -94,7 +96,7 @@ def litellm_exception_to_error_msg(e: Exception, llm: LLM) -> str:
error_msg = "Request timed out: The operation took too long to complete. Please try again."
elif isinstance(e, APIError):
error_msg = f"API error: An error occurred while communicating with the API. Details: {str(e)}"
else:
elif not fallback_to_error_msg:
error_msg = "An unexpected error occurred while processing your request. Please try again later."
return error_msg
@@ -112,7 +114,7 @@ def translate_danswer_msg_to_langchain(
content = build_content_with_imgs(msg.message, files)
if msg.message_type == MessageType.SYSTEM:
return SystemMessage(content=content)
raise ValueError("System messages are not currently part of history")
if msg.message_type == MessageType.ASSISTANT:
return AIMessage(content=content)
if msg.message_type == MessageType.USER:
@@ -133,21 +135,6 @@ def translate_history_to_basemessages(
return history_basemessages, history_token_counts
def _process_csv_file(file: InMemoryChatFile) -> str:
import pandas as pd
import io
df = pd.read_csv(io.StringIO(file.content.decode("utf-8")))
csv_preview = df.head().to_string()
file_name_section = (
f"CSV FILE NAME: {file.filename}\n"
if file.filename
else "CSV FILE (NO NAME PROVIDED):\n"
)
return f"{file_name_section}{CODE_BLOCK_PAT.format(csv_preview)}\n\n\n"
def _build_content(
message: str,
files: list[InMemoryChatFile] | None = None,
@@ -158,26 +145,16 @@ def _build_content(
if files
else None
)
csv_files = (
[file for file in files if file.file_type == ChatFileType.CSV]
if files
else None
)
if not text_files and not csv_files:
if not text_files:
return message
final_message_with_files = "FILES:\n\n"
for file in text_files or []:
for file in text_files:
file_content = file.content.decode("utf-8")
file_name_section = f"DOCUMENT: {file.filename}\n" if file.filename else ""
final_message_with_files += (
f"{file_name_section}{CODE_BLOCK_PAT.format(file_content.strip())}\n\n\n"
)
for file in csv_files or []:
final_message_with_files += _process_csv_file(file)
final_message_with_files += message
return final_message_with_files

View File

@@ -51,7 +51,6 @@ from danswer.db.credentials import create_initial_public_credential
from danswer.db.document import check_docs_exist
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.engine import init_sqlalchemy_engine
from danswer.db.engine import warm_up_connections
from danswer.db.index_attempt import cancel_indexing_attempts_past_model
from danswer.db.index_attempt import expire_index_attempts
from danswer.db.llm import fetch_default_provider
@@ -62,7 +61,6 @@ from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.db.search_settings import update_current_search_settings
from danswer.db.search_settings import update_secondary_search_settings
from danswer.db.standard_answer import create_initial_default_standard_answer_category
from danswer.db.swap_index import check_index_swap
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import DocumentIndex
@@ -111,6 +109,8 @@ from danswer.server.query_and_chat.query_backend import (
from danswer.server.query_and_chat.query_backend import basic_router as query_router
from danswer.server.settings.api import admin_router as settings_admin_router
from danswer.server.settings.api import basic_router as settings_router
from danswer.server.settings.store import load_settings
from danswer.server.settings.store import store_settings
from danswer.server.token_rate_limits.api import (
router as token_rate_limit_settings_router,
)
@@ -125,10 +125,10 @@ from danswer.utils.telemetry import RecordType
from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import global_version
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
from shared_configs.configs import CORS_ALLOWED_ORIGIN
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
logger = setup_logger()
@@ -186,9 +186,6 @@ def setup_postgres(db_session: Session) -> None:
create_initial_default_connector(db_session)
associate_default_cc_pair(db_session)
logger.notice("Verifying default standard answer category exists.")
create_initial_default_standard_answer_category(db_session)
logger.notice("Loading default Prompts and Personas")
delete_old_default_personas(db_session)
load_chat_yamls()
@@ -245,6 +242,12 @@ def update_default_multipass_indexing(db_session: Session) -> None:
)
update_current_search_settings(db_session, updated_settings)
# Update settings with GPU availability
settings = load_settings()
settings.gpu_enabled = gpu_available
store_settings(settings)
logger.notice(f"Updated settings with GPU availability: {gpu_available}")
else:
logger.debug(
"Existing docs or connectors found. Skipping multipass indexing update."
@@ -365,7 +368,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
logger.notice("Generative AI Q&A disabled")
# fill up Postgres connection pools
await warm_up_connections()
# await warm_up_connections()
# We cache this at the beginning so there is no delay in the first telemetry
get_or_generate_uuid()
@@ -591,7 +594,7 @@ def get_application() -> FastAPI:
application.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Change this to the list of allowed origins if needed
allow_origins=CORS_ALLOWED_ORIGIN, # Configurable via environment variable
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],

View File

@@ -26,6 +26,7 @@ from danswer.db.chat import translate_db_message_to_chat_message_detail
from danswer.db.chat import translate_db_search_doc_to_server_search_doc
from danswer.db.chat import update_search_docs_table_with_relevance
from danswer.db.engine import get_session_context_manager
from danswer.db.models import Persona
from danswer.db.models import User
from danswer.db.persona import get_prompt_by_id
from danswer.llm.answering.answer import Answer
@@ -60,7 +61,7 @@ from danswer.tools.tool import ToolResponse
from danswer.tools.tool_runner import ToolCallKickoff
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_generator_function_time
from ee.danswer.server.query_and_chat.utils import create_temporary_persona
logger = setup_logger()
@@ -118,7 +119,17 @@ def stream_answer_objects(
one_shot=True,
danswerbot_flow=danswerbot_flow,
)
llm, fast_llm = get_llms_for_persona(persona=chat_session.persona)
temporary_persona: Persona | None = None
if query_req.persona_config is not None:
new_persona = create_temporary_persona(
db_session=db_session, persona_config=query_req.persona_config, user=user
)
temporary_persona = new_persona
persona = temporary_persona if temporary_persona else chat_session.persona
llm, fast_llm = get_llms_for_persona(persona=persona)
llm_tokenizer = get_tokenizer(
model_name=llm.config.model_name,
@@ -153,11 +164,11 @@ def stream_answer_objects(
prompt_id=query_req.prompt_id, user=None, db_session=db_session
)
if prompt is None:
if not chat_session.persona.prompts:
if not persona.prompts:
raise RuntimeError(
"Persona does not have any prompts - this should never happen"
)
prompt = chat_session.persona.prompts[0]
prompt = persona.prompts[0]
# Create the first User query message
new_user_message = create_new_chat_message(
@@ -174,9 +185,7 @@ def stream_answer_objects(
prompt_config = PromptConfig.from_model(prompt)
document_pruning_config = DocumentPruningConfig(
max_chunks=int(
chat_session.persona.num_chunks
if chat_session.persona.num_chunks is not None
else default_num_chunks
persona.num_chunks if persona.num_chunks is not None else default_num_chunks
),
max_tokens=max_document_tokens,
)
@@ -187,16 +196,16 @@ def stream_answer_objects(
evaluation_type=LLMEvaluationType.SKIP
if DISABLE_LLM_DOC_RELEVANCE
else query_req.evaluation_type,
persona=chat_session.persona,
persona=persona,
retrieval_options=query_req.retrieval_options,
prompt_config=prompt_config,
llm=llm,
fast_llm=fast_llm,
pruning_config=document_pruning_config,
bypass_acl=bypass_acl,
chunks_above=query_req.chunks_above,
chunks_below=query_req.chunks_below,
full_doc=query_req.full_doc,
bypass_acl=bypass_acl,
)
answer_config = AnswerStyleConfig(
@@ -209,13 +218,15 @@ def stream_answer_objects(
question=query_msg.message,
answer_style_config=answer_config,
prompt_config=PromptConfig.from_model(prompt),
llm=get_main_llm_from_tuple(get_llms_for_persona(persona=chat_session.persona)),
llm=get_main_llm_from_tuple(get_llms_for_persona(persona=persona)),
single_message_history=history_str,
tools=[search_tool],
force_use_tool=ForceUseTool(
force_use=True,
tool_name=search_tool.name,
args={"query": rephrased_query},
tools=[search_tool] if search_tool else [],
force_use_tool=(
ForceUseTool(
tool_name=search_tool.name,
args={"query": rephrased_query},
force_use=True,
)
),
# for now, don't use tool calling for this flow, as we haven't
# tested quotes with tool calling too much yet
@@ -223,9 +234,7 @@ def stream_answer_objects(
return_contexts=query_req.return_contexts,
skip_gen_ai_answer_generation=query_req.skip_gen_ai_answer_generation,
)
# won't be any ImageGenerationDisplay responses since that tool is never passed in
for packet in cast(AnswerObjectIterator, answer.processed_streamed_output):
# for one-shot flow, don't currently do anything with these
if isinstance(packet, ToolResponse):
@@ -261,6 +270,7 @@ def stream_answer_objects(
applied_time_cutoff=search_response_summary.final_filters.time_cutoff,
recency_bias_multiplier=search_response_summary.recency_bias_multiplier,
)
yield initial_response
elif packet.id == SEARCH_DOC_CONTENT_ID:
@@ -287,6 +297,7 @@ def stream_answer_objects(
relevance_summary=evaluation_response,
)
yield evaluation_response
else:
yield packet

View File

@@ -1,3 +1,5 @@
from typing import Any
from pydantic import BaseModel
from pydantic import Field
from pydantic import model_validator
@@ -8,6 +10,8 @@ from danswer.chat.models import DanswerQuotes
from danswer.chat.models import QADocsResponse
from danswer.configs.constants import MessageType
from danswer.search.enums import LLMEvaluationType
from danswer.search.enums import RecencyBiasSetting
from danswer.search.enums import SearchType
from danswer.search.models import ChunkContext
from danswer.search.models import RerankingDetails
from danswer.search.models import RetrievalDetails
@@ -23,10 +27,49 @@ class ThreadMessage(BaseModel):
role: MessageType = MessageType.USER
class PromptConfig(BaseModel):
name: str
description: str = ""
system_prompt: str
task_prompt: str = ""
include_citations: bool = True
datetime_aware: bool = True
class DocumentSetConfig(BaseModel):
id: int
class ToolConfig(BaseModel):
id: int
class PersonaConfig(BaseModel):
name: str
description: str
search_type: SearchType = SearchType.SEMANTIC
num_chunks: float | None = None
llm_relevance_filter: bool = False
llm_filter_extraction: bool = False
recency_bias: RecencyBiasSetting = RecencyBiasSetting.AUTO
llm_model_provider_override: str | None = None
llm_model_version_override: str | None = None
prompts: list[PromptConfig] = Field(default_factory=list)
prompt_ids: list[int] = Field(default_factory=list)
document_set_ids: list[int] = Field(default_factory=list)
tools: list[ToolConfig] = Field(default_factory=list)
tool_ids: list[int] = Field(default_factory=list)
custom_tools_openapi: list[dict[str, Any]] = Field(default_factory=list)
class DirectQARequest(ChunkContext):
persona_config: PersonaConfig | None = None
persona_id: int | None = None
messages: list[ThreadMessage]
prompt_id: int | None
persona_id: int
prompt_id: int | None = None
multilingual_query_expansion: list[str] | None = None
retrieval_options: RetrievalDetails = Field(default_factory=RetrievalDetails)
rerank_settings: RerankingDetails | None = None
@@ -43,6 +86,12 @@ class DirectQARequest(ChunkContext):
# If True, skips generative an AI response to the search query
skip_gen_ai_answer_generation: bool = False
@model_validator(mode="after")
def check_persona_fields(self) -> "DirectQARequest":
if (self.persona_config is None) == (self.persona_id is None):
raise ValueError("Exactly one of persona_config or persona_id must be set")
return self
@model_validator(mode="after")
def check_chain_of_thought_and_prompt_id(self) -> "DirectQARequest":
if self.chain_of_thought and self.prompt_id is not None:

View File

@@ -39,46 +39,6 @@ CHAT_USER_CONTEXT_FREE_PROMPT = f"""
""".strip()
GRAPHING_QUERY_REPHRASE_GRAPH = f"""
Given the following conversation and a follow-up input, generate Python code using matplotlib to create the requested graph.
IMPORTANT: The code should be complete and executable, including data generation if necessary.
Focus on creating a clear and informative visualization based on the user's request.
Guidelines:
- Import matplotlib.pyplot as plt at the beginning of the code.
- Generate sample data if specific data is not provided in the conversation.
- Use appropriate graph types (line, bar, scatter, pie) based on the nature of the data and request.
- Include proper labeling (title, x-axis, y-axis, legend) in the graph.
- Use plt.figure() to create the figure and assign it to a variable named 'fig'.
- Do not include plt.show() at the end of the code.
{GENERAL_SEP_PAT}
Chat History:
{{chat_history}}
{GENERAL_SEP_PAT}
Follow Up Input: {{question}}
Python Code for Graph (Respond with only the Python code to generate the graph):
```python
# Your code here
```
""".strip()
GRAPHING_GET_FILE_NAME_PROMPT = f"""
Given the following conversation, a follow-up input, and a list of available CSV files,
provide the name of the CSV file to analyze.
{GENERAL_SEP_PAT}
Chat History:
{{chat_history}}
{GENERAL_SEP_PAT}
Follow Up Input: {{question}}
{GENERAL_SEP_PAT}
Available CSV Files:
{{file_list}}
{GENERAL_SEP_PAT}
CSV File Name to Analyze:
"""
# Design considerations for the below:
# - In case of uncertainty, favor yes search so place the "yes" sections near the start of the
# prompt and after the no section as well to deemphasize the no section

View File

@@ -109,6 +109,9 @@ CITATIONS_PROMPT_FOR_TOOL_CALLING = f"""
Refer to the provided context documents when responding to me.{DEFAULT_IGNORE_STATEMENT} \
You should always get right to the point, and never use extraneous language.
CHAT HISTORY:
{{history_block}}
{{task_prompt}}
{QUESTION_PAT.upper()}

View File

@@ -84,6 +84,7 @@ class SavedSearchSettings(InferenceSettings, IndexingSetting):
# Multilingual Expansion
multilingual_expansion=search_settings.multilingual_expansion,
rerank_api_url=search_settings.rerank_api_url,
disable_rerank_for_streaming=search_settings.disable_rerank_for_streaming,
)

View File

@@ -67,6 +67,9 @@ def retrieval_preprocessing(
]
time_filter = preset_filters.time_cutoff
if time_filter is None and persona:
time_filter = persona.search_start_date
source_filter = preset_filters.source_type
auto_detect_time_filter = True
@@ -154,7 +157,7 @@ def retrieval_preprocessing(
final_filters = IndexFilters(
source_type=preset_filters.source_type or predicted_source_filters,
document_set=preset_filters.document_set,
time_cutoff=preset_filters.time_cutoff or predicted_time_cutoff,
time_cutoff=time_filter or predicted_time_cutoff,
tags=preset_filters.tags, # Tags are never auto-extracted
access_control_list=user_acl_filters,
)

View File

@@ -16,8 +16,9 @@ from danswer.db.connector_credential_pair import remove_credential_from_connecto
from danswer.db.connector_credential_pair import (
update_connector_credential_pair_from_id,
)
from danswer.db.document import get_document_cnts_for_cc_pairs
from danswer.db.document import get_document_counts_for_cc_pairs
from danswer.db.engine import get_session
from danswer.db.enums import AccessType
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.index_attempt import cancel_indexing_attempts_for_ccpair
from danswer.db.index_attempt import cancel_indexing_attempts_past_model
@@ -95,7 +96,7 @@ def get_cc_pair_full_info(
)
document_count_info_list = list(
get_document_cnts_for_cc_pairs(
get_document_counts_for_cc_pairs(
db_session=db_session,
cc_pair_identifiers=[cc_pair_identifier],
)
@@ -201,7 +202,7 @@ def associate_credential_to_connector(
db_session=db_session,
user=user,
target_group_ids=metadata.groups,
object_is_public=metadata.is_public,
object_is_public=metadata.access_type == AccessType.PUBLIC,
)
try:
@@ -211,7 +212,8 @@ def associate_credential_to_connector(
connector_id=connector_id,
credential_id=credential_id,
cc_pair_name=metadata.name,
is_public=True if metadata.is_public is None else metadata.is_public,
access_type=metadata.access_type,
auto_sync_options=metadata.auto_sync_options,
groups=metadata.groups,
)

View File

@@ -62,8 +62,9 @@ from danswer.db.credentials import delete_gmail_service_account_credentials
from danswer.db.credentials import delete_google_drive_service_account_credentials
from danswer.db.credentials import fetch_credential_by_id
from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
from danswer.db.document import get_document_cnts_for_cc_pairs
from danswer.db.document import get_document_counts_for_cc_pairs
from danswer.db.engine import get_session
from danswer.db.enums import AccessType
from danswer.db.index_attempt import create_index_attempt
from danswer.db.index_attempt import get_index_attempts_for_cc_pair
from danswer.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
@@ -511,7 +512,7 @@ def get_connector_indexing_status(
for index_attempt in latest_index_attempts
}
document_count_info = get_document_cnts_for_cc_pairs(
document_count_info = get_document_counts_for_cc_pairs(
db_session=db_session,
cc_pair_identifiers=cc_pair_identifiers,
)
@@ -559,7 +560,7 @@ def get_connector_indexing_status(
cc_pair_status=cc_pair.status,
connector=ConnectorSnapshot.from_connector_db_model(connector),
credential=CredentialSnapshot.from_credential_db_model(credential),
public_doc=cc_pair.is_public,
access_type=cc_pair.access_type,
owner=credential.user.email if credential.user else "",
groups=group_cc_pair_relationships_dict.get(cc_pair.id, []),
last_finished_status=(
@@ -668,12 +669,15 @@ def create_connector_with_mock_credential(
credential = create_credential(
mock_credential, user=user, db_session=db_session
)
access_type = (
AccessType.PUBLIC if connector_data.is_public else AccessType.PRIVATE
)
response = add_credential_to_connector(
db_session=db_session,
user=user,
connector_id=cast(int, connector_response.id), # will aways be an int
credential_id=credential.id,
is_public=connector_data.is_public or False,
access_type=access_type,
cc_pair_name=connector_data.name,
groups=connector_data.groups,
)
@@ -968,7 +972,7 @@ def get_basic_connector_indexing_status(
)
for cc_pair in cc_pairs
]
document_count_info = get_document_cnts_for_cc_pairs(
document_count_info = get_document_counts_for_cc_pairs(
db_session=db_session,
cc_pair_identifiers=cc_pair_identifiers,
)

View File

@@ -10,6 +10,7 @@ from danswer.configs.app_configs import MASK_CREDENTIAL_PREFIX
from danswer.configs.constants import DocumentSource
from danswer.connectors.models import DocumentErrorSummary
from danswer.connectors.models import InputType
from danswer.db.enums import AccessType
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import Connector
from danswer.db.models import ConnectorCredentialPair
@@ -218,7 +219,7 @@ class CCPairFullInfo(BaseModel):
number_of_index_attempts: int
last_index_attempt_status: IndexingStatus | None
latest_deletion_attempt: DeletionAttemptSnapshot | None
is_public: bool
access_type: AccessType
is_editable_for_current_user: bool
deletion_failure_message: str | None
@@ -261,7 +262,7 @@ class CCPairFullInfo(BaseModel):
number_of_index_attempts=number_of_index_attempts,
last_index_attempt_status=last_indexing_status,
latest_deletion_attempt=latest_deletion_attempt,
is_public=cc_pair_model.is_public,
access_type=cc_pair_model.access_type,
is_editable_for_current_user=is_editable_for_current_user,
deletion_failure_message=cc_pair_model.deletion_failure_message,
)
@@ -288,7 +289,7 @@ class ConnectorIndexingStatus(BaseModel):
credential: CredentialSnapshot
owner: str
groups: list[int]
public_doc: bool
access_type: AccessType
last_finished_status: IndexingStatus | None
last_status: IndexingStatus | None
last_success: datetime | None
@@ -306,7 +307,8 @@ class ConnectorCredentialPairIdentifier(BaseModel):
class ConnectorCredentialPairMetadata(BaseModel):
name: str | None = None
is_public: bool | None = None
access_type: AccessType
auto_sync_options: dict[str, Any] | None = None
groups: list[int] = Field(default_factory=list)

View File

@@ -47,7 +47,6 @@ class DocumentSet(BaseModel):
description: str
cc_pair_descriptors: list[ConnectorCredentialPairDescriptor]
is_up_to_date: bool
contains_non_public: bool
is_public: bool
# For Private Document Sets, who should be able to access these
users: list[UUID]
@@ -59,12 +58,6 @@ class DocumentSet(BaseModel):
id=document_set_model.id,
name=document_set_model.name,
description=document_set_model.description,
contains_non_public=any(
[
not cc_pair.is_public
for cc_pair in document_set_model.connector_credential_pairs
]
),
cc_pair_descriptors=[
ConnectorCredentialPairDescriptor(
id=cc_pair.id,

View File

@@ -3,6 +3,7 @@ from uuid import UUID
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Query
from fastapi import UploadFile
from pydantic import BaseModel
@@ -20,6 +21,7 @@ from danswer.db.persona import get_personas
from danswer.db.persona import mark_persona_as_deleted
from danswer.db.persona import mark_persona_as_not_deleted
from danswer.db.persona import update_all_personas_display_priority
from danswer.db.persona import update_persona_public_status
from danswer.db.persona import update_persona_shared_users
from danswer.db.persona import update_persona_visibility
from danswer.file_store.file_store import get_default_file_store
@@ -43,6 +45,10 @@ class IsVisibleRequest(BaseModel):
is_visible: bool
class IsPublicRequest(BaseModel):
is_public: bool
@admin_router.patch("/{persona_id}/visible")
def patch_persona_visibility(
persona_id: int,
@@ -58,6 +64,25 @@ def patch_persona_visibility(
)
@basic_router.patch("/{persona_id}/public")
def patch_user_presona_public_status(
persona_id: int,
is_public_request: IsPublicRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
try:
update_persona_public_status(
persona_id=persona_id,
is_public=is_public_request.is_public,
db_session=db_session,
user=user,
)
except ValueError as e:
logger.exception("Failed to update persona public status")
raise HTTPException(status_code=403, detail=str(e))
@admin_router.put("/display-priority")
def patch_persona_display_priority(
display_priority_request: DisplayPriorityRequest,
@@ -122,25 +147,6 @@ def upload_file(
return {"file_id": file_id}
@admin_router.post("/upload-csv")
def upload_csv(
file: UploadFile,
db_session: Session = Depends(get_session),
_: User | None = Depends(current_user),
) -> dict[str, str]:
file_store = get_default_file_store(db_session)
file_type = ChatFileType.CSV
file_id = str(uuid.uuid4())
file_store.save_file(
file_name=file_id,
content=file.file,
display_name=file.filename,
file_origin=FileOrigin.CHAT_UPLOAD,
file_type=file.content_type or file_type.value,
)
return {"file_id": file_id}
"""Endpoints for all"""

View File

@@ -1,3 +1,4 @@
from datetime import datetime
from uuid import UUID
from pydantic import BaseModel
@@ -12,7 +13,6 @@ from danswer.server.features.tool.api import ToolSnapshot
from danswer.server.models import MinimalUserSnapshot
from danswer.utils.logger import setup_logger
logger = setup_logger()
@@ -38,6 +38,9 @@ class CreatePersonaRequest(BaseModel):
icon_shape: int | None = None
uploaded_image_id: str | None = None # New field for uploaded image
remove_image: bool | None = None
is_default_persona: bool = False
display_priority: int | None = None
search_start_date: datetime | None = None
class PersonaSnapshot(BaseModel):
@@ -54,7 +57,7 @@ class PersonaSnapshot(BaseModel):
llm_model_provider_override: str | None
llm_model_version_override: str | None
starter_messages: list[StarterMessage] | None
default_persona: bool
builtin_persona: bool
prompts: list[PromptSnapshot]
tools: list[ToolSnapshot]
document_sets: list[DocumentSet]
@@ -63,6 +66,8 @@ class PersonaSnapshot(BaseModel):
icon_color: str | None
icon_shape: int | None
uploaded_image_id: str | None = None
is_default_persona: bool
search_start_date: datetime | None = None
@classmethod
def from_model(
@@ -93,7 +98,8 @@ class PersonaSnapshot(BaseModel):
llm_model_provider_override=persona.llm_model_provider_override,
llm_model_version_override=persona.llm_model_version_override,
starter_messages=persona.starter_messages,
default_persona=persona.default_persona,
builtin_persona=persona.builtin_persona,
is_default_persona=persona.is_default_persona,
prompts=[PromptSnapshot.from_model(prompt) for prompt in persona.prompts],
tools=[ToolSnapshot.from_model(tool) for tool in persona.tools],
document_sets=[
@@ -108,6 +114,7 @@ class PersonaSnapshot(BaseModel):
icon_color=persona.icon_color,
icon_shape=persona.icon_shape,
uploaded_image_id=persona.uploaded_image_id,
search_start_date=persona.search_start_date,
)

View File

@@ -15,6 +15,8 @@ from danswer.db.tools import delete_tool
from danswer.db.tools import get_tool_by_id
from danswer.db.tools import get_tools
from danswer.db.tools import update_tool
from danswer.server.features.tool.models import CustomToolCreate
from danswer.server.features.tool.models import CustomToolUpdate
from danswer.server.features.tool.models import ToolSnapshot
from danswer.tools.custom.openapi_parsing import MethodSpec
from danswer.tools.custom.openapi_parsing import openapi_to_method_specs
@@ -24,18 +26,6 @@ router = APIRouter(prefix="/tool")
admin_router = APIRouter(prefix="/admin/tool")
class CustomToolCreate(BaseModel):
name: str
description: str | None = None
definition: dict[str, Any]
class CustomToolUpdate(BaseModel):
name: str | None = None
description: str | None = None
definition: dict[str, Any] | None = None
def _validate_tool_definition(definition: dict[str, Any]) -> None:
try:
validate_openapi_schema(definition)
@@ -54,6 +44,7 @@ def create_custom_tool(
name=tool_data.name,
description=tool_data.description,
openapi_schema=tool_data.definition,
custom_headers=tool_data.custom_headers,
user_id=user.id if user else None,
db_session=db_session,
)
@@ -74,6 +65,7 @@ def update_custom_tool(
name=tool_data.name,
description=tool_data.description,
openapi_schema=tool_data.definition,
custom_headers=tool_data.custom_headers,
user_id=user.id if user else None,
db_session=db_session,
)

View File

@@ -12,6 +12,7 @@ class ToolSnapshot(BaseModel):
definition: dict[str, Any] | None
display_name: str
in_code_tool_id: str | None
custom_headers: list[Any] | None
@classmethod
def from_model(cls, tool: Tool) -> "ToolSnapshot":
@@ -22,4 +23,24 @@ class ToolSnapshot(BaseModel):
definition=tool.openapi_schema,
display_name=tool.display_name or tool.name,
in_code_tool_id=tool.in_code_tool_id,
custom_headers=tool.custom_headers,
)
class Header(BaseModel):
key: str
value: str
class CustomToolCreate(BaseModel):
name: str
description: str | None = None
definition: dict[str, Any]
custom_headers: list[Header] | None = None
class CustomToolUpdate(BaseModel):
name: str | None = None
description: str | None = None
definition: dict[str, Any] | None = None
custom_headers: list[Header] | None = None

View File

@@ -11,6 +11,7 @@ from sqlalchemy.orm import Session
from danswer.auth.users import current_admin_user
from danswer.auth.users import current_curator_or_admin_user
from danswer.configs.app_configs import GENERATIVE_MODEL_ACCESS_CHECK_FREQ
from danswer.configs.constants import DanswerCeleryPriority
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import KV_GEN_AI_KEY_CHECK_TIME
from danswer.db.connector_credential_pair import get_connector_credential_pair
@@ -146,7 +147,7 @@ def create_deletion_attempt_for_connector_id(
db_session: Session = Depends(get_session),
) -> None:
from danswer.background.celery.celery_app import (
cleanup_connector_credential_pair_task,
check_for_connector_deletion_task,
)
connector_id = connector_credential_pair_identifier.connector_id
@@ -191,9 +192,10 @@ def create_deletion_attempt_for_connector_id(
cc_pair_id=cc_pair.id,
status=ConnectorCredentialPairStatus.DELETING,
)
# actually kick off the deletion
cleanup_connector_credential_pair_task.apply_async(
kwargs=dict(connector_id=connector_id, credential_id=credential_id),
# run the beat task to pick up this deletion early
check_for_connector_deletion_task.apply_async(
priority=DanswerCeleryPriority.HIGH,
)
if cc_pair.connector.source == DocumentSource.FILE:

View File

@@ -18,6 +18,9 @@ class TestEmbeddingRequest(BaseModel):
api_url: str | None = None
model_name: str | None = None
# This disables the "model_" protected namespace for pydantic
model_config = {"protected_namespaces": ()}
class CloudEmbeddingProvider(BaseModel):
provider_type: EmbeddingProvider

View File

@@ -3,6 +3,7 @@ from collections.abc import Callable
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Query
from sqlalchemy.orm import Session
from danswer.auth.users import current_admin_user
@@ -17,6 +18,7 @@ from danswer.llm.factory import get_default_llms
from danswer.llm.factory import get_llm
from danswer.llm.llm_provider_options import fetch_available_well_known_llms
from danswer.llm.llm_provider_options import WellKnownLLMProviderDescriptor
from danswer.llm.utils import litellm_exception_to_error_msg
from danswer.llm.utils import test_llm
from danswer.server.manage.llm.models import FullLLMProvider
from danswer.server.manage.llm.models import LLMProviderDescriptor
@@ -77,7 +79,10 @@ def test_llm_configuration(
)
if error:
raise HTTPException(status_code=400, detail=error)
client_error_msg = litellm_exception_to_error_msg(
error, llm, fallback_to_error_msg=True
)
raise HTTPException(status_code=400, detail=client_error_msg)
@admin_router.post("/test/default")
@@ -118,10 +123,22 @@ def list_llm_providers(
@admin_router.put("/provider")
def put_llm_provider(
llm_provider: LLMProviderUpsertRequest,
is_creation: bool = Query(
True,
description="True if updating an existing provider, False if creating a new one",
),
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> FullLLMProvider:
return upsert_llm_provider(llm_provider=llm_provider, db_session=db_session)
try:
return upsert_llm_provider(
llm_provider=llm_provider,
db_session=db_session,
is_creation=is_creation,
)
except ValueError as e:
logger.exception("Failed to upsert LLM Provider")
raise HTTPException(status_code=400, detail=str(e))
@admin_router.delete("/provider/{provider_id}")

View File

@@ -1,6 +1,4 @@
import re
from datetime import datetime
from typing import Any
from typing import TYPE_CHECKING
from pydantic import BaseModel
@@ -17,13 +15,12 @@ from danswer.db.models import AllowedAnswerFilters
from danswer.db.models import ChannelConfig
from danswer.db.models import SlackBotConfig as SlackBotConfigModel
from danswer.db.models import SlackBotResponseType
from danswer.db.models import StandardAnswer as StandardAnswerModel
from danswer.db.models import StandardAnswerCategory as StandardAnswerCategoryModel
from danswer.db.models import User
from danswer.search.models import SavedSearchSettings
from danswer.server.features.persona.models import PersonaSnapshot
from danswer.server.models import FullUserSnapshot
from danswer.server.models import InvitedUserSnapshot
from ee.danswer.server.manage.models import StandardAnswerCategory
if TYPE_CHECKING:
@@ -43,6 +40,9 @@ class AuthTypeResponse(BaseModel):
class UserPreferences(BaseModel):
chosen_assistants: list[int] | None = None
hidden_assistants: list[int] = []
visible_assistants: list[int] = []
default_model: str | None = None
@@ -76,6 +76,8 @@ class UserInfo(BaseModel):
UserPreferences(
chosen_assistants=user.chosen_assistants,
default_model=user.default_model,
hidden_assistants=user.hidden_assistants,
visible_assistants=user.visible_assistants,
)
),
# set to None if TRACK_EXTERNAL_IDP_EXPIRY is False so that we avoid cases
@@ -119,95 +121,6 @@ class HiddenUpdateRequest(BaseModel):
hidden: bool
class StandardAnswerCategoryCreationRequest(BaseModel):
name: str
class StandardAnswerCategory(BaseModel):
id: int
name: str
@classmethod
def from_model(
cls, standard_answer_category: StandardAnswerCategoryModel
) -> "StandardAnswerCategory":
return cls(
id=standard_answer_category.id,
name=standard_answer_category.name,
)
class StandardAnswer(BaseModel):
id: int
keyword: str
answer: str
categories: list[StandardAnswerCategory]
match_regex: bool
match_any_keywords: bool
@classmethod
def from_model(cls, standard_answer_model: StandardAnswerModel) -> "StandardAnswer":
return cls(
id=standard_answer_model.id,
keyword=standard_answer_model.keyword,
answer=standard_answer_model.answer,
match_regex=standard_answer_model.match_regex,
match_any_keywords=standard_answer_model.match_any_keywords,
categories=[
StandardAnswerCategory.from_model(standard_answer_category_model)
for standard_answer_category_model in standard_answer_model.categories
],
)
class StandardAnswerCreationRequest(BaseModel):
keyword: str
answer: str
categories: list[int]
match_regex: bool
match_any_keywords: bool
@field_validator("categories", mode="before")
@classmethod
def validate_categories(cls, value: list[int]) -> list[int]:
if len(value) < 1:
raise ValueError(
"At least one category must be attached to a standard answer"
)
return value
@model_validator(mode="after")
def validate_only_match_any_if_not_regex(self) -> Any:
if self.match_regex and self.match_any_keywords:
raise ValueError(
"Can only match any keywords in keyword mode, not regex mode"
)
return self
@model_validator(mode="after")
def validate_keyword_if_regex(self) -> Any:
if not self.match_regex:
# no validation for keywords
return self
try:
re.compile(self.keyword)
return self
except re.error as err:
if isinstance(err.pattern, bytes):
raise ValueError(
f'invalid regex pattern r"{err.pattern.decode()}" in `keyword`: {err.msg}'
)
else:
pattern = f'r"{err.pattern}"' if err.pattern is not None else ""
raise ValueError(
" ".join(
["invalid regex pattern", pattern, f"in `keyword`: {err.msg}"]
)
)
class SlackBotTokens(BaseModel):
bot_token: str
app_token: str
@@ -233,6 +146,7 @@ class SlackBotConfigCreationRequest(BaseModel):
# list of user emails
follow_up_tags: list[str] | None = None
response_type: SlackBotResponseType
# XXX this is going away soon
standard_answer_categories: list[int] = Field(default_factory=list)
@field_validator("answer_filters", mode="before")
@@ -257,6 +171,7 @@ class SlackBotConfig(BaseModel):
persona: PersonaSnapshot | None
channel_config: ChannelConfig
response_type: SlackBotResponseType
# XXX this is going away soon
standard_answer_categories: list[StandardAnswerCategory]
enable_auto_filters: bool
@@ -275,6 +190,7 @@ class SlackBotConfig(BaseModel):
),
channel_config=slack_bot_config_model.channel_config,
response_type=slack_bot_config_model.response_type,
# XXX this is going away soon
standard_answer_categories=[
StandardAnswerCategory.from_model(standard_answer_category_model)
for standard_answer_category_model in slack_bot_config_model.standard_answer_categories

View File

@@ -108,6 +108,7 @@ def create_slack_bot_config(
persona_id=persona_id,
channel_config=channel_config,
response_type=slack_bot_config_creation_request.response_type,
# XXX this is going away soon
standard_answer_category_ids=slack_bot_config_creation_request.standard_answer_categories,
db_session=db_session,
enable_auto_filters=slack_bot_config_creation_request.enable_auto_filters,

View File

@@ -31,13 +31,18 @@ from danswer.configs.app_configs import VALID_EMAIL_DOMAINS
from danswer.configs.constants import AuthType
from danswer.db.engine import get_session
from danswer.db.models import AccessToken
from danswer.db.models import DocumentSet__User
from danswer.db.models import Persona__User
from danswer.db.models import SamlAccount
from danswer.db.models import User
from danswer.db.models import User__UserGroup
from danswer.db.users import get_user_by_email
from danswer.db.users import list_users
from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.server.manage.models import AllUsersResponse
from danswer.server.manage.models import UserByEmail
from danswer.server.manage.models import UserInfo
from danswer.server.manage.models import UserPreferences
from danswer.server.manage.models import UserRoleResponse
from danswer.server.manage.models import UserRoleUpdateRequest
from danswer.server.models import FullUserSnapshot
@@ -45,6 +50,7 @@ from danswer.server.models import InvitedUserSnapshot
from danswer.server.models import MinimalUserSnapshot
from danswer.utils.logger import setup_logger
from ee.danswer.db.api_key import is_api_key_email_address
from ee.danswer.db.external_perm import delete_user__ext_group_for_user__no_commit
from ee.danswer.db.user_group import remove_curator_status__no_commit
logger = setup_logger()
@@ -237,10 +243,25 @@ async def delete_user(
db_session.expunge(user_to_delete)
try:
# Delete related OAuthAccounts first
for oauth_account in user_to_delete.oauth_accounts:
db_session.delete(oauth_account)
delete_user__ext_group_for_user__no_commit(
db_session=db_session,
user_id=user_to_delete.id,
)
db_session.query(SamlAccount).filter(
SamlAccount.user_id == user_to_delete.id
).delete()
db_session.query(DocumentSet__User).filter(
DocumentSet__User.user_id == user_to_delete.id
).delete()
db_session.query(Persona__User).filter(
Persona__User.user_id == user_to_delete.id
).delete()
db_session.query(User__UserGroup).filter(
User__UserGroup.user_id == user_to_delete.id
).delete()
db_session.delete(user_to_delete)
db_session.commit()
@@ -254,6 +275,10 @@ async def delete_user(
logger.info(f"Deleted user {user_to_delete.email}")
except Exception as e:
import traceback
full_traceback = traceback.format_exc()
logger.error(f"Full stack trace:\n{full_traceback}")
db_session.rollback()
logger.error(f"Error deleting user {user_to_delete.email}: {str(e)}")
raise HTTPException(status_code=500, detail="Error deleting user")
@@ -423,3 +448,64 @@ def update_user_assistant_list(
.values(chosen_assistants=request.chosen_assistants)
)
db_session.commit()
def update_assistant_list(
preferences: UserPreferences, assistant_id: int, show: bool
) -> UserPreferences:
visible_assistants = preferences.visible_assistants or []
hidden_assistants = preferences.hidden_assistants or []
chosen_assistants = preferences.chosen_assistants or []
if show:
if assistant_id not in visible_assistants:
visible_assistants.append(assistant_id)
if assistant_id in hidden_assistants:
hidden_assistants.remove(assistant_id)
if assistant_id not in chosen_assistants:
chosen_assistants.append(assistant_id)
else:
if assistant_id in visible_assistants:
visible_assistants.remove(assistant_id)
if assistant_id not in hidden_assistants:
hidden_assistants.append(assistant_id)
if assistant_id in chosen_assistants:
chosen_assistants.remove(assistant_id)
preferences.visible_assistants = visible_assistants
preferences.hidden_assistants = hidden_assistants
preferences.chosen_assistants = chosen_assistants
return preferences
@router.patch("/user/assistant-list/update/{assistant_id}")
def update_user_assistant_visibility(
assistant_id: int,
show: bool,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
if user is None:
if AUTH_TYPE == AuthType.DISABLED:
store = get_dynamic_config_store()
no_auth_user = fetch_no_auth_user(store)
preferences = no_auth_user.preferences
updated_preferences = update_assistant_list(preferences, assistant_id, show)
set_no_auth_user_preferences(store, updated_preferences)
return
else:
raise RuntimeError("This should never happen")
user_preferences = UserInfo.from_model(user).preferences
updated_preferences = update_assistant_list(user_preferences, assistant_id, show)
db_session.execute(
update(User)
.where(User.id == user.id) # type: ignore
.values(
hidden_assistants=updated_preferences.hidden_assistants,
visible_assistants=updated_preferences.visible_assistants,
chosen_assistants=updated_preferences.chosen_assistants,
)
)
db_session.commit()

View File

@@ -164,7 +164,7 @@ def get_chat_session(
chat_session_id=session_id,
description=chat_session.description,
persona_id=chat_session.persona_id,
persona_name=chat_session.persona.name,
persona_name=chat_session.persona.name if chat_session.persona else None,
current_alternate_model=chat_session.current_alternate_model,
messages=[
translate_db_message_to_chat_message_detail(
@@ -281,17 +281,14 @@ async def is_disconnected(request: Request) -> Callable[[], bool]:
def is_disconnected_sync() -> bool:
future = asyncio.run_coroutine_threadsafe(request.is_disconnected(), main_loop)
try:
result = not future.result(timeout=0.01)
return result
return not future.result(timeout=0.01)
except asyncio.TimeoutError:
logger.error("Asyncio timed out while checking client connection")
return True
except asyncio.CancelledError:
logger.error("Asyncio timed out")
return True
except Exception as e:
error_msg = str(e)
logger.critical(
f"An unexpected error occurred with the disconnect check coroutine: {error_msg}"
f"An unexpected error occured with the disconnect check coroutine: {error_msg}"
)
return True
@@ -520,6 +517,7 @@ def upload_files_for_chat(
image_content_types = {"image/jpeg", "image/png", "image/webp"}
text_content_types = {
"text/plain",
"text/csv",
"text/markdown",
"text/x-markdown",
"text/x-config",
@@ -529,9 +527,6 @@ def upload_files_for_chat(
"text/xml",
"application/x-yaml",
}
csv_content_types = {
"text/csv",
}
document_content_types = {
"application/pdf",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
@@ -541,10 +536,8 @@ def upload_files_for_chat(
"application/epub+zip",
}
allowed_content_types = (
image_content_types.union(text_content_types)
.union(document_content_types)
.union(csv_content_types)
allowed_content_types = image_content_types.union(text_content_types).union(
document_content_types
)
for file in files:
@@ -552,12 +545,8 @@ def upload_files_for_chat(
if file.content_type in image_content_types:
error_detail = "Unsupported image file type. Supported image types include .jpg, .jpeg, .png, .webp."
elif file.content_type in text_content_types:
error_detail = "Unsupported text file type. Supported text types include .txt, .md, .mdx, .conf, "
error_detail = "Unsupported text file type. Supported text types include .txt, .csv, .md, .mdx, .conf, "
".log, .tsv."
elif file.content_type in csv_content_types:
error_detail = (
"Unsupported csv file type. Supported CSV types include .csv "
)
else:
error_detail = (
"Unsupported document file type. Supported document types include .pdf, .docx, .pptx, .xlsx, "
@@ -583,8 +572,6 @@ def upload_files_for_chat(
file_type = ChatFileType.IMAGE
elif file.content_type in document_content_types:
file_type = ChatFileType.DOC
elif file.content_type in csv_content_types:
file_type = ChatFileType.CSV
else:
file_type = ChatFileType.PLAIN_TEXT
@@ -597,7 +584,6 @@ def upload_files_for_chat(
file_origin=FileOrigin.CHAT_UPLOAD,
file_type=file.content_type or file_type.value,
)
print(f"FILE TYPE IS {file_type}")
# if the file is a doc, extract text and store that so we don't need
# to re-extract it every time we send a message

View File

@@ -136,7 +136,7 @@ class RenameChatSessionResponse(BaseModel):
class ChatSessionDetails(BaseModel):
id: int
name: str
persona_id: int
persona_id: int | None = None
time_created: str
shared_status: ChatSessionSharedStatus
folder_id: int | None = None
@@ -178,7 +178,7 @@ class ChatMessageDetail(BaseModel):
chat_session_id: int | None = None
citations: dict[int, int] | None = None
files: list[FileDescriptor]
tool_call: ToolCallFinalResult | None
tool_calls: list[ToolCallFinalResult]
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
initial_dict = super().model_dump(mode="json", *args, **kwargs) # type: ignore
@@ -196,8 +196,8 @@ class SearchSessionDetailResponse(BaseModel):
class ChatSessionDetailResponse(BaseModel):
chat_session_id: int
description: str
persona_id: int
persona_name: str
persona_id: int | None = None
persona_name: str | None
messages: list[ChatMessageDetail]
time_created: datetime
shared_status: ChatSessionSharedStatus

View File

@@ -37,6 +37,7 @@ class Settings(BaseModel):
search_page_enabled: bool = True
default_page: PageType = PageType.SEARCH
maximum_chat_retention_days: int | None = None
gpu_enabled: bool | None = None
def check_validity(self) -> None:
chat_page_enabled = self.chat_page_enabled

View File

@@ -1,22 +1,19 @@
import base64
import json
from datetime import datetime
from typing import Any
class DateTimeAndBytesEncoder(json.JSONEncoder):
"""Custom JSON encoder that converts datetime objects to ISO format strings and bytes to base64."""
class DateTimeEncoder(json.JSONEncoder):
"""Custom JSON encoder that converts datetime objects to ISO format strings."""
def default(self, obj: Any) -> Any:
if isinstance(obj, datetime):
return obj.isoformat()
elif isinstance(obj, bytes):
return base64.b64encode(obj).decode("utf-8")
return super().default(obj)
def get_json_line(
json_dict: dict[str, Any], encoder: type[json.JSONEncoder] = DateTimeAndBytesEncoder
json_dict: dict[str, Any], encoder: type[json.JSONEncoder] = DateTimeEncoder
) -> str:
"""
Convert a dictionary to a JSON string with datetime handling, and add a newline.

Some files were not shown because too many files have changed in this diff Show More