Compare commits

..

74 Commits
bot_nit ... k

Author SHA1 Message Date
Richard Kuo (Danswer)
a13dea160a use redis completion signal to double check exit code 2024-12-12 10:56:11 -08:00
pablonyx
ca172f3306 Merge pull request #3442 from onyx-dot-app/vespa_seeding_fix
Update initial seeding for latency requirements
2024-12-12 09:59:50 -08:00
pablodanswer
e5d0587efa pre-commit 2024-12-12 09:12:08 -08:00
pablonyx
a9516202fe update conditional (#3446) 2024-12-12 17:07:30 +00:00
Richard Kuo
d23fca96c4 reverse commit (fix later) 2024-12-11 22:19:10 -08:00
pablodanswer
a45724c899 run black 2024-12-11 19:18:06 -08:00
pablodanswer
34e250407a k 2024-12-11 19:14:10 -08:00
pablodanswer
046c0fbe3e update indexing 2024-12-11 19:08:05 -08:00
pablonyx
76595facef Merge pull request #3432 from onyx-dot-app/vercel_preview
Enable Vercel Preview
2024-12-11 18:55:14 -08:00
pablodanswer
af2d548766 k 2024-12-11 18:52:47 -08:00
Weves
7c29b1e028 add more egnyte failure logging 2024-12-11 18:19:55 -08:00
pablonyx
a52c821e78 Merge pull request #3436 from onyx-dot-app/cloud_improvements
cloud improvements
2024-12-11 17:06:06 -08:00
pablonyx
0770a587f1 remove slack workspace (#3394)
* remove slack workspace

* update client tokens

* fix up

* clean up docs

* fix up tests
2024-12-12 01:01:43 +00:00
hagen-danswer
748b79b0ef Added text for empty table and cascade delete for slack bot deletion (#3390)
* fixed fk issue for slack bot deletion

* Added text for empty table and cascade delete for slack bot deletion
2024-12-12 01:00:32 +00:00
pablonyx
9cacb373ef let users specify resourcing caps (#3403)
* let users specify resourcing caps

* functioanl resource limits

* improve defaults

* k

* update

* update comment + refer to proper resource

* self nit

* update var names
2024-12-12 00:59:41 +00:00
pablodanswer
21967d4b6f cloud improvements 2024-12-11 16:48:00 -08:00
pablodanswer
f5d638161b k 2024-12-11 15:35:44 -08:00
pablodanswer
0b5013b47d k 2024-12-11 15:34:26 -08:00
pablodanswer
1b846fbf06 update config 2024-12-11 15:17:11 -08:00
hagen-danswer
cae8a131a2 Made frontend conditional check for source (#3434) 2024-12-11 22:46:32 +00:00
pablonyx
72b4e8e9fe Clean citation cards (#3396)
* seed

* initial steps

* clean up

* fully clickable
2024-12-11 21:37:11 +00:00
pablonyx
c04e2f14d9 remove double x (#3387) 2024-12-11 21:36:58 +00:00
pablonyx
b40a12d5d7 clean up cursor pointers (#3385)
* update

* nit
2024-12-11 21:36:43 +00:00
pablonyx
5e7d454ebe Merge pull request #3433 from onyx-dot-app/silence_integration
Silence Slack Permission Sync test flakiness
2024-12-11 13:49:52 -08:00
pablodanswer
238509c536 silence 2024-12-11 13:48:37 -08:00
pablodanswer
d7f8cf8f18 testing 2024-12-11 13:36:10 -08:00
pablodanswer
5d810d373e k 2024-12-11 13:32:09 -08:00
joachim-danswer
9455576078 Mismatch issue of Documents shown and Citation number in text fix (#3421)
* Mismatch issue of Documents shown and Citation number in text fix

When document order presented to LLM differs from order shown to user, wrong doc numbers are cited.

Fix:
 - SearchTool.get_search_result  returns now final and initial ranking
 - initial ranking is passed through a few objects and used for replacement in citation processing

Notes:
 - the citation_num in the CitationInfo() object has not been changed.

* PR fixes

 - linting
 - removed erroneous tab
 - added a substitution test case
 - adjusted original citation extraction use case

* Included a key test and

* Fixed extra spaces

* Updated test documentation

Updated:
 - test_citation_substitution (changed description)
 - test_citation_processing (removed data only relevant for the substitution)
2024-12-11 19:58:24 +00:00
rkuo-danswer
71421bb782 better handling around index attempts that don't exist and remove unn… (#3417)
* better handling around index attempts that don't exist and remove unnecessary index attempt deletions

* don't delete index attempts, just update them

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2024-12-11 19:32:04 +00:00
pablonyx
b88cb388b7 Faster api hashing (#3423)
* migrate hashing to run faster v1

* k
2024-12-11 19:30:05 +00:00
Wendi
639986001f Fix bug (title overflow) (#3431) 2024-12-11 12:09:44 -08:00
pablonyx
e7a7e78969 clean up csv prompt + frontend (#3393)
* clean up csv prompt + frontend

* nit

* nit

* detect uploading

* upload
2024-12-11 19:10:34 +00:00
rkuo-danswer
e255ff7d23 editable refresh and prune for connectors (#3406)
* editable refresh and prune for connectors

* add extra validations on pruning/refresh frequency

* fix validation

* fix icon usage

* fix TextFormField error formatting

* nit

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
Co-authored-by: pablodanswer <pablo@danswer.ai>
2024-12-11 19:04:09 +00:00
pablonyx
1be2502112 finalize (#3398)
Co-authored-by: hagen-danswer <hagen@danswer.ai>
2024-12-11 18:52:20 +00:00
pablonyx
f2bedb8fdd Borders (#3388)
* remove double x

* incorporate base default padding for modals
2024-12-11 18:47:26 +00:00
pablonyx
637404f482 Connector page lists (pending feedback) (#3415)
* v1 (pending feedback)

* nits

* nit
2024-12-11 18:45:27 +00:00
pablonyx
daae146920 recognize updates (#3397) 2024-12-11 18:19:00 +00:00
pablonyx
d95959fb41 base role setting fix (#3381)
* base role setting fix

* update user tables

* finalize

* minor cleanup

* fix chromatic
2024-12-11 18:09:47 +00:00
rkuo-danswer
c667d28e7a update helm charts for onyx-dot-app rebrand (#3412)
* update helm charts for onyx-dot-app rebrand

* fix helm chart testing config

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2024-12-11 18:08:39 +00:00
pablonyx
9e0b482f47 k (#3399) 2024-12-11 18:05:39 +00:00
pablonyx
fa84eb657f cleaner citations (#3389) 2024-12-11 17:36:15 +00:00
pablonyx
264df3441b Various clean ups (#3413)
* tbd

* minor

* prettify

* update sidebar values
2024-12-11 17:19:14 +00:00
pablonyx
b9bad8b7a0 fix wikipedia icon (#3395) 2024-12-11 09:03:29 -08:00
pablonyx
600ebb6432 remove doc sets (#3400) 2024-12-11 16:31:14 +00:00
pablonyx
09fe8ea868 improved display - no odd cutoffs (#3401) 2024-12-11 16:09:19 +00:00
evan-danswer
ad6be03b4d centered score in feedbac panel (#3426) 2024-12-11 08:19:53 -08:00
rkuo-danswer
65d2511216 change text and formatting to guide users away from thinking "Back to… (#3382)
* change text and formatting to guide users away from thinking "Back to Danswer" is a back button

* regular text color and different icon

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2024-12-11 03:31:27 +00:00
Weves
113bf19c65 Remove dev-only check 2024-12-10 19:04:21 -08:00
Yuhong Sun
6026536110 Model Server Async (#3386)
* need-verify

* fix some lib calls

* k

* tests

* k

* k

* k

* Address the comments

* fix comment
2024-12-11 01:33:44 +00:00
Weves
056b671cd4 Small tweaks to get Egynte to work on our cloud 2024-12-10 17:43:46 -08:00
pablonyx
8d83ae2ee8 fix linear (#3402) 2024-12-11 00:45:06 +00:00
Yuhong Sun
ca988f5c5f Max File Size (#3422)
* k

* k

* k
2024-12-11 00:06:47 +00:00
Chris Weaver
4e4214b82c Egnyte connector (#3420) 2024-12-10 16:07:33 -08:00
Yuhong Sun
fe83f676df k (#3404) 2024-12-10 23:27:48 +00:00
hagen-danswer
6d6e12119b made external group emails lowercase (#3410) 2024-12-10 22:08:00 +00:00
pablonyx
1f2b7cb9c8 strip text for slackbot (#3416)
* stripe text for slackbot

* k
2024-12-10 21:42:35 +00:00
pablonyx
878a189011 delete input prompts (#3380)
* delete input prompts

* nit

* remove vestigial test

* nit
2024-12-10 21:36:40 +00:00
hagen-danswer
48c10271c2 fixed ephemeral slackbot messages (#3409) 2024-12-10 18:00:34 +00:00
evan-danswer
c6a79d847e fix typo (#3408)
expliticly -> explicitly
2024-12-10 16:44:42 +00:00
hagen-danswer
1bc3f8b96f Revert "Fixed ephemeral slackbot messages"
This reverts commit 7f6a6944d6.
2024-12-10 08:18:31 -08:00
hagen-danswer
7f6a6944d6 Fixed ephemeral slackbot messages 2024-12-10 07:57:28 -08:00
Weves
06f4146597 Bump litellm to support Nova models from AWS 2024-12-09 21:19:11 -08:00
hagen-danswer
7ea73d5a5a Temp slackbot url error Fix (#3392) 2024-12-09 18:34:38 -08:00
Weves
30dfe6dcb4 Add better vertex support + LLM form cleanup 2024-12-09 13:44:44 -08:00
Yuhong Sun
dc5d5dfe05 README Update (#3383) 2024-12-09 13:17:53 -08:00
pablonyx
0746e0be5b unify toggling (#3378) 2024-12-09 19:48:06 +00:00
Chris Weaver
970320bd49 Persona / prompt hardening (#3375)
* Persona / prompt hardening

* fix it
2024-12-09 03:39:59 +00:00
Chris Weaver
4a7bd5578e Fix Confluence perm sync for cloud users (#3374) 2024-12-09 01:41:30 +00:00
Chris Weaver
874b098a4b Add more logging + retries to teams connector (#3369) 2024-12-08 00:56:34 +00:00
pablodanswer
ce18b63eea hide oauth sources (#3368) 2024-12-07 23:57:37 +00:00
Yuhong Sun
7a919c3589 Dev Version Niceness 2024-12-07 15:10:13 -08:00
rkuo-danswer
631bac4432 Bugfix/log exit code (#3362)
* log the exit code of the spawned task

* exitcode can be negative

* mypy fixes
2024-12-06 22:32:59 +00:00
hagen-danswer
53428f6e9c More logging/fixes (#3364)
* More logging for external group syncing

* Fixed edge case where some spaces were not being fetched

* made refresh frequency for confluence syncs configurable

* clarity
2024-12-06 21:56:29 +00:00
pablodanswer
53b3dcbace fix slackbot channel config nullable (#3363)
* fix slackbot

* nit
2024-12-06 21:24:36 +00:00
160 changed files with 3800 additions and 3254 deletions

View File

@@ -1,48 +1,48 @@
<!-- DANSWER_METADATA={"link": "https://github.com/danswer-ai/danswer/blob/main/README.md"} -->
<!-- DANSWER_METADATA={"link": "https://github.com/onyx-dot-app/onyx/blob/main/README.md"} -->
<a name="readme-top"></a>
<h2 align="center">
<a href="https://www.danswer.ai/"> <img width="50%" src="https://github.com/danswer-owners/danswer/blob/1fabd9372d66cd54238847197c33f091a724803b/DanswerWithName.png?raw=true)" /></a>
<a href="https://www.onyx.app/"> <img width="50%" src="https://github.com/onyx-dot-app/onyx/blob/logo/LogoOnyx.png?raw=true)" /></a>
</h2>
<p align="center">
<p align="center">Open Source Gen-AI Chat + Unified Search.</p>
<p align="center">Open Source Gen-AI + Enterprise Search.</p>
<p align="center">
<a href="https://docs.danswer.dev/" target="_blank">
<a href="https://docs.onyx.app/" target="_blank">
<img src="https://img.shields.io/badge/docs-view-blue" alt="Documentation">
</a>
<a href="https://join.slack.com/t/danswer/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA" target="_blank">
<a href="https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2sslpdbyq-iIbTaNIVPBw_i_4vrujLYQ" target="_blank">
<img src="https://img.shields.io/badge/slack-join-blue.svg?logo=slack" alt="Slack">
</a>
<a href="https://discord.gg/TDJ59cGV2X" target="_blank">
<img src="https://img.shields.io/badge/discord-join-blue.svg?logo=discord&logoColor=white" alt="Discord">
</a>
<a href="https://github.com/danswer-ai/danswer/blob/main/README.md" target="_blank">
<a href="https://github.com/onyx-dot-app/onyx/blob/main/README.md" target="_blank">
<img src="https://img.shields.io/static/v1?label=license&message=MIT&color=blue" alt="License">
</a>
</p>
<strong>[Danswer](https://www.danswer.ai/)</strong> is the AI Assistant connected to your company's docs, apps, and people.
Danswer provides a Chat interface and plugs into any LLM of your choice. Danswer can be deployed anywhere and for any
<strong>[Onyx](https://www.onyx.app/)</strong> (Formerly Danswer) is the AI Assistant connected to your company's docs, apps, and people.
Onyx provides a Chat interface and plugs into any LLM of your choice. Onyx can be deployed anywhere and for any
scale - on a laptop, on-premise, or to cloud. Since you own the deployment, your user data and chats are fully in your
own control. Danswer is MIT licensed and designed to be modular and easily extensible. The system also comes fully ready
own control. Onyx is dual Licensed with most of it under MIT license and designed to be modular and easily extensible. The system also comes fully ready
for production usage with user authentication, role management (admin/basic users), chat persistence, and a UI for
configuring Personas (AI Assistants) and their Prompts.
configuring AI Assistants.
Danswer also serves as a Unified Search across all common workplace tools such as Slack, Google Drive, Confluence, etc.
By combining LLMs and team specific knowledge, Danswer becomes a subject matter expert for the team. Imagine ChatGPT if
Onyx also serves as a Enterprise Search across all common workplace tools such as Slack, Google Drive, Confluence, etc.
By combining LLMs and team specific knowledge, Onyx becomes a subject matter expert for the team. Imagine ChatGPT if
it had access to your team's unique knowledge! It enables questions such as "A customer wants feature X, is this already
supported?" or "Where's the pull request for feature Y?"
<h3>Usage</h3>
Danswer Web App:
Onyx Web App:
https://github.com/danswer-ai/danswer/assets/32520769/563be14c-9304-47b5-bf0a-9049c2b6f410
Or, plug Danswer into your existing Slack workflows (more integrations to come 😁):
Or, plug Onyx into your existing Slack workflows (more integrations to come 😁):
https://github.com/danswer-ai/danswer/assets/25087905/3e19739b-d178-4371-9a38-011430bdec1b
@@ -52,16 +52,16 @@ For more details on the Admin UI to manage connectors and users, check out our
## Deployment
Danswer can easily be run locally (even on a laptop) or deployed on a virtual machine with a single
`docker compose` command. Checkout our [docs](https://docs.danswer.dev/quickstart) to learn more.
Onyx can easily be run locally (even on a laptop) or deployed on a virtual machine with a single
`docker compose` command. Checkout our [docs](https://docs.onyx.app/quickstart) to learn more.
We also have built-in support for deployment on Kubernetes. Files for that can be found [here](https://github.com/danswer-ai/danswer/tree/main/deployment/kubernetes).
We also have built-in support for deployment on Kubernetes. Files for that can be found [here](https://github.com/onyx-dot-app/onyx/tree/main/deployment/kubernetes).
## 💃 Main Features
* Chat UI with the ability to select documents to chat with.
* Create custom AI Assistants with different prompts and backing knowledge sets.
* Connect Danswer with LLM of your choice (self-host for a fully airgapped solution).
* Connect Onyx with LLM of your choice (self-host for a fully airgapped solution).
* Document Search + AI Answers for natural language queries.
* Connectors to all common workplace tools like Google Drive, Confluence, Slack, etc.
* Slack integration to get answers and search results directly in Slack.
@@ -75,12 +75,12 @@ We also have built-in support for deployment on Kubernetes. Files for that can b
* Organizational understanding and ability to locate and suggest experts from your team.
## Other Notable Benefits of Danswer
## Other Notable Benefits of Onyx
* User Authentication with document level access management.
* Best in class Hybrid Search across all sources (BM-25 + prefix aware embedding models).
* Admin Dashboard to configure connectors, document-sets, access, etc.
* Custom deep learning models + learn from user feedback.
* Easy deployment and ability to host Danswer anywhere of your choosing.
* Easy deployment and ability to host Onyx anywhere of your choosing.
## 🔌 Connectors
@@ -108,10 +108,10 @@ Efficiently pulls the latest changes from:
## 📚 Editions
There are two editions of Danswer:
There are two editions of Onyx:
* Danswer Community Edition (CE) is available freely under the MIT Expat license. This version has ALL the core features discussed above. This is the version of Danswer you will get if you follow the Deployment guide above.
* Danswer Enterprise Edition (EE) includes extra features that are primarily useful for larger organizations. Specifically, this includes:
* Onyx Community Edition (CE) is available freely under the MIT Expat license. This version has ALL the core features discussed above. This is the version of Onyx you will get if you follow the Deployment guide above.
* Onyx Enterprise Edition (EE) includes extra features that are primarily useful for larger organizations. Specifically, this includes:
* Single Sign-On (SSO), with support for both SAML and OIDC
* Role-based access control
* Document permission inheritance from connected sources
@@ -119,24 +119,24 @@ There are two editions of Danswer:
* Whitelabeling
* API key authentication
* Encryption of secrets
* Any many more! Checkout [our website](https://www.danswer.ai/) for the latest.
* Any many more! Checkout [our website](https://www.onyx.app/) for the latest.
To try the Danswer Enterprise Edition:
To try the Onyx Enterprise Edition:
1. Checkout our [Cloud product](https://app.danswer.ai/signup).
2. For self-hosting, contact us at [founders@danswer.ai](mailto:founders@danswer.ai) or book a call with us on our [Cal](https://cal.com/team/danswer/founders).
1. Checkout our [Cloud product](https://cloud.onyx.app/signup).
2. For self-hosting, contact us at [founders@onyx.app](mailto:founders@onyx.app) or book a call with us on our [Cal](https://cal.com/team/danswer/founders).
## 💡 Contributing
Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md) for more details.
## ⭐Star History
[![Star History Chart](https://api.star-history.com/svg?repos=danswer-ai/danswer&type=Date)](https://star-history.com/#danswer-ai/danswer&Date)
[![Star History Chart](https://api.star-history.com/svg?repos=onyx-dot-app/onyx&type=Date)](https://star-history.com/#onyx-dot-app/onyx&Date)
## ✨Contributors
<a href="https://github.com/danswer-ai/danswer/graphs/contributors">
<img alt="contributors" src="https://contrib.rocks/image?repo=danswer-ai/danswer"/>
<a href="https://github.com/onyx-dot-app/onyx/graphs/contributors">
<img alt="contributors" src="https://contrib.rocks/image?repo=onyx-dot-app/onyx"/>
</a>
<p align="right" style="font-size: 14px; color: #555; margin-top: 20px;">

View File

@@ -0,0 +1,57 @@
"""delete_input_prompts
Revision ID: bf7a81109301
Revises: f7a894b06d02
Create Date: 2024-12-09 12:00:49.884228
"""
from alembic import op
import sqlalchemy as sa
import fastapi_users_db_sqlalchemy
# revision identifiers, used by Alembic.
revision = "bf7a81109301"
down_revision = "f7a894b06d02"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.drop_table("inputprompt__user")
op.drop_table("inputprompt")
def downgrade() -> None:
op.create_table(
"inputprompt",
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
sa.Column("prompt", sa.String(), nullable=False),
sa.Column("content", sa.String(), nullable=False),
sa.Column("active", sa.Boolean(), nullable=False),
sa.Column("is_public", sa.Boolean(), nullable=False),
sa.Column(
"user_id",
fastapi_users_db_sqlalchemy.generics.GUID(),
nullable=True,
),
sa.ForeignKeyConstraint(
["user_id"],
["user.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"inputprompt__user",
sa.Column("input_prompt_id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["input_prompt_id"],
["inputprompt.id"],
),
sa.ForeignKeyConstraint(
["user_id"],
["inputprompt.id"],
),
sa.PrimaryKeyConstraint("input_prompt_id", "user_id"),
)

View File

@@ -1,3 +1,4 @@
import hashlib
import secrets
import uuid
from urllib.parse import quote
@@ -18,7 +19,8 @@ _API_KEY_HEADER_NAME = "Authorization"
# organizations like the Internet Engineering Task Force (IETF).
_API_KEY_HEADER_ALTERNATIVE_NAME = "X-Danswer-Authorization"
_BEARER_PREFIX = "Bearer "
_API_KEY_PREFIX = "dn_"
_API_KEY_PREFIX = "on_"
_DEPRECATED_API_KEY_PREFIX = "dn_"
_API_KEY_LEN = 192
@@ -52,7 +54,9 @@ def extract_tenant_from_api_key_header(request: Request) -> str | None:
api_key = raw_api_key_header[len(_BEARER_PREFIX) :].strip()
if not api_key.startswith(_API_KEY_PREFIX):
if not api_key.startswith(_API_KEY_PREFIX) and not api_key.startswith(
_DEPRECATED_API_KEY_PREFIX
):
return None
parts = api_key[len(_API_KEY_PREFIX) :].split(".", 1)
@@ -63,10 +67,19 @@ def extract_tenant_from_api_key_header(request: Request) -> str | None:
return unquote(tenant_id) if tenant_id else None
def _deprecated_hash_api_key(api_key: str) -> str:
return sha256_crypt.hash(api_key, salt="", rounds=API_KEY_HASH_ROUNDS)
def hash_api_key(api_key: str) -> str:
# NOTE: no salt is needed, as the API key is randomly generated
# and overlaps are impossible
return sha256_crypt.hash(api_key, salt="", rounds=API_KEY_HASH_ROUNDS)
if api_key.startswith(_API_KEY_PREFIX):
return hashlib.sha256(api_key.encode("utf-8")).hexdigest()
elif api_key.startswith(_DEPRECATED_API_KEY_PREFIX):
return _deprecated_hash_api_key(api_key)
else:
raise ValueError(f"Invalid API key prefix: {api_key[:3]}")
def build_displayable_api_key(api_key: str) -> str:

View File

@@ -9,7 +9,6 @@ from danswer.utils.special_types import JSON_ro
def get_invited_users() -> list[str]:
try:
store = get_kv_store()
return cast(list, store.load(KV_USER_STORE_KEY))
except KvKeyNotFoundError:
return list()

View File

@@ -131,7 +131,7 @@ def get_display_email(email: str | None, space_less: bool = False) -> str:
def user_needs_to_be_verified() -> bool:
if AUTH_TYPE == AuthType.BASIC:
if AUTH_TYPE == AuthType.BASIC or AUTH_TYPE == AuthType.CLOUD:
return REQUIRE_EMAIL_VERIFICATION
# For other auth types, if the user is authenticated it's assumed that

View File

@@ -598,7 +598,7 @@ def connector_indexing_proxy_task(
db_session,
"Connector termination signal detected",
)
finally:
except Exception:
# if the DB exceptions, we'll just get an unfriendly failure message
# in the UI instead of the cancellation message
logger.exception(
@@ -640,14 +640,41 @@ def connector_indexing_proxy_task(
continue
if job.status == "error":
task_logger.error(
"Indexing watchdog - spawned task exceptioned: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"error={job.exception()}"
)
ignore_exitcode = False
exit_code: int | None = None
if job.process:
exit_code = job.process.exitcode
# seeing non-deterministic behavior where spawned tasks occasionally return exit code 1
# even though logging clearly indicates that they completed successfully
# to work around this, we ignore the job error state if the completion signal is OK
status_int = redis_connector_index.get_completion()
if status_int:
status_enum = HTTPStatus(status_int)
if status_enum == HTTPStatus.OK:
ignore_exitcode = True
if ignore_exitcode:
task_logger.warning(
"Indexing watchdog - spawned task has non-zero exit code "
"but completion signal is OK. Continuing...: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"exit_code={exit_code}"
)
else:
task_logger.error(
"Indexing watchdog - spawned task exceptioned: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"exit_code={exit_code} "
f"error={job.exception()}"
)
job.release()
break

View File

@@ -680,17 +680,28 @@ def monitor_ccpair_indexing_taskset(
)
task_logger.warning(msg)
index_attempt = get_index_attempt(db_session, payload.index_attempt_id)
if index_attempt:
if (
index_attempt.status != IndexingStatus.CANCELED
and index_attempt.status != IndexingStatus.FAILED
):
mark_attempt_failed(
index_attempt_id=payload.index_attempt_id,
db_session=db_session,
failure_reason=msg,
)
try:
index_attempt = get_index_attempt(
db_session, payload.index_attempt_id
)
if index_attempt:
if (
index_attempt.status != IndexingStatus.CANCELED
and index_attempt.status != IndexingStatus.FAILED
):
mark_attempt_failed(
index_attempt_id=payload.index_attempt_id,
db_session=db_session,
failure_reason=msg,
)
except Exception:
task_logger.exception(
"monitor_ccpair_indexing_taskset - transient exception marking index attempt as failed: "
f"attempt={payload.index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
redis_connector_index.reset()
return

View File

@@ -82,7 +82,7 @@ class SimpleJob:
return "running"
elif self.process.exitcode is None:
return "cancelled"
elif self.process.exitcode > 0:
elif self.process.exitcode != 0:
return "error"
else:
return "finished"
@@ -123,7 +123,8 @@ class SimpleJobClient:
self._cleanup_completed_jobs()
if len(self.jobs) >= self.n_workers:
logger.debug(
f"No available workers to run job. Currently running '{len(self.jobs)}' jobs, with a limit of '{self.n_workers}'."
f"No available workers to run job. "
f"Currently running '{len(self.jobs)}' jobs, with a limit of '{self.n_workers}'."
)
return None

View File

@@ -206,7 +206,9 @@ class Answer:
# + figure out what the next LLM call should be
tool_call_handler = ToolResponseHandler(current_llm_call.tools)
search_result = SearchTool.get_search_result(current_llm_call) or []
search_result, displayed_search_results_map = SearchTool.get_search_result(
current_llm_call
) or ([], {})
# Quotes are no longer supported
# answer_handler: AnswerResponseHandler
@@ -224,6 +226,7 @@ class Answer:
answer_handler = CitationResponseHandler(
context_docs=search_result,
doc_id_to_rank_map=map_document_id_order(search_result),
display_doc_order_dict=displayed_search_results_map,
)
response_handler_manager = LLMResponseHandlerManager(

View File

@@ -35,13 +35,18 @@ class DummyAnswerResponseHandler(AnswerResponseHandler):
class CitationResponseHandler(AnswerResponseHandler):
def __init__(
self, context_docs: list[LlmDoc], doc_id_to_rank_map: DocumentIdOrderMapping
self,
context_docs: list[LlmDoc],
doc_id_to_rank_map: DocumentIdOrderMapping,
display_doc_order_dict: dict[str, int],
):
self.context_docs = context_docs
self.doc_id_to_rank_map = doc_id_to_rank_map
self.display_doc_order_dict = display_doc_order_dict
self.citation_processor = CitationProcessor(
context_docs=self.context_docs,
doc_id_to_rank_map=self.doc_id_to_rank_map,
display_doc_order_dict=self.display_doc_order_dict,
)
self.processed_text = ""
self.citations: list[CitationInfo] = []

View File

@@ -22,12 +22,16 @@ class CitationProcessor:
self,
context_docs: list[LlmDoc],
doc_id_to_rank_map: DocumentIdOrderMapping,
display_doc_order_dict: dict[str, int],
stop_stream: str | None = STOP_STREAM_PAT,
):
self.context_docs = context_docs
self.doc_id_to_rank_map = doc_id_to_rank_map
self.stop_stream = stop_stream
self.order_mapping = doc_id_to_rank_map.order_mapping
self.display_doc_order_dict = (
display_doc_order_dict # original order of docs to displayed to user
)
self.llm_out = ""
self.max_citation_num = len(context_docs)
self.citation_order: list[int] = []
@@ -98,6 +102,18 @@ class CitationProcessor:
self.citation_order.index(real_citation_num) + 1
)
# get the value that was displayed to user, should always
# be in the display_doc_order_dict. But check anyways
if context_llm_doc.document_id in self.display_doc_order_dict:
displayed_citation_num = self.display_doc_order_dict[
context_llm_doc.document_id
]
else:
displayed_citation_num = real_citation_num
logger.warning(
f"Doc {context_llm_doc.document_id} not in display_doc_order_dict. Used LLM citation number instead."
)
# Skip consecutive citations of the same work
if target_citation_num in self.current_citations:
start, end = citation.span()
@@ -118,6 +134,7 @@ class CitationProcessor:
doc_id = int(match.group(1))
context_llm_doc = self.context_docs[doc_id - 1]
yield CitationInfo(
# stay with the original for now (order of LLM cites)
citation_num=target_citation_num,
document_id=context_llm_doc.document_id,
)
@@ -139,6 +156,7 @@ class CitationProcessor:
if target_citation_num not in self.cited_inds:
self.cited_inds.add(target_citation_num)
yield CitationInfo(
# stay with the original for now (order of LLM cites)
citation_num=target_citation_num,
document_id=context_llm_doc.document_id,
)
@@ -148,7 +166,8 @@ class CitationProcessor:
prev_length = len(self.curr_segment)
self.curr_segment = (
self.curr_segment[: start + length_to_add]
+ f"[[{target_citation_num}]]({link})"
+ f"[[{displayed_citation_num}]]({link})" # use the value that was displayed to user
# + f"[[{target_citation_num}]]({link})"
+ self.curr_segment[end + length_to_add :]
)
length_to_add += len(self.curr_segment) - prev_length
@@ -156,7 +175,8 @@ class CitationProcessor:
prev_length = len(self.curr_segment)
self.curr_segment = (
self.curr_segment[: start + length_to_add]
+ f"[[{target_citation_num}]]()"
+ f"[[{displayed_citation_num}]]()" # use the value that was displayed to user
# + f"[[{target_citation_num}]]()"
+ self.curr_segment[end + length_to_add :]
)
length_to_add += len(self.curr_segment) - prev_length

View File

@@ -348,6 +348,12 @@ GITLAB_CONNECTOR_INCLUDE_CODE_FILES = (
os.environ.get("GITLAB_CONNECTOR_INCLUDE_CODE_FILES", "").lower() == "true"
)
# Egnyte specific configs
EGNYTE_LOCALHOST_OVERRIDE = os.getenv("EGNYTE_LOCALHOST_OVERRIDE")
EGNYTE_BASE_DOMAIN = os.getenv("EGNYTE_DOMAIN")
EGNYTE_CLIENT_ID = os.getenv("EGNYTE_CLIENT_ID")
EGNYTE_CLIENT_SECRET = os.getenv("EGNYTE_CLIENT_SECRET")
DASK_JOB_CLIENT_ENABLED = (
os.environ.get("DASK_JOB_CLIENT_ENABLED", "").lower() == "true"
)
@@ -411,21 +417,28 @@ LARGE_CHUNK_RATIO = 4
# We don't want the metadata to overwhelm the actual contents of the chunk
SKIP_METADATA_IN_CHUNK = os.environ.get("SKIP_METADATA_IN_CHUNK", "").lower() == "true"
# Timeout to wait for job's last update before killing it, in hours
CLEANUP_INDEXING_JOBS_TIMEOUT = int(os.environ.get("CLEANUP_INDEXING_JOBS_TIMEOUT", 3))
CLEANUP_INDEXING_JOBS_TIMEOUT = int(
os.environ.get("CLEANUP_INDEXING_JOBS_TIMEOUT") or 3
)
# The indexer will warn in the logs whenver a document exceeds this threshold (in bytes)
INDEXING_SIZE_WARNING_THRESHOLD = int(
os.environ.get("INDEXING_SIZE_WARNING_THRESHOLD", 100 * 1024 * 1024)
os.environ.get("INDEXING_SIZE_WARNING_THRESHOLD") or 100 * 1024 * 1024
)
# during indexing, will log verbose memory diff stats every x batches and at the end.
# 0 disables this behavior and is the default.
INDEXING_TRACER_INTERVAL = int(os.environ.get("INDEXING_TRACER_INTERVAL", 0))
INDEXING_TRACER_INTERVAL = int(os.environ.get("INDEXING_TRACER_INTERVAL") or 0)
# During an indexing attempt, specifies the number of batches which are allowed to
# exception without aborting the attempt.
INDEXING_EXCEPTION_LIMIT = int(os.environ.get("INDEXING_EXCEPTION_LIMIT", 0))
INDEXING_EXCEPTION_LIMIT = int(os.environ.get("INDEXING_EXCEPTION_LIMIT") or 0)
# Maximum file size in a document to be indexed
MAX_DOCUMENT_CHARS = int(os.environ.get("MAX_DOCUMENT_CHARS") or 5_000_000)
MAX_FILE_SIZE_BYTES = int(
os.environ.get("MAX_FILE_SIZE_BYTES") or 2 * 1024 * 1024 * 1024
) # 2GB in bytes
#####
# Miscellaneous

View File

@@ -3,7 +3,6 @@ import os
PROMPTS_YAML = "./danswer/seeding/prompts.yaml"
PERSONAS_YAML = "./danswer/seeding/personas.yaml"
INPUT_PROMPT_YAML = "./danswer/seeding/input_prompts.yaml"
NUM_RETURNED_HITS = 50
# Used for LLM filtering and reranking

View File

@@ -132,6 +132,7 @@ class DocumentSource(str, Enum):
NOT_APPLICABLE = "not_applicable"
FRESHDESK = "freshdesk"
FIREFLIES = "fireflies"
EGNYTE = "egnyte"
DocumentSourceRequiringTenantContext: list[DocumentSource] = [DocumentSource.FILE]

View File

@@ -368,4 +368,5 @@ def build_confluence_client(
backoff_and_retry=True,
max_backoff_retries=10,
max_backoff_seconds=60,
cloud=is_cloud,
)

View File

@@ -0,0 +1,384 @@
import io
import os
from collections.abc import Generator
from datetime import datetime
from datetime import timezone
from logging import Logger
from typing import Any
from typing import cast
from typing import IO
import requests
from retry import retry
from danswer.configs.app_configs import EGNYTE_BASE_DOMAIN
from danswer.configs.app_configs import EGNYTE_CLIENT_ID
from danswer.configs.app_configs import EGNYTE_CLIENT_SECRET
from danswer.configs.app_configs import EGNYTE_LOCALHOST_OVERRIDE
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import OAuthConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.models import ConnectorMissingCredentialError
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.file_processing.extract_file_text import detect_encoding
from danswer.file_processing.extract_file_text import extract_file_text
from danswer.file_processing.extract_file_text import get_file_ext
from danswer.file_processing.extract_file_text import is_text_file_extension
from danswer.file_processing.extract_file_text import is_valid_file_ext
from danswer.file_processing.extract_file_text import read_text_file
from danswer.utils.logger import setup_logger
logger = setup_logger()
_EGNYTE_API_BASE = "https://{domain}.egnyte.com/pubapi/v1"
_EGNYTE_APP_BASE = "https://{domain}.egnyte.com"
_TIMEOUT = 60
def _request_with_retries(
method: str,
url: str,
data: dict[str, Any] | None = None,
headers: dict[str, Any] | None = None,
params: dict[str, Any] | None = None,
timeout: int = _TIMEOUT,
stream: bool = False,
tries: int = 8,
delay: float = 1,
backoff: float = 2,
) -> requests.Response:
@retry(tries=tries, delay=delay, backoff=backoff, logger=cast(Logger, logger))
def _make_request() -> requests.Response:
response = requests.request(
method,
url,
data=data,
headers=headers,
params=params,
timeout=timeout,
stream=stream,
)
try:
response.raise_for_status()
except requests.exceptions.HTTPError as e:
if e.response.status_code != 403:
logger.exception(
f"Failed to call Egnyte API.\n"
f"URL: {url}\n"
f"Headers: {headers}\n"
f"Data: {data}\n"
f"Params: {params}"
)
raise e
return response
return _make_request()
def _parse_last_modified(last_modified: str) -> datetime:
return datetime.strptime(last_modified, "%a, %d %b %Y %H:%M:%S %Z").replace(
tzinfo=timezone.utc
)
def _process_egnyte_file(
file_metadata: dict[str, Any],
file_content: IO,
base_url: str,
folder_path: str | None = None,
) -> Document | None:
"""Process an Egnyte file into a Document object
Args:
file_data: The file data from Egnyte API
file_content: The raw content of the file in bytes
base_url: The base URL for the Egnyte instance
folder_path: Optional folder path to filter results
"""
# Skip if file path doesn't match folder path filter
if folder_path and not file_metadata["path"].startswith(folder_path):
raise ValueError(
f"File path {file_metadata['path']} does not match folder path {folder_path}"
)
file_name = file_metadata["name"]
extension = get_file_ext(file_name)
if not is_valid_file_ext(extension):
logger.warning(f"Skipping file '{file_name}' with extension '{extension}'")
return None
# Extract text content based on file type
if is_text_file_extension(file_name):
encoding = detect_encoding(file_content)
file_content_raw, file_metadata = read_text_file(
file_content, encoding=encoding, ignore_danswer_metadata=False
)
else:
file_content_raw = extract_file_text(
file=file_content,
file_name=file_name,
break_on_unprocessable=True,
)
# Build the web URL for the file
web_url = f"{base_url}/navigate/file/{file_metadata['group_id']}"
# Create document metadata
metadata: dict[str, str | list[str]] = {
"file_path": file_metadata["path"],
"last_modified": file_metadata.get("last_modified", ""),
}
# Add lock info if present
if lock_info := file_metadata.get("lock_info"):
metadata[
"lock_owner"
] = f"{lock_info.get('first_name', '')} {lock_info.get('last_name', '')}"
# Create the document owners
primary_owner = None
if uploaded_by := file_metadata.get("uploaded_by"):
primary_owner = BasicExpertInfo(
email=uploaded_by, # Using username as email since that's what we have
)
# Create the document
return Document(
id=f"egnyte-{file_metadata['entry_id']}",
sections=[Section(text=file_content_raw.strip(), link=web_url)],
source=DocumentSource.EGNYTE,
semantic_identifier=file_name,
metadata=metadata,
doc_updated_at=(
_parse_last_modified(file_metadata["last_modified"])
if "last_modified" in file_metadata
else None
),
primary_owners=[primary_owner] if primary_owner else None,
)
class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
def __init__(
self,
folder_path: str | None = None,
batch_size: int = INDEX_BATCH_SIZE,
) -> None:
self.domain = "" # will always be set in `load_credentials`
self.folder_path = folder_path or "" # Root folder if not specified
self.batch_size = batch_size
self.access_token: str | None = None
@classmethod
def oauth_id(cls) -> DocumentSource:
return DocumentSource.EGNYTE
@classmethod
def oauth_authorization_url(cls, base_domain: str, state: str) -> str:
if not EGNYTE_CLIENT_ID:
raise ValueError("EGNYTE_CLIENT_ID environment variable must be set")
if not EGNYTE_BASE_DOMAIN:
raise ValueError("EGNYTE_DOMAIN environment variable must be set")
if EGNYTE_LOCALHOST_OVERRIDE:
base_domain = EGNYTE_LOCALHOST_OVERRIDE
callback_uri = f"{base_domain.strip('/')}/connector/oauth/callback/egnyte"
return (
f"https://{EGNYTE_BASE_DOMAIN}.egnyte.com/puboauth/token"
f"?client_id={EGNYTE_CLIENT_ID}"
f"&redirect_uri={callback_uri}"
f"&scope=Egnyte.filesystem"
f"&state={state}"
f"&response_type=code"
)
@classmethod
def oauth_code_to_token(cls, code: str) -> dict[str, Any]:
if not EGNYTE_CLIENT_ID:
raise ValueError("EGNYTE_CLIENT_ID environment variable must be set")
if not EGNYTE_CLIENT_SECRET:
raise ValueError("EGNYTE_CLIENT_SECRET environment variable must be set")
if not EGNYTE_BASE_DOMAIN:
raise ValueError("EGNYTE_DOMAIN environment variable must be set")
# Exchange code for token
url = f"https://{EGNYTE_BASE_DOMAIN}.egnyte.com/puboauth/token"
data = {
"client_id": EGNYTE_CLIENT_ID,
"client_secret": EGNYTE_CLIENT_SECRET,
"code": code,
"grant_type": "authorization_code",
"redirect_uri": f"{EGNYTE_LOCALHOST_OVERRIDE or ''}/connector/oauth/callback/egnyte",
"scope": "Egnyte.filesystem",
}
headers = {"Content-Type": "application/x-www-form-urlencoded"}
response = _request_with_retries(
method="POST",
url=url,
data=data,
headers=headers,
# try a lot faster since this is a realtime flow
backoff=0,
delay=0.1,
)
if not response.ok:
raise RuntimeError(f"Failed to exchange code for token: {response.text}")
token_data = response.json()
return {
"domain": EGNYTE_BASE_DOMAIN,
"access_token": token_data["access_token"],
}
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self.domain = credentials["domain"]
self.access_token = credentials["access_token"]
return None
def _get_files_list(
self,
path: str,
) -> list[dict[str, Any]]:
if not self.access_token or not self.domain:
raise ConnectorMissingCredentialError("Egnyte")
headers = {
"Authorization": f"Bearer {self.access_token}",
}
params: dict[str, Any] = {
"list_content": True,
}
url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs/{path or ''}"
response = _request_with_retries(
method="GET", url=url, headers=headers, params=params, timeout=_TIMEOUT
)
if not response.ok:
raise RuntimeError(f"Failed to fetch files from Egnyte: {response.text}")
data = response.json()
all_files: list[dict[str, Any]] = []
# Add files from current directory
all_files.extend(data.get("files", []))
# Recursively traverse folders
for item in data.get("folders", []):
all_files.extend(self._get_files_list(item["path"]))
return all_files
def _filter_files(
self,
files: list[dict[str, Any]],
start_time: datetime | None = None,
end_time: datetime | None = None,
) -> list[dict[str, Any]]:
filtered_files = []
for file in files:
if file["is_folder"]:
continue
file_modified = _parse_last_modified(file["last_modified"])
if start_time and file_modified < start_time:
continue
if end_time and file_modified > end_time:
continue
filtered_files.append(file)
return filtered_files
def _process_files(
self,
start_time: datetime | None = None,
end_time: datetime | None = None,
) -> Generator[list[Document], None, None]:
files = self._get_files_list(self.folder_path)
files = self._filter_files(files, start_time, end_time)
current_batch: list[Document] = []
for file in files:
try:
# Set up request with streaming enabled
headers = {
"Authorization": f"Bearer {self.access_token}",
}
url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs-content/{file['path']}"
response = _request_with_retries(
method="GET",
url=url,
headers=headers,
timeout=_TIMEOUT,
stream=True,
)
if not response.ok:
logger.error(
f"Failed to fetch file content: {file['path']} (status code: {response.status_code})"
)
continue
# Stream the response content into a BytesIO buffer
buffer = io.BytesIO()
for chunk in response.iter_content(chunk_size=8192):
if chunk:
buffer.write(chunk)
# Reset buffer's position to the start
buffer.seek(0)
# Process the streamed file content
doc = _process_egnyte_file(
file_metadata=file,
file_content=buffer,
base_url=_EGNYTE_APP_BASE.format(domain=self.domain),
folder_path=self.folder_path,
)
if doc is not None:
current_batch.append(doc)
if len(current_batch) >= self.batch_size:
yield current_batch
current_batch = []
except Exception:
logger.exception(f"Failed to process file {file['path']}")
continue
if current_batch:
yield current_batch
def load_from_state(self) -> GenerateDocumentsOutput:
yield from self._process_files()
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
start_time = datetime.fromtimestamp(start, tz=timezone.utc)
end_time = datetime.fromtimestamp(end, tz=timezone.utc)
yield from self._process_files(start_time=start_time, end_time=end_time)
if __name__ == "__main__":
connector = EgnyteConnector()
connector.load_credentials(
{
"domain": os.environ["EGNYTE_DOMAIN"],
"access_token": os.environ["EGNYTE_ACCESS_TOKEN"],
}
)
document_batches = connector.load_from_state()
print(next(document_batches))

View File

@@ -15,6 +15,7 @@ from danswer.connectors.danswer_jira.connector import JiraConnector
from danswer.connectors.discourse.connector import DiscourseConnector
from danswer.connectors.document360.connector import Document360Connector
from danswer.connectors.dropbox.connector import DropboxConnector
from danswer.connectors.egnyte.connector import EgnyteConnector
from danswer.connectors.file.connector import LocalFileConnector
from danswer.connectors.fireflies.connector import FirefliesConnector
from danswer.connectors.freshdesk.connector import FreshdeskConnector
@@ -40,7 +41,6 @@ from danswer.connectors.salesforce.connector import SalesforceConnector
from danswer.connectors.sharepoint.connector import SharepointConnector
from danswer.connectors.slab.connector import SlabConnector
from danswer.connectors.slack.connector import SlackPollConnector
from danswer.connectors.slack.load_connector import SlackLoadConnector
from danswer.connectors.teams.connector import TeamsConnector
from danswer.connectors.web.connector import WebConnector
from danswer.connectors.wikipedia.connector import WikipediaConnector
@@ -63,7 +63,6 @@ def identify_connector_class(
DocumentSource.WEB: WebConnector,
DocumentSource.FILE: LocalFileConnector,
DocumentSource.SLACK: {
InputType.LOAD_STATE: SlackLoadConnector,
InputType.POLL: SlackPollConnector,
InputType.SLIM_RETRIEVAL: SlackPollConnector,
},
@@ -103,6 +102,7 @@ def identify_connector_class(
DocumentSource.XENFORO: XenforoConnector,
DocumentSource.FRESHDESK: FreshdeskConnector,
DocumentSource.FIREFLIES: FirefliesConnector,
DocumentSource.EGNYTE: EgnyteConnector,
}
connector_by_source = connector_map.get(source, {})

View File

@@ -17,11 +17,11 @@ from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.db.engine import get_session_with_tenant
from danswer.file_processing.extract_file_text import check_file_ext_is_valid
from danswer.file_processing.extract_file_text import detect_encoding
from danswer.file_processing.extract_file_text import extract_file_text
from danswer.file_processing.extract_file_text import get_file_ext
from danswer.file_processing.extract_file_text import is_text_file_extension
from danswer.file_processing.extract_file_text import is_valid_file_ext
from danswer.file_processing.extract_file_text import load_files_from_zip
from danswer.file_processing.extract_file_text import read_pdf_file
from danswer.file_processing.extract_file_text import read_text_file
@@ -50,7 +50,7 @@ def _read_files_and_metadata(
file_content, ignore_dirs=True
):
yield os.path.join(directory_path, file_info.filename), file, metadata
elif check_file_ext_is_valid(extension):
elif is_valid_file_ext(extension):
yield file_name, file_content, metadata
else:
logger.warning(f"Skipping file '{file_name}' with extension '{extension}'")
@@ -63,7 +63,7 @@ def _process_file(
pdf_pass: str | None = None,
) -> list[Document]:
extension = get_file_ext(file_name)
if not check_file_ext_is_valid(extension):
if not is_valid_file_ext(extension):
logger.warning(f"Skipping file '{file_name}' with extension '{extension}'")
return []

View File

@@ -4,11 +4,13 @@ from concurrent.futures import as_completed
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from typing import Any
from typing import cast
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.app_configs import MAX_FILE_SIZE_BYTES
from danswer.configs.constants import DocumentSource
from danswer.connectors.google_drive.doc_conversion import build_slim_document
from danswer.connectors.google_drive.doc_conversion import (
@@ -452,12 +454,14 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
if isinstance(self.creds, ServiceAccountCredentials)
else self._manage_oauth_retrieval
)
return retrieval_method(
drive_files = retrieval_method(
is_slim=is_slim,
start=start,
end=end,
)
return drive_files
def _extract_docs_from_google_drive(
self,
start: SecondsSinceUnixEpoch | None = None,
@@ -473,6 +477,15 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
files_to_process = []
# Gather the files into batches to be processed in parallel
for file in self._fetch_drive_items(is_slim=False, start=start, end=end):
if (
file.get("size")
and int(cast(str, file.get("size"))) > MAX_FILE_SIZE_BYTES
):
logger.warning(
f"Skipping file {file.get('name', 'Unknown')} as it is too large: {file.get('size')} bytes"
)
continue
files_to_process.append(file)
if len(files_to_process) >= LARGE_BATCH_SIZE:
yield from _process_files_batch(

View File

@@ -16,7 +16,7 @@ logger = setup_logger()
FILE_FIELDS = (
"nextPageToken, files(mimeType, id, name, permissions, modifiedTime, webViewLink, "
"shortcutDetails, owners(emailAddress))"
"shortcutDetails, owners(emailAddress), size)"
)
SLIM_FILE_FIELDS = (
"nextPageToken, files(mimeType, id, name, permissions(emailAddress, type), "

View File

@@ -2,6 +2,7 @@ import abc
from collections.abc import Iterator
from typing import Any
from danswer.configs.constants import DocumentSource
from danswer.connectors.models import Document
from danswer.connectors.models import SlimDocument
@@ -64,6 +65,23 @@ class SlimConnector(BaseConnector):
raise NotImplementedError
class OAuthConnector(BaseConnector):
@classmethod
@abc.abstractmethod
def oauth_id(cls) -> DocumentSource:
raise NotImplementedError
@classmethod
@abc.abstractmethod
def oauth_authorization_url(cls, base_domain: str, state: str) -> str:
raise NotImplementedError
@classmethod
@abc.abstractmethod
def oauth_code_to_token(cls, code: str) -> dict[str, Any]:
raise NotImplementedError
# Event driven
class EventConnector(BaseConnector):
@abc.abstractmethod

View File

@@ -132,7 +132,6 @@ class LinearConnector(LoadConnector, PollConnector):
branchName
customerTicketCount
description
descriptionData
comments {
nodes {
url
@@ -215,5 +214,6 @@ class LinearConnector(LoadConnector, PollConnector):
if __name__ == "__main__":
connector = LinearConnector()
connector.load_credentials({"linear_api_key": os.environ["LINEAR_API_KEY"]})
document_batches = connector.load_from_state()
print(next(document_batches))

View File

@@ -134,7 +134,6 @@ def get_latest_message_time(thread: ThreadType) -> datetime:
def thread_to_doc(
workspace: str,
channel: ChannelType,
thread: ThreadType,
slack_cleaner: SlackTextCleaner,
@@ -171,15 +170,15 @@ def thread_to_doc(
else first_message
)
doc_sem_id = f"{initial_sender_name} in #{channel['name']}: {snippet}"
doc_sem_id = f"{initial_sender_name} in #{channel['name']}: {snippet}".replace(
"\n", " "
)
return Document(
id=f"{channel_id}__{thread[0]['ts']}",
sections=[
Section(
link=get_message_link(
event=m, workspace=workspace, channel_id=channel_id
),
link=get_message_link(event=m, client=client, channel_id=channel_id),
text=slack_cleaner.index_clean(cast(str, m["text"])),
)
for m in thread
@@ -263,7 +262,6 @@ def filter_channels(
def _get_all_docs(
client: WebClient,
workspace: str,
channels: list[str] | None = None,
channel_name_regex_enabled: bool = False,
oldest: str | None = None,
@@ -310,7 +308,6 @@ def _get_all_docs(
if filtered_thread:
channel_docs += 1
yield thread_to_doc(
workspace=workspace,
channel=channel,
thread=filtered_thread,
slack_cleaner=slack_cleaner,
@@ -373,14 +370,12 @@ def _get_all_doc_ids(
class SlackPollConnector(PollConnector, SlimConnector):
def __init__(
self,
workspace: str,
channels: list[str] | None = None,
# if specified, will treat the specified channel strings as
# regexes, and will only index channels that fully match the regexes
channel_regex_enabled: bool = False,
batch_size: int = INDEX_BATCH_SIZE,
) -> None:
self.workspace = workspace
self.channels = channels
self.channel_regex_enabled = channel_regex_enabled
self.batch_size = batch_size
@@ -414,7 +409,6 @@ class SlackPollConnector(PollConnector, SlimConnector):
documents: list[Document] = []
for document in _get_all_docs(
client=self.client,
workspace=self.workspace,
channels=self.channels,
channel_name_regex_enabled=self.channel_regex_enabled,
# NOTE: need to impute to `None` instead of using 0.0, since Slack will
@@ -438,7 +432,6 @@ if __name__ == "__main__":
slack_channel = os.environ.get("SLACK_CHANNEL")
connector = SlackPollConnector(
workspace=os.environ["SLACK_WORKSPACE"],
channels=[slack_channel] if slack_channel else None,
)
connector.load_credentials({"slack_bot_token": os.environ["SLACK_BOT_TOKEN"]})

View File

@@ -1,140 +0,0 @@
import json
import os
from datetime import datetime
from datetime import timezone
from pathlib import Path
from typing import Any
from typing import cast
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.connectors.slack.connector import filter_channels
from danswer.connectors.slack.utils import get_message_link
from danswer.utils.logger import setup_logger
logger = setup_logger()
def get_event_time(event: dict[str, Any]) -> datetime | None:
ts = event.get("ts")
if not ts:
return None
return datetime.fromtimestamp(float(ts), tz=timezone.utc)
class SlackLoadConnector(LoadConnector):
# WARNING: DEPRECATED, DO NOT USE
def __init__(
self,
workspace: str,
export_path_str: str,
channels: list[str] | None = None,
# if specified, will treat the specified channel strings as
# regexes, and will only index channels that fully match the regexes
channel_regex_enabled: bool = False,
batch_size: int = INDEX_BATCH_SIZE,
) -> None:
self.workspace = workspace
self.channels = channels
self.channel_regex_enabled = channel_regex_enabled
self.export_path_str = export_path_str
self.batch_size = batch_size
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
if credentials:
logger.warning("Unexpected credentials provided for Slack Load Connector")
return None
@staticmethod
def _process_batch_event(
slack_event: dict[str, Any],
channel: dict[str, Any],
matching_doc: Document | None,
workspace: str,
) -> Document | None:
if (
slack_event["type"] == "message"
and slack_event.get("subtype") != "channel_join"
):
if matching_doc:
return Document(
id=matching_doc.id,
sections=matching_doc.sections
+ [
Section(
link=get_message_link(
event=slack_event,
workspace=workspace,
channel_id=channel["id"],
),
text=slack_event["text"],
)
],
source=matching_doc.source,
semantic_identifier=matching_doc.semantic_identifier,
title="", # slack docs don't really have a "title"
doc_updated_at=get_event_time(slack_event),
metadata=matching_doc.metadata,
)
return Document(
id=slack_event["ts"],
sections=[
Section(
link=get_message_link(
event=slack_event,
workspace=workspace,
channel_id=channel["id"],
),
text=slack_event["text"],
)
],
source=DocumentSource.SLACK,
semantic_identifier=channel["name"],
title="", # slack docs don't really have a "title"
doc_updated_at=get_event_time(slack_event),
metadata={},
)
return None
def load_from_state(self) -> GenerateDocumentsOutput:
export_path = Path(self.export_path_str)
with open(export_path / "channels.json") as f:
all_channels = json.load(f)
filtered_channels = filter_channels(
all_channels, self.channels, self.channel_regex_enabled
)
document_batch: dict[str, Document] = {}
for channel_info in filtered_channels:
channel_dir_path = export_path / cast(str, channel_info["name"])
channel_file_paths = [
channel_dir_path / file_name
for file_name in os.listdir(channel_dir_path)
]
for path in channel_file_paths:
with open(path) as f:
events = cast(list[dict[str, Any]], json.load(f))
for slack_event in events:
doc = self._process_batch_event(
slack_event=slack_event,
channel=channel_info,
matching_doc=document_batch.get(
slack_event.get("thread_ts", "")
),
workspace=self.workspace,
)
if doc:
document_batch[doc.id] = doc
if len(document_batch) >= self.batch_size:
yield list(document_batch.values())
yield list(document_batch.values())

View File

@@ -2,6 +2,7 @@ import re
import time
from collections.abc import Callable
from collections.abc import Generator
from functools import lru_cache
from functools import wraps
from typing import Any
from typing import cast
@@ -21,19 +22,21 @@ basic_retry_wrapper = retry_builder()
_SLACK_LIMIT = 900
@lru_cache()
def get_base_url(token: str) -> str:
"""Retrieve and cache the base URL of the Slack workspace based on the client token."""
client = WebClient(token=token)
return client.auth_test()["url"]
def get_message_link(
event: dict[str, Any], workspace: str, channel_id: str | None = None
event: dict[str, Any], client: WebClient, channel_id: str | None = None
) -> str:
channel_id = channel_id or cast(
str, event["channel"]
) # channel must either be present in the event or passed in
message_ts = cast(str, event["ts"])
message_ts_without_dot = message_ts.replace(".", "")
thread_ts = cast(str | None, event.get("thread_ts"))
return (
f"https://{workspace}.slack.com/archives/{channel_id}/p{message_ts_without_dot}"
+ (f"?thread_ts={thread_ts}" if thread_ts else "")
)
channel_id = channel_id or event["channel"]
message_ts = event["ts"]
response = client.chat_getPermalink(channel=channel_id, message_ts=message_ts)
permalink = response["permalink"]
return permalink
def _make_slack_api_call_logged(

View File

@@ -33,7 +33,7 @@ def get_created_datetime(chat_message: ChatMessage) -> datetime:
def _extract_channel_members(channel: Channel) -> list[BasicExpertInfo]:
channel_members_list: list[BasicExpertInfo] = []
members = channel.members.get().execute_query()
members = channel.members.get().execute_query_retry()
for member in members:
channel_members_list.append(BasicExpertInfo(display_name=member.display_name))
return channel_members_list
@@ -51,7 +51,7 @@ def _get_threads_from_channel(
end = end.replace(tzinfo=timezone.utc)
query = channel.messages.get()
base_messages: list[ChatMessage] = query.execute_query()
base_messages: list[ChatMessage] = query.execute_query_retry()
threads: list[list[ChatMessage]] = []
for base_message in base_messages:
@@ -65,7 +65,7 @@ def _get_threads_from_channel(
continue
reply_query = base_message.replies.get_all()
replies = reply_query.execute_query()
replies = reply_query.execute_query_retry()
# start a list containing the base message and its replies
thread: list[ChatMessage] = [base_message]
@@ -82,7 +82,7 @@ def _get_channels_from_teams(
channels_list: list[Channel] = []
for team in teams:
query = team.channels.get()
channels = query.execute_query()
channels = query.execute_query_retry()
channels_list.extend(channels)
return channels_list
@@ -210,7 +210,7 @@ class TeamsConnector(LoadConnector, PollConnector):
teams_list: list[Team] = []
teams = self.graph_client.teams.get().execute_query()
teams = self.graph_client.teams.get().execute_query_retry()
if len(self.requested_team_list) > 0:
adjusted_request_strings = [
@@ -234,14 +234,25 @@ class TeamsConnector(LoadConnector, PollConnector):
raise ConnectorMissingCredentialError("Teams")
teams = self._get_all_teams()
logger.debug(f"Found available teams: {[str(t) for t in teams]}")
if not teams:
msg = "No teams found."
logger.error(msg)
raise ValueError(msg)
channels = _get_channels_from_teams(
teams=teams,
)
logger.debug(f"Found available channels: {[c.id for c in channels]}")
if not channels:
msg = "No channels found."
logger.error(msg)
raise ValueError(msg)
# goes over channels, converts them into Document objects and then yields them in batches
doc_batch: list[Document] = []
for channel in channels:
logger.debug(f"Fetching threads from channel: {channel.id}")
thread_list = _get_threads_from_channel(channel, start=start, end=end)
for thread in thread_list:
converted_doc = _convert_thread_to_document(channel, thread)
@@ -259,8 +270,8 @@ class TeamsConnector(LoadConnector, PollConnector):
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
start_datetime = datetime.utcfromtimestamp(start)
end_datetime = datetime.utcfromtimestamp(end)
start_datetime = datetime.fromtimestamp(start, timezone.utc)
end_datetime = datetime.fromtimestamp(end, timezone.utc)
return self._fetch_from_teams(start=start_datetime, end=end_datetime)

View File

@@ -204,7 +204,8 @@ def _build_documents_blocks(
continue
seen_docs_identifiers.add(d.document_id)
doc_sem_id = d.semantic_identifier
# Strip newlines from the semantic identifier for Slackbot formatting
doc_sem_id = d.semantic_identifier.replace("\n", " ")
if d.source_type == DocumentSource.SLACK.value:
doc_sem_id = "#" + doc_sem_id

View File

@@ -373,7 +373,9 @@ def handle_regular_answer(
respond_in_thread(
client=client,
channel=channel,
receiver_ids=receiver_ids,
receiver_ids=[message_info.sender]
if message_info.is_bot_msg and message_info.sender
else receiver_ids,
text="Hello! Danswer has some results for you!",
blocks=all_blocks,
thread_ts=message_ts_to_respond_to,

View File

@@ -11,6 +11,7 @@ from retry import retry
from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from slack_sdk.models.blocks import Block
from slack_sdk.models.blocks import SectionBlock
from slack_sdk.models.metadata import Metadata
from slack_sdk.socket_mode import SocketModeClient
@@ -140,6 +141,40 @@ def remove_danswer_bot_tag(message_str: str, client: WebClient) -> str:
return re.sub(rf"<@{bot_tag_id}>\s", "", message_str)
def _check_for_url_in_block(block: Block) -> bool:
"""
Check if the block has a key that contains "url" in it
"""
block_dict = block.to_dict()
def check_dict_for_url(d: dict) -> bool:
for key, value in d.items():
if "url" in key.lower():
return True
if isinstance(value, dict):
if check_dict_for_url(value):
return True
elif isinstance(value, list):
for item in value:
if isinstance(item, dict) and check_dict_for_url(item):
return True
return False
return check_dict_for_url(block_dict)
def _build_error_block(error_message: str) -> Block:
"""
Build an error block to display in slack so that the user can see
the error without completely breaking
"""
display_text = (
"There was an error displaying all of the Onyx answers."
f" Please let an admin or an onyx developer know. Error: {error_message}"
)
return SectionBlock(text=display_text)
@retry(
tries=DANSWER_BOT_NUM_RETRIES,
delay=0.25,
@@ -162,24 +197,9 @@ def respond_in_thread(
message_ids: list[str] = []
if not receiver_ids:
slack_call = make_slack_api_rate_limited(client.chat_postMessage)
response = slack_call(
channel=channel,
text=text,
blocks=blocks,
thread_ts=thread_ts,
metadata=metadata,
unfurl_links=unfurl,
unfurl_media=unfurl,
)
if not response.get("ok"):
raise RuntimeError(f"Failed to post message: {response}")
message_ids.append(response["message_ts"])
else:
slack_call = make_slack_api_rate_limited(client.chat_postEphemeral)
for receiver in receiver_ids:
try:
response = slack_call(
channel=channel,
user=receiver,
text=text,
blocks=blocks,
thread_ts=thread_ts,
@@ -187,8 +207,68 @@ def respond_in_thread(
unfurl_links=unfurl,
unfurl_media=unfurl,
)
if not response.get("ok"):
raise RuntimeError(f"Failed to post message: {response}")
except Exception as e:
logger.warning(f"Failed to post message: {e} \n blocks: {blocks}")
logger.warning("Trying again without blocks that have urls")
if not blocks:
raise e
blocks_without_urls = [
block for block in blocks if not _check_for_url_in_block(block)
]
blocks_without_urls.append(_build_error_block(str(e)))
# Try again wtihout blocks containing url
response = slack_call(
channel=channel,
text=text,
blocks=blocks_without_urls,
thread_ts=thread_ts,
metadata=metadata,
unfurl_links=unfurl,
unfurl_media=unfurl,
)
message_ids.append(response["message_ts"])
else:
slack_call = make_slack_api_rate_limited(client.chat_postEphemeral)
for receiver in receiver_ids:
try:
response = slack_call(
channel=channel,
user=receiver,
text=text,
blocks=blocks,
thread_ts=thread_ts,
metadata=metadata,
unfurl_links=unfurl,
unfurl_media=unfurl,
)
except Exception as e:
logger.warning(f"Failed to post message: {e} \n blocks: {blocks}")
logger.warning("Trying again without blocks that have urls")
if not blocks:
raise e
blocks_without_urls = [
block for block in blocks if not _check_for_url_in_block(block)
]
blocks_without_urls.append(_build_error_block(str(e)))
# Try again wtihout blocks containing url
response = slack_call(
channel=channel,
user=receiver,
text=text,
blocks=blocks_without_urls,
thread_ts=thread_ts,
metadata=metadata,
unfurl_links=unfurl,
unfurl_media=unfurl,
)
message_ids.append(response["message_ts"])
return message_ids

View File

@@ -20,7 +20,6 @@ from danswer.db.models import DocumentByConnectorCredentialPair
from danswer.db.models import User
from danswer.db.models import User__UserGroup
from danswer.server.documents.models import CredentialBase
from danswer.server.documents.models import CredentialDataUpdateRequest
from danswer.utils.logger import setup_logger
@@ -262,7 +261,8 @@ def _cleanup_credential__user_group_relationships__no_commit(
def alter_credential(
credential_id: int,
credential_data: CredentialDataUpdateRequest,
name: str,
credential_json: dict[str, Any],
user: User,
db_session: Session,
) -> Credential | None:
@@ -272,11 +272,13 @@ def alter_credential(
if credential is None:
return None
credential.name = credential_data.name
credential.name = name
# Update only the keys present in credential_data.credential_json
for key, value in credential_data.credential_json.items():
credential.credential_json[key] = value
# Assign a new dictionary to credential.credential_json
credential.credential_json = {
**credential.credential_json,
**credential_json,
}
credential.user_id = user.id if user is not None else None
db_session.commit()
@@ -309,8 +311,8 @@ def update_credential_json(
credential = fetch_credential_by_id(credential_id, user, db_session)
if credential is None:
return None
credential.credential_json = credential_json
credential.credential_json = credential_json
db_session.commit()
return credential

View File

@@ -522,12 +522,16 @@ def expire_index_attempts(
search_settings_id: int,
db_session: Session,
) -> None:
delete_query = (
delete(IndexAttempt)
not_started_query = (
update(IndexAttempt)
.where(IndexAttempt.search_settings_id == search_settings_id)
.where(IndexAttempt.status == IndexingStatus.NOT_STARTED)
.values(
status=IndexingStatus.CANCELED,
error_msg="Canceled, likely due to model swap",
)
)
db_session.execute(delete_query)
db_session.execute(not_started_query)
update_query = (
update(IndexAttempt)
@@ -549,9 +553,14 @@ def cancel_indexing_attempts_for_ccpair(
include_secondary_index: bool = False,
) -> None:
stmt = (
delete(IndexAttempt)
update(IndexAttempt)
.where(IndexAttempt.connector_credential_pair_id == cc_pair_id)
.where(IndexAttempt.status == IndexingStatus.NOT_STARTED)
.values(
status=IndexingStatus.CANCELED,
error_msg="Canceled by user",
time_started=datetime.now(timezone.utc),
)
)
if not include_secondary_index:

View File

@@ -1,202 +0,0 @@
from uuid import UUID
from fastapi import HTTPException
from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.db.models import InputPrompt
from danswer.db.models import User
from danswer.server.features.input_prompt.models import InputPromptSnapshot
from danswer.server.manage.models import UserInfo
from danswer.utils.logger import setup_logger
logger = setup_logger()
def insert_input_prompt_if_not_exists(
user: User | None,
input_prompt_id: int | None,
prompt: str,
content: str,
active: bool,
is_public: bool,
db_session: Session,
commit: bool = True,
) -> InputPrompt:
if input_prompt_id is not None:
input_prompt = (
db_session.query(InputPrompt).filter_by(id=input_prompt_id).first()
)
else:
query = db_session.query(InputPrompt).filter(InputPrompt.prompt == prompt)
if user:
query = query.filter(InputPrompt.user_id == user.id)
else:
query = query.filter(InputPrompt.user_id.is_(None))
input_prompt = query.first()
if input_prompt is None:
input_prompt = InputPrompt(
id=input_prompt_id,
prompt=prompt,
content=content,
active=active,
is_public=is_public or user is None,
user_id=user.id if user else None,
)
db_session.add(input_prompt)
if commit:
db_session.commit()
return input_prompt
def insert_input_prompt(
prompt: str,
content: str,
is_public: bool,
user: User | None,
db_session: Session,
) -> InputPrompt:
input_prompt = InputPrompt(
prompt=prompt,
content=content,
active=True,
is_public=is_public or user is None,
user_id=user.id if user is not None else None,
)
db_session.add(input_prompt)
db_session.commit()
return input_prompt
def update_input_prompt(
user: User | None,
input_prompt_id: int,
prompt: str,
content: str,
active: bool,
db_session: Session,
) -> InputPrompt:
input_prompt = db_session.scalar(
select(InputPrompt).where(InputPrompt.id == input_prompt_id)
)
if input_prompt is None:
raise ValueError(f"No input prompt with id {input_prompt_id}")
if not validate_user_prompt_authorization(user, input_prompt):
raise HTTPException(status_code=401, detail="You don't own this prompt")
input_prompt.prompt = prompt
input_prompt.content = content
input_prompt.active = active
db_session.commit()
return input_prompt
def validate_user_prompt_authorization(
user: User | None, input_prompt: InputPrompt
) -> bool:
prompt = InputPromptSnapshot.from_model(input_prompt=input_prompt)
if prompt.user_id is not None:
if user is None:
return False
user_details = UserInfo.from_model(user)
if str(user_details.id) != str(prompt.user_id):
return False
return True
def remove_public_input_prompt(input_prompt_id: int, db_session: Session) -> None:
input_prompt = db_session.scalar(
select(InputPrompt).where(InputPrompt.id == input_prompt_id)
)
if input_prompt is None:
raise ValueError(f"No input prompt with id {input_prompt_id}")
if not input_prompt.is_public:
raise HTTPException(status_code=400, detail="This prompt is not public")
db_session.delete(input_prompt)
db_session.commit()
def remove_input_prompt(
user: User | None, input_prompt_id: int, db_session: Session
) -> None:
input_prompt = db_session.scalar(
select(InputPrompt).where(InputPrompt.id == input_prompt_id)
)
if input_prompt is None:
raise ValueError(f"No input prompt with id {input_prompt_id}")
if input_prompt.is_public:
raise HTTPException(
status_code=400, detail="Cannot delete public prompts with this method"
)
if not validate_user_prompt_authorization(user, input_prompt):
raise HTTPException(status_code=401, detail="You do not own this prompt")
db_session.delete(input_prompt)
db_session.commit()
def fetch_input_prompt_by_id(
id: int, user_id: UUID | None, db_session: Session
) -> InputPrompt:
query = select(InputPrompt).where(InputPrompt.id == id)
if user_id:
query = query.where(
(InputPrompt.user_id == user_id) | (InputPrompt.user_id is None)
)
else:
# If no user_id is provided, only fetch prompts without a user_id (aka public)
query = query.where(InputPrompt.user_id == None) # noqa
result = db_session.scalar(query)
if result is None:
raise HTTPException(422, "No input prompt found")
return result
def fetch_public_input_prompts(
db_session: Session,
) -> list[InputPrompt]:
query = select(InputPrompt).where(InputPrompt.is_public)
return list(db_session.scalars(query).all())
def fetch_input_prompts_by_user(
db_session: Session,
user_id: UUID | None,
active: bool | None = None,
include_public: bool = False,
) -> list[InputPrompt]:
query = select(InputPrompt)
if user_id is not None:
if include_public:
query = query.where(
(InputPrompt.user_id == user_id) | InputPrompt.is_public
)
else:
query = query.where(InputPrompt.user_id == user_id)
elif include_public:
query = query.where(InputPrompt.is_public)
if active is not None:
query = query.where(InputPrompt.active == active)
return list(db_session.scalars(query).all())

View File

@@ -159,9 +159,6 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
)
prompts: Mapped[list["Prompt"]] = relationship("Prompt", back_populates="user")
input_prompts: Mapped[list["InputPrompt"]] = relationship(
"InputPrompt", back_populates="user"
)
# Personas owned by this user
personas: Mapped[list["Persona"]] = relationship("Persona", back_populates="user")
@@ -178,31 +175,6 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
)
class InputPrompt(Base):
__tablename__ = "inputprompt"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
prompt: Mapped[str] = mapped_column(String)
content: Mapped[str] = mapped_column(String)
active: Mapped[bool] = mapped_column(Boolean)
user: Mapped[User | None] = relationship("User", back_populates="input_prompts")
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
)
class InputPrompt__User(Base):
__tablename__ = "inputprompt__user"
input_prompt_id: Mapped[int] = mapped_column(
ForeignKey("inputprompt.id"), primary_key=True
)
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("inputprompt.id"), primary_key=True
)
class AccessToken(SQLAlchemyBaseAccessTokenTableUUID, Base):
pass
@@ -596,6 +568,25 @@ class Connector(Base):
list["DocumentByConnectorCredentialPair"]
] = relationship("DocumentByConnectorCredentialPair", back_populates="connector")
# synchronize this validation logic with RefreshFrequencySchema etc on front end
# until we have a centralized validation schema
# TODO(rkuo): experiment with SQLAlchemy validators rather than manual checks
# https://docs.sqlalchemy.org/en/20/orm/mapped_attributes.html
def validate_refresh_freq(self) -> None:
if self.refresh_freq is not None:
if self.refresh_freq < 60:
raise ValueError(
"refresh_freq must be greater than or equal to 60 seconds."
)
def validate_prune_freq(self) -> None:
if self.prune_freq is not None:
if self.prune_freq < 86400:
raise ValueError(
"prune_freq must be greater than or equal to 86400 seconds."
)
class Credential(Base):
__tablename__ = "credential"
@@ -1530,6 +1521,7 @@ class SlackBot(Base):
slack_channel_configs: Mapped[list[SlackChannelConfig]] = relationship(
"SlackChannelConfig",
back_populates="slack_bot",
cascade="all, delete-orphan",
)

View File

@@ -453,9 +453,9 @@ def upsert_persona(
"""
if persona_id is not None:
persona = db_session.query(Persona).filter_by(id=persona_id).first()
existing_persona = db_session.query(Persona).filter_by(id=persona_id).first()
else:
persona = _get_persona_by_name(
existing_persona = _get_persona_by_name(
persona_name=name, user=user, db_session=db_session
)
@@ -481,62 +481,78 @@ def upsert_persona(
prompts = None
if prompt_ids is not None:
prompts = db_session.query(Prompt).filter(Prompt.id.in_(prompt_ids)).all()
if not prompts and prompt_ids:
raise ValueError("prompts not found")
if prompts is not None and len(prompts) == 0:
raise ValueError(
f"Invalid Persona config, no valid prompts "
f"specified. Specified IDs were: '{prompt_ids}'"
)
# ensure all specified tools are valid
if tools:
validate_persona_tools(tools)
if persona:
if existing_persona:
# Built-in personas can only be updated through YAML configuration.
# This ensures that core system personas are not modified unintentionally.
if persona.builtin_persona and not builtin_persona:
if existing_persona.builtin_persona and not builtin_persona:
raise ValueError("Cannot update builtin persona with non-builtin.")
# this checks if the user has permission to edit the persona
persona = fetch_persona_by_id(
db_session=db_session, persona_id=persona.id, user=user, get_editable=True
# will raise an Exception if the user does not have permission
existing_persona = fetch_persona_by_id(
db_session=db_session,
persona_id=existing_persona.id,
user=user,
get_editable=True,
)
# The following update excludes `default`, `built-in`, and display priority.
# Display priority is handled separately in the `display-priority` endpoint.
# `default` and `built-in` properties can only be set when creating a persona.
persona.name = name
persona.description = description
persona.num_chunks = num_chunks
persona.chunks_above = chunks_above
persona.chunks_below = chunks_below
persona.llm_relevance_filter = llm_relevance_filter
persona.llm_filter_extraction = llm_filter_extraction
persona.recency_bias = recency_bias
persona.llm_model_provider_override = llm_model_provider_override
persona.llm_model_version_override = llm_model_version_override
persona.starter_messages = starter_messages
persona.deleted = False # Un-delete if previously deleted
persona.is_public = is_public
persona.icon_color = icon_color
persona.icon_shape = icon_shape
existing_persona.name = name
existing_persona.description = description
existing_persona.num_chunks = num_chunks
existing_persona.chunks_above = chunks_above
existing_persona.chunks_below = chunks_below
existing_persona.llm_relevance_filter = llm_relevance_filter
existing_persona.llm_filter_extraction = llm_filter_extraction
existing_persona.recency_bias = recency_bias
existing_persona.llm_model_provider_override = llm_model_provider_override
existing_persona.llm_model_version_override = llm_model_version_override
existing_persona.starter_messages = starter_messages
existing_persona.deleted = False # Un-delete if previously deleted
existing_persona.is_public = is_public
existing_persona.icon_color = icon_color
existing_persona.icon_shape = icon_shape
if remove_image or uploaded_image_id:
persona.uploaded_image_id = uploaded_image_id
persona.is_visible = is_visible
persona.search_start_date = search_start_date
persona.category_id = category_id
existing_persona.uploaded_image_id = uploaded_image_id
existing_persona.is_visible = is_visible
existing_persona.search_start_date = search_start_date
existing_persona.category_id = category_id
# Do not delete any associations manually added unless
# a new updated list is provided
if document_sets is not None:
persona.document_sets.clear()
persona.document_sets = document_sets or []
existing_persona.document_sets.clear()
existing_persona.document_sets = document_sets or []
if prompts is not None:
persona.prompts.clear()
persona.prompts = prompts or []
existing_persona.prompts.clear()
existing_persona.prompts = prompts
if tools is not None:
persona.tools = tools or []
existing_persona.tools = tools or []
persona = existing_persona
else:
persona = Persona(
if not prompts:
raise ValueError(
"Invalid Persona config. "
"Must specify at least one prompt for a new persona."
)
new_persona = Persona(
id=persona_id,
user_id=user.id if user else None,
is_public=is_public,
@@ -549,7 +565,7 @@ def upsert_persona(
llm_filter_extraction=llm_filter_extraction,
recency_bias=recency_bias,
builtin_persona=builtin_persona,
prompts=prompts or [],
prompts=prompts,
document_sets=document_sets or [],
llm_model_provider_override=llm_model_provider_override,
llm_model_version_override=llm_model_version_override,
@@ -564,8 +580,8 @@ def upsert_persona(
is_default_persona=is_default_persona,
category_id=category_id,
)
db_session.add(persona)
db_session.add(new_persona)
persona = new_persona
if commit:
db_session.commit()
else:

View File

@@ -148,6 +148,7 @@ class Indexable(abc.ABC):
def index(
self,
chunks: list[DocMetadataAwareIndexChunk],
fresh_index: bool = False,
) -> set[DocumentInsertionRecord]:
"""
Takes a list of document chunks and indexes them in the document index
@@ -165,9 +166,14 @@ class Indexable(abc.ABC):
only needs to index chunks into the PRIMARY index. Do not update the secondary index here,
it is done automatically outside of this code.
NOTE: The fresh_index parameter, when set to True, assumes no documents have been previously
indexed for the given index/tenant. This can be used to optimize the indexing process for
new or empty indices.
Parameters:
- chunks: Document chunks with all of the information needed for indexing to the document
index.
- fresh_index: Boolean indicating whether this is a fresh index with no existing documents.
Returns:
List of document ids which map to unique documents and are used for deduping chunks

View File

@@ -306,6 +306,7 @@ class VespaIndex(DocumentIndex):
def index(
self,
chunks: list[DocMetadataAwareIndexChunk],
fresh_index: bool = False,
) -> set[DocumentInsertionRecord]:
"""Receive a list of chunks from a batch of documents and index the chunks into Vespa along
with updating the associated permissions. Assumes that a document will not be split into
@@ -322,26 +323,29 @@ class VespaIndex(DocumentIndex):
concurrent.futures.ThreadPoolExecutor(max_workers=NUM_THREADS) as executor,
get_vespa_http_client() as http_client,
):
# Check for existing documents, existing documents need to have all of their chunks deleted
# prior to indexing as the document size (num chunks) may have shrunk
first_chunks = [chunk for chunk in cleaned_chunks if chunk.chunk_id == 0]
for chunk_batch in batch_generator(first_chunks, BATCH_SIZE):
existing_docs.update(
get_existing_documents_from_chunks(
chunks=chunk_batch,
if not fresh_index:
# Check for existing documents, existing documents need to have all of their chunks deleted
# prior to indexing as the document size (num chunks) may have shrunk
first_chunks = [
chunk for chunk in cleaned_chunks if chunk.chunk_id == 0
]
for chunk_batch in batch_generator(first_chunks, BATCH_SIZE):
existing_docs.update(
get_existing_documents_from_chunks(
chunks=chunk_batch,
index_name=self.index_name,
http_client=http_client,
executor=executor,
)
)
for doc_id_batch in batch_generator(existing_docs, BATCH_SIZE):
delete_vespa_docs(
document_ids=doc_id_batch,
index_name=self.index_name,
http_client=http_client,
executor=executor,
)
)
for doc_id_batch in batch_generator(existing_docs, BATCH_SIZE):
delete_vespa_docs(
document_ids=doc_id_batch,
index_name=self.index_name,
http_client=http_client,
executor=executor,
)
for chunk_batch in batch_generator(cleaned_chunks, BATCH_SIZE):
batch_index_vespa_chunks(

View File

@@ -70,7 +70,7 @@ def get_file_ext(file_path_or_name: str | Path) -> str:
return extension
def check_file_ext_is_valid(ext: str) -> bool:
def is_valid_file_ext(ext: str) -> bool:
return ext in VALID_FILE_EXTENSIONS
@@ -364,7 +364,7 @@ def extract_file_text(
elif file_name is not None:
final_extension = get_file_ext(file_name)
if check_file_ext_is_valid(final_extension):
if is_valid_file_ext(final_extension):
return extension_to_function.get(final_extension, file_io_to_text)(file)
# Either the file somehow has no name or the extension is not one that we recognize

View File

@@ -1,4 +1,5 @@
import traceback
from collections.abc import Callable
from functools import partial
from http import HTTPStatus
from typing import Protocol
@@ -12,6 +13,7 @@ from danswer.access.access import get_access_for_documents
from danswer.access.models import DocumentAccess
from danswer.configs.app_configs import ENABLE_MULTIPASS_INDEXING
from danswer.configs.app_configs import INDEXING_EXCEPTION_LIMIT
from danswer.configs.app_configs import MAX_DOCUMENT_CHARS
from danswer.configs.constants import DEFAULT_BOOST
from danswer.connectors.cross_connector_utils.miscellaneous_utils import (
get_experts_stores_representations,
@@ -202,40 +204,13 @@ def index_doc_batch_with_handler(
def index_doc_batch_prepare(
document_batch: list[Document],
documents: list[Document],
index_attempt_metadata: IndexAttemptMetadata,
db_session: Session,
ignore_time_skip: bool = False,
) -> DocumentBatchPrepareContext | None:
"""Sets up the documents in the relational DB (source of truth) for permissions, metadata, etc.
This preceeds indexing it into the actual document index."""
documents: list[Document] = []
for document in document_batch:
empty_contents = not any(section.text.strip() for section in document.sections)
if (
(not document.title or not document.title.strip())
and not document.semantic_identifier.strip()
and empty_contents
):
# Skip documents that have neither title nor content
# If the document doesn't have either, then there is no useful information in it
# This is again verified later in the pipeline after chunking but at that point there should
# already be no documents that are empty.
logger.warning(
f"Skipping document with ID {document.id} as it has neither title nor content."
)
continue
if document.title is not None and not document.title.strip() and empty_contents:
# The title is explicitly empty ("" and not None) and the document is empty
# so when building the chunk text representation, it will be empty and unuseable
logger.warning(
f"Skipping document with ID {document.id} as the chunks will be empty."
)
continue
documents.append(document)
# Create a trimmed list of docs that don't have a newer updated at
# Shortcuts the time-consuming flow on connector index retries
document_ids: list[str] = [document.id for document in documents]
@@ -282,17 +257,64 @@ def index_doc_batch_prepare(
)
def filter_documents(document_batch: list[Document]) -> list[Document]:
documents: list[Document] = []
for document in document_batch:
empty_contents = not any(section.text.strip() for section in document.sections)
if (
(not document.title or not document.title.strip())
and not document.semantic_identifier.strip()
and empty_contents
):
# Skip documents that have neither title nor content
# If the document doesn't have either, then there is no useful information in it
# This is again verified later in the pipeline after chunking but at that point there should
# already be no documents that are empty.
logger.warning(
f"Skipping document with ID {document.id} as it has neither title nor content."
)
continue
if document.title is not None and not document.title.strip() and empty_contents:
# The title is explicitly empty ("" and not None) and the document is empty
# so when building the chunk text representation, it will be empty and unuseable
logger.warning(
f"Skipping document with ID {document.id} as the chunks will be empty."
)
continue
section_chars = sum(len(section.text) for section in document.sections)
if (
MAX_DOCUMENT_CHARS
and len(document.title or document.semantic_identifier) + section_chars
> MAX_DOCUMENT_CHARS
):
# Skip documents that are too long, later on there are more memory intensive steps done on the text
# and the container will run out of memory and crash. Several other checks are included upstream but
# those are at the connector level so a catchall is still needed.
# Assumption here is that files that are that long, are generated files and not the type users
# generally care for.
logger.warning(
f"Skipping document with ID {document.id} as it is too long."
)
continue
documents.append(document)
return documents
@log_function_time(debug_only=True)
def index_doc_batch(
*,
document_batch: list[Document],
chunker: Chunker,
embedder: IndexingEmbedder,
document_index: DocumentIndex,
document_batch: list[Document],
index_attempt_metadata: IndexAttemptMetadata,
db_session: Session,
ignore_time_skip: bool = False,
tenant_id: str | None = None,
filter_fnc: Callable[[list[Document]], list[Document]] = filter_documents,
) -> tuple[int, int]:
"""Takes different pieces of the indexing pipeline and applies it to a batch of documents
Note that the documents should already be batched at this point so that it does not inflate the
@@ -309,8 +331,11 @@ def index_doc_batch(
is_public=False,
)
logger.debug("Filtering Documents")
filtered_documents = filter_fnc(document_batch)
ctx = index_doc_batch_prepare(
document_batch=document_batch,
documents=filtered_documents,
index_attempt_metadata=index_attempt_metadata,
ignore_time_skip=ignore_time_skip,
db_session=db_session,

View File

@@ -268,12 +268,16 @@ class DefaultMultiLLM(LLM):
# NOTE: have to set these as environment variables for Litellm since
# not all are able to passed in but they always support them set as env
# variables
# variables. We'll also try passing them in, since litellm just ignores
# addtional kwargs (and some kwargs MUST be passed in rather than set as
# env variables)
if custom_config:
for k, v in custom_config.items():
os.environ[k] = v
model_kwargs = model_kwargs or {}
if custom_config:
model_kwargs.update(custom_config)
if extra_headers:
model_kwargs.update({"extra_headers": extra_headers})
if extra_body:

View File

@@ -1,5 +1,4 @@
import copy
import io
import json
from collections.abc import Callable
from collections.abc import Iterator
@@ -7,7 +6,6 @@ from typing import Any
from typing import cast
import litellm # type: ignore
import pandas as pd
import tiktoken
from langchain.prompts.base import StringPromptValue
from langchain.prompts.chat import ChatPromptValue
@@ -100,53 +98,32 @@ def litellm_exception_to_error_msg(
return error_msg
# Processes CSV files to show the first 5 rows and max_columns (default 40) columns
def _process_csv_file(file: InMemoryChatFile, max_columns: int = 40) -> str:
df = pd.read_csv(io.StringIO(file.content.decode("utf-8")))
csv_preview = df.head().to_string(max_cols=max_columns)
file_name_section = (
f"CSV FILE NAME: {file.filename}\n"
if file.filename
else "CSV FILE (NO NAME PROVIDED):\n"
)
return f"{file_name_section}{CODE_BLOCK_PAT.format(csv_preview)}\n\n\n"
def _build_content(
message: str,
files: list[InMemoryChatFile] | None = None,
) -> str:
"""Applies all non-image files."""
text_files = (
[file for file in files if file.file_type == ChatFileType.PLAIN_TEXT]
if files
else None
)
if not files:
return message
csv_files = (
[file for file in files if file.file_type == ChatFileType.CSV]
if files
else None
)
text_files = [
file
for file in files
if file.file_type in (ChatFileType.PLAIN_TEXT, ChatFileType.CSV)
]
if not text_files and not csv_files:
if not text_files:
return message
final_message_with_files = "FILES:\n\n"
for file in text_files or []:
for file in text_files:
file_content = file.content.decode("utf-8")
file_name_section = f"DOCUMENT: {file.filename}\n" if file.filename else ""
final_message_with_files += (
f"{file_name_section}{CODE_BLOCK_PAT.format(file_content.strip())}\n\n\n"
)
for file in csv_files or []:
final_message_with_files += _process_csv_file(file)
final_message_with_files += message
return final_message_with_files
return final_message_with_files + message
def build_content_with_imgs(

View File

@@ -52,12 +52,9 @@ from danswer.server.documents.connector import router as connector_router
from danswer.server.documents.credential import router as credential_router
from danswer.server.documents.document import router as document_router
from danswer.server.documents.indexing import router as indexing_router
from danswer.server.documents.standard_oauth import router as oauth_router
from danswer.server.features.document_set.api import router as document_set_router
from danswer.server.features.folder.api import router as folder_router
from danswer.server.features.input_prompt.api import (
admin_router as admin_input_prompt_router,
)
from danswer.server.features.input_prompt.api import basic_router as input_prompt_router
from danswer.server.features.notifications.api import router as notification_router
from danswer.server.features.persona.api import admin_router as admin_persona_router
from danswer.server.features.persona.api import basic_router as persona_router
@@ -258,8 +255,6 @@ def get_application() -> FastAPI:
)
include_router_with_global_prefix_prepended(application, persona_router)
include_router_with_global_prefix_prepended(application, admin_persona_router)
include_router_with_global_prefix_prepended(application, input_prompt_router)
include_router_with_global_prefix_prepended(application, admin_input_prompt_router)
include_router_with_global_prefix_prepended(application, notification_router)
include_router_with_global_prefix_prepended(application, prompt_router)
include_router_with_global_prefix_prepended(application, tool_router)
@@ -282,6 +277,7 @@ def get_application() -> FastAPI:
)
include_router_with_global_prefix_prepended(application, long_term_logs_router)
include_router_with_global_prefix_prepended(application, api_key_router)
include_router_with_global_prefix_prepended(application, oauth_router)
if AUTH_TYPE == AuthType.DISABLED:
# Server logs this during auth setup verification step

View File

@@ -1,24 +0,0 @@
input_prompts:
- id: -5
prompt: "Elaborate"
content: "Elaborate on the above, give me a more in depth explanation."
active: true
is_public: true
- id: -4
prompt: "Reword"
content: "Help me rewrite the following politely and concisely for professional communication:\n"
active: true
is_public: true
- id: -3
prompt: "Email"
content: "Write a professional email for me including a subject line, signature, etc. Template the parts that need editing with [ ]. The email should cover the following points:\n"
active: true
is_public: true
- id: -2
prompt: "Debug"
content: "Provide step-by-step troubleshooting instructions for the following issue:\n"
active: true
is_public: true

View File

@@ -196,7 +196,7 @@ def seed_initial_documents(
docs, chunks = _create_indexable_chunks(processed_docs, tenant_id)
index_doc_batch_prepare(
document_batch=docs,
documents=docs,
index_attempt_metadata=IndexAttemptMetadata(
connector_id=connector_id,
credential_id=PUBLIC_CREDENTIAL_ID,
@@ -216,7 +216,7 @@ def seed_initial_documents(
# as we just sent over the Vespa schema and there is a slight delay
index_with_retries = retry_builder()(document_index.index)
index_with_retries(chunks=chunks)
index_with_retries(chunks=chunks, fresh_index=True)
# Mock a run for the UI even though it did not actually call out to anything
mock_successful_index_attempt(

View File

@@ -1,13 +1,11 @@
import yaml
from sqlalchemy.orm import Session
from danswer.configs.chat_configs import INPUT_PROMPT_YAML
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from danswer.configs.chat_configs import PERSONAS_YAML
from danswer.configs.chat_configs import PROMPTS_YAML
from danswer.context.search.enums import RecencyBiasSetting
from danswer.db.document_set import get_or_create_document_set_by_name
from danswer.db.input_prompt import insert_input_prompt_if_not_exists
from danswer.db.models import DocumentSet as DocumentSetDBModel
from danswer.db.models import Persona
from danswer.db.models import Prompt as PromptDBModel
@@ -79,6 +77,9 @@ def load_personas_from_yaml(
if prompts:
prompt_ids = [prompt.id for prompt in prompts if prompt is not None]
if not prompt_ids:
raise ValueError("Invalid Persona config, no prompts exist")
p_id = persona.get("id")
tool_ids = []
@@ -123,45 +124,24 @@ def load_personas_from_yaml(
tool_ids=tool_ids,
builtin_persona=True,
is_public=True,
display_priority=existing_persona.display_priority
if existing_persona is not None
else persona.get("display_priority"),
is_visible=existing_persona.is_visible
if existing_persona is not None
else persona.get("is_visible"),
display_priority=(
existing_persona.display_priority
if existing_persona is not None
else persona.get("display_priority")
),
is_visible=(
existing_persona.is_visible
if existing_persona is not None
else persona.get("is_visible")
),
db_session=db_session,
)
def load_input_prompts_from_yaml(
db_session: Session, input_prompts_yaml: str = INPUT_PROMPT_YAML
) -> None:
with open(input_prompts_yaml, "r") as file:
data = yaml.safe_load(file)
all_input_prompts = data.get("input_prompts", [])
for input_prompt in all_input_prompts:
# If these prompts are deleted (which is a hard delete in the DB), on server startup
# they will be recreated, but the user can always just deactivate them, just a light inconvenience
insert_input_prompt_if_not_exists(
user=None,
input_prompt_id=input_prompt.get("id"),
prompt=input_prompt["prompt"],
content=input_prompt["content"],
is_public=input_prompt["is_public"],
active=input_prompt.get("active", True),
db_session=db_session,
commit=True,
)
def load_chat_yamls(
db_session: Session,
prompt_yaml: str = PROMPTS_YAML,
personas_yaml: str = PERSONAS_YAML,
input_prompts_yaml: str = INPUT_PROMPT_YAML,
) -> None:
load_prompts_from_yaml(db_session, prompt_yaml)
load_personas_from_yaml(db_session, personas_yaml)
load_input_prompts_from_yaml(db_session, input_prompts_yaml)

View File

@@ -33,8 +33,6 @@ from danswer.db.engine import get_current_tenant_id
from danswer.db.engine import get_session
from danswer.db.enums import AccessType
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.index_attempt import cancel_indexing_attempts_for_ccpair
from danswer.db.index_attempt import cancel_indexing_attempts_past_model
from danswer.db.index_attempt import count_index_attempts_for_connector
from danswer.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
from danswer.db.index_attempt import get_paginated_index_attempts_for_cc_pair_id
@@ -45,6 +43,7 @@ from danswer.db.search_settings import get_current_search_settings
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_pool import get_redis_client
from danswer.server.documents.models import CCPairFullInfo
from danswer.server.documents.models import CCPropertyUpdateRequest
from danswer.server.documents.models import CCStatusUpdateRequest
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
from danswer.server.documents.models import ConnectorCredentialPairMetadata
@@ -192,9 +191,6 @@ def update_cc_pair_status(
db_session
)
cancel_indexing_attempts_for_ccpair(cc_pair_id, db_session)
cancel_indexing_attempts_past_model(db_session)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
try:
@@ -308,6 +304,46 @@ def update_cc_pair_name(
raise HTTPException(status_code=400, detail="Name must be unique")
@router.put("/admin/cc-pair/{cc_pair_id}/property")
def update_cc_pair_property(
cc_pair_id: int,
update_request: CCPropertyUpdateRequest, # in seconds
user: User | None = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> StatusResponse[int]:
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id=cc_pair_id,
db_session=db_session,
user=user,
get_editable=True,
)
if not cc_pair:
raise HTTPException(
status_code=400, detail="CC Pair not found for current user's permissions"
)
# Can we centralize logic for updating connector properties
# so that we don't need to manually validate everywhere?
if update_request.name == "refresh_frequency":
cc_pair.connector.refresh_freq = int(update_request.value)
cc_pair.connector.validate_refresh_freq()
db_session.commit()
msg = "Refresh frequency updated successfully"
elif update_request.name == "pruning_frequency":
cc_pair.connector.prune_freq = int(update_request.value)
cc_pair.connector.validate_prune_freq()
db_session.commit()
msg = "Pruning frequency updated successfully"
else:
raise HTTPException(
status_code=400, detail=f"Property name {update_request.name} is not valid."
)
return StatusResponse(success=True, message=msg, data=cc_pair_id)
@router.get("/admin/cc-pair/{cc_pair_id}/last_pruned")
def get_cc_pair_last_pruned(
cc_pair_id: int,

View File

@@ -181,7 +181,13 @@ def update_credential_data(
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> CredentialBase:
credential = alter_credential(credential_id, credential_update, user, db_session)
credential = alter_credential(
credential_id,
credential_update.name,
credential_update.credential_json,
user,
db_session,
)
if credential is None:
raise HTTPException(

View File

@@ -364,6 +364,11 @@ class RunConnectorRequest(BaseModel):
from_beginning: bool = False
class CCPropertyUpdateRequest(BaseModel):
name: str
value: str
"""Connectors Models"""

View File

@@ -0,0 +1,142 @@
import uuid
from typing import Annotated
from typing import cast
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Query
from pydantic import BaseModel
from sqlalchemy.orm import Session
from danswer.auth.users import current_user
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import DocumentSource
from danswer.connectors.interfaces import OAuthConnector
from danswer.db.credentials import create_credential
from danswer.db.engine import get_current_tenant_id
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.redis.redis_pool import get_redis_client
from danswer.server.documents.models import CredentialBase
from danswer.utils.logger import setup_logger
from danswer.utils.subclasses import find_all_subclasses_in_dir
logger = setup_logger()
router = APIRouter(prefix="/connector/oauth")
_OAUTH_STATE_KEY_FMT = "oauth_state:{state}"
_OAUTH_STATE_EXPIRATION_SECONDS = 10 * 60 # 10 minutes
# Cache for OAuth connectors, populated at module load time
_OAUTH_CONNECTORS: dict[DocumentSource, type[OAuthConnector]] = {}
def _discover_oauth_connectors() -> dict[DocumentSource, type[OAuthConnector]]:
"""Walk through the connectors package to find all OAuthConnector implementations"""
global _OAUTH_CONNECTORS
if _OAUTH_CONNECTORS: # Return cached connectors if already discovered
return _OAUTH_CONNECTORS
oauth_connectors = find_all_subclasses_in_dir(
cast(type[OAuthConnector], OAuthConnector), "danswer.connectors"
)
_OAUTH_CONNECTORS = {cls.oauth_id(): cls for cls in oauth_connectors}
return _OAUTH_CONNECTORS
# Discover OAuth connectors at module load time
_discover_oauth_connectors()
class AuthorizeResponse(BaseModel):
redirect_url: str
@router.get("/authorize/{source}")
def oauth_authorize(
source: DocumentSource,
desired_return_url: Annotated[str | None, Query()] = None,
_: User = Depends(current_user),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> AuthorizeResponse:
"""Initiates the OAuth flow by redirecting to the provider's auth page"""
oauth_connectors = _discover_oauth_connectors()
if source not in oauth_connectors:
raise HTTPException(status_code=400, detail=f"Unknown OAuth source: {source}")
connector_cls = oauth_connectors[source]
base_url = WEB_DOMAIN
# store state in redis
if not desired_return_url:
desired_return_url = f"{base_url}/admin/connectors/{source}?step=0"
redis_client = get_redis_client(tenant_id=tenant_id)
state = str(uuid.uuid4())
redis_client.set(
_OAUTH_STATE_KEY_FMT.format(state=state),
desired_return_url,
ex=_OAUTH_STATE_EXPIRATION_SECONDS,
)
return AuthorizeResponse(
redirect_url=connector_cls.oauth_authorization_url(base_url, state)
)
class CallbackResponse(BaseModel):
redirect_url: str
@router.get("/callback/{source}")
def oauth_callback(
source: DocumentSource,
code: Annotated[str, Query()],
state: Annotated[str, Query()],
db_session: Session = Depends(get_session),
user: User = Depends(current_user),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> CallbackResponse:
"""Handles the OAuth callback and exchanges the code for tokens"""
oauth_connectors = _discover_oauth_connectors()
if source not in oauth_connectors:
raise HTTPException(status_code=400, detail=f"Unknown OAuth source: {source}")
connector_cls = oauth_connectors[source]
# get state from redis
redis_client = get_redis_client(tenant_id=tenant_id)
original_url_bytes = cast(
bytes, redis_client.get(_OAUTH_STATE_KEY_FMT.format(state=state))
)
if not original_url_bytes:
raise HTTPException(status_code=400, detail="Invalid OAuth state")
original_url = original_url_bytes.decode("utf-8")
token_info = connector_cls.oauth_code_to_token(code)
# Create a new credential with the token info
credential_data = CredentialBase(
credential_json=token_info,
admin_public=True, # Or based on some logic/parameter
source=source,
name=f"{source.title()} OAuth Credential",
)
credential = create_credential(
credential_data=credential_data,
user=user,
db_session=db_session,
)
return CallbackResponse(
redirect_url=(
f"{original_url}?credentialId={credential.id}"
if "?" not in original_url
else f"{original_url}&credentialId={credential.id}"
)
)

View File

@@ -1,134 +0,0 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session
from danswer.auth.users import current_admin_user
from danswer.auth.users import current_user
from danswer.db.engine import get_session
from danswer.db.input_prompt import fetch_input_prompt_by_id
from danswer.db.input_prompt import fetch_input_prompts_by_user
from danswer.db.input_prompt import fetch_public_input_prompts
from danswer.db.input_prompt import insert_input_prompt
from danswer.db.input_prompt import remove_input_prompt
from danswer.db.input_prompt import remove_public_input_prompt
from danswer.db.input_prompt import update_input_prompt
from danswer.db.models import User
from danswer.server.features.input_prompt.models import CreateInputPromptRequest
from danswer.server.features.input_prompt.models import InputPromptSnapshot
from danswer.server.features.input_prompt.models import UpdateInputPromptRequest
from danswer.utils.logger import setup_logger
logger = setup_logger()
basic_router = APIRouter(prefix="/input_prompt")
admin_router = APIRouter(prefix="/admin/input_prompt")
@basic_router.get("")
def list_input_prompts(
user: User | None = Depends(current_user),
include_public: bool = False,
db_session: Session = Depends(get_session),
) -> list[InputPromptSnapshot]:
user_prompts = fetch_input_prompts_by_user(
user_id=user.id if user is not None else None,
db_session=db_session,
include_public=include_public,
)
return [InputPromptSnapshot.from_model(prompt) for prompt in user_prompts]
@basic_router.get("/{input_prompt_id}")
def get_input_prompt(
input_prompt_id: int,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> InputPromptSnapshot:
input_prompt = fetch_input_prompt_by_id(
id=input_prompt_id,
user_id=user.id if user is not None else None,
db_session=db_session,
)
return InputPromptSnapshot.from_model(input_prompt=input_prompt)
@basic_router.post("")
def create_input_prompt(
create_input_prompt_request: CreateInputPromptRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> InputPromptSnapshot:
input_prompt = insert_input_prompt(
prompt=create_input_prompt_request.prompt,
content=create_input_prompt_request.content,
is_public=create_input_prompt_request.is_public,
user=user,
db_session=db_session,
)
return InputPromptSnapshot.from_model(input_prompt)
@basic_router.patch("/{input_prompt_id}")
def patch_input_prompt(
input_prompt_id: int,
update_input_prompt_request: UpdateInputPromptRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> InputPromptSnapshot:
try:
updated_input_prompt = update_input_prompt(
user=user,
input_prompt_id=input_prompt_id,
prompt=update_input_prompt_request.prompt,
content=update_input_prompt_request.content,
active=update_input_prompt_request.active,
db_session=db_session,
)
except ValueError as e:
error_msg = "Error occurred while updated input prompt"
logger.warn(f"{error_msg}. Stack trace: {e}")
raise HTTPException(status_code=404, detail=error_msg)
return InputPromptSnapshot.from_model(updated_input_prompt)
@basic_router.delete("/{input_prompt_id}")
def delete_input_prompt(
input_prompt_id: int,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
try:
remove_input_prompt(user, input_prompt_id, db_session)
except ValueError as e:
error_msg = "Error occurred while deleting input prompt"
logger.warn(f"{error_msg}. Stack trace: {e}")
raise HTTPException(status_code=404, detail=error_msg)
@admin_router.delete("/{input_prompt_id}")
def delete_public_input_prompt(
input_prompt_id: int,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
try:
remove_public_input_prompt(input_prompt_id, db_session)
except ValueError as e:
error_msg = "Error occurred while deleting input prompt"
logger.warn(f"{error_msg}. Stack trace: {e}")
raise HTTPException(status_code=404, detail=error_msg)
@admin_router.get("")
def list_public_input_prompts(
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[InputPromptSnapshot]:
user_prompts = fetch_public_input_prompts(
db_session=db_session,
)
return [InputPromptSnapshot.from_model(prompt) for prompt in user_prompts]

View File

@@ -1,47 +0,0 @@
from uuid import UUID
from pydantic import BaseModel
from danswer.db.models import InputPrompt
from danswer.utils.logger import setup_logger
logger = setup_logger()
class CreateInputPromptRequest(BaseModel):
prompt: str
content: str
is_public: bool
class UpdateInputPromptRequest(BaseModel):
prompt: str
content: str
active: bool
class InputPromptResponse(BaseModel):
id: int
prompt: str
content: str
active: bool
class InputPromptSnapshot(BaseModel):
id: int
prompt: str
content: str
active: bool
user_id: UUID | None
is_public: bool
@classmethod
def from_model(cls, input_prompt: InputPrompt) -> "InputPromptSnapshot":
return InputPromptSnapshot(
id=input_prompt.id,
prompt=input_prompt.prompt,
content=input_prompt.content,
active=input_prompt.active,
user_id=input_prompt.user_id,
is_public=input_prompt.is_public,
)

View File

@@ -266,5 +266,7 @@ class FullModelVersionResponse(BaseModel):
class AllUsersResponse(BaseModel):
accepted: list[FullUserSnapshot]
invited: list[InvitedUserSnapshot]
slack_users: list[FullUserSnapshot]
accepted_pages: int
invited_pages: int
slack_users_pages: int

View File

@@ -119,6 +119,7 @@ def set_user_role(
def list_all_users(
q: str | None = None,
accepted_page: int | None = None,
slack_users_page: int | None = None,
invited_page: int | None = None,
user: User | None = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
@@ -131,7 +132,12 @@ def list_all_users(
for user in list_users(db_session, email_filter_string=q)
if not is_api_key_email_address(user.email)
]
accepted_emails = {user.email for user in users}
slack_users = [user for user in users if user.role == UserRole.SLACK_USER]
accepted_users = [user for user in users if user.role != UserRole.SLACK_USER]
accepted_emails = {user.email for user in accepted_users}
slack_users_emails = {user.email for user in slack_users}
invited_emails = get_invited_users()
if q:
invited_emails = [
@@ -139,10 +145,11 @@ def list_all_users(
]
accepted_count = len(accepted_emails)
slack_users_count = len(slack_users_emails)
invited_count = len(invited_emails)
# If any of q, accepted_page, or invited_page is None, return all users
if accepted_page is None or invited_page is None:
if accepted_page is None or invited_page is None or slack_users_page is None:
return AllUsersResponse(
accepted=[
FullUserSnapshot(
@@ -153,11 +160,23 @@ def list_all_users(
UserStatus.LIVE if user.is_active else UserStatus.DEACTIVATED
),
)
for user in users
for user in accepted_users
],
slack_users=[
FullUserSnapshot(
id=user.id,
email=user.email,
role=user.role,
status=(
UserStatus.LIVE if user.is_active else UserStatus.DEACTIVATED
),
)
for user in slack_users
],
invited=[InvitedUserSnapshot(email=email) for email in invited_emails],
accepted_pages=1,
invited_pages=1,
slack_users_pages=1,
)
# Otherwise, return paginated results
@@ -169,13 +188,27 @@ def list_all_users(
role=user.role,
status=UserStatus.LIVE if user.is_active else UserStatus.DEACTIVATED,
)
for user in users
for user in accepted_users
][accepted_page * USERS_PAGE_SIZE : (accepted_page + 1) * USERS_PAGE_SIZE],
slack_users=[
FullUserSnapshot(
id=user.id,
email=user.email,
role=user.role,
status=UserStatus.LIVE if user.is_active else UserStatus.DEACTIVATED,
)
for user in slack_users
][
slack_users_page
* USERS_PAGE_SIZE : (slack_users_page + 1)
* USERS_PAGE_SIZE
],
invited=[InvitedUserSnapshot(email=email) for email in invited_emails][
invited_page * USERS_PAGE_SIZE : (invited_page + 1) * USERS_PAGE_SIZE
],
accepted_pages=accepted_count // USERS_PAGE_SIZE + 1,
invited_pages=invited_count // USERS_PAGE_SIZE + 1,
slack_users_pages=slack_users_count // USERS_PAGE_SIZE + 1,
)

View File

@@ -48,6 +48,9 @@ from danswer.tools.tool_implementations.search_like_tool_utils import (
from danswer.tools.tool_implementations.search_like_tool_utils import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
from danswer.tools.tool_implementations.search_like_tool_utils import (
ORIGINAL_CONTEXT_DOCUMENTS_ID,
)
from danswer.utils.logger import setup_logger
from danswer.utils.special_types import JSON_ro
@@ -391,15 +394,35 @@ class SearchTool(Tool):
"""Other utility functions"""
@classmethod
def get_search_result(cls, llm_call: LLMCall) -> list[LlmDoc] | None:
def get_search_result(
cls, llm_call: LLMCall
) -> tuple[list[LlmDoc], dict[str, int]] | None:
"""
Returns the final search results and a map of docs to their original search rank (which is what is displayed to user)
"""
if not llm_call.tool_call_info:
return None
final_search_results = []
doc_id_to_original_search_rank_map = {}
for yield_item in llm_call.tool_call_info:
if (
isinstance(yield_item, ToolResponse)
and yield_item.id == FINAL_CONTEXT_DOCUMENTS_ID
):
return cast(list[LlmDoc], yield_item.response)
final_search_results = cast(list[LlmDoc], yield_item.response)
elif (
isinstance(yield_item, ToolResponse)
and yield_item.id == ORIGINAL_CONTEXT_DOCUMENTS_ID
):
search_contexts = yield_item.response.contexts
original_doc_search_rank = 1
for idx, doc in enumerate(search_contexts):
if doc.document_id not in doc_id_to_original_search_rank_map:
doc_id_to_original_search_rank_map[
doc.document_id
] = original_doc_search_rank
original_doc_search_rank += 1
return None
return final_search_results, doc_id_to_original_search_rank_map

View File

@@ -15,6 +15,7 @@ from danswer.tools.message import ToolCallSummary
from danswer.tools.models import ToolResponse
ORIGINAL_CONTEXT_DOCUMENTS_ID = "search_doc_content"
FINAL_CONTEXT_DOCUMENTS_ID = "final_context_documents"

View File

@@ -0,0 +1,77 @@
from __future__ import annotations
import importlib
import os
import pkgutil
import sys
from types import ModuleType
from typing import List
from typing import Type
from typing import TypeVar
T = TypeVar("T")
def import_all_modules_from_dir(dir_path: str) -> List[ModuleType]:
"""
Imports all modules found in the given directory and its subdirectories,
returning a list of imported module objects.
"""
dir_path = os.path.abspath(dir_path)
if dir_path not in sys.path:
sys.path.insert(0, dir_path)
imported_modules: List[ModuleType] = []
for _, package_name, _ in pkgutil.walk_packages([dir_path]):
try:
module = importlib.import_module(package_name)
imported_modules.append(module)
except Exception as e:
# Handle or log exceptions as needed
print(f"Could not import {package_name}: {e}")
return imported_modules
def all_subclasses(cls: Type[T]) -> List[Type[T]]:
"""
Recursively find all subclasses of the given class.
"""
direct_subs = cls.__subclasses__()
result: List[Type[T]] = []
for subclass in direct_subs:
result.append(subclass)
# Extend the result by recursively calling all_subclasses
result.extend(all_subclasses(subclass))
return result
def find_all_subclasses_in_dir(parent_class: Type[T], directory: str) -> List[Type[T]]:
"""
Imports all modules from the given directory (and subdirectories),
then returns all classes that are subclasses of parent_class.
:param parent_class: The class to find subclasses of.
:param directory: The directory to search for subclasses.
:return: A list of all subclasses of parent_class found in the directory.
"""
# First import all modules to ensure classes are loaded into memory
import_all_modules_from_dir(directory)
# Gather all subclasses of the given parent class
subclasses = all_subclasses(parent_class)
return subclasses
# Example usage:
if __name__ == "__main__":
class Animal:
pass
# Suppose "mymodules" contains files that define classes inheriting from Animal
found_subclasses = find_all_subclasses_in_dir(Animal, "mymodules")
for sc in found_subclasses:
print("Found subclass:", sc.__name__)

View File

@@ -11,6 +11,14 @@ SAML_CONF_DIR = os.environ.get("SAML_CONF_DIR") or "/app/ee/danswer/configs/saml
#####
# Auto Permission Sync
#####
# In seconds, default is 5 minutes
CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY = int(
os.environ.get("CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY") or 5 * 60
)
# In seconds, default is 5 minutes
CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY = int(
os.environ.get("CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
)
NUM_PERMISSION_WORKERS = int(os.environ.get("NUM_PERMISSION_WORKERS") or 2)

View File

@@ -10,6 +10,9 @@ from danswer.access.utils import prefix_group_w_source
from danswer.configs.constants import DocumentSource
from danswer.db.models import User__ExternalUserGroupId
from danswer.db.users import batch_add_ext_perm_user_if_not_exists
from danswer.utils.logger import setup_logger
logger = setup_logger()
class ExternalUserGroup(BaseModel):
@@ -73,7 +76,13 @@ def replace_user__ext_group_for_cc_pair(
new_external_permissions = []
for external_group in group_defs:
for user_email in external_group.user_emails:
user_id = email_id_map[user_email]
user_id = email_id_map.get(user_email.lower())
if user_id is None:
logger.warning(
f"User in group {external_group.id}"
f" with email {user_email} not found"
)
continue
new_external_permissions.append(
User__ExternalUserGroupId(
user_id=user_id,

View File

@@ -195,6 +195,7 @@ def _fetch_all_page_restrictions_for_space(
confluence_client: OnyxConfluence,
slim_docs: list[SlimDocument],
space_permissions_by_space_key: dict[str, ExternalAccess],
is_cloud: bool,
) -> list[DocExternalAccess]:
"""
For all pages, if a page has restrictions, then use those restrictions.
@@ -222,29 +223,50 @@ def _fetch_all_page_restrictions_for_space(
continue
space_key = slim_doc.perm_sync_data.get("space_key")
if space_permissions := space_permissions_by_space_key.get(space_key):
# If there are no restrictions, then use the space's restrictions
document_restrictions.append(
DocExternalAccess(
doc_id=slim_doc.id,
external_access=space_permissions,
)
if not (space_permissions := space_permissions_by_space_key.get(space_key)):
logger.debug(
f"Individually fetching space permissions for space {space_key}"
)
if (
not space_permissions.is_public
and not space_permissions.external_user_emails
and not space_permissions.external_user_group_ids
):
try:
# If the space permissions are not in the cache, then fetch them
if is_cloud:
retrieved_space_permissions = _get_cloud_space_permissions(
confluence_client=confluence_client, space_key=space_key
)
else:
retrieved_space_permissions = _get_server_space_permissions(
confluence_client=confluence_client, space_key=space_key
)
space_permissions_by_space_key[space_key] = retrieved_space_permissions
space_permissions = retrieved_space_permissions
except Exception as e:
logger.warning(
f"Permissions are empty for document: {slim_doc.id}\n"
"This means space permissions are may be wrong for"
f" Space key: {space_key}"
f"Error fetching space permissions for space {space_key}: {e}"
)
if not space_permissions:
logger.warning(
f"No permissions found for document {slim_doc.id} in space {space_key}"
)
continue
logger.warning(
f"No permissions found for document {slim_doc.id} in space {space_key}"
# If there are no restrictions, then use the space's restrictions
document_restrictions.append(
DocExternalAccess(
doc_id=slim_doc.id,
external_access=space_permissions,
)
)
if (
not space_permissions.is_public
and not space_permissions.external_user_emails
and not space_permissions.external_user_group_ids
):
logger.warning(
f"Permissions are empty for document: {slim_doc.id}\n"
"This means space permissions are may be wrong for"
f" Space key: {space_key}"
)
logger.debug("Finished fetching all page restrictions for space")
return document_restrictions
@@ -283,4 +305,5 @@ def confluence_doc_sync(
confluence_client=confluence_connector.confluence_client,
slim_docs=slim_docs,
space_permissions_by_space_key=space_permissions_by_space_key,
is_cloud=is_cloud,
)

View File

@@ -3,6 +3,8 @@ from collections.abc import Callable
from danswer.access.models import DocExternalAccess
from danswer.configs.constants import DocumentSource
from danswer.db.models import ConnectorCredentialPair
from ee.danswer.configs.app_configs import CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY
from ee.danswer.configs.app_configs import CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY
from ee.danswer.db.external_perm import ExternalUserGroup
from ee.danswer.external_permissions.confluence.doc_sync import confluence_doc_sync
from ee.danswer.external_permissions.confluence.group_sync import confluence_group_sync
@@ -56,7 +58,7 @@ GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC: set[DocumentSource] = {
# If nothing is specified here, we run the doc_sync every time the celery beat runs
DOC_PERMISSION_SYNC_PERIODS: dict[DocumentSource, int] = {
# Polling is not supported so we fetch all doc permissions every 5 minutes
DocumentSource.CONFLUENCE: 5 * 60,
DocumentSource.CONFLUENCE: CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY,
DocumentSource.SLACK: 5 * 60,
}
@@ -64,7 +66,7 @@ DOC_PERMISSION_SYNC_PERIODS: dict[DocumentSource, int] = {
EXTERNAL_GROUP_SYNC_PERIODS: dict[DocumentSource, int] = {
# Polling is not supported so we fetch all group permissions every 30 minutes
DocumentSource.GOOGLE_DRIVE: 5 * 60,
DocumentSource.CONFLUENCE: 30 * 60,
DocumentSource.CONFLUENCE: CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY,
}

View File

@@ -132,13 +132,18 @@ def _seed_personas(db_session: Session, personas: list[CreatePersonaRequest]) ->
if personas:
logger.notice("Seeding Personas")
for persona in personas:
if not persona.prompt_ids:
raise ValueError(
f"Invalid Persona with name {persona.name}; no prompts exist"
)
upsert_persona(
user=None, # Seeding is done as admin
name=persona.name,
description=persona.description,
num_chunks=persona.num_chunks
if persona.num_chunks is not None
else 0.0,
num_chunks=(
persona.num_chunks if persona.num_chunks is not None else 0.0
),
llm_relevance_filter=persona.llm_relevance_filter,
llm_filter_extraction=persona.llm_filter_extraction,
recency_bias=RecencyBiasSetting.AUTO,

View File

@@ -1,4 +1,6 @@
import asyncio
import json
from types import TracebackType
from typing import cast
from typing import Optional
@@ -6,11 +8,11 @@ import httpx
import openai
import vertexai # type: ignore
import voyageai # type: ignore
from cohere import Client as CohereClient
from cohere import AsyncClient as CohereAsyncClient
from fastapi import APIRouter
from fastapi import HTTPException
from google.oauth2 import service_account # type: ignore
from litellm import embedding
from litellm import aembedding
from litellm.exceptions import RateLimitError
from retry import retry
from sentence_transformers import CrossEncoder # type: ignore
@@ -63,22 +65,31 @@ class CloudEmbedding:
provider: EmbeddingProvider,
api_url: str | None = None,
api_version: str | None = None,
timeout: int = API_BASED_EMBEDDING_TIMEOUT,
) -> None:
self.provider = provider
self.api_key = api_key
self.api_url = api_url
self.api_version = api_version
self.timeout = timeout
self.http_client = httpx.AsyncClient(timeout=timeout)
self._closed = False
def _embed_openai(self, texts: list[str], model: str | None) -> list[Embedding]:
async def _embed_openai(
self, texts: list[str], model: str | None
) -> list[Embedding]:
if not model:
model = DEFAULT_OPENAI_MODEL
client = openai.OpenAI(api_key=self.api_key, timeout=OPENAI_EMBEDDING_TIMEOUT)
# Use the OpenAI specific timeout for this one
client = openai.AsyncOpenAI(
api_key=self.api_key, timeout=OPENAI_EMBEDDING_TIMEOUT
)
final_embeddings: list[Embedding] = []
try:
for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN):
response = client.embeddings.create(input=text_batch, model=model)
response = await client.embeddings.create(input=text_batch, model=model)
final_embeddings.extend(
[embedding.embedding for embedding in response.data]
)
@@ -93,19 +104,19 @@ class CloudEmbedding:
logger.error(error_string)
raise RuntimeError(error_string)
def _embed_cohere(
async def _embed_cohere(
self, texts: list[str], model: str | None, embedding_type: str
) -> list[Embedding]:
if not model:
model = DEFAULT_COHERE_MODEL
client = CohereClient(api_key=self.api_key, timeout=API_BASED_EMBEDDING_TIMEOUT)
client = CohereAsyncClient(api_key=self.api_key)
final_embeddings: list[Embedding] = []
for text_batch in batch_list(texts, _COHERE_MAX_INPUT_LEN):
# Does not use the same tokenizer as the Danswer API server but it's approximately the same
# empirically it's only off by a very few tokens so it's not a big deal
response = client.embed(
response = await client.embed(
texts=text_batch,
model=model,
input_type=embedding_type,
@@ -114,26 +125,29 @@ class CloudEmbedding:
final_embeddings.extend(cast(list[Embedding], response.embeddings))
return final_embeddings
def _embed_voyage(
async def _embed_voyage(
self, texts: list[str], model: str | None, embedding_type: str
) -> list[Embedding]:
if not model:
model = DEFAULT_VOYAGE_MODEL
client = voyageai.Client(
client = voyageai.AsyncClient(
api_key=self.api_key, timeout=API_BASED_EMBEDDING_TIMEOUT
)
response = client.embed(
texts,
response = await client.embed(
texts=texts,
model=model,
input_type=embedding_type,
truncation=True,
)
return response.embeddings
def _embed_azure(self, texts: list[str], model: str | None) -> list[Embedding]:
response = embedding(
async def _embed_azure(
self, texts: list[str], model: str | None
) -> list[Embedding]:
response = await aembedding(
model=model,
input=texts,
timeout=API_BASED_EMBEDDING_TIMEOUT,
@@ -142,10 +156,9 @@ class CloudEmbedding:
api_version=self.api_version,
)
embeddings = [embedding["embedding"] for embedding in response.data]
return embeddings
def _embed_vertex(
async def _embed_vertex(
self, texts: list[str], model: str | None, embedding_type: str
) -> list[Embedding]:
if not model:
@@ -158,7 +171,7 @@ class CloudEmbedding:
vertexai.init(project=project_id, credentials=credentials)
client = TextEmbeddingModel.from_pretrained(model)
embeddings = client.get_embeddings(
embeddings = await client.get_embeddings_async(
[
TextEmbeddingInput(
text,
@@ -166,11 +179,11 @@ class CloudEmbedding:
)
for text in texts
],
auto_truncate=True, # Also this is default
auto_truncate=True, # This is the default
)
return [embedding.values for embedding in embeddings]
def _embed_litellm_proxy(
async def _embed_litellm_proxy(
self, texts: list[str], model_name: str | None
) -> list[Embedding]:
if not model_name:
@@ -183,22 +196,20 @@ class CloudEmbedding:
{} if not self.api_key else {"Authorization": f"Bearer {self.api_key}"}
)
with httpx.Client() as client:
response = client.post(
self.api_url,
json={
"model": model_name,
"input": texts,
},
headers=headers,
timeout=API_BASED_EMBEDDING_TIMEOUT,
)
response.raise_for_status()
result = response.json()
return [embedding["embedding"] for embedding in result["data"]]
response = await self.http_client.post(
self.api_url,
json={
"model": model_name,
"input": texts,
},
headers=headers,
)
response.raise_for_status()
result = response.json()
return [embedding["embedding"] for embedding in result["data"]]
@retry(tries=_RETRY_TRIES, delay=_RETRY_DELAY)
def embed(
async def embed(
self,
*,
texts: list[str],
@@ -207,19 +218,19 @@ class CloudEmbedding:
deployment_name: str | None = None,
) -> list[Embedding]:
if self.provider == EmbeddingProvider.OPENAI:
return self._embed_openai(texts, model_name)
return await self._embed_openai(texts, model_name)
elif self.provider == EmbeddingProvider.AZURE:
return self._embed_azure(texts, f"azure/{deployment_name}")
return await self._embed_azure(texts, f"azure/{deployment_name}")
elif self.provider == EmbeddingProvider.LITELLM:
return self._embed_litellm_proxy(texts, model_name)
return await self._embed_litellm_proxy(texts, model_name)
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
if self.provider == EmbeddingProvider.COHERE:
return self._embed_cohere(texts, model_name, embedding_type)
return await self._embed_cohere(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.VOYAGE:
return self._embed_voyage(texts, model_name, embedding_type)
return await self._embed_voyage(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.GOOGLE:
return self._embed_vertex(texts, model_name, embedding_type)
return await self._embed_vertex(texts, model_name, embedding_type)
else:
raise ValueError(f"Unsupported provider: {self.provider}")
@@ -233,6 +244,30 @@ class CloudEmbedding:
logger.debug(f"Creating Embedding instance for provider: {provider}")
return CloudEmbedding(api_key, provider, api_url, api_version)
async def aclose(self) -> None:
"""Explicitly close the client."""
if not self._closed:
await self.http_client.aclose()
self._closed = True
async def __aenter__(self) -> "CloudEmbedding":
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
await self.aclose()
def __del__(self) -> None:
"""Finalizer to warn about unclosed clients."""
if not self._closed:
logger.warning(
"CloudEmbedding was not properly closed. Use 'async with' or call aclose()"
)
def get_embedding_model(
model_name: str,
@@ -242,9 +277,6 @@ def get_embedding_model(
global _GLOBAL_MODELS_DICT # A dictionary to store models
if _GLOBAL_MODELS_DICT is None:
_GLOBAL_MODELS_DICT = {}
if model_name not in _GLOBAL_MODELS_DICT:
logger.notice(f"Loading {model_name}")
# Some model architectures that aren't built into the Transformers or Sentence
@@ -275,7 +307,7 @@ def get_local_reranking_model(
@simple_log_function_time()
def embed_text(
async def embed_text(
texts: list[str],
text_type: EmbedTextType,
model_name: str | None,
@@ -311,18 +343,18 @@ def embed_text(
"Cloud models take an explicit text type instead."
)
cloud_model = CloudEmbedding(
async with CloudEmbedding(
api_key=api_key,
provider=provider_type,
api_url=api_url,
api_version=api_version,
)
embeddings = cloud_model.embed(
texts=texts,
model_name=model_name,
deployment_name=deployment_name,
text_type=text_type,
)
) as cloud_model:
embeddings = await cloud_model.embed(
texts=texts,
model_name=model_name,
deployment_name=deployment_name,
text_type=text_type,
)
if any(embedding is None for embedding in embeddings):
error_message = "Embeddings contain None values\n"
@@ -338,8 +370,12 @@ def embed_text(
local_model = get_embedding_model(
model_name=model_name, max_context_length=max_context_length
)
embeddings_vectors = local_model.encode(
prefixed_texts, normalize_embeddings=normalize_embeddings
# Run CPU-bound embedding in a thread pool
embeddings_vectors = await asyncio.get_event_loop().run_in_executor(
None,
lambda: local_model.encode(
prefixed_texts, normalize_embeddings=normalize_embeddings
),
)
embeddings = [
embedding if isinstance(embedding, list) else embedding.tolist()
@@ -357,27 +393,31 @@ def embed_text(
@simple_log_function_time()
def local_rerank(query: str, docs: list[str], model_name: str) -> list[float]:
async def local_rerank(query: str, docs: list[str], model_name: str) -> list[float]:
cross_encoder = get_local_reranking_model(model_name)
return cross_encoder.predict([(query, doc) for doc in docs]).tolist() # type: ignore
# Run CPU-bound reranking in a thread pool
return await asyncio.get_event_loop().run_in_executor(
None,
lambda: cross_encoder.predict([(query, doc) for doc in docs]).tolist(), # type: ignore
)
def cohere_rerank(
async def cohere_rerank(
query: str, docs: list[str], model_name: str, api_key: str
) -> list[float]:
cohere_client = CohereClient(api_key=api_key)
response = cohere_client.rerank(query=query, documents=docs, model=model_name)
cohere_client = CohereAsyncClient(api_key=api_key)
response = await cohere_client.rerank(query=query, documents=docs, model=model_name)
results = response.results
sorted_results = sorted(results, key=lambda item: item.index)
return [result.relevance_score for result in sorted_results]
def litellm_rerank(
async def litellm_rerank(
query: str, docs: list[str], api_url: str, model_name: str, api_key: str | None
) -> list[float]:
headers = {} if not api_key else {"Authorization": f"Bearer {api_key}"}
with httpx.Client() as client:
response = client.post(
async with httpx.AsyncClient() as client:
response = await client.post(
api_url,
json={
"model": model_name,
@@ -411,7 +451,7 @@ async def process_embed_request(
else:
prefix = None
embeddings = embed_text(
embeddings = await embed_text(
texts=embed_request.texts,
model_name=embed_request.model_name,
deployment_name=embed_request.deployment_name,
@@ -451,7 +491,7 @@ async def process_rerank_request(rerank_request: RerankRequest) -> RerankRespons
try:
if rerank_request.provider_type is None:
sim_scores = local_rerank(
sim_scores = await local_rerank(
query=rerank_request.query,
docs=rerank_request.documents,
model_name=rerank_request.model_name,
@@ -461,7 +501,7 @@ async def process_rerank_request(rerank_request: RerankRequest) -> RerankRespons
if rerank_request.api_url is None:
raise ValueError("API URL is required for LiteLLM reranking.")
sim_scores = litellm_rerank(
sim_scores = await litellm_rerank(
query=rerank_request.query,
docs=rerank_request.documents,
api_url=rerank_request.api_url,
@@ -474,7 +514,7 @@ async def process_rerank_request(rerank_request: RerankRequest) -> RerankRespons
elif rerank_request.provider_type == RerankerProvider.COHERE:
if rerank_request.api_key is None:
raise RuntimeError("Cohere Rerank Requires an API Key")
sim_scores = cohere_rerank(
sim_scores = await cohere_rerank(
query=rerank_request.query,
docs=rerank_request.documents,
model_name=rerank_request.model_name,

View File

@@ -6,12 +6,12 @@ router = APIRouter(prefix="/api")
@router.get("/health")
def healthcheck() -> Response:
async def healthcheck() -> Response:
return Response(status_code=200)
@router.get("/gpu-status")
def gpu_status() -> dict[str, bool | str]:
async def gpu_status() -> dict[str, bool | str]:
if torch.cuda.is_available():
return {"gpu_available": True, "type": "cuda"}
elif torch.backends.mps.is_available():

View File

@@ -1,3 +1,4 @@
import asyncio
import time
from collections.abc import Callable
from collections.abc import Generator
@@ -21,21 +22,39 @@ def simple_log_function_time(
include_args: bool = False,
) -> Callable[[F], F]:
def decorator(func: F) -> F:
@wraps(func)
def wrapped_func(*args: Any, **kwargs: Any) -> Any:
start_time = time.time()
result = func(*args, **kwargs)
elapsed_time_str = str(time.time() - start_time)
log_name = func_name or func.__name__
args_str = f" args={args} kwargs={kwargs}" if include_args else ""
final_log = f"{log_name}{args_str} took {elapsed_time_str} seconds"
if debug_only:
logger.debug(final_log)
else:
logger.notice(final_log)
if asyncio.iscoroutinefunction(func):
return result
@wraps(func)
async def wrapped_async_func(*args: Any, **kwargs: Any) -> Any:
start_time = time.time()
result = await func(*args, **kwargs)
elapsed_time_str = str(time.time() - start_time)
log_name = func_name or func.__name__
args_str = f" args={args} kwargs={kwargs}" if include_args else ""
final_log = f"{log_name}{args_str} took {elapsed_time_str} seconds"
if debug_only:
logger.debug(final_log)
else:
logger.notice(final_log)
return result
return cast(F, wrapped_func)
return cast(F, wrapped_async_func)
else:
@wraps(func)
def wrapped_sync_func(*args: Any, **kwargs: Any) -> Any:
start_time = time.time()
result = func(*args, **kwargs)
elapsed_time_str = str(time.time() - start_time)
log_name = func_name or func.__name__
args_str = f" args={args} kwargs={kwargs}" if include_args else ""
final_log = f"{log_name}{args_str} took {elapsed_time_str} seconds"
if debug_only:
logger.debug(final_log)
else:
logger.notice(final_log)
return result
return cast(F, wrapped_sync_func)
return decorator

View File

@@ -29,7 +29,7 @@ trafilatura==1.12.2
langchain==0.1.17
langchain-core==0.1.50
langchain-text-splitters==0.0.1
litellm==1.53.1
litellm==1.54.1
lxml==5.3.0
lxml_html_clean==0.2.2
llama-index==0.9.45

View File

@@ -1,30 +1,34 @@
black==23.3.0
boto3-stubs[s3]==1.34.133
celery-types==0.19.0
cohere==5.6.1
google-cloud-aiplatform==1.58.0
lxml==5.3.0
lxml_html_clean==0.2.2
mypy-extensions==1.0.0
mypy==1.8.0
pandas-stubs==2.2.3.241009
pandas==2.2.3
pre-commit==3.2.2
pytest-asyncio==0.22.0
pytest==7.4.4
reorder-python-imports==3.9.0
ruff==0.0.286
types-PyYAML==6.0.12.11
sentence-transformers==2.6.1
trafilatura==1.12.2
types-beautifulsoup4==4.12.0.3
types-html5lib==1.1.11.13
types-oauthlib==3.2.0.9
types-setuptools==68.0.0.3
types-Pillow==10.2.0.20240822
types-passlib==1.7.7.20240106
types-Pillow==10.2.0.20240822
types-psutil==5.9.5.17
types-psycopg2==2.9.21.10
types-python-dateutil==2.8.19.13
types-pytz==2023.3.1.1
types-PyYAML==6.0.12.11
types-regex==2023.3.23.1
types-requests==2.28.11.17
types-retry==0.9.9.3
types-setuptools==68.0.0.3
types-urllib3==1.26.25.11
trafilatura==1.12.2
lxml==5.3.0
lxml_html_clean==0.2.2
boto3-stubs[s3]==1.34.133
pandas==2.2.3
pandas-stubs==2.2.3.241009
cohere==5.6.1
voyageai==0.2.3

View File

@@ -12,5 +12,5 @@ torch==2.2.0
transformers==4.39.2
uvicorn==0.21.1
voyageai==0.2.3
litellm==1.50.2
litellm==1.54.1
sentry-sdk[fastapi,celery,starlette]==2.14.0

View File

@@ -42,7 +42,7 @@ class PersonaManager:
"is_public": is_public,
"llm_filter_extraction": llm_filter_extraction,
"recency_bias": recency_bias,
"prompt_ids": prompt_ids or [],
"prompt_ids": prompt_ids or [0],
"document_set_ids": document_set_ids or [],
"tool_ids": tool_ids or [],
"llm_model_provider_override": llm_model_provider_override,

View File

@@ -69,8 +69,10 @@ class TenantManager:
return AllUsersResponse(
accepted=[FullUserSnapshot(**user) for user in data["accepted"]],
invited=[InvitedUserSnapshot(**user) for user in data["invited"]],
slack_users=[FullUserSnapshot(**user) for user in data["slack_users"]],
accepted_pages=data["accepted_pages"],
invited_pages=data["invited_pages"],
slack_users_pages=data["slack_users_pages"],
)
@staticmethod

View File

@@ -130,8 +130,10 @@ class UserManager:
all_users = AllUsersResponse(
accepted=[FullUserSnapshot(**user) for user in data["accepted"]],
invited=[InvitedUserSnapshot(**user) for user in data["invited"]],
slack_users=[FullUserSnapshot(**user) for user in data["slack_users"]],
accepted_pages=data["accepted_pages"],
invited_pages=data["invited_pages"],
slack_users_pages=data["slack_users_pages"],
)
for accepted_user in all_users.accepted:
if accepted_user.email == user.email and accepted_user.id == user.id:

View File

@@ -3,6 +3,8 @@ from datetime import datetime
from datetime import timezone
from typing import Any
import pytest
from danswer.connectors.models import InputType
from danswer.db.enums import AccessType
from danswer.server.documents.models import DocumentSource
@@ -23,7 +25,7 @@ from tests.integration.common_utils.vespa import vespa_fixture
from tests.integration.connector_job_tests.slack.slack_api_utils import SlackManager
# @pytest.mark.xfail(reason="flaky - see DAN-789 for example", strict=False)
@pytest.mark.xfail(reason="flaky - see DAN-789 for example", strict=False)
def test_slack_permission_sync(
reset: None,
vespa_client: vespa_fixture,
@@ -65,7 +67,6 @@ def test_slack_permission_sync(
input_type=InputType.POLL,
source=DocumentSource.SLACK,
connector_specific_config={
"workspace": "onyx-test-workspace",
"channels": [public_channel["name"], private_channel["name"]],
},
access_type=AccessType.SYNC,
@@ -279,7 +280,6 @@ def test_slack_group_permission_sync(
input_type=InputType.POLL,
source=DocumentSource.SLACK,
connector_specific_config={
"workspace": "onyx-test-workspace",
"channels": [private_channel["name"]],
},
access_type=AccessType.SYNC,

View File

@@ -61,7 +61,6 @@ def test_slack_prune(
input_type=InputType.POLL,
source=DocumentSource.SLACK,
connector_specific_config={
"workspace": "onyx-test-workspace",
"channels": [public_channel["name"], private_channel["name"]],
},
access_type=AccessType.PUBLIC,

View File

@@ -27,13 +27,6 @@ def test_limited(reset: None) -> None:
)
assert response.status_code == 200
# test basic endpoints
response = requests.get(
f"{API_SERVER_URL}/input_prompt",
headers=api_key.headers,
)
assert response.status_code == 403
# test admin endpoints
response = requests.get(
f"{API_SERVER_URL}/admin/api-key",

View File

@@ -72,8 +72,10 @@ def process_text(
processor = CitationProcessor(
context_docs=mock_docs,
doc_id_to_rank_map=mapping,
display_doc_order_dict=mock_doc_id_to_rank_map,
stop_stream=None,
)
result: list[DanswerAnswerPiece | CitationInfo] = []
for token in tokens:
result.extend(processor.process_token(token))
@@ -86,6 +88,7 @@ def process_text(
final_answer_text += piece.answer_piece or ""
elif isinstance(piece, CitationInfo):
citations.append(piece)
return final_answer_text, citations

View File

@@ -0,0 +1,132 @@
from datetime import datetime
import pytest
from danswer.chat.models import CitationInfo
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import LlmDoc
from danswer.chat.stream_processing.citation_processing import CitationProcessor
from danswer.chat.stream_processing.utils import DocumentIdOrderMapping
from danswer.configs.constants import DocumentSource
"""
This module contains tests for the citation extraction functionality in Danswer,
specifically the substitution of the number of document cited in the UI. (The LLM
will see the sources post re-ranking and relevance check, the UI before these steps.)
This module is a derivative of test_citation_processing.py.
The tests focusses specifically on the substitution of the number of document cited in the UI.
Key components:
- mock_docs: A list of mock LlmDoc objects used for testing.
- mock_doc_mapping: A dictionary mapping document IDs to their initial ranks.
- mock_doc_mapping_rerank: A dictionary mapping document IDs to their ranks after re-ranking/relevance check.
- process_text: A helper function that simulates the citation extraction process.
- test_citation_extraction: A parametrized test function covering various citation scenarios.
To add new test cases:
1. Add a new tuple to the @pytest.mark.parametrize decorator of test_citation_extraction.
2. Each tuple should contain:
- A descriptive test name (string)
- Input tokens (list of strings)
- Expected output text (string)
- Expected citations (list of document IDs)
"""
mock_docs = [
LlmDoc(
document_id=f"doc_{int(id/2)}",
content="Document is a doc",
blurb=f"Document #{id}",
semantic_identifier=f"Doc {id}",
source_type=DocumentSource.WEB,
metadata={},
updated_at=datetime.now(),
link=f"https://{int(id/2)}.com" if int(id / 2) % 2 == 0 else None,
source_links={0: "https://mintlify.com/docs/settings/broken-links"},
match_highlights=[],
)
for id in range(10)
]
mock_doc_mapping = {
"doc_0": 1,
"doc_1": 2,
"doc_2": 3,
"doc_3": 4,
"doc_4": 5,
"doc_5": 6,
}
mock_doc_mapping_rerank = {
"doc_0": 2,
"doc_1": 1,
"doc_2": 4,
"doc_3": 3,
"doc_4": 6,
"doc_5": 5,
}
@pytest.fixture
def mock_data() -> tuple[list[LlmDoc], dict[str, int]]:
return mock_docs, mock_doc_mapping
def process_text(
tokens: list[str], mock_data: tuple[list[LlmDoc], dict[str, int]]
) -> tuple[str, list[CitationInfo]]:
mock_docs, mock_doc_id_to_rank_map = mock_data
mapping = DocumentIdOrderMapping(order_mapping=mock_doc_id_to_rank_map)
processor = CitationProcessor(
context_docs=mock_docs,
doc_id_to_rank_map=mapping,
display_doc_order_dict=mock_doc_mapping_rerank,
stop_stream=None,
)
result: list[DanswerAnswerPiece | CitationInfo] = []
for token in tokens:
result.extend(processor.process_token(token))
result.extend(processor.process_token(None))
final_answer_text = ""
citations = []
for piece in result:
if isinstance(piece, DanswerAnswerPiece):
final_answer_text += piece.answer_piece or ""
elif isinstance(piece, CitationInfo):
citations.append(piece)
return final_answer_text, citations
@pytest.mark.parametrize(
"test_name, input_tokens, expected_text, expected_citations",
[
(
"Single citation",
["Gro", "wth! [", "1", "]", "."],
"Growth! [[2]](https://0.com).",
["doc_0"],
),
],
)
def test_citation_substitution(
mock_data: tuple[list[LlmDoc], dict[str, int]],
test_name: str,
input_tokens: list[str],
expected_text: str,
expected_citations: list[str],
) -> None:
final_answer_text, citations = process_text(input_tokens, mock_data)
assert (
final_answer_text.strip() == expected_text.strip()
), f"Test '{test_name}' failed: Final answer text does not match expected output."
assert [
citation.document_id for citation in citations
] == expected_citations, (
f"Test '{test_name}' failed: Citations do not match expected output."
)

View File

@@ -0,0 +1,120 @@
from typing import List
from danswer.configs.app_configs import MAX_DOCUMENT_CHARS
from danswer.connectors.models import Document
from danswer.connectors.models import DocumentSource
from danswer.connectors.models import Section
from danswer.indexing.indexing_pipeline import filter_documents
def create_test_document(
doc_id: str = "test_id",
title: str | None = "Test Title",
semantic_id: str = "test_semantic_id",
sections: List[Section] | None = None,
) -> Document:
if sections is None:
sections = [Section(text="Test content", link="test_link")]
return Document(
id=doc_id,
title=title,
semantic_identifier=semantic_id,
sections=sections,
source=DocumentSource.FILE,
metadata={},
)
def test_filter_documents_empty_title_and_content() -> None:
doc = create_test_document(
title="", semantic_id="", sections=[Section(text="", link="test_link")]
)
result = filter_documents([doc])
assert len(result) == 0
def test_filter_documents_empty_title_with_content() -> None:
doc = create_test_document(
title="", sections=[Section(text="Valid content", link="test_link")]
)
result = filter_documents([doc])
assert len(result) == 1
assert result[0].id == "test_id"
def test_filter_documents_empty_content_with_title() -> None:
doc = create_test_document(
title="Valid Title", sections=[Section(text="", link="test_link")]
)
result = filter_documents([doc])
assert len(result) == 1
assert result[0].id == "test_id"
def test_filter_documents_exceeding_max_chars() -> None:
if not MAX_DOCUMENT_CHARS: # Skip if no max chars configured
return
long_text = "a" * (MAX_DOCUMENT_CHARS + 1)
doc = create_test_document(sections=[Section(text=long_text, link="test_link")])
result = filter_documents([doc])
assert len(result) == 0
def test_filter_documents_valid_document() -> None:
doc = create_test_document(
title="Valid Title", sections=[Section(text="Valid content", link="test_link")]
)
result = filter_documents([doc])
assert len(result) == 1
assert result[0].id == "test_id"
assert result[0].title == "Valid Title"
def test_filter_documents_whitespace_only() -> None:
doc = create_test_document(
title=" ", semantic_id=" ", sections=[Section(text=" ", link="test_link")]
)
result = filter_documents([doc])
assert len(result) == 0
def test_filter_documents_semantic_id_no_title() -> None:
doc = create_test_document(
title=None,
semantic_id="Valid Semantic ID",
sections=[Section(text="Valid content", link="test_link")],
)
result = filter_documents([doc])
assert len(result) == 1
assert result[0].semantic_identifier == "Valid Semantic ID"
def test_filter_documents_multiple_sections() -> None:
doc = create_test_document(
sections=[
Section(text="Content 1", link="test_link"),
Section(text="Content 2", link="test_link"),
Section(text="Content 3", link="test_link"),
]
)
result = filter_documents([doc])
assert len(result) == 1
assert len(result[0].sections) == 3
def test_filter_documents_multiple_documents() -> None:
docs = [
create_test_document(doc_id="1", title="Title 1"),
create_test_document(
doc_id="2", title="", sections=[Section(text="", link="test_link")]
), # Should be filtered
create_test_document(doc_id="3", title="Title 3"),
]
result = filter_documents(docs)
assert len(result) == 2
assert {doc.id for doc in result} == {"1", "3"}
def test_filter_documents_empty_batch() -> None:
result = filter_documents([])
assert len(result) == 0

View File

@@ -0,0 +1,198 @@
import asyncio
import time
from collections.abc import AsyncGenerator
from typing import Any
from typing import List
from unittest.mock import AsyncMock
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from httpx import AsyncClient
from litellm.exceptions import RateLimitError
from model_server.encoders import CloudEmbedding
from model_server.encoders import embed_text
from model_server.encoders import local_rerank
from model_server.encoders import process_embed_request
from shared_configs.enums import EmbeddingProvider
from shared_configs.enums import EmbedTextType
from shared_configs.model_server_models import EmbedRequest
@pytest.fixture
async def mock_http_client() -> AsyncGenerator[AsyncMock, None]:
with patch("httpx.AsyncClient") as mock:
client = AsyncMock(spec=AsyncClient)
mock.return_value = client
client.post = AsyncMock()
async with client as c:
yield c
@pytest.fixture
def sample_embeddings() -> List[List[float]]:
return [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
@pytest.mark.asyncio
async def test_cloud_embedding_context_manager() -> None:
async with CloudEmbedding("fake-key", EmbeddingProvider.OPENAI) as embedding:
assert not embedding._closed
assert embedding._closed
@pytest.mark.asyncio
async def test_cloud_embedding_explicit_close() -> None:
embedding = CloudEmbedding("fake-key", EmbeddingProvider.OPENAI)
assert not embedding._closed
await embedding.aclose()
assert embedding._closed
@pytest.mark.asyncio
async def test_openai_embedding(
mock_http_client: AsyncMock, sample_embeddings: List[List[float]]
) -> None:
with patch("openai.AsyncOpenAI") as mock_openai:
mock_client = AsyncMock()
mock_openai.return_value = mock_client
mock_response = MagicMock()
mock_response.data = [MagicMock(embedding=emb) for emb in sample_embeddings]
mock_client.embeddings.create = AsyncMock(return_value=mock_response)
embedding = CloudEmbedding("fake-key", EmbeddingProvider.OPENAI)
result = await embedding._embed_openai(
["test1", "test2"], "text-embedding-ada-002"
)
assert result == sample_embeddings
mock_client.embeddings.create.assert_called_once()
@pytest.mark.asyncio
async def test_embed_text_cloud_provider() -> None:
with patch("model_server.encoders.CloudEmbedding.embed") as mock_embed:
mock_embed.return_value = [[0.1, 0.2], [0.3, 0.4]]
mock_embed.side_effect = AsyncMock(return_value=[[0.1, 0.2], [0.3, 0.4]])
result = await embed_text(
texts=["test1", "test2"],
text_type=EmbedTextType.QUERY,
model_name="fake-model",
deployment_name=None,
max_context_length=512,
normalize_embeddings=True,
api_key="fake-key",
provider_type=EmbeddingProvider.OPENAI,
prefix=None,
api_url=None,
api_version=None,
)
assert result == [[0.1, 0.2], [0.3, 0.4]]
mock_embed.assert_called_once()
@pytest.mark.asyncio
async def test_embed_text_local_model() -> None:
with patch("model_server.encoders.get_embedding_model") as mock_get_model:
mock_model = MagicMock()
mock_model.encode.return_value = [[0.1, 0.2], [0.3, 0.4]]
mock_get_model.return_value = mock_model
result = await embed_text(
texts=["test1", "test2"],
text_type=EmbedTextType.QUERY,
model_name="fake-local-model",
deployment_name=None,
max_context_length=512,
normalize_embeddings=True,
api_key=None,
provider_type=None,
prefix=None,
api_url=None,
api_version=None,
)
assert result == [[0.1, 0.2], [0.3, 0.4]]
mock_model.encode.assert_called_once()
@pytest.mark.asyncio
async def test_local_rerank() -> None:
with patch("model_server.encoders.get_local_reranking_model") as mock_get_model:
mock_model = MagicMock()
mock_array = MagicMock()
mock_array.tolist.return_value = [0.8, 0.6]
mock_model.predict.return_value = mock_array
mock_get_model.return_value = mock_model
result = await local_rerank(
query="test query", docs=["doc1", "doc2"], model_name="fake-rerank-model"
)
assert result == [0.8, 0.6]
mock_model.predict.assert_called_once()
@pytest.mark.asyncio
async def test_rate_limit_handling() -> None:
with patch("model_server.encoders.CloudEmbedding.embed") as mock_embed:
mock_embed.side_effect = RateLimitError(
"Rate limit exceeded", llm_provider="openai", model="fake-model"
)
with pytest.raises(RateLimitError):
await embed_text(
texts=["test"],
text_type=EmbedTextType.QUERY,
model_name="fake-model",
deployment_name=None,
max_context_length=512,
normalize_embeddings=True,
api_key="fake-key",
provider_type=EmbeddingProvider.OPENAI,
prefix=None,
api_url=None,
api_version=None,
)
@pytest.mark.asyncio
async def test_concurrent_embeddings() -> None:
def mock_encode(*args: Any, **kwargs: Any) -> List[List[float]]:
time.sleep(5)
return [[0.1, 0.2, 0.3]]
test_req = EmbedRequest(
texts=["test"],
model_name="'nomic-ai/nomic-embed-text-v1'",
deployment_name=None,
max_context_length=512,
normalize_embeddings=True,
api_key=None,
provider_type=None,
text_type=EmbedTextType.QUERY,
manual_query_prefix=None,
manual_passage_prefix=None,
api_url=None,
api_version=None,
)
with patch("model_server.encoders.get_embedding_model") as mock_get_model:
mock_model = MagicMock()
mock_model.encode = mock_encode
mock_get_model.return_value = mock_model
start_time = time.time()
tasks = [process_embed_request(test_req) for _ in range(5)]
await asyncio.gather(*tasks)
end_time = time.time()
# 5 * 5 seconds = 25 seconds, this test ensures that the embeddings are at least yielding the thread
# However, the developer may still introduce unnecessary blocking above the mock and this test will
# still pass as long as it's less than (7 - 5) / 5 seconds
assert end_time - start_time < 7

View File

@@ -6,7 +6,7 @@ chart-dirs:
# must be kept in sync with Chart.yaml
chart-repos:
- vespa=https://danswer-ai.github.io/vespa-helm-charts
- vespa=https://onyx-dot-app.github.io/vespa-helm-charts
- postgresql=https://charts.bitnami.com/bitnami
helm-extra-args: --debug --timeout 600s

View File

@@ -183,6 +183,13 @@ services:
- GONG_CONNECTOR_START_TIME=${GONG_CONNECTOR_START_TIME:-}
- NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP=${NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP:-}
- GITHUB_CONNECTOR_BASE_URL=${GITHUB_CONNECTOR_BASE_URL:-}
- MAX_DOCUMENT_CHARS=${MAX_DOCUMENT_CHARS:-}
- MAX_FILE_SIZE_BYTES=${MAX_FILE_SIZE_BYTES:-}
# Egnyte OAuth Configs
- EGNYTE_CLIENT_ID=${EGNYTE_CLIENT_ID:-}
- EGNYTE_CLIENT_SECRET=${EGNYTE_CLIENT_SECRET:-}
- EGNYTE_BASE_DOMAIN=${EGNYTE_BASE_DOMAIN:-}
- EGNYTE_LOCALHOST_OVERRIDE=${EGNYTE_LOCALHOST_OVERRIDE:-}
# Celery Configs (defaults are set in the supervisord.conf file.
# prefer doing that to have one source of defaults)
- CELERY_WORKER_INDEXING_CONCURRENCY=${CELERY_WORKER_INDEXING_CONCURRENCY:-}

View File

@@ -0,0 +1,74 @@
# Docker service resource limits. Most are commented out by default.
# 'background' service has preset (override-able) limits due to variable resource needs.
# Uncomment and set env vars for specific service limits.
# See: https://docs.danswer.dev/deployment/resource-sizing for details.
services:
background:
deploy:
resources:
limits:
cpus: ${BACKGROUND_CPU_LIMIT:-4}
memory: ${BACKGROUND_MEM_LIMIT:-4g}
# reservations:
# cpus: ${BACKGROUND_CPU_RESERVATION}
# memory: ${BACKGROUND_MEM_RESERVATION}
# nginx:
# deploy:
# resources:
# limits:
# cpus: ${NGINX_CPU_LIMIT}
# memory: ${NGINX_MEM_LIMIT}
# reservations:
# cpus: ${NGINX_CPU_RESERVATION}
# memory: ${NGINX_MEM_RESERVATION}
# api_server:
# deploy:
# resources:
# limits:
# cpus: ${API_SERVER_CPU_LIMIT}
# memory: ${API_SERVER_MEM_LIMIT}
# reservations:
# cpus: ${API_SERVER_CPU_RESERVATION}
# memory: ${API_SERVER_MEM_RESERVATION}
# index:
# deploy:
# resources:
# limits:
# cpus: ${VESPA_CPU_LIMIT}
# memory: ${VESPA_MEM_LIMIT}
# reservations:
# cpus: ${VESPA_CPU_RESERVATION}
# memory: ${VESPA_MEM_RESERVATION}
# inference_model_server:
# deploy:
# resources:
# limits:
# cpus: ${INFERENCE_CPU_LIMIT}
# memory: ${INFERENCE_MEM_LIMIT}
# reservations:
# cpus: ${INFERENCE_CPU_RESERVATION}
# memory: ${INFERENCE_MEM_RESERVATION}
# indexing_model_server:
# deploy:
# resources:
# limits:
# cpus: ${INDEXING_CPU_LIMIT}
# memory: ${INDEXING_MEM_LIMIT}
# reservations:
# cpus: ${INDEXING_CPU_RESERVATION}
# memory: ${INDEXING_MEM_RESERVATION}
# relational_db:
# deploy:
# resources:
# limits:
# cpus: ${POSTGRES_CPU_LIMIT}
# memory: ${POSTGRES_MEM_LIMIT}
# reservations:
# cpus: ${POSTGRES_CPU_RESERVATION}
# memory: ${POSTGRES_MEM_RESERVATION}

View File

@@ -3,13 +3,13 @@ dependencies:
repository: https://charts.bitnami.com/bitnami
version: 14.3.1
- name: vespa
repository: https://danswer-ai.github.io/vespa-helm-charts
version: 0.2.16
repository: https://onyx-dot-app.github.io/vespa-helm-charts
version: 0.2.18
- name: nginx
repository: oci://registry-1.docker.io/bitnamicharts
version: 15.14.0
- name: redis
repository: https://charts.bitnami.com/bitnami
version: 20.1.0
digest: sha256:711bbb76ba6ab604a36c9bf1839ab6faa5610afb21e535afd933c78f2d102232
generated: "2024-11-07T09:39:30.17171-08:00"
digest: sha256:5c9eb3d55d5f8e3beb64f26d26f686c8d62755daa10e2e6d87530bdf2fbbf957
generated: "2024-12-10T10:47:35.812483-08:00"

View File

@@ -23,8 +23,8 @@ dependencies:
repository: https://charts.bitnami.com/bitnami
condition: postgresql.enabled
- name: vespa
version: 0.2.16
repository: https://danswer-ai.github.io/vespa-helm-charts
version: 0.2.18
repository: https://onyx-dot-app.github.io/vespa-helm-charts
condition: vespa.enabled
- name: nginx
version: 15.14.0

View File

@@ -61,6 +61,8 @@ data:
WEB_CONNECTOR_VALIDATE_URLS: ""
GONG_CONNECTOR_START_TIME: ""
NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP: ""
MAX_DOCUMENT_CHARS: ""
MAX_FILE_SIZE_BYTES: ""
# DanswerBot SlackBot Configs
DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER: ""
DANSWER_BOT_DISPLAY_ERROR_MSGS: ""

View File

@@ -66,6 +66,9 @@ ARG NEXT_PUBLIC_POSTHOG_HOST
ENV NEXT_PUBLIC_POSTHOG_KEY=${NEXT_PUBLIC_POSTHOG_KEY}
ENV NEXT_PUBLIC_POSTHOG_HOST=${NEXT_PUBLIC_POSTHOG_HOST}
ARG NEXT_PUBLIC_CLOUD_ENABLED
ENV NEXT_PUBLIC_CLOUD_ENABLED=${NEXT_PUBLIC_CLOUD_ENABLED}
ARG NEXT_PUBLIC_SENTRY_DSN
ENV NEXT_PUBLIC_SENTRY_DSN=${NEXT_PUBLIC_SENTRY_DSN}
@@ -138,6 +141,9 @@ ARG NEXT_PUBLIC_POSTHOG_HOST
ENV NEXT_PUBLIC_POSTHOG_KEY=${NEXT_PUBLIC_POSTHOG_KEY}
ENV NEXT_PUBLIC_POSTHOG_HOST=${NEXT_PUBLIC_POSTHOG_HOST}
ARG NEXT_PUBLIC_CLOUD_ENABLED
ENV NEXT_PUBLIC_CLOUD_ENABLED=${NEXT_PUBLIC_CLOUD_ENABLED}
ARG NEXT_PUBLIC_SENTRY_DSN
ENV NEXT_PUBLIC_SENTRY_DSN=${NEXT_PUBLIC_SENTRY_DSN}

4
web/package-lock.json generated
View File

@@ -1,12 +1,12 @@
{
"name": "qa",
"version": "0.2.0-dev",
"version": "1.0.0-dev",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"name": "qa",
"version": "0.2.0-dev",
"version": "1.0.0-dev",
"dependencies": {
"@dnd-kit/core": "^6.1.0",
"@dnd-kit/modifiers": "^7.0.0",

View File

@@ -1,6 +1,6 @@
{
"name": "qa",
"version": "0.2.0-dev",
"version": "1.0.0-dev",
"version-comment": "version field must be SemVer or chromatic will barf",
"private": true,
"scripts": {

BIN
web/public/Egnyte.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 12 KiB

BIN
web/public/Wikipedia.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 769 KiB

File diff suppressed because one or more lines are too long

Before

Width:  |  Height:  |  Size: 164 KiB

View File

@@ -82,7 +82,7 @@ export const DanswerApiKeyForm = ({
}}
>
{({ isSubmitting, values, setFieldValue }) => (
<Form>
<Form className="w-full overflow-visible">
<Text className="mb-4 text-lg">
Choose a memorable name for your API key. This is optional and
can be added or changed later!

View File

@@ -45,9 +45,6 @@ function NewApiKeyModal({
<div className="px-8 py-8">
<div className="flex w-full border-b border-border mb-4 pb-4">
<Title>New API Key</Title>
<div onClick={onClose} className="ml-auto p-1 rounded hover:bg-hover">
<FiX size={18} />
</div>
</div>
<div className="h-32">
<Text className="mb-4">

View File

@@ -96,6 +96,16 @@ export function SlackBotTable({ slackBots }: { slackBots: SlackBot[] }) {
</ClickableTableRow>
);
})}
{slackBots.length === 0 && (
<TableRow>
<TableCell
colSpan={4}
className="text-center text-muted-foreground"
>
Please add a New Slack Bot to begin chatting with Danswer!
</TableCell>
</TableRow>
)}
</TableBody>
</Table>
{slackBots.length > NUM_IN_PAGE && (

View File

@@ -275,8 +275,9 @@ export function CustomLLMProviderUpdateForm({
<SubLabel>
<>
<div>
Additional configurations needed by the model provider. Are
passed to litellm via environment variables.
Additional configurations needed by the model provider. These
are passed to litellm via environment + as arguments into the
`completion` call.
</div>
<div className="mt-2">
@@ -290,14 +291,14 @@ export function CustomLLMProviderUpdateForm({
<FieldArray
name="custom_config_list"
render={(arrayHelpers: ArrayHelpers<any[]>) => (
<div>
<div className="w-full">
{formikProps.values.custom_config_list.map((_, index) => {
return (
<div
key={index}
className={index === 0 ? "mt-2" : "mt-6"}
className={(index === 0 ? "mt-2" : "mt-6") + " w-full"}
>
<div className="flex">
<div className="flex w-full">
<div className="w-full mr-6 border border-border p-3 rounded">
<div>
<Label>Key</Label>
@@ -457,6 +458,7 @@ export function CustomLLMProviderUpdateForm({
<Button
type="button"
variant="destructive"
className="ml-3"
icon={FiTrash}
onClick={async () => {
const response = await fetch(

View File

@@ -2,6 +2,11 @@ import CardSection from "@/components/admin/CardSection";
import { getNameFromPath } from "@/lib/fileUtils";
import { ValidSources } from "@/lib/types";
import Title from "@/components/ui/title";
import { EditIcon } from "@/components/icons/icons";
import { useState } from "react";
import { ChevronUpIcon } from "lucide-react";
import { ChevronDownIcon } from "@/components/icons/icons";
function convertObjectToString(obj: any): string | any {
// Check if obj is an object and not an array or null
@@ -39,14 +44,83 @@ function buildConfigEntries(
return obj;
}
function ConfigItem({ label, value }: { label: string; value: any }) {
const [isExpanded, setIsExpanded] = useState(false);
const isExpandable = Array.isArray(value) && value.length > 5;
const renderValue = () => {
if (Array.isArray(value)) {
const displayedItems = isExpanded ? value : value.slice(0, 5);
return (
<ul className="list-disc max-w-full pl-4 mt-2 overflow-x-auto">
{displayedItems.map((item, index) => (
<li
key={index}
className="mb-1 max-w-full overflow-hidden text-right text-ellipsis whitespace-nowrap"
>
{convertObjectToString(item)}
</li>
))}
</ul>
);
} else if (typeof value === "object" && value !== null) {
return (
<div className="mt-2 overflow-x-auto">
{Object.entries(value).map(([key, val]) => (
<div key={key} className="mb-1">
<span className="font-semibold">{key}:</span>{" "}
{convertObjectToString(val)}
</div>
))}
</div>
);
}
return convertObjectToString(value) || "-";
};
return (
<li className="w-full py-2">
<div className="flex items-center justify-between w-full">
<span className="mb-2">{label}</span>
<div className="mt-2 overflow-x-auto w-fit">
{renderValue()}
{isExpandable && (
<button
onClick={() => setIsExpanded(!isExpanded)}
className="mt-2 ml-auto text-text-600 hover:text-text-800 flex items-center"
>
{isExpanded ? (
<>
<ChevronUpIcon className="h-4 w-4 mr-1" />
Show less
</>
) : (
<>
<ChevronDownIcon className="h-4 w-4 mr-1" />
Show all ({value.length} items)
</>
)}
</button>
)}
</div>
</div>
</li>
);
}
export function AdvancedConfigDisplay({
pruneFreq,
refreshFreq,
indexingStart,
onRefreshEdit,
onPruningEdit,
}: {
pruneFreq: number | null;
refreshFreq: number | null;
indexingStart: Date | null;
onRefreshEdit: () => void;
onPruningEdit: () => void;
}) {
const formatRefreshFrequency = (seconds: number | null): string => {
if (seconds === null) return "-";
@@ -75,14 +149,21 @@ export function AdvancedConfigDisplay({
<>
<Title className="mt-8 mb-2">Advanced Configuration</Title>
<CardSection>
<ul className="w-full text-sm divide-y divide-neutral-200 dark:divide-neutral-700">
<ul className="w-full text-sm divide-y divide-background-200 dark:divide-background-700">
{pruneFreq && (
<li
key={0}
className="w-full flex justify-between items-center py-2"
>
<span>Pruning Frequency</span>
<span>{formatPruneFrequency(pruneFreq)}</span>
<span className="ml-auto w-24">
{formatPruneFrequency(pruneFreq)}
</span>
<span className="w-8 text-right">
<button onClick={() => onPruningEdit()}>
<EditIcon size={12} />
</button>
</span>
</li>
)}
{refreshFreq && (
@@ -91,7 +172,14 @@ export function AdvancedConfigDisplay({
className="w-full flex justify-between items-center py-2"
>
<span>Refresh Frequency</span>
<span>{formatRefreshFrequency(refreshFreq)}</span>
<span className="ml-auto w-24">
{formatRefreshFrequency(refreshFreq)}
</span>
<span className="w-8 text-right">
<button onClick={() => onRefreshEdit()}>
<EditIcon size={12} />
</button>
</span>
</li>
)}
{indexingStart && (
@@ -127,15 +215,9 @@ export function ConfigDisplay({
<>
<Title className="mb-2">Configuration</Title>
<CardSection>
<ul className="w-full text-sm divide-y divide-neutral-200 dark:divide-neutral-700">
<ul className="w-full text-sm divide-y divide-background-200 dark:divide-background-700">
{configEntries.map(([key, value]) => (
<li
key={key}
className="w-full flex justify-between items-center py-2"
>
<span>{key}</span>
<span>{convertObjectToString(value) || "-"}</span>
</li>
<ConfigItem key={key} label={key} value={value} />
))}
</ul>
</CardSection>

View File

@@ -7,7 +7,10 @@ import { SourceIcon } from "@/components/SourceIcon";
import { CCPairStatus } from "@/components/Status";
import { usePopup } from "@/components/admin/connectors/Popup";
import CredentialSection from "@/components/credentials/CredentialSection";
import { updateConnectorCredentialPairName } from "@/lib/connector";
import {
updateConnectorCredentialPairName,
updateConnectorCredentialPairProperty,
} from "@/lib/connector";
import { credentialTemplates } from "@/lib/connectors/credentials";
import { errorHandlingFetcher } from "@/lib/fetcher";
import { ValidSources } from "@/lib/types";
@@ -26,12 +29,33 @@ import { buildCCPairInfoUrl } from "./lib";
import { CCPairFullInfo, ConnectorCredentialPairStatus } from "./types";
import { EditableStringFieldDisplay } from "@/components/EditableStringFieldDisplay";
import { Button } from "@/components/ui/button";
import EditPropertyModal from "@/components/modals/EditPropertyModal";
import * as Yup from "yup";
// since the uploaded files are cleaned up after some period of time
// re-indexing will not work for the file connector. Also, it would not
// make sense to re-index, since the files will not have changed.
const CONNECTOR_TYPES_THAT_CANT_REINDEX: ValidSources[] = [ValidSources.File];
// synchronize these validations with the SQLAlchemy connector class until we have a
// centralized schema for both frontend and backend
const RefreshFrequencySchema = Yup.object().shape({
propertyValue: Yup.number()
.typeError("Property value must be a valid number")
.integer("Property value must be an integer")
.min(60, "Property value must be greater than or equal to 60")
.required("Property value is required"),
});
const PruneFrequencySchema = Yup.object().shape({
propertyValue: Yup.number()
.typeError("Property value must be a valid number")
.integer("Property value must be an integer")
.min(86400, "Property value must be greater than or equal to 86400")
.required("Property value is required"),
});
function Main({ ccPairId }: { ccPairId: number }) {
const router = useRouter(); // Initialize the router
const {
@@ -45,6 +69,8 @@ function Main({ ccPairId }: { ccPairId: number }) {
);
const [hasLoadedOnce, setHasLoadedOnce] = useState(false);
const [editingRefreshFrequency, setEditingRefreshFrequency] = useState(false);
const [editingPruningFrequency, setEditingPruningFrequency] = useState(false);
const { popup, setPopup } = usePopup();
const finishConnectorDeletion = useCallback(() => {
@@ -90,6 +116,86 @@ function Main({ ccPairId }: { ccPairId: number }) {
}
};
const handleRefreshEdit = async () => {
setEditingRefreshFrequency(true);
};
const handlePruningEdit = async () => {
setEditingPruningFrequency(true);
};
const handleRefreshSubmit = async (
propertyName: string,
propertyValue: string
) => {
const parsedRefreshFreq = parseInt(propertyValue, 10);
if (isNaN(parsedRefreshFreq)) {
setPopup({
message: "Invalid refresh frequency: must be an integer",
type: "error",
});
return;
}
try {
const response = await updateConnectorCredentialPairProperty(
ccPairId,
propertyName,
String(parsedRefreshFreq)
);
if (!response.ok) {
throw new Error(await response.text());
}
mutate(buildCCPairInfoUrl(ccPairId));
setPopup({
message: "Connector refresh frequency updated successfully",
type: "success",
});
} catch (error) {
setPopup({
message: "Failed to update connector refresh frequency",
type: "error",
});
}
};
const handlePruningSubmit = async (
propertyName: string,
propertyValue: string
) => {
const parsedFreq = parseInt(propertyValue, 10);
if (isNaN(parsedFreq)) {
setPopup({
message: "Invalid pruning frequency: must be an integer",
type: "error",
});
return;
}
try {
const response = await updateConnectorCredentialPairProperty(
ccPairId,
propertyName,
String(parsedFreq)
);
if (!response.ok) {
throw new Error(await response.text());
}
mutate(buildCCPairInfoUrl(ccPairId));
setPopup({
message: "Connector pruning frequency updated successfully",
type: "success",
});
} catch (error) {
setPopup({
message: "Failed to update connector pruning frequency",
type: "error",
});
}
};
if (isLoading) {
return <ThreeDotsLoader />;
}
@@ -114,9 +220,35 @@ function Main({ ccPairId }: { ccPairId: number }) {
refresh_freq: refreshFreq,
indexing_start: indexingStart,
} = ccPair.connector;
return (
<>
{popup}
{editingRefreshFrequency && (
<EditPropertyModal
propertyTitle="Refresh Frequency"
propertyDetails="How often the connector should refresh (in seconds)"
propertyName="refresh_frequency"
propertyValue={String(refreshFreq)}
validationSchema={RefreshFrequencySchema}
onSubmit={handleRefreshSubmit}
onClose={() => setEditingRefreshFrequency(false)}
/>
)}
{editingPruningFrequency && (
<EditPropertyModal
propertyTitle="Pruning Frequency"
propertyDetails="How often the connector should be pruned (in seconds)"
propertyName="pruning_frequency"
propertyValue={String(pruneFreq)}
validationSchema={PruneFrequencySchema}
onSubmit={handlePruningSubmit}
onClose={() => setEditingPruningFrequency(false)}
/>
)}
<BackButton
behaviorOverride={() => router.push("/admin/indexing/status")}
/>
@@ -125,7 +257,7 @@ function Main({ ccPairId }: { ccPairId: number }) {
<SourceIcon iconSize={32} sourceType={ccPair.connector.source} />
</div>
<div className="ml-1">
<div className="ml-1 overflow-hidden text-ellipsis whitespace-nowrap flex-1 mr-4">
<EditableStringFieldDisplay
value={ccPair.name}
isEditable={ccPair.is_editable_for_current_user}
@@ -213,6 +345,8 @@ function Main({ ccPairId }: { ccPairId: number }) {
pruneFreq={pruneFreq}
indexingStart={indexingStart}
refreshFreq={refreshFreq}
onRefreshEdit={handleRefreshEdit}
onPruningEdit={handlePruningEdit}
/>
)}

View File

@@ -19,7 +19,11 @@ import AdvancedFormPage from "./pages/Advanced";
import DynamicConnectionForm from "./pages/DynamicConnectorCreationForm";
import CreateCredential from "@/components/credentials/actions/CreateCredential";
import ModifyCredential from "@/components/credentials/actions/ModifyCredential";
import { ConfigurableSources, ValidSources } from "@/lib/types";
import {
ConfigurableSources,
oauthSupportedSources,
ValidSources,
} from "@/lib/types";
import { Credential, credentialTemplates } from "@/lib/connectors/credentials";
import {
ConnectionConfiguration,
@@ -45,6 +49,8 @@ import { useRouter } from "next/navigation";
import CardSection from "@/components/admin/CardSection";
import { prepareOAuthAuthorizationRequest } from "@/lib/oauth_utils";
import { EE_ENABLED, NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants";
import TemporaryLoadingModal from "@/components/TemporaryLoadingModal";
import { getConnectorOauthRedirectUrl } from "@/lib/connectors/oauth";
export interface AdvancedConfig {
refreshFreq: number;
pruneFreq: number;
@@ -154,14 +160,9 @@ export default function AddConnector({
const configuration: ConnectionConfiguration = connectorConfigs[connector];
// Form context and popup management
const {
setFormStep,
setAllowCreate: setAllowCreate,
formStep,
nextFormStep,
prevFormStep,
} = useFormContext();
const { setFormStep, setAllowCreate, formStep } = useFormContext();
const { popup, setPopup } = usePopup();
const [uploading, setUploading] = useState(false);
// Hooks for Google Drive and Gmail credentials
const { liveGDriveCredential } = useGoogleDriveCredentials(connector);
@@ -339,16 +340,24 @@ export default function AddConnector({
}
// File-specific handling
if (connector == "file") {
const response = await submitFiles(
selectedFiles,
setPopup,
name,
access_type,
groups
);
if (response) {
onSuccess();
setUploading(true);
try {
const response = await submitFiles(
selectedFiles,
setPopup,
name,
access_type,
groups
);
if (response) {
onSuccess();
}
} catch (error) {
setPopup({ message: "Error uploading files", type: "error" });
} finally {
setUploading(false);
}
return;
}
@@ -410,9 +419,9 @@ export default function AddConnector({
<div className="mx-auto mb-8 w-full">
{popup}
<div className="mb-4">
<HealthCheckBanner />
</div>
{uploading && (
<TemporaryLoadingModal content="Uploading files..." />
)}
<AdminPageTitle
includeDivider={false}
@@ -444,26 +453,38 @@ export default function AddConnector({
{/* Button to pop up a form to manually enter credentials */}
<button
className="mt-6 text-sm bg-background-900 px-2 py-1.5 flex text-text-200 flex-none rounded mr-4"
onClick={() =>
setCreateConnectorToggle(
(createConnectorToggle) => !createConnectorToggle
)
}
onClick={async () => {
const redirectUrl =
await getConnectorOauthRedirectUrl(connector);
// if redirect is supported, just use it
if (redirectUrl) {
window.location.href = redirectUrl;
} else {
setCreateConnectorToggle(
(createConnectorToggle) =>
!createConnectorToggle
);
}
}}
>
Create New
</button>
{/* Button to sign in via OAuth */}
<button
onClick={handleAuthorize}
className="mt-6 text-sm bg-blue-500 px-2 py-1.5 flex text-text-200 flex-none rounded"
disabled={isAuthorizing}
hidden={!isAuthorizeVisible}
>
{isAuthorizing
? "Authorizing..."
: `Authorize with ${getSourceDisplayName(connector)}`}
</button>
{oauthSupportedSources.includes(connector) &&
NEXT_PUBLIC_CLOUD_ENABLED && (
<button
onClick={handleAuthorize}
className="mt-6 text-sm bg-blue-500 px-2 py-1.5 flex text-text-200 flex-none rounded"
disabled={isAuthorizing}
hidden={!isAuthorizeVisible}
>
{isAuthorizing
? "Authorizing..."
: `Authorize with ${getSourceDisplayName(
connector
)}`}
</button>
)}
</div>
)}

View File

@@ -104,7 +104,9 @@ const GDriveMain = ({}: {}) => {
const googleDriveServiceAccountCredential:
| Credential<GoogleDriveServiceAccountCredentialJson>
| undefined = credentialsData.find(
(credential) => credential.credential_json?.google_service_account_key
(credential) =>
credential.credential_json?.google_service_account_key &&
credential.source === "google_drive"
);
const googleDriveConnectorIndexingStatuses: ConnectorIndexingStatus<

View File

@@ -135,7 +135,7 @@ export const DocumentFeedbackTable = ({
/>
</TableCell>
<TableCell>
<div className="ml-auto flex w-16">
<div className="relative">
<div
key={document.document_id}
className="h-10 ml-auto mr-8"

View File

@@ -353,13 +353,9 @@ export function CCPairIndexingStatusTable({
);
};
const toggleSources = () => {
const currentToggledCount =
Object.values(connectorsToggled).filter(Boolean).length;
const shouldToggleOn = currentToggledCount < sortedSources.length / 2;
const connectors = sortedSources.reduce(
(acc, source) => {
acc[source] = shouldToggleOn;
acc[source] = shouldExpand;
return acc;
},
{} as Record<ValidSources, boolean>
@@ -368,6 +364,7 @@ export function CCPairIndexingStatusTable({
setConnectorsToggled(connectors);
Cookies.set(TOGGLED_CONNECTORS_COOKIE_NAME, JSON.stringify(connectors));
};
const shouldExpand =
Object.values(connectorsToggled).filter(Boolean).length <
sortedSources.length;

View File

@@ -1,46 +0,0 @@
import useSWR from "swr";
import { InputPrompt } from "./interfaces";
const fetcher = (url: string) => fetch(url).then((res) => res.json());
export const useAdminInputPrompts = () => {
const { data, error, mutate } = useSWR<InputPrompt[]>(
`/api/admin/input_prompt`,
fetcher
);
return {
data,
error,
isLoading: !error && !data,
refreshInputPrompts: mutate,
};
};
export const useInputPrompts = (includePublic: boolean = false) => {
const { data, error, mutate } = useSWR<InputPrompt[]>(
`/api/input_prompt${includePublic ? "?include_public=true" : ""}`,
fetcher
);
return {
data,
error,
isLoading: !error && !data,
refreshInputPrompts: mutate,
};
};
export const useInputPrompt = (id: number) => {
const { data, error, mutate } = useSWR<InputPrompt>(
`/api/input_prompt/${id}`,
fetcher
);
return {
data,
error,
isLoading: !error && !data,
refreshInputPrompt: mutate,
};
};

Some files were not shown because too many files have changed in this diff Show More