Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-02-21 01:35:46 +00:00)

Compare commits (74 commits)
Commit SHA1s in this comparison:

a13dea160a, ca172f3306, e5d0587efa, a9516202fe, d23fca96c4, a45724c899, 34e250407a, 046c0fbe3e,
76595facef, af2d548766, 7c29b1e028, a52c821e78, 0770a587f1, 748b79b0ef, 9cacb373ef, 21967d4b6f,
f5d638161b, 0b5013b47d, 1b846fbf06, cae8a131a2, 72b4e8e9fe, c04e2f14d9, b40a12d5d7, 5e7d454ebe,
238509c536, d7f8cf8f18, 5d810d373e, 9455576078, 71421bb782, b88cb388b7, 639986001f, e7a7e78969,
e255ff7d23, 1be2502112, f2bedb8fdd, 637404f482, daae146920, d95959fb41, c667d28e7a, 9e0b482f47,
fa84eb657f, 264df3441b, b9bad8b7a0, 600ebb6432, 09fe8ea868, ad6be03b4d, 65d2511216, 113bf19c65,
6026536110, 056b671cd4, 8d83ae2ee8, ca988f5c5f, 4e4214b82c, fe83f676df, 6d6e12119b, 1f2b7cb9c8,
878a189011, 48c10271c2, c6a79d847e, 1bc3f8b96f, 7f6a6944d6, 06f4146597, 7ea73d5a5a, 30dfe6dcb4,
dc5d5dfe05, 0746e0be5b, 970320bd49, 4a7bd5578e, 874b098a4b, ce18b63eea, 7a919c3589, 631bac4432,
53428f6e9c, 53b3dcbace
README.md (60 changed lines)

@@ -1,48 +1,48 @@
<!-- DANSWER_METADATA={"link": "https://github.com/danswer-ai/danswer/blob/main/README.md"} -->
<!-- DANSWER_METADATA={"link": "https://github.com/onyx-dot-app/onyx/blob/main/README.md"} -->
<a name="readme-top"></a>

<h2 align="center">
<a href="https://www.danswer.ai/"> <img width="50%" src="https://github.com/danswer-owners/danswer/blob/1fabd9372d66cd54238847197c33f091a724803b/DanswerWithName.png?raw=true)" /></a>
<a href="https://www.onyx.app/"> <img width="50%" src="https://github.com/onyx-dot-app/onyx/blob/logo/LogoOnyx.png?raw=true)" /></a>
</h2>

<p align="center">
<p align="center">Open Source Gen-AI Chat + Unified Search.</p>
<p align="center">Open Source Gen-AI + Enterprise Search.</p>

<p align="center">
<a href="https://docs.danswer.dev/" target="_blank">
<a href="https://docs.onyx.app/" target="_blank">
<img src="https://img.shields.io/badge/docs-view-blue" alt="Documentation">
</a>
<a href="https://join.slack.com/t/danswer/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA" target="_blank">
<a href="https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2sslpdbyq-iIbTaNIVPBw_i_4vrujLYQ" target="_blank">
<img src="https://img.shields.io/badge/slack-join-blue.svg?logo=slack" alt="Slack">
</a>
<a href="https://discord.gg/TDJ59cGV2X" target="_blank">
<img src="https://img.shields.io/badge/discord-join-blue.svg?logo=discord&logoColor=white" alt="Discord">
</a>
<a href="https://github.com/danswer-ai/danswer/blob/main/README.md" target="_blank">
<a href="https://github.com/onyx-dot-app/onyx/blob/main/README.md" target="_blank">
<img src="https://img.shields.io/static/v1?label=license&message=MIT&color=blue" alt="License">
</a>
</p>

<strong>[Danswer](https://www.danswer.ai/)</strong> is the AI Assistant connected to your company's docs, apps, and people.
Danswer provides a Chat interface and plugs into any LLM of your choice. Danswer can be deployed anywhere and for any
<strong>[Onyx](https://www.onyx.app/)</strong> (Formerly Danswer) is the AI Assistant connected to your company's docs, apps, and people.
Onyx provides a Chat interface and plugs into any LLM of your choice. Onyx can be deployed anywhere and for any
scale - on a laptop, on-premise, or to cloud. Since you own the deployment, your user data and chats are fully in your
own control. Danswer is MIT licensed and designed to be modular and easily extensible. The system also comes fully ready
own control. Onyx is dual Licensed with most of it under MIT license and designed to be modular and easily extensible. The system also comes fully ready
for production usage with user authentication, role management (admin/basic users), chat persistence, and a UI for
configuring Personas (AI Assistants) and their Prompts.
configuring AI Assistants.

Danswer also serves as a Unified Search across all common workplace tools such as Slack, Google Drive, Confluence, etc.
By combining LLMs and team specific knowledge, Danswer becomes a subject matter expert for the team. Imagine ChatGPT if
Onyx also serves as a Enterprise Search across all common workplace tools such as Slack, Google Drive, Confluence, etc.
By combining LLMs and team specific knowledge, Onyx becomes a subject matter expert for the team. Imagine ChatGPT if
it had access to your team's unique knowledge! It enables questions such as "A customer wants feature X, is this already
supported?" or "Where's the pull request for feature Y?"

<h3>Usage</h3>

Danswer Web App:
Onyx Web App:

https://github.com/danswer-ai/danswer/assets/32520769/563be14c-9304-47b5-bf0a-9049c2b6f410

Or, plug Danswer into your existing Slack workflows (more integrations to come 😁):
Or, plug Onyx into your existing Slack workflows (more integrations to come 😁):

https://github.com/danswer-ai/danswer/assets/25087905/3e19739b-d178-4371-9a38-011430bdec1b

@@ -52,16 +52,16 @@ For more details on the Admin UI to manage connectors and users, check out our

## Deployment

Danswer can easily be run locally (even on a laptop) or deployed on a virtual machine with a single
`docker compose` command. Checkout our [docs](https://docs.danswer.dev/quickstart) to learn more.
Onyx can easily be run locally (even on a laptop) or deployed on a virtual machine with a single
`docker compose` command. Checkout our [docs](https://docs.onyx.app/quickstart) to learn more.

We also have built-in support for deployment on Kubernetes. Files for that can be found [here](https://github.com/danswer-ai/danswer/tree/main/deployment/kubernetes).
We also have built-in support for deployment on Kubernetes. Files for that can be found [here](https://github.com/onyx-dot-app/onyx/tree/main/deployment/kubernetes).

## 💃 Main Features
* Chat UI with the ability to select documents to chat with.
* Create custom AI Assistants with different prompts and backing knowledge sets.
* Connect Danswer with LLM of your choice (self-host for a fully airgapped solution).
* Connect Onyx with LLM of your choice (self-host for a fully airgapped solution).
* Document Search + AI Answers for natural language queries.
* Connectors to all common workplace tools like Google Drive, Confluence, Slack, etc.
* Slack integration to get answers and search results directly in Slack.

@@ -75,12 +75,12 @@ We also have built-in support for deployment on Kubernetes. Files for that can b
* Organizational understanding and ability to locate and suggest experts from your team.

## Other Notable Benefits of Danswer
## Other Notable Benefits of Onyx
* User Authentication with document level access management.
* Best in class Hybrid Search across all sources (BM-25 + prefix aware embedding models).
* Admin Dashboard to configure connectors, document-sets, access, etc.
* Custom deep learning models + learn from user feedback.
* Easy deployment and ability to host Danswer anywhere of your choosing.
* Easy deployment and ability to host Onyx anywhere of your choosing.

## 🔌 Connectors

@@ -108,10 +108,10 @@ Efficiently pulls the latest changes from:

## 📚 Editions

There are two editions of Danswer:
There are two editions of Onyx:

* Danswer Community Edition (CE) is available freely under the MIT Expat license. This version has ALL the core features discussed above. This is the version of Danswer you will get if you follow the Deployment guide above.
* Danswer Enterprise Edition (EE) includes extra features that are primarily useful for larger organizations. Specifically, this includes:
* Onyx Community Edition (CE) is available freely under the MIT Expat license. This version has ALL the core features discussed above. This is the version of Onyx you will get if you follow the Deployment guide above.
* Onyx Enterprise Edition (EE) includes extra features that are primarily useful for larger organizations. Specifically, this includes:
* Single Sign-On (SSO), with support for both SAML and OIDC
* Role-based access control
* Document permission inheritance from connected sources

@@ -119,24 +119,24 @@ There are two editions of Danswer:
* Whitelabeling
* API key authentication
* Encryption of secrets
* Any many more! Checkout [our website](https://www.danswer.ai/) for the latest.
* Any many more! Checkout [our website](https://www.onyx.app/) for the latest.

To try the Danswer Enterprise Edition:
To try the Onyx Enterprise Edition:

1. Checkout our [Cloud product](https://app.danswer.ai/signup).
2. For self-hosting, contact us at [founders@danswer.ai](mailto:founders@danswer.ai) or book a call with us on our [Cal](https://cal.com/team/danswer/founders).
1. Checkout our [Cloud product](https://cloud.onyx.app/signup).
2. For self-hosting, contact us at [founders@onyx.app](mailto:founders@onyx.app) or book a call with us on our [Cal](https://cal.com/team/danswer/founders).

## 💡 Contributing
Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md) for more details.

## ⭐Star History

[](https://star-history.com/#danswer-ai/danswer&Date)
[](https://star-history.com/#onyx-dot-app/onyx&Date)

## ✨Contributors

<a href="https://github.com/danswer-ai/danswer/graphs/contributors">
<img alt="contributors" src="https://contrib.rocks/image?repo=danswer-ai/danswer"/>
<a href="https://github.com/onyx-dot-app/onyx/graphs/contributors">
<img alt="contributors" src="https://contrib.rocks/image?repo=onyx-dot-app/onyx"/>
</a>

<p align="right" style="font-size: 14px; color: #555; margin-top: 20px;">
@@ -0,0 +1,57 @@
"""delete_input_prompts

Revision ID: bf7a81109301
Revises: f7a894b06d02
Create Date: 2024-12-09 12:00:49.884228

"""
from alembic import op
import sqlalchemy as sa
import fastapi_users_db_sqlalchemy


# revision identifiers, used by Alembic.
revision = "bf7a81109301"
down_revision = "f7a894b06d02"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.drop_table("inputprompt__user")
    op.drop_table("inputprompt")


def downgrade() -> None:
    op.create_table(
        "inputprompt",
        sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
        sa.Column("prompt", sa.String(), nullable=False),
        sa.Column("content", sa.String(), nullable=False),
        sa.Column("active", sa.Boolean(), nullable=False),
        sa.Column("is_public", sa.Boolean(), nullable=False),
        sa.Column(
            "user_id",
            fastapi_users_db_sqlalchemy.generics.GUID(),
            nullable=True,
        ),
        sa.ForeignKeyConstraint(
            ["user_id"],
            ["user.id"],
        ),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_table(
        "inputprompt__user",
        sa.Column("input_prompt_id", sa.Integer(), nullable=False),
        sa.Column("user_id", sa.Integer(), nullable=False),
        sa.ForeignKeyConstraint(
            ["input_prompt_id"],
            ["inputprompt.id"],
        ),
        sa.ForeignKeyConstraint(
            ["user_id"],
            ["inputprompt.id"],
        ),
        sa.PrimaryKeyConstraint("input_prompt_id", "user_id"),
    )
@@ -1,3 +1,4 @@
import hashlib
import secrets
import uuid
from urllib.parse import quote

@@ -18,7 +19,8 @@ _API_KEY_HEADER_NAME = "Authorization"
# organizations like the Internet Engineering Task Force (IETF).
_API_KEY_HEADER_ALTERNATIVE_NAME = "X-Danswer-Authorization"
_BEARER_PREFIX = "Bearer "
_API_KEY_PREFIX = "dn_"
_API_KEY_PREFIX = "on_"
_DEPRECATED_API_KEY_PREFIX = "dn_"
_API_KEY_LEN = 192

@@ -52,7 +54,9 @@ def extract_tenant_from_api_key_header(request: Request) -> str | None:
    api_key = raw_api_key_header[len(_BEARER_PREFIX) :].strip()

    if not api_key.startswith(_API_KEY_PREFIX):
    if not api_key.startswith(_API_KEY_PREFIX) and not api_key.startswith(
        _DEPRECATED_API_KEY_PREFIX
    ):
        return None

    parts = api_key[len(_API_KEY_PREFIX) :].split(".", 1)

@@ -63,10 +67,19 @@ def extract_tenant_from_api_key_header(request: Request) -> str | None:
    return unquote(tenant_id) if tenant_id else None


def _deprecated_hash_api_key(api_key: str) -> str:
    return sha256_crypt.hash(api_key, salt="", rounds=API_KEY_HASH_ROUNDS)


def hash_api_key(api_key: str) -> str:
    # NOTE: no salt is needed, as the API key is randomly generated
    # and overlaps are impossible
    return sha256_crypt.hash(api_key, salt="", rounds=API_KEY_HASH_ROUNDS)
    if api_key.startswith(_API_KEY_PREFIX):
        return hashlib.sha256(api_key.encode("utf-8")).hexdigest()
    elif api_key.startswith(_DEPRECATED_API_KEY_PREFIX):
        return _deprecated_hash_api_key(api_key)
    else:
        raise ValueError(f"Invalid API key prefix: {api_key[:3]}")


def build_displayable_api_key(api_key: str) -> str:
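The hunk above keeps old keys working while changing the hash for new ones: keys minted with the new `on_` prefix are hashed with plain SHA-256, while keys that still carry the legacy `dn_` prefix fall back to the original `sha256_crypt` hash. A minimal standalone sketch of that dispatch (not the repository's module; `API_KEY_HASH_ROUNDS` is an assumed value here, the real constant lives in the repo's config):

```python
# Hypothetical sketch of the prefix-based dispatch shown in the diff above.
import hashlib

from passlib.hash import sha256_crypt

_API_KEY_PREFIX = "on_"
_DEPRECATED_API_KEY_PREFIX = "dn_"
API_KEY_HASH_ROUNDS = 535_000  # assumed; the real value comes from the repo's configuration


def hash_api_key_sketch(api_key: str) -> str:
    if api_key.startswith(_API_KEY_PREFIX):
        # new keys: a plain SHA-256 digest (keys are random, so no salt is needed)
        return hashlib.sha256(api_key.encode("utf-8")).hexdigest()
    if api_key.startswith(_DEPRECATED_API_KEY_PREFIX):
        # legacy keys: reproduce the old passlib hash so stored hashes still match
        return sha256_crypt.hash(api_key, salt="", rounds=API_KEY_HASH_ROUNDS)
    raise ValueError(f"Invalid API key prefix: {api_key[:3]}")
```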
@@ -9,7 +9,6 @@ from danswer.utils.special_types import JSON_ro
def get_invited_users() -> list[str]:
    try:
        store = get_kv_store()

        return cast(list, store.load(KV_USER_STORE_KEY))
    except KvKeyNotFoundError:
        return list()
@@ -131,7 +131,7 @@ def get_display_email(email: str | None, space_less: bool = False) -> str:

def user_needs_to_be_verified() -> bool:
    if AUTH_TYPE == AuthType.BASIC:
    if AUTH_TYPE == AuthType.BASIC or AUTH_TYPE == AuthType.CLOUD:
        return REQUIRE_EMAIL_VERIFICATION

    # For other auth types, if the user is authenticated it's assumed that
@@ -598,7 +598,7 @@ def connector_indexing_proxy_task(
            db_session,
            "Connector termination signal detected",
        )
    finally:
    except Exception:
        # if the DB exceptions, we'll just get an unfriendly failure message
        # in the UI instead of the cancellation message
        logger.exception(

@@ -640,14 +640,41 @@ def connector_indexing_proxy_task(
            continue

        if job.status == "error":
            task_logger.error(
                "Indexing watchdog - spawned task exceptioned: "
                f"attempt={index_attempt_id} "
                f"tenant={tenant_id} "
                f"cc_pair={cc_pair_id} "
                f"search_settings={search_settings_id} "
                f"error={job.exception()}"
            )
            ignore_exitcode = False

            exit_code: int | None = None
            if job.process:
                exit_code = job.process.exitcode

            # seeing non-deterministic behavior where spawned tasks occasionally return exit code 1
            # even though logging clearly indicates that they completed successfully
            # to work around this, we ignore the job error state if the completion signal is OK
            status_int = redis_connector_index.get_completion()
            if status_int:
                status_enum = HTTPStatus(status_int)
                if status_enum == HTTPStatus.OK:
                    ignore_exitcode = True

            if ignore_exitcode:
                task_logger.warning(
                    "Indexing watchdog - spawned task has non-zero exit code "
                    "but completion signal is OK. Continuing...: "
                    f"attempt={index_attempt_id} "
                    f"tenant={tenant_id} "
                    f"cc_pair={cc_pair_id} "
                    f"search_settings={search_settings_id} "
                    f"exit_code={exit_code}"
                )
            else:
                task_logger.error(
                    "Indexing watchdog - spawned task exceptioned: "
                    f"attempt={index_attempt_id} "
                    f"tenant={tenant_id} "
                    f"cc_pair={cc_pair_id} "
                    f"search_settings={search_settings_id} "
                    f"exit_code={exit_code} "
                    f"error={job.exception()}"
                )

            job.release()
            break
@@ -680,17 +680,28 @@ def monitor_ccpair_indexing_taskset(
        )
        task_logger.warning(msg)

        index_attempt = get_index_attempt(db_session, payload.index_attempt_id)
        if index_attempt:
            if (
                index_attempt.status != IndexingStatus.CANCELED
                and index_attempt.status != IndexingStatus.FAILED
            ):
                mark_attempt_failed(
                    index_attempt_id=payload.index_attempt_id,
                    db_session=db_session,
                    failure_reason=msg,
                )
        try:
            index_attempt = get_index_attempt(
                db_session, payload.index_attempt_id
            )
            if index_attempt:
                if (
                    index_attempt.status != IndexingStatus.CANCELED
                    and index_attempt.status != IndexingStatus.FAILED
                ):
                    mark_attempt_failed(
                        index_attempt_id=payload.index_attempt_id,
                        db_session=db_session,
                        failure_reason=msg,
                    )
        except Exception:
            task_logger.exception(
                "monitor_ccpair_indexing_taskset - transient exception marking index attempt as failed: "
                f"attempt={payload.index_attempt_id} "
                f"tenant={tenant_id} "
                f"cc_pair={cc_pair_id} "
                f"search_settings={search_settings_id}"
            )

        redis_connector_index.reset()
        return
@@ -82,7 +82,7 @@ class SimpleJob:
            return "running"
        elif self.process.exitcode is None:
            return "cancelled"
        elif self.process.exitcode > 0:
        elif self.process.exitcode != 0:
            return "error"
        else:
            return "finished"

@@ -123,7 +123,8 @@ class SimpleJobClient:
        self._cleanup_completed_jobs()
        if len(self.jobs) >= self.n_workers:
            logger.debug(
                f"No available workers to run job. Currently running '{len(self.jobs)}' jobs, with a limit of '{self.n_workers}'."
                f"No available workers to run job. "
                f"Currently running '{len(self.jobs)}' jobs, with a limit of '{self.n_workers}'."
            )
            return None
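For context on the `exitcode > 0` to `exitcode != 0` change above: `multiprocessing` reports a process killed by a signal with a negative exit code, which the old check silently classified as "finished". A small, self-contained illustration (POSIX only; not the repository's code):

```python
import multiprocessing as mp
import os
import signal
import time


def _sleep() -> None:
    time.sleep(30)


if __name__ == "__main__":
    p = mp.Process(target=_sleep)
    p.start()
    os.kill(p.pid, signal.SIGKILL)  # simulate the spawned task being killed
    p.join()
    print(p.exitcode)        # -9: negative of the signal number
    print(p.exitcode > 0)    # False -> old check would have reported "finished"
    print(p.exitcode != 0)   # True  -> new check correctly reports "error"
```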
@@ -206,7 +206,9 @@ class Answer:
        # + figure out what the next LLM call should be
        tool_call_handler = ToolResponseHandler(current_llm_call.tools)

        search_result = SearchTool.get_search_result(current_llm_call) or []
        search_result, displayed_search_results_map = SearchTool.get_search_result(
            current_llm_call
        ) or ([], {})

        # Quotes are no longer supported
        # answer_handler: AnswerResponseHandler

@@ -224,6 +226,7 @@ class Answer:
            answer_handler = CitationResponseHandler(
                context_docs=search_result,
                doc_id_to_rank_map=map_document_id_order(search_result),
                display_doc_order_dict=displayed_search_results_map,
            )

            response_handler_manager = LLMResponseHandlerManager(
@@ -35,13 +35,18 @@ class DummyAnswerResponseHandler(AnswerResponseHandler):

class CitationResponseHandler(AnswerResponseHandler):
    def __init__(
        self, context_docs: list[LlmDoc], doc_id_to_rank_map: DocumentIdOrderMapping
        self,
        context_docs: list[LlmDoc],
        doc_id_to_rank_map: DocumentIdOrderMapping,
        display_doc_order_dict: dict[str, int],
    ):
        self.context_docs = context_docs
        self.doc_id_to_rank_map = doc_id_to_rank_map
        self.display_doc_order_dict = display_doc_order_dict
        self.citation_processor = CitationProcessor(
            context_docs=self.context_docs,
            doc_id_to_rank_map=self.doc_id_to_rank_map,
            display_doc_order_dict=self.display_doc_order_dict,
        )
        self.processed_text = ""
        self.citations: list[CitationInfo] = []
@@ -22,12 +22,16 @@ class CitationProcessor:
        self,
        context_docs: list[LlmDoc],
        doc_id_to_rank_map: DocumentIdOrderMapping,
        display_doc_order_dict: dict[str, int],
        stop_stream: str | None = STOP_STREAM_PAT,
    ):
        self.context_docs = context_docs
        self.doc_id_to_rank_map = doc_id_to_rank_map
        self.stop_stream = stop_stream
        self.order_mapping = doc_id_to_rank_map.order_mapping
        self.display_doc_order_dict = (
            display_doc_order_dict  # original order of docs to displayed to user
        )
        self.llm_out = ""
        self.max_citation_num = len(context_docs)
        self.citation_order: list[int] = []

@@ -98,6 +102,18 @@ class CitationProcessor:
                self.citation_order.index(real_citation_num) + 1
            )

            # get the value that was displayed to user, should always
            # be in the display_doc_order_dict. But check anyways
            if context_llm_doc.document_id in self.display_doc_order_dict:
                displayed_citation_num = self.display_doc_order_dict[
                    context_llm_doc.document_id
                ]
            else:
                displayed_citation_num = real_citation_num
                logger.warning(
                    f"Doc {context_llm_doc.document_id} not in display_doc_order_dict. Used LLM citation number instead."
                )

            # Skip consecutive citations of the same work
            if target_citation_num in self.current_citations:
                start, end = citation.span()

@@ -118,6 +134,7 @@ class CitationProcessor:
                doc_id = int(match.group(1))
                context_llm_doc = self.context_docs[doc_id - 1]
                yield CitationInfo(
                    # stay with the original for now (order of LLM cites)
                    citation_num=target_citation_num,
                    document_id=context_llm_doc.document_id,
                )

@@ -139,6 +156,7 @@ class CitationProcessor:
            if target_citation_num not in self.cited_inds:
                self.cited_inds.add(target_citation_num)
                yield CitationInfo(
                    # stay with the original for now (order of LLM cites)
                    citation_num=target_citation_num,
                    document_id=context_llm_doc.document_id,
                )

@@ -148,7 +166,8 @@ class CitationProcessor:
                prev_length = len(self.curr_segment)
                self.curr_segment = (
                    self.curr_segment[: start + length_to_add]
                    + f"[[{target_citation_num}]]({link})"
                    + f"[[{displayed_citation_num}]]({link})"  # use the value that was displayed to user
                    # + f"[[{target_citation_num}]]({link})"
                    + self.curr_segment[end + length_to_add :]
                )
                length_to_add += len(self.curr_segment) - prev_length

@@ -156,7 +175,8 @@ class CitationProcessor:
                prev_length = len(self.curr_segment)
                self.curr_segment = (
                    self.curr_segment[: start + length_to_add]
                    + f"[[{target_citation_num}]]()"
                    + f"[[{displayed_citation_num}]]()"  # use the value that was displayed to user
                    # + f"[[{target_citation_num}]]()"
                    + self.curr_segment[end + length_to_add :]
                )
                length_to_add += len(self.curr_segment) - prev_length
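The `display_doc_order_dict` plumbing added above exists because the LLM cites documents by the order they appeared in its prompt, while the UI may have shown them to the user in a different order; the rendered `[[n]]` now uses the displayed number, falling back to the LLM's own number when a document was never displayed. A toy illustration with made-up document IDs (not the repository's code):

```python
# doc_id -> citation number the user actually saw in the UI (hypothetical data)
display_doc_order_dict = {"doc-a": 2, "doc-b": 1}


def rendered_citation(doc_id: str, llm_citation_num: int) -> str:
    # fall back to the LLM's own number if the doc was never displayed,
    # mirroring the warning branch in the hunk above
    displayed_citation_num = display_doc_order_dict.get(doc_id, llm_citation_num)
    return f"[[{displayed_citation_num}]]"


print(rendered_citation("doc-b", 2))  # [[1]] - doc-b was shown to the user first
print(rendered_citation("doc-c", 3))  # [[3]] - unknown doc keeps the LLM's number
```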
@@ -348,6 +348,12 @@ GITLAB_CONNECTOR_INCLUDE_CODE_FILES = (
    os.environ.get("GITLAB_CONNECTOR_INCLUDE_CODE_FILES", "").lower() == "true"
)

# Egnyte specific configs
EGNYTE_LOCALHOST_OVERRIDE = os.getenv("EGNYTE_LOCALHOST_OVERRIDE")
EGNYTE_BASE_DOMAIN = os.getenv("EGNYTE_DOMAIN")
EGNYTE_CLIENT_ID = os.getenv("EGNYTE_CLIENT_ID")
EGNYTE_CLIENT_SECRET = os.getenv("EGNYTE_CLIENT_SECRET")

DASK_JOB_CLIENT_ENABLED = (
    os.environ.get("DASK_JOB_CLIENT_ENABLED", "").lower() == "true"
)

@@ -411,21 +417,28 @@ LARGE_CHUNK_RATIO = 4
# We don't want the metadata to overwhelm the actual contents of the chunk
SKIP_METADATA_IN_CHUNK = os.environ.get("SKIP_METADATA_IN_CHUNK", "").lower() == "true"
# Timeout to wait for job's last update before killing it, in hours
CLEANUP_INDEXING_JOBS_TIMEOUT = int(os.environ.get("CLEANUP_INDEXING_JOBS_TIMEOUT", 3))
CLEANUP_INDEXING_JOBS_TIMEOUT = int(
    os.environ.get("CLEANUP_INDEXING_JOBS_TIMEOUT") or 3
)

# The indexer will warn in the logs whenver a document exceeds this threshold (in bytes)
INDEXING_SIZE_WARNING_THRESHOLD = int(
    os.environ.get("INDEXING_SIZE_WARNING_THRESHOLD", 100 * 1024 * 1024)
    os.environ.get("INDEXING_SIZE_WARNING_THRESHOLD") or 100 * 1024 * 1024
)

# during indexing, will log verbose memory diff stats every x batches and at the end.
# 0 disables this behavior and is the default.
INDEXING_TRACER_INTERVAL = int(os.environ.get("INDEXING_TRACER_INTERVAL", 0))
INDEXING_TRACER_INTERVAL = int(os.environ.get("INDEXING_TRACER_INTERVAL") or 0)

# During an indexing attempt, specifies the number of batches which are allowed to
# exception without aborting the attempt.
INDEXING_EXCEPTION_LIMIT = int(os.environ.get("INDEXING_EXCEPTION_LIMIT", 0))
INDEXING_EXCEPTION_LIMIT = int(os.environ.get("INDEXING_EXCEPTION_LIMIT") or 0)

# Maximum file size in a document to be indexed
MAX_DOCUMENT_CHARS = int(os.environ.get("MAX_DOCUMENT_CHARS") or 5_000_000)
MAX_FILE_SIZE_BYTES = int(
    os.environ.get("MAX_FILE_SIZE_BYTES") or 2 * 1024 * 1024 * 1024
)  # 2GB in bytes

#####
# Miscellaneous
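The switch from a `get()` default to `... or <default>` in the hunk above matters when a variable is present but empty (for example, declared in a compose file with no value): `int("")` raises, whereas `"" or 3` falls back to the default. A quick demonstration:

```python
import os

# an env var that is set but empty, e.g. `CLEANUP_INDEXING_JOBS_TIMEOUT=` in a compose file
os.environ["CLEANUP_INDEXING_JOBS_TIMEOUT"] = ""

# old form: int(os.environ.get("CLEANUP_INDEXING_JOBS_TIMEOUT", 3)) -> ValueError (invalid literal)
timeout = int(os.environ.get("CLEANUP_INDEXING_JOBS_TIMEOUT") or 3)
print(timeout)  # 3
```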
@@ -3,7 +3,6 @@ import os

PROMPTS_YAML = "./danswer/seeding/prompts.yaml"
PERSONAS_YAML = "./danswer/seeding/personas.yaml"
INPUT_PROMPT_YAML = "./danswer/seeding/input_prompts.yaml"

NUM_RETURNED_HITS = 50
# Used for LLM filtering and reranking
@@ -132,6 +132,7 @@ class DocumentSource(str, Enum):
    NOT_APPLICABLE = "not_applicable"
    FRESHDESK = "freshdesk"
    FIREFLIES = "fireflies"
    EGNYTE = "egnyte"


DocumentSourceRequiringTenantContext: list[DocumentSource] = [DocumentSource.FILE]
@@ -368,4 +368,5 @@ def build_confluence_client(
        backoff_and_retry=True,
        max_backoff_retries=10,
        max_backoff_seconds=60,
        cloud=is_cloud,
    )
backend/danswer/connectors/egnyte/connector.py (new file, 384 lines)

@@ -0,0 +1,384 @@
import io
import os
from collections.abc import Generator
from datetime import datetime
from datetime import timezone
from logging import Logger
from typing import Any
from typing import cast
from typing import IO

import requests
from retry import retry

from danswer.configs.app_configs import EGNYTE_BASE_DOMAIN
from danswer.configs.app_configs import EGNYTE_CLIENT_ID
from danswer.configs.app_configs import EGNYTE_CLIENT_SECRET
from danswer.configs.app_configs import EGNYTE_LOCALHOST_OVERRIDE
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import OAuthConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.models import ConnectorMissingCredentialError
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.file_processing.extract_file_text import detect_encoding
from danswer.file_processing.extract_file_text import extract_file_text
from danswer.file_processing.extract_file_text import get_file_ext
from danswer.file_processing.extract_file_text import is_text_file_extension
from danswer.file_processing.extract_file_text import is_valid_file_ext
from danswer.file_processing.extract_file_text import read_text_file
from danswer.utils.logger import setup_logger


logger = setup_logger()

_EGNYTE_API_BASE = "https://{domain}.egnyte.com/pubapi/v1"
_EGNYTE_APP_BASE = "https://{domain}.egnyte.com"
_TIMEOUT = 60


def _request_with_retries(
    method: str,
    url: str,
    data: dict[str, Any] | None = None,
    headers: dict[str, Any] | None = None,
    params: dict[str, Any] | None = None,
    timeout: int = _TIMEOUT,
    stream: bool = False,
    tries: int = 8,
    delay: float = 1,
    backoff: float = 2,
) -> requests.Response:
    @retry(tries=tries, delay=delay, backoff=backoff, logger=cast(Logger, logger))
    def _make_request() -> requests.Response:
        response = requests.request(
            method,
            url,
            data=data,
            headers=headers,
            params=params,
            timeout=timeout,
            stream=stream,
        )
        try:
            response.raise_for_status()
        except requests.exceptions.HTTPError as e:
            if e.response.status_code != 403:
                logger.exception(
                    f"Failed to call Egnyte API.\n"
                    f"URL: {url}\n"
                    f"Headers: {headers}\n"
                    f"Data: {data}\n"
                    f"Params: {params}"
                )
            raise e
        return response

    return _make_request()


def _parse_last_modified(last_modified: str) -> datetime:
    return datetime.strptime(last_modified, "%a, %d %b %Y %H:%M:%S %Z").replace(
        tzinfo=timezone.utc
    )


def _process_egnyte_file(
    file_metadata: dict[str, Any],
    file_content: IO,
    base_url: str,
    folder_path: str | None = None,
) -> Document | None:
    """Process an Egnyte file into a Document object

    Args:
        file_data: The file data from Egnyte API
        file_content: The raw content of the file in bytes
        base_url: The base URL for the Egnyte instance
        folder_path: Optional folder path to filter results
    """
    # Skip if file path doesn't match folder path filter
    if folder_path and not file_metadata["path"].startswith(folder_path):
        raise ValueError(
            f"File path {file_metadata['path']} does not match folder path {folder_path}"
        )

    file_name = file_metadata["name"]
    extension = get_file_ext(file_name)
    if not is_valid_file_ext(extension):
        logger.warning(f"Skipping file '{file_name}' with extension '{extension}'")
        return None

    # Extract text content based on file type
    if is_text_file_extension(file_name):
        encoding = detect_encoding(file_content)
        file_content_raw, file_metadata = read_text_file(
            file_content, encoding=encoding, ignore_danswer_metadata=False
        )
    else:
        file_content_raw = extract_file_text(
            file=file_content,
            file_name=file_name,
            break_on_unprocessable=True,
        )

    # Build the web URL for the file
    web_url = f"{base_url}/navigate/file/{file_metadata['group_id']}"

    # Create document metadata
    metadata: dict[str, str | list[str]] = {
        "file_path": file_metadata["path"],
        "last_modified": file_metadata.get("last_modified", ""),
    }

    # Add lock info if present
    if lock_info := file_metadata.get("lock_info"):
        metadata[
            "lock_owner"
        ] = f"{lock_info.get('first_name', '')} {lock_info.get('last_name', '')}"

    # Create the document owners
    primary_owner = None
    if uploaded_by := file_metadata.get("uploaded_by"):
        primary_owner = BasicExpertInfo(
            email=uploaded_by,  # Using username as email since that's what we have
        )

    # Create the document
    return Document(
        id=f"egnyte-{file_metadata['entry_id']}",
        sections=[Section(text=file_content_raw.strip(), link=web_url)],
        source=DocumentSource.EGNYTE,
        semantic_identifier=file_name,
        metadata=metadata,
        doc_updated_at=(
            _parse_last_modified(file_metadata["last_modified"])
            if "last_modified" in file_metadata
            else None
        ),
        primary_owners=[primary_owner] if primary_owner else None,
    )


class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
    def __init__(
        self,
        folder_path: str | None = None,
        batch_size: int = INDEX_BATCH_SIZE,
    ) -> None:
        self.domain = ""  # will always be set in `load_credentials`
        self.folder_path = folder_path or ""  # Root folder if not specified
        self.batch_size = batch_size
        self.access_token: str | None = None

    @classmethod
    def oauth_id(cls) -> DocumentSource:
        return DocumentSource.EGNYTE

    @classmethod
    def oauth_authorization_url(cls, base_domain: str, state: str) -> str:
        if not EGNYTE_CLIENT_ID:
            raise ValueError("EGNYTE_CLIENT_ID environment variable must be set")
        if not EGNYTE_BASE_DOMAIN:
            raise ValueError("EGNYTE_DOMAIN environment variable must be set")

        if EGNYTE_LOCALHOST_OVERRIDE:
            base_domain = EGNYTE_LOCALHOST_OVERRIDE

        callback_uri = f"{base_domain.strip('/')}/connector/oauth/callback/egnyte"
        return (
            f"https://{EGNYTE_BASE_DOMAIN}.egnyte.com/puboauth/token"
            f"?client_id={EGNYTE_CLIENT_ID}"
            f"&redirect_uri={callback_uri}"
            f"&scope=Egnyte.filesystem"
            f"&state={state}"
            f"&response_type=code"
        )

    @classmethod
    def oauth_code_to_token(cls, code: str) -> dict[str, Any]:
        if not EGNYTE_CLIENT_ID:
            raise ValueError("EGNYTE_CLIENT_ID environment variable must be set")
        if not EGNYTE_CLIENT_SECRET:
            raise ValueError("EGNYTE_CLIENT_SECRET environment variable must be set")
        if not EGNYTE_BASE_DOMAIN:
            raise ValueError("EGNYTE_DOMAIN environment variable must be set")

        # Exchange code for token
        url = f"https://{EGNYTE_BASE_DOMAIN}.egnyte.com/puboauth/token"
        data = {
            "client_id": EGNYTE_CLIENT_ID,
            "client_secret": EGNYTE_CLIENT_SECRET,
            "code": code,
            "grant_type": "authorization_code",
            "redirect_uri": f"{EGNYTE_LOCALHOST_OVERRIDE or ''}/connector/oauth/callback/egnyte",
            "scope": "Egnyte.filesystem",
        }
        headers = {"Content-Type": "application/x-www-form-urlencoded"}

        response = _request_with_retries(
            method="POST",
            url=url,
            data=data,
            headers=headers,
            # try a lot faster since this is a realtime flow
            backoff=0,
            delay=0.1,
        )
        if not response.ok:
            raise RuntimeError(f"Failed to exchange code for token: {response.text}")

        token_data = response.json()
        return {
            "domain": EGNYTE_BASE_DOMAIN,
            "access_token": token_data["access_token"],
        }

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        self.domain = credentials["domain"]
        self.access_token = credentials["access_token"]
        return None

    def _get_files_list(
        self,
        path: str,
    ) -> list[dict[str, Any]]:
        if not self.access_token or not self.domain:
            raise ConnectorMissingCredentialError("Egnyte")

        headers = {
            "Authorization": f"Bearer {self.access_token}",
        }

        params: dict[str, Any] = {
            "list_content": True,
        }

        url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs/{path or ''}"
        response = _request_with_retries(
            method="GET", url=url, headers=headers, params=params, timeout=_TIMEOUT
        )
        if not response.ok:
            raise RuntimeError(f"Failed to fetch files from Egnyte: {response.text}")

        data = response.json()
        all_files: list[dict[str, Any]] = []

        # Add files from current directory
        all_files.extend(data.get("files", []))

        # Recursively traverse folders
        for item in data.get("folders", []):
            all_files.extend(self._get_files_list(item["path"]))

        return all_files

    def _filter_files(
        self,
        files: list[dict[str, Any]],
        start_time: datetime | None = None,
        end_time: datetime | None = None,
    ) -> list[dict[str, Any]]:
        filtered_files = []
        for file in files:
            if file["is_folder"]:
                continue

            file_modified = _parse_last_modified(file["last_modified"])
            if start_time and file_modified < start_time:
                continue
            if end_time and file_modified > end_time:
                continue

            filtered_files.append(file)

        return filtered_files

    def _process_files(
        self,
        start_time: datetime | None = None,
        end_time: datetime | None = None,
    ) -> Generator[list[Document], None, None]:
        files = self._get_files_list(self.folder_path)
        files = self._filter_files(files, start_time, end_time)

        current_batch: list[Document] = []
        for file in files:
            try:
                # Set up request with streaming enabled
                headers = {
                    "Authorization": f"Bearer {self.access_token}",
                }
                url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs-content/{file['path']}"
                response = _request_with_retries(
                    method="GET",
                    url=url,
                    headers=headers,
                    timeout=_TIMEOUT,
                    stream=True,
                )

                if not response.ok:
                    logger.error(
                        f"Failed to fetch file content: {file['path']} (status code: {response.status_code})"
                    )
                    continue

                # Stream the response content into a BytesIO buffer
                buffer = io.BytesIO()
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        buffer.write(chunk)

                # Reset buffer's position to the start
                buffer.seek(0)

                # Process the streamed file content
                doc = _process_egnyte_file(
                    file_metadata=file,
                    file_content=buffer,
                    base_url=_EGNYTE_APP_BASE.format(domain=self.domain),
                    folder_path=self.folder_path,
                )

                if doc is not None:
                    current_batch.append(doc)

                    if len(current_batch) >= self.batch_size:
                        yield current_batch
                        current_batch = []

            except Exception:
                logger.exception(f"Failed to process file {file['path']}")
                continue

        if current_batch:
            yield current_batch

    def load_from_state(self) -> GenerateDocumentsOutput:
        yield from self._process_files()

    def poll_source(
        self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
    ) -> GenerateDocumentsOutput:
        start_time = datetime.fromtimestamp(start, tz=timezone.utc)
        end_time = datetime.fromtimestamp(end, tz=timezone.utc)

        yield from self._process_files(start_time=start_time, end_time=end_time)


if __name__ == "__main__":
    connector = EgnyteConnector()
    connector.load_credentials(
        {
            "domain": os.environ["EGNYTE_DOMAIN"],
            "access_token": os.environ["EGNYTE_ACCESS_TOKEN"],
        }
    )
    document_batches = connector.load_from_state()
    print(next(document_batches))
@@ -15,6 +15,7 @@ from danswer.connectors.danswer_jira.connector import JiraConnector
from danswer.connectors.discourse.connector import DiscourseConnector
from danswer.connectors.document360.connector import Document360Connector
from danswer.connectors.dropbox.connector import DropboxConnector
from danswer.connectors.egnyte.connector import EgnyteConnector
from danswer.connectors.file.connector import LocalFileConnector
from danswer.connectors.fireflies.connector import FirefliesConnector
from danswer.connectors.freshdesk.connector import FreshdeskConnector

@@ -40,7 +41,6 @@ from danswer.connectors.salesforce.connector import SalesforceConnector
from danswer.connectors.sharepoint.connector import SharepointConnector
from danswer.connectors.slab.connector import SlabConnector
from danswer.connectors.slack.connector import SlackPollConnector
from danswer.connectors.slack.load_connector import SlackLoadConnector
from danswer.connectors.teams.connector import TeamsConnector
from danswer.connectors.web.connector import WebConnector
from danswer.connectors.wikipedia.connector import WikipediaConnector

@@ -63,7 +63,6 @@ def identify_connector_class(
        DocumentSource.WEB: WebConnector,
        DocumentSource.FILE: LocalFileConnector,
        DocumentSource.SLACK: {
            InputType.LOAD_STATE: SlackLoadConnector,
            InputType.POLL: SlackPollConnector,
            InputType.SLIM_RETRIEVAL: SlackPollConnector,
        },

@@ -103,6 +102,7 @@ def identify_connector_class(
        DocumentSource.XENFORO: XenforoConnector,
        DocumentSource.FRESHDESK: FreshdeskConnector,
        DocumentSource.FIREFLIES: FirefliesConnector,
        DocumentSource.EGNYTE: EgnyteConnector,
    }
    connector_by_source = connector_map.get(source, {})
@@ -17,11 +17,11 @@ from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.db.engine import get_session_with_tenant
from danswer.file_processing.extract_file_text import check_file_ext_is_valid
from danswer.file_processing.extract_file_text import detect_encoding
from danswer.file_processing.extract_file_text import extract_file_text
from danswer.file_processing.extract_file_text import get_file_ext
from danswer.file_processing.extract_file_text import is_text_file_extension
from danswer.file_processing.extract_file_text import is_valid_file_ext
from danswer.file_processing.extract_file_text import load_files_from_zip
from danswer.file_processing.extract_file_text import read_pdf_file
from danswer.file_processing.extract_file_text import read_text_file

@@ -50,7 +50,7 @@ def _read_files_and_metadata(
            file_content, ignore_dirs=True
        ):
            yield os.path.join(directory_path, file_info.filename), file, metadata
    elif check_file_ext_is_valid(extension):
    elif is_valid_file_ext(extension):
        yield file_name, file_content, metadata
    else:
        logger.warning(f"Skipping file '{file_name}' with extension '{extension}'")

@@ -63,7 +63,7 @@ def _process_file(
    pdf_pass: str | None = None,
) -> list[Document]:
    extension = get_file_ext(file_name)
    if not check_file_ext_is_valid(extension):
    if not is_valid_file_ext(extension):
        logger.warning(f"Skipping file '{file_name}' with extension '{extension}'")
        return []
@@ -4,11 +4,13 @@ from concurrent.futures import as_completed
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from typing import Any
from typing import cast

from google.oauth2.credentials import Credentials as OAuthCredentials  # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials  # type: ignore

from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.app_configs import MAX_FILE_SIZE_BYTES
from danswer.configs.constants import DocumentSource
from danswer.connectors.google_drive.doc_conversion import build_slim_document
from danswer.connectors.google_drive.doc_conversion import (

@@ -452,12 +454,14 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
            if isinstance(self.creds, ServiceAccountCredentials)
            else self._manage_oauth_retrieval
        )
        return retrieval_method(
        drive_files = retrieval_method(
            is_slim=is_slim,
            start=start,
            end=end,
        )

        return drive_files

    def _extract_docs_from_google_drive(
        self,
        start: SecondsSinceUnixEpoch | None = None,

@@ -473,6 +477,15 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
        files_to_process = []
        # Gather the files into batches to be processed in parallel
        for file in self._fetch_drive_items(is_slim=False, start=start, end=end):
            if (
                file.get("size")
                and int(cast(str, file.get("size"))) > MAX_FILE_SIZE_BYTES
            ):
                logger.warning(
                    f"Skipping file {file.get('name', 'Unknown')} as it is too large: {file.get('size')} bytes"
                )
                continue

            files_to_process.append(file)
            if len(files_to_process) >= LARGE_BATCH_SIZE:
                yield from _process_files_batch(
@@ -16,7 +16,7 @@ logger = setup_logger()

FILE_FIELDS = (
    "nextPageToken, files(mimeType, id, name, permissions, modifiedTime, webViewLink, "
    "shortcutDetails, owners(emailAddress))"
    "shortcutDetails, owners(emailAddress), size)"
)
SLIM_FILE_FIELDS = (
    "nextPageToken, files(mimeType, id, name, permissions(emailAddress, type), "
@@ -2,6 +2,7 @@ import abc
from collections.abc import Iterator
from typing import Any

from danswer.configs.constants import DocumentSource
from danswer.connectors.models import Document
from danswer.connectors.models import SlimDocument

@@ -64,6 +65,23 @@ class SlimConnector(BaseConnector):
        raise NotImplementedError


class OAuthConnector(BaseConnector):
    @classmethod
    @abc.abstractmethod
    def oauth_id(cls) -> DocumentSource:
        raise NotImplementedError

    @classmethod
    @abc.abstractmethod
    def oauth_authorization_url(cls, base_domain: str, state: str) -> str:
        raise NotImplementedError

    @classmethod
    @abc.abstractmethod
    def oauth_code_to_token(cls, code: str) -> dict[str, Any]:
        raise NotImplementedError


# Event driven
class EventConnector(BaseConnector):
    @abc.abstractmethod
@@ -132,7 +132,6 @@ class LinearConnector(LoadConnector, PollConnector):
                        branchName
                        customerTicketCount
                        description
                        descriptionData
                        comments {
                            nodes {
                                url

@@ -215,5 +214,6 @@ class LinearConnector(LoadConnector, PollConnector):
if __name__ == "__main__":
    connector = LinearConnector()
    connector.load_credentials({"linear_api_key": os.environ["LINEAR_API_KEY"]})

    document_batches = connector.load_from_state()
    print(next(document_batches))
@@ -134,7 +134,6 @@ def get_latest_message_time(thread: ThreadType) -> datetime:


def thread_to_doc(
    workspace: str,
    channel: ChannelType,
    thread: ThreadType,
    slack_cleaner: SlackTextCleaner,

@@ -171,15 +170,15 @@ def thread_to_doc(
        else first_message
    )

    doc_sem_id = f"{initial_sender_name} in #{channel['name']}: {snippet}"
    doc_sem_id = f"{initial_sender_name} in #{channel['name']}: {snippet}".replace(
        "\n", " "
    )

    return Document(
        id=f"{channel_id}__{thread[0]['ts']}",
        sections=[
            Section(
                link=get_message_link(
                    event=m, workspace=workspace, channel_id=channel_id
                ),
                link=get_message_link(event=m, client=client, channel_id=channel_id),
                text=slack_cleaner.index_clean(cast(str, m["text"])),
            )
            for m in thread

@@ -263,7 +262,6 @@ def filter_channels(

def _get_all_docs(
    client: WebClient,
    workspace: str,
    channels: list[str] | None = None,
    channel_name_regex_enabled: bool = False,
    oldest: str | None = None,

@@ -310,7 +308,6 @@ def _get_all_docs(
            if filtered_thread:
                channel_docs += 1
                yield thread_to_doc(
                    workspace=workspace,
                    channel=channel,
                    thread=filtered_thread,
                    slack_cleaner=slack_cleaner,

@@ -373,14 +370,12 @@ def _get_all_doc_ids(
class SlackPollConnector(PollConnector, SlimConnector):
    def __init__(
        self,
        workspace: str,
        channels: list[str] | None = None,
        # if specified, will treat the specified channel strings as
        # regexes, and will only index channels that fully match the regexes
        channel_regex_enabled: bool = False,
        batch_size: int = INDEX_BATCH_SIZE,
    ) -> None:
        self.workspace = workspace
        self.channels = channels
        self.channel_regex_enabled = channel_regex_enabled
        self.batch_size = batch_size

@@ -414,7 +409,6 @@ class SlackPollConnector(PollConnector, SlimConnector):
        documents: list[Document] = []
        for document in _get_all_docs(
            client=self.client,
            workspace=self.workspace,
            channels=self.channels,
            channel_name_regex_enabled=self.channel_regex_enabled,
            # NOTE: need to impute to `None` instead of using 0.0, since Slack will

@@ -438,7 +432,6 @@ if __name__ == "__main__":

    slack_channel = os.environ.get("SLACK_CHANNEL")
    connector = SlackPollConnector(
        workspace=os.environ["SLACK_WORKSPACE"],
        channels=[slack_channel] if slack_channel else None,
    )
    connector.load_credentials({"slack_bot_token": os.environ["SLACK_BOT_TOKEN"]})
@@ -1,140 +0,0 @@
import json
import os
from datetime import datetime
from datetime import timezone
from pathlib import Path
from typing import Any
from typing import cast

from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.connectors.slack.connector import filter_channels
from danswer.connectors.slack.utils import get_message_link
from danswer.utils.logger import setup_logger


logger = setup_logger()


def get_event_time(event: dict[str, Any]) -> datetime | None:
    ts = event.get("ts")
    if not ts:
        return None
    return datetime.fromtimestamp(float(ts), tz=timezone.utc)


class SlackLoadConnector(LoadConnector):
    # WARNING: DEPRECATED, DO NOT USE
    def __init__(
        self,
        workspace: str,
        export_path_str: str,
        channels: list[str] | None = None,
        # if specified, will treat the specified channel strings as
        # regexes, and will only index channels that fully match the regexes
        channel_regex_enabled: bool = False,
        batch_size: int = INDEX_BATCH_SIZE,
    ) -> None:
        self.workspace = workspace
        self.channels = channels
        self.channel_regex_enabled = channel_regex_enabled
        self.export_path_str = export_path_str
        self.batch_size = batch_size

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        if credentials:
            logger.warning("Unexpected credentials provided for Slack Load Connector")
        return None

    @staticmethod
    def _process_batch_event(
        slack_event: dict[str, Any],
        channel: dict[str, Any],
        matching_doc: Document | None,
        workspace: str,
    ) -> Document | None:
        if (
            slack_event["type"] == "message"
            and slack_event.get("subtype") != "channel_join"
        ):
            if matching_doc:
                return Document(
                    id=matching_doc.id,
                    sections=matching_doc.sections
                    + [
                        Section(
                            link=get_message_link(
                                event=slack_event,
                                workspace=workspace,
                                channel_id=channel["id"],
                            ),
                            text=slack_event["text"],
                        )
                    ],
                    source=matching_doc.source,
                    semantic_identifier=matching_doc.semantic_identifier,
                    title="",  # slack docs don't really have a "title"
                    doc_updated_at=get_event_time(slack_event),
                    metadata=matching_doc.metadata,
                )

            return Document(
                id=slack_event["ts"],
                sections=[
                    Section(
                        link=get_message_link(
                            event=slack_event,
                            workspace=workspace,
                            channel_id=channel["id"],
                        ),
                        text=slack_event["text"],
                    )
                ],
                source=DocumentSource.SLACK,
                semantic_identifier=channel["name"],
                title="",  # slack docs don't really have a "title"
                doc_updated_at=get_event_time(slack_event),
                metadata={},
            )

        return None

    def load_from_state(self) -> GenerateDocumentsOutput:
        export_path = Path(self.export_path_str)

        with open(export_path / "channels.json") as f:
            all_channels = json.load(f)

        filtered_channels = filter_channels(
            all_channels, self.channels, self.channel_regex_enabled
        )

        document_batch: dict[str, Document] = {}
        for channel_info in filtered_channels:
            channel_dir_path = export_path / cast(str, channel_info["name"])
            channel_file_paths = [
                channel_dir_path / file_name
                for file_name in os.listdir(channel_dir_path)
            ]
            for path in channel_file_paths:
                with open(path) as f:
                    events = cast(list[dict[str, Any]], json.load(f))
                for slack_event in events:
                    doc = self._process_batch_event(
                        slack_event=slack_event,
                        channel=channel_info,
                        matching_doc=document_batch.get(
                            slack_event.get("thread_ts", "")
                        ),
                        workspace=self.workspace,
                    )
                    if doc:
                        document_batch[doc.id] = doc
                        if len(document_batch) >= self.batch_size:
                            yield list(document_batch.values())

        yield list(document_batch.values())
@@ -2,6 +2,7 @@ import re
import time
from collections.abc import Callable
from collections.abc import Generator
from functools import lru_cache
from functools import wraps
from typing import Any
from typing import cast

@@ -21,19 +22,21 @@ basic_retry_wrapper = retry_builder()
_SLACK_LIMIT = 900


@lru_cache()
def get_base_url(token: str) -> str:
    """Retrieve and cache the base URL of the Slack workspace based on the client token."""
    client = WebClient(token=token)
    return client.auth_test()["url"]


def get_message_link(
    event: dict[str, Any], workspace: str, channel_id: str | None = None
    event: dict[str, Any], client: WebClient, channel_id: str | None = None
) -> str:
    channel_id = channel_id or cast(
        str, event["channel"]
    )  # channel must either be present in the event or passed in
    message_ts = cast(str, event["ts"])
    message_ts_without_dot = message_ts.replace(".", "")
    thread_ts = cast(str | None, event.get("thread_ts"))
    return (
        f"https://{workspace}.slack.com/archives/{channel_id}/p{message_ts_without_dot}"
        + (f"?thread_ts={thread_ts}" if thread_ts else "")
    )
    channel_id = channel_id or event["channel"]
    message_ts = event["ts"]
    response = client.chat_getPermalink(channel=channel_id, message_ts=message_ts)
    permalink = response["permalink"]
    return permalink


def _make_slack_api_call_logged(
|
||||
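For context, a minimal sketch of how the permalink-based helper introduced in the hunk above might be exercised against the Slack SDK directly; the token, channel ID, and timestamp are placeholders for illustration only, not values from the diff:

```python
from slack_sdk import WebClient

# Placeholder credentials and IDs, purely illustrative.
client = WebClient(token="xoxb-example-token")
event = {"channel": "C0123456789", "ts": "1700000000.000100"}

# chat_getPermalink resolves a canonical link for a message, which avoids
# hand-building workspace archive URLs the way the old implementation did.
response = client.chat_getPermalink(channel=event["channel"], message_ts=event["ts"])
print(response["permalink"])
```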
@@ -33,7 +33,7 @@ def get_created_datetime(chat_message: ChatMessage) -> datetime:

def _extract_channel_members(channel: Channel) -> list[BasicExpertInfo]:
channel_members_list: list[BasicExpertInfo] = []
members = channel.members.get().execute_query()
members = channel.members.get().execute_query_retry()
for member in members:
channel_members_list.append(BasicExpertInfo(display_name=member.display_name))
return channel_members_list
@@ -51,7 +51,7 @@ def _get_threads_from_channel(
end = end.replace(tzinfo=timezone.utc)

query = channel.messages.get()
base_messages: list[ChatMessage] = query.execute_query()
base_messages: list[ChatMessage] = query.execute_query_retry()

threads: list[list[ChatMessage]] = []
for base_message in base_messages:
@@ -65,7 +65,7 @@ def _get_threads_from_channel(
continue

reply_query = base_message.replies.get_all()
replies = reply_query.execute_query()
replies = reply_query.execute_query_retry()

# start a list containing the base message and its replies
thread: list[ChatMessage] = [base_message]
@@ -82,7 +82,7 @@ def _get_channels_from_teams(
channels_list: list[Channel] = []
for team in teams:
query = team.channels.get()
channels = query.execute_query()
channels = query.execute_query_retry()
channels_list.extend(channels)

return channels_list
@@ -210,7 +210,7 @@ class TeamsConnector(LoadConnector, PollConnector):

teams_list: list[Team] = []

teams = self.graph_client.teams.get().execute_query()
teams = self.graph_client.teams.get().execute_query_retry()

if len(self.requested_team_list) > 0:
adjusted_request_strings = [
@@ -234,14 +234,25 @@ class TeamsConnector(LoadConnector, PollConnector):
raise ConnectorMissingCredentialError("Teams")

teams = self._get_all_teams()
logger.debug(f"Found available teams: {[str(t) for t in teams]}")
if not teams:
msg = "No teams found."
logger.error(msg)
raise ValueError(msg)

channels = _get_channels_from_teams(
teams=teams,
)
logger.debug(f"Found available channels: {[c.id for c in channels]}")
if not channels:
msg = "No channels found."
logger.error(msg)
raise ValueError(msg)

# goes over channels, converts them into Document objects and then yields them in batches
doc_batch: list[Document] = []
for channel in channels:
logger.debug(f"Fetching threads from channel: {channel.id}")
thread_list = _get_threads_from_channel(channel, start=start, end=end)
for thread in thread_list:
converted_doc = _convert_thread_to_document(channel, thread)
@@ -259,8 +270,8 @@ class TeamsConnector(LoadConnector, PollConnector):
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
start_datetime = datetime.utcfromtimestamp(start)
end_datetime = datetime.utcfromtimestamp(end)
start_datetime = datetime.fromtimestamp(start, timezone.utc)
end_datetime = datetime.fromtimestamp(end, timezone.utc)
return self._fetch_from_teams(start=start_datetime, end=end_datetime)

@@ -204,7 +204,8 @@ def _build_documents_blocks(
continue
seen_docs_identifiers.add(d.document_id)

doc_sem_id = d.semantic_identifier
# Strip newlines from the semantic identifier for Slackbot formatting
doc_sem_id = d.semantic_identifier.replace("\n", " ")
if d.source_type == DocumentSource.SLACK.value:
doc_sem_id = "#" + doc_sem_id


@@ -373,7 +373,9 @@ def handle_regular_answer(
respond_in_thread(
client=client,
channel=channel,
receiver_ids=receiver_ids,
receiver_ids=[message_info.sender]
if message_info.is_bot_msg and message_info.sender
else receiver_ids,
text="Hello! Danswer has some results for you!",
blocks=all_blocks,
thread_ts=message_ts_to_respond_to,
@@ -11,6 +11,7 @@ from retry import retry
from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from slack_sdk.models.blocks import Block
from slack_sdk.models.blocks import SectionBlock
from slack_sdk.models.metadata import Metadata
from slack_sdk.socket_mode import SocketModeClient

@@ -140,6 +141,40 @@ def remove_danswer_bot_tag(message_str: str, client: WebClient) -> str:
return re.sub(rf"<@{bot_tag_id}>\s", "", message_str)


def _check_for_url_in_block(block: Block) -> bool:
"""
Check if the block has a key that contains "url" in it
"""
block_dict = block.to_dict()

def check_dict_for_url(d: dict) -> bool:
for key, value in d.items():
if "url" in key.lower():
return True
if isinstance(value, dict):
if check_dict_for_url(value):
return True
elif isinstance(value, list):
for item in value:
if isinstance(item, dict) and check_dict_for_url(item):
return True
return False

return check_dict_for_url(block_dict)


def _build_error_block(error_message: str) -> Block:
"""
Build an error block to display in slack so that the user can see
the error without completely breaking
"""
display_text = (
"There was an error displaying all of the Onyx answers."
f" Please let an admin or an onyx developer know. Error: {error_message}"
)
return SectionBlock(text=display_text)


@retry(
tries=DANSWER_BOT_NUM_RETRIES,
delay=0.25,
@@ -162,24 +197,9 @@ def respond_in_thread(
message_ids: list[str] = []
if not receiver_ids:
slack_call = make_slack_api_rate_limited(client.chat_postMessage)
response = slack_call(
channel=channel,
text=text,
blocks=blocks,
thread_ts=thread_ts,
metadata=metadata,
unfurl_links=unfurl,
unfurl_media=unfurl,
)
if not response.get("ok"):
raise RuntimeError(f"Failed to post message: {response}")
message_ids.append(response["message_ts"])
else:
slack_call = make_slack_api_rate_limited(client.chat_postEphemeral)
for receiver in receiver_ids:
try:
response = slack_call(
channel=channel,
user=receiver,
text=text,
blocks=blocks,
thread_ts=thread_ts,
@@ -187,8 +207,68 @@ def respond_in_thread(
unfurl_links=unfurl,
unfurl_media=unfurl,
)
if not response.get("ok"):
raise RuntimeError(f"Failed to post message: {response}")
except Exception as e:
logger.warning(f"Failed to post message: {e} \n blocks: {blocks}")
logger.warning("Trying again without blocks that have urls")

if not blocks:
raise e

blocks_without_urls = [
block for block in blocks if not _check_for_url_in_block(block)
]
blocks_without_urls.append(_build_error_block(str(e)))

# Try again without blocks containing url
response = slack_call(
channel=channel,
text=text,
blocks=blocks_without_urls,
thread_ts=thread_ts,
metadata=metadata,
unfurl_links=unfurl,
unfurl_media=unfurl,
)

message_ids.append(response["message_ts"])
else:
slack_call = make_slack_api_rate_limited(client.chat_postEphemeral)
for receiver in receiver_ids:
try:
response = slack_call(
channel=channel,
user=receiver,
text=text,
blocks=blocks,
thread_ts=thread_ts,
metadata=metadata,
unfurl_links=unfurl,
unfurl_media=unfurl,
)
except Exception as e:
logger.warning(f"Failed to post message: {e} \n blocks: {blocks}")
logger.warning("Trying again without blocks that have urls")

if not blocks:
raise e

blocks_without_urls = [
block for block in blocks if not _check_for_url_in_block(block)
]
blocks_without_urls.append(_build_error_block(str(e)))

# Try again without blocks containing url
response = slack_call(
channel=channel,
user=receiver,
text=text,
blocks=blocks_without_urls,
thread_ts=thread_ts,
metadata=metadata,
unfurl_links=unfurl,
unfurl_media=unfurl,
)

message_ids.append(response["message_ts"])

return message_ids

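To make the retry-without-URL-blocks idea above easier to follow, here is a standalone sketch (not code from the repository) of the same recursive "does any key contain url" check applied to Slack SDK blocks; the block contents are made up for illustration:

```python
from slack_sdk.models.blocks import ActionsBlock, ButtonElement, SectionBlock


def has_url_key(d: dict) -> bool:
    # Same recursive idea as _check_for_url_in_block: flag any nested key containing "url".
    for key, value in d.items():
        if "url" in key.lower():
            return True
        if isinstance(value, dict) and has_url_key(value):
            return True
        if isinstance(value, list):
            if any(isinstance(item, dict) and has_url_key(item) for item in value):
                return True
    return False


plain = SectionBlock(text="No links here")
with_link = ActionsBlock(elements=[ButtonElement(text="Open", url="https://example.com")])

print(has_url_key(plain.to_dict()))      # False -> block would be kept on retry
print(has_url_key(with_link.to_dict()))  # True  -> block would be dropped on retry
```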
@@ -20,7 +20,6 @@ from danswer.db.models import DocumentByConnectorCredentialPair
from danswer.db.models import User
from danswer.db.models import User__UserGroup
from danswer.server.documents.models import CredentialBase
from danswer.server.documents.models import CredentialDataUpdateRequest
from danswer.utils.logger import setup_logger


@@ -262,7 +261,8 @@ def _cleanup_credential__user_group_relationships__no_commit(

def alter_credential(
credential_id: int,
credential_data: CredentialDataUpdateRequest,
name: str,
credential_json: dict[str, Any],
user: User,
db_session: Session,
) -> Credential | None:
@@ -272,11 +272,13 @@ def alter_credential(
if credential is None:
return None

credential.name = credential_data.name
credential.name = name

# Update only the keys present in credential_data.credential_json
for key, value in credential_data.credential_json.items():
credential.credential_json[key] = value
# Assign a new dictionary to credential.credential_json
credential.credential_json = {
**credential.credential_json,
**credential_json,
}

credential.user_id = user.id if user is not None else None
db_session.commit()
@@ -309,8 +311,8 @@ def update_credential_json(
credential = fetch_credential_by_id(credential_id, user, db_session)
if credential is None:
return None
credential.credential_json = credential_json

credential.credential_json = credential_json
db_session.commit()
return credential


@@ -522,12 +522,16 @@ def expire_index_attempts(
search_settings_id: int,
db_session: Session,
) -> None:
delete_query = (
delete(IndexAttempt)
not_started_query = (
update(IndexAttempt)
.where(IndexAttempt.search_settings_id == search_settings_id)
.where(IndexAttempt.status == IndexingStatus.NOT_STARTED)
.values(
status=IndexingStatus.CANCELED,
error_msg="Canceled, likely due to model swap",
)
)
db_session.execute(delete_query)
db_session.execute(not_started_query)

update_query = (
update(IndexAttempt)
@@ -549,9 +553,14 @@ def cancel_indexing_attempts_for_ccpair(
include_secondary_index: bool = False,
) -> None:
stmt = (
delete(IndexAttempt)
update(IndexAttempt)
.where(IndexAttempt.connector_credential_pair_id == cc_pair_id)
.where(IndexAttempt.status == IndexingStatus.NOT_STARTED)
.values(
status=IndexingStatus.CANCELED,
error_msg="Canceled by user",
time_started=datetime.now(timezone.utc),
)
)

if not include_secondary_index:
@@ -1,202 +0,0 @@
from uuid import UUID

from fastapi import HTTPException
from sqlalchemy import select
from sqlalchemy.orm import Session

from danswer.db.models import InputPrompt
from danswer.db.models import User
from danswer.server.features.input_prompt.models import InputPromptSnapshot
from danswer.server.manage.models import UserInfo
from danswer.utils.logger import setup_logger


logger = setup_logger()


def insert_input_prompt_if_not_exists(
user: User | None,
input_prompt_id: int | None,
prompt: str,
content: str,
active: bool,
is_public: bool,
db_session: Session,
commit: bool = True,
) -> InputPrompt:
if input_prompt_id is not None:
input_prompt = (
db_session.query(InputPrompt).filter_by(id=input_prompt_id).first()
)
else:
query = db_session.query(InputPrompt).filter(InputPrompt.prompt == prompt)
if user:
query = query.filter(InputPrompt.user_id == user.id)
else:
query = query.filter(InputPrompt.user_id.is_(None))
input_prompt = query.first()

if input_prompt is None:
input_prompt = InputPrompt(
id=input_prompt_id,
prompt=prompt,
content=content,
active=active,
is_public=is_public or user is None,
user_id=user.id if user else None,
)
db_session.add(input_prompt)

if commit:
db_session.commit()

return input_prompt


def insert_input_prompt(
prompt: str,
content: str,
is_public: bool,
user: User | None,
db_session: Session,
) -> InputPrompt:
input_prompt = InputPrompt(
prompt=prompt,
content=content,
active=True,
is_public=is_public or user is None,
user_id=user.id if user is not None else None,
)
db_session.add(input_prompt)
db_session.commit()

return input_prompt


def update_input_prompt(
user: User | None,
input_prompt_id: int,
prompt: str,
content: str,
active: bool,
db_session: Session,
) -> InputPrompt:
input_prompt = db_session.scalar(
select(InputPrompt).where(InputPrompt.id == input_prompt_id)
)
if input_prompt is None:
raise ValueError(f"No input prompt with id {input_prompt_id}")

if not validate_user_prompt_authorization(user, input_prompt):
raise HTTPException(status_code=401, detail="You don't own this prompt")

input_prompt.prompt = prompt
input_prompt.content = content
input_prompt.active = active

db_session.commit()
return input_prompt


def validate_user_prompt_authorization(
user: User | None, input_prompt: InputPrompt
) -> bool:
prompt = InputPromptSnapshot.from_model(input_prompt=input_prompt)

if prompt.user_id is not None:
if user is None:
return False

user_details = UserInfo.from_model(user)
if str(user_details.id) != str(prompt.user_id):
return False
return True


def remove_public_input_prompt(input_prompt_id: int, db_session: Session) -> None:
input_prompt = db_session.scalar(
select(InputPrompt).where(InputPrompt.id == input_prompt_id)
)

if input_prompt is None:
raise ValueError(f"No input prompt with id {input_prompt_id}")

if not input_prompt.is_public:
raise HTTPException(status_code=400, detail="This prompt is not public")

db_session.delete(input_prompt)
db_session.commit()


def remove_input_prompt(
user: User | None, input_prompt_id: int, db_session: Session
) -> None:
input_prompt = db_session.scalar(
select(InputPrompt).where(InputPrompt.id == input_prompt_id)
)
if input_prompt is None:
raise ValueError(f"No input prompt with id {input_prompt_id}")

if input_prompt.is_public:
raise HTTPException(
status_code=400, detail="Cannot delete public prompts with this method"
)

if not validate_user_prompt_authorization(user, input_prompt):
raise HTTPException(status_code=401, detail="You do not own this prompt")

db_session.delete(input_prompt)
db_session.commit()


def fetch_input_prompt_by_id(
id: int, user_id: UUID | None, db_session: Session
) -> InputPrompt:
query = select(InputPrompt).where(InputPrompt.id == id)

if user_id:
query = query.where(
(InputPrompt.user_id == user_id) | (InputPrompt.user_id is None)
)
else:
# If no user_id is provided, only fetch prompts without a user_id (aka public)
query = query.where(InputPrompt.user_id == None) # noqa

result = db_session.scalar(query)

if result is None:
raise HTTPException(422, "No input prompt found")

return result


def fetch_public_input_prompts(
db_session: Session,
) -> list[InputPrompt]:
query = select(InputPrompt).where(InputPrompt.is_public)
return list(db_session.scalars(query).all())


def fetch_input_prompts_by_user(
db_session: Session,
user_id: UUID | None,
active: bool | None = None,
include_public: bool = False,
) -> list[InputPrompt]:
query = select(InputPrompt)

if user_id is not None:
if include_public:
query = query.where(
(InputPrompt.user_id == user_id) | InputPrompt.is_public
)
else:
query = query.where(InputPrompt.user_id == user_id)

elif include_public:
query = query.where(InputPrompt.is_public)

if active is not None:
query = query.where(InputPrompt.active == active)

return list(db_session.scalars(query).all())
@@ -159,9 +159,6 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
)

prompts: Mapped[list["Prompt"]] = relationship("Prompt", back_populates="user")
input_prompts: Mapped[list["InputPrompt"]] = relationship(
"InputPrompt", back_populates="user"
)

# Personas owned by this user
personas: Mapped[list["Persona"]] = relationship("Persona", back_populates="user")
@@ -178,31 +175,6 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
)


class InputPrompt(Base):
__tablename__ = "inputprompt"

id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
prompt: Mapped[str] = mapped_column(String)
content: Mapped[str] = mapped_column(String)
active: Mapped[bool] = mapped_column(Boolean)
user: Mapped[User | None] = relationship("User", back_populates="input_prompts")
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
)


class InputPrompt__User(Base):
__tablename__ = "inputprompt__user"

input_prompt_id: Mapped[int] = mapped_column(
ForeignKey("inputprompt.id"), primary_key=True
)
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("inputprompt.id"), primary_key=True
)


class AccessToken(SQLAlchemyBaseAccessTokenTableUUID, Base):
pass

@@ -596,6 +568,25 @@ class Connector(Base):
list["DocumentByConnectorCredentialPair"]
] = relationship("DocumentByConnectorCredentialPair", back_populates="connector")

# synchronize this validation logic with RefreshFrequencySchema etc on front end
# until we have a centralized validation schema

# TODO(rkuo): experiment with SQLAlchemy validators rather than manual checks
# https://docs.sqlalchemy.org/en/20/orm/mapped_attributes.html
def validate_refresh_freq(self) -> None:
if self.refresh_freq is not None:
if self.refresh_freq < 60:
raise ValueError(
"refresh_freq must be greater than or equal to 60 seconds."
)

def validate_prune_freq(self) -> None:
if self.prune_freq is not None:
if self.prune_freq < 86400:
raise ValueError(
"prune_freq must be greater than or equal to 86400 seconds."
)


class Credential(Base):
__tablename__ = "credential"
@@ -1530,6 +1521,7 @@ class SlackBot(Base):
slack_channel_configs: Mapped[list[SlackChannelConfig]] = relationship(
"SlackChannelConfig",
back_populates="slack_bot",
cascade="all, delete-orphan",
)

@@ -453,9 +453,9 @@ def upsert_persona(
"""

if persona_id is not None:
persona = db_session.query(Persona).filter_by(id=persona_id).first()
existing_persona = db_session.query(Persona).filter_by(id=persona_id).first()
else:
persona = _get_persona_by_name(
existing_persona = _get_persona_by_name(
persona_name=name, user=user, db_session=db_session
)

@@ -481,62 +481,78 @@ def upsert_persona(
prompts = None
if prompt_ids is not None:
prompts = db_session.query(Prompt).filter(Prompt.id.in_(prompt_ids)).all()
if not prompts and prompt_ids:
raise ValueError("prompts not found")

if prompts is not None and len(prompts) == 0:
raise ValueError(
f"Invalid Persona config, no valid prompts "
f"specified. Specified IDs were: '{prompt_ids}'"
)

# ensure all specified tools are valid
if tools:
validate_persona_tools(tools)

if persona:
if existing_persona:
# Built-in personas can only be updated through YAML configuration.
# This ensures that core system personas are not modified unintentionally.
if persona.builtin_persona and not builtin_persona:
if existing_persona.builtin_persona and not builtin_persona:
raise ValueError("Cannot update builtin persona with non-builtin.")

# this checks if the user has permission to edit the persona
persona = fetch_persona_by_id(
db_session=db_session, persona_id=persona.id, user=user, get_editable=True
# will raise an Exception if the user does not have permission
existing_persona = fetch_persona_by_id(
db_session=db_session,
persona_id=existing_persona.id,
user=user,
get_editable=True,
)

# The following update excludes `default`, `built-in`, and display priority.
# Display priority is handled separately in the `display-priority` endpoint.
# `default` and `built-in` properties can only be set when creating a persona.
persona.name = name
persona.description = description
persona.num_chunks = num_chunks
persona.chunks_above = chunks_above
persona.chunks_below = chunks_below
persona.llm_relevance_filter = llm_relevance_filter
persona.llm_filter_extraction = llm_filter_extraction
persona.recency_bias = recency_bias
persona.llm_model_provider_override = llm_model_provider_override
persona.llm_model_version_override = llm_model_version_override
persona.starter_messages = starter_messages
persona.deleted = False # Un-delete if previously deleted
persona.is_public = is_public
persona.icon_color = icon_color
persona.icon_shape = icon_shape
existing_persona.name = name
existing_persona.description = description
existing_persona.num_chunks = num_chunks
existing_persona.chunks_above = chunks_above
existing_persona.chunks_below = chunks_below
existing_persona.llm_relevance_filter = llm_relevance_filter
existing_persona.llm_filter_extraction = llm_filter_extraction
existing_persona.recency_bias = recency_bias
existing_persona.llm_model_provider_override = llm_model_provider_override
existing_persona.llm_model_version_override = llm_model_version_override
existing_persona.starter_messages = starter_messages
existing_persona.deleted = False # Un-delete if previously deleted
existing_persona.is_public = is_public
existing_persona.icon_color = icon_color
existing_persona.icon_shape = icon_shape
if remove_image or uploaded_image_id:
persona.uploaded_image_id = uploaded_image_id
persona.is_visible = is_visible
persona.search_start_date = search_start_date
persona.category_id = category_id
existing_persona.uploaded_image_id = uploaded_image_id
existing_persona.is_visible = is_visible
existing_persona.search_start_date = search_start_date
existing_persona.category_id = category_id
# Do not delete any associations manually added unless
# a new updated list is provided
if document_sets is not None:
persona.document_sets.clear()
persona.document_sets = document_sets or []
existing_persona.document_sets.clear()
existing_persona.document_sets = document_sets or []

if prompts is not None:
persona.prompts.clear()
persona.prompts = prompts or []
existing_persona.prompts.clear()
existing_persona.prompts = prompts

if tools is not None:
persona.tools = tools or []
existing_persona.tools = tools or []

persona = existing_persona

else:
persona = Persona(
if not prompts:
raise ValueError(
"Invalid Persona config. "
"Must specify at least one prompt for a new persona."
)

new_persona = Persona(
id=persona_id,
user_id=user.id if user else None,
is_public=is_public,
@@ -549,7 +565,7 @@ def upsert_persona(
llm_filter_extraction=llm_filter_extraction,
recency_bias=recency_bias,
builtin_persona=builtin_persona,
prompts=prompts or [],
prompts=prompts,
document_sets=document_sets or [],
llm_model_provider_override=llm_model_provider_override,
llm_model_version_override=llm_model_version_override,
@@ -564,8 +580,8 @@ def upsert_persona(
is_default_persona=is_default_persona,
category_id=category_id,
)
db_session.add(persona)

db_session.add(new_persona)
persona = new_persona
if commit:
db_session.commit()
else:
@@ -148,6 +148,7 @@ class Indexable(abc.ABC):
def index(
self,
chunks: list[DocMetadataAwareIndexChunk],
fresh_index: bool = False,
) -> set[DocumentInsertionRecord]:
"""
Takes a list of document chunks and indexes them in the document index
@@ -165,9 +166,14 @@ class Indexable(abc.ABC):
only needs to index chunks into the PRIMARY index. Do not update the secondary index here,
it is done automatically outside of this code.

NOTE: The fresh_index parameter, when set to True, assumes no documents have been previously
indexed for the given index/tenant. This can be used to optimize the indexing process for
new or empty indices.

Parameters:
- chunks: Document chunks with all of the information needed for indexing to the document
index.
- fresh_index: Boolean indicating whether this is a fresh index with no existing documents.

Returns:
List of document ids which map to unique documents and are used for deduping chunks

@@ -306,6 +306,7 @@ class VespaIndex(DocumentIndex):
def index(
self,
chunks: list[DocMetadataAwareIndexChunk],
fresh_index: bool = False,
) -> set[DocumentInsertionRecord]:
"""Receive a list of chunks from a batch of documents and index the chunks into Vespa along
with updating the associated permissions. Assumes that a document will not be split into
@@ -322,26 +323,29 @@ class VespaIndex(DocumentIndex):
concurrent.futures.ThreadPoolExecutor(max_workers=NUM_THREADS) as executor,
get_vespa_http_client() as http_client,
):
# Check for existing documents, existing documents need to have all of their chunks deleted
# prior to indexing as the document size (num chunks) may have shrunk
first_chunks = [chunk for chunk in cleaned_chunks if chunk.chunk_id == 0]
for chunk_batch in batch_generator(first_chunks, BATCH_SIZE):
existing_docs.update(
get_existing_documents_from_chunks(
chunks=chunk_batch,
if not fresh_index:
# Check for existing documents, existing documents need to have all of their chunks deleted
# prior to indexing as the document size (num chunks) may have shrunk
first_chunks = [
chunk for chunk in cleaned_chunks if chunk.chunk_id == 0
]
for chunk_batch in batch_generator(first_chunks, BATCH_SIZE):
existing_docs.update(
get_existing_documents_from_chunks(
chunks=chunk_batch,
index_name=self.index_name,
http_client=http_client,
executor=executor,
)
)

for doc_id_batch in batch_generator(existing_docs, BATCH_SIZE):
delete_vespa_docs(
document_ids=doc_id_batch,
index_name=self.index_name,
http_client=http_client,
executor=executor,
)
)

for doc_id_batch in batch_generator(existing_docs, BATCH_SIZE):
delete_vespa_docs(
document_ids=doc_id_batch,
index_name=self.index_name,
http_client=http_client,
executor=executor,
)

for chunk_batch in batch_generator(cleaned_chunks, BATCH_SIZE):
batch_index_vespa_chunks(
@@ -70,7 +70,7 @@ def get_file_ext(file_path_or_name: str | Path) -> str:
return extension


def check_file_ext_is_valid(ext: str) -> bool:
def is_valid_file_ext(ext: str) -> bool:
return ext in VALID_FILE_EXTENSIONS


@@ -364,7 +364,7 @@ def extract_file_text(
elif file_name is not None:
final_extension = get_file_ext(file_name)

if check_file_ext_is_valid(final_extension):
if is_valid_file_ext(final_extension):
return extension_to_function.get(final_extension, file_io_to_text)(file)

# Either the file somehow has no name or the extension is not one that we recognize
@@ -1,4 +1,5 @@
import traceback
from collections.abc import Callable
from functools import partial
from http import HTTPStatus
from typing import Protocol
@@ -12,6 +13,7 @@ from danswer.access.access import get_access_for_documents
from danswer.access.models import DocumentAccess
from danswer.configs.app_configs import ENABLE_MULTIPASS_INDEXING
from danswer.configs.app_configs import INDEXING_EXCEPTION_LIMIT
from danswer.configs.app_configs import MAX_DOCUMENT_CHARS
from danswer.configs.constants import DEFAULT_BOOST
from danswer.connectors.cross_connector_utils.miscellaneous_utils import (
get_experts_stores_representations,
@@ -202,40 +204,13 @@ def index_doc_batch_with_handler(


def index_doc_batch_prepare(
document_batch: list[Document],
documents: list[Document],
index_attempt_metadata: IndexAttemptMetadata,
db_session: Session,
ignore_time_skip: bool = False,
) -> DocumentBatchPrepareContext | None:
"""Sets up the documents in the relational DB (source of truth) for permissions, metadata, etc.
This precedes indexing it into the actual document index."""
documents: list[Document] = []
for document in document_batch:
empty_contents = not any(section.text.strip() for section in document.sections)
if (
(not document.title or not document.title.strip())
and not document.semantic_identifier.strip()
and empty_contents
):
# Skip documents that have neither title nor content
# If the document doesn't have either, then there is no useful information in it
# This is again verified later in the pipeline after chunking but at that point there should
# already be no documents that are empty.
logger.warning(
f"Skipping document with ID {document.id} as it has neither title nor content."
)
continue

if document.title is not None and not document.title.strip() and empty_contents:
# The title is explicitly empty ("" and not None) and the document is empty
# so when building the chunk text representation, it will be empty and unusable
logger.warning(
f"Skipping document with ID {document.id} as the chunks will be empty."
)
continue

documents.append(document)

# Create a trimmed list of docs that don't have a newer updated at
# Shortcuts the time-consuming flow on connector index retries
document_ids: list[str] = [document.id for document in documents]
@@ -282,17 +257,64 @@ def index_doc_batch_prepare(
)


def filter_documents(document_batch: list[Document]) -> list[Document]:
documents: list[Document] = []
for document in document_batch:
empty_contents = not any(section.text.strip() for section in document.sections)
if (
(not document.title or not document.title.strip())
and not document.semantic_identifier.strip()
and empty_contents
):
# Skip documents that have neither title nor content
# If the document doesn't have either, then there is no useful information in it
# This is again verified later in the pipeline after chunking but at that point there should
# already be no documents that are empty.
logger.warning(
f"Skipping document with ID {document.id} as it has neither title nor content."
)
continue

if document.title is not None and not document.title.strip() and empty_contents:
# The title is explicitly empty ("" and not None) and the document is empty
# so when building the chunk text representation, it will be empty and unusable
logger.warning(
f"Skipping document with ID {document.id} as the chunks will be empty."
)
continue

section_chars = sum(len(section.text) for section in document.sections)
if (
MAX_DOCUMENT_CHARS
and len(document.title or document.semantic_identifier) + section_chars
> MAX_DOCUMENT_CHARS
):
# Skip documents that are too long, later on there are more memory intensive steps done on the text
# and the container will run out of memory and crash. Several other checks are included upstream but
# those are at the connector level so a catchall is still needed.
# Assumption here is that files that are that long, are generated files and not the type users
# generally care for.
logger.warning(
f"Skipping document with ID {document.id} as it is too long."
)
continue

documents.append(document)
return documents


@log_function_time(debug_only=True)
def index_doc_batch(
*,
document_batch: list[Document],
chunker: Chunker,
embedder: IndexingEmbedder,
document_index: DocumentIndex,
document_batch: list[Document],
index_attempt_metadata: IndexAttemptMetadata,
db_session: Session,
ignore_time_skip: bool = False,
tenant_id: str | None = None,
filter_fnc: Callable[[list[Document]], list[Document]] = filter_documents,
) -> tuple[int, int]:
"""Takes different pieces of the indexing pipeline and applies it to a batch of documents
Note that the documents should already be batched at this point so that it does not inflate the
@@ -309,8 +331,11 @@ def index_doc_batch(
is_public=False,
)

logger.debug("Filtering Documents")
filtered_documents = filter_fnc(document_batch)

ctx = index_doc_batch_prepare(
document_batch=document_batch,
documents=filtered_documents,
index_attempt_metadata=index_attempt_metadata,
ignore_time_skip=ignore_time_skip,
db_session=db_session,
@@ -268,12 +268,16 @@ class DefaultMultiLLM(LLM):

# NOTE: have to set these as environment variables for Litellm since
# not all are able to be passed in but they always support them set as env
# variables
# variables. We'll also try passing them in, since litellm just ignores
# additional kwargs (and some kwargs MUST be passed in rather than set as
# env variables)
if custom_config:
for k, v in custom_config.items():
os.environ[k] = v

model_kwargs = model_kwargs or {}
if custom_config:
model_kwargs.update(custom_config)
if extra_headers:
model_kwargs.update({"extra_headers": extra_headers})
if extra_body:
@@ -1,5 +1,4 @@
import copy
import io
import json
from collections.abc import Callable
from collections.abc import Iterator
@@ -7,7 +6,6 @@ from typing import Any
from typing import cast

import litellm # type: ignore
import pandas as pd
import tiktoken
from langchain.prompts.base import StringPromptValue
from langchain.prompts.chat import ChatPromptValue
@@ -100,53 +98,32 @@ def litellm_exception_to_error_msg(
return error_msg


# Processes CSV files to show the first 5 rows and max_columns (default 40) columns
def _process_csv_file(file: InMemoryChatFile, max_columns: int = 40) -> str:
df = pd.read_csv(io.StringIO(file.content.decode("utf-8")))

csv_preview = df.head().to_string(max_cols=max_columns)

file_name_section = (
f"CSV FILE NAME: {file.filename}\n"
if file.filename
else "CSV FILE (NO NAME PROVIDED):\n"
)
return f"{file_name_section}{CODE_BLOCK_PAT.format(csv_preview)}\n\n\n"


def _build_content(
message: str,
files: list[InMemoryChatFile] | None = None,
) -> str:
"""Applies all non-image files."""
text_files = (
[file for file in files if file.file_type == ChatFileType.PLAIN_TEXT]
if files
else None
)
if not files:
return message

csv_files = (
[file for file in files if file.file_type == ChatFileType.CSV]
if files
else None
)
text_files = [
file
for file in files
if file.file_type in (ChatFileType.PLAIN_TEXT, ChatFileType.CSV)
]

if not text_files and not csv_files:
if not text_files:
return message

final_message_with_files = "FILES:\n\n"
for file in text_files or []:
for file in text_files:
file_content = file.content.decode("utf-8")
file_name_section = f"DOCUMENT: {file.filename}\n" if file.filename else ""
final_message_with_files += (
f"{file_name_section}{CODE_BLOCK_PAT.format(file_content.strip())}\n\n\n"
)
for file in csv_files or []:
final_message_with_files += _process_csv_file(file)

final_message_with_files += message

return final_message_with_files
return final_message_with_files + message


def build_content_with_imgs(
@@ -52,12 +52,9 @@ from danswer.server.documents.connector import router as connector_router
from danswer.server.documents.credential import router as credential_router
from danswer.server.documents.document import router as document_router
from danswer.server.documents.indexing import router as indexing_router
from danswer.server.documents.standard_oauth import router as oauth_router
from danswer.server.features.document_set.api import router as document_set_router
from danswer.server.features.folder.api import router as folder_router
from danswer.server.features.input_prompt.api import (
admin_router as admin_input_prompt_router,
)
from danswer.server.features.input_prompt.api import basic_router as input_prompt_router
from danswer.server.features.notifications.api import router as notification_router
from danswer.server.features.persona.api import admin_router as admin_persona_router
from danswer.server.features.persona.api import basic_router as persona_router
@@ -258,8 +255,6 @@ def get_application() -> FastAPI:
)
include_router_with_global_prefix_prepended(application, persona_router)
include_router_with_global_prefix_prepended(application, admin_persona_router)
include_router_with_global_prefix_prepended(application, input_prompt_router)
include_router_with_global_prefix_prepended(application, admin_input_prompt_router)
include_router_with_global_prefix_prepended(application, notification_router)
include_router_with_global_prefix_prepended(application, prompt_router)
include_router_with_global_prefix_prepended(application, tool_router)
@@ -282,6 +277,7 @@ def get_application() -> FastAPI:
)
include_router_with_global_prefix_prepended(application, long_term_logs_router)
include_router_with_global_prefix_prepended(application, api_key_router)
include_router_with_global_prefix_prepended(application, oauth_router)

if AUTH_TYPE == AuthType.DISABLED:
# Server logs this during auth setup verification step
@@ -1,24 +0,0 @@
input_prompts:
- id: -5
prompt: "Elaborate"
content: "Elaborate on the above, give me a more in depth explanation."
active: true
is_public: true

- id: -4
prompt: "Reword"
content: "Help me rewrite the following politely and concisely for professional communication:\n"
active: true
is_public: true

- id: -3
prompt: "Email"
content: "Write a professional email for me including a subject line, signature, etc. Template the parts that need editing with [ ]. The email should cover the following points:\n"
active: true
is_public: true

- id: -2
prompt: "Debug"
content: "Provide step-by-step troubleshooting instructions for the following issue:\n"
active: true
is_public: true
@@ -196,7 +196,7 @@ def seed_initial_documents(
docs, chunks = _create_indexable_chunks(processed_docs, tenant_id)

index_doc_batch_prepare(
document_batch=docs,
documents=docs,
index_attempt_metadata=IndexAttemptMetadata(
connector_id=connector_id,
credential_id=PUBLIC_CREDENTIAL_ID,
@@ -216,7 +216,7 @@ def seed_initial_documents(
# as we just sent over the Vespa schema and there is a slight delay

index_with_retries = retry_builder()(document_index.index)
index_with_retries(chunks=chunks)
index_with_retries(chunks=chunks, fresh_index=True)

# Mock a run for the UI even though it did not actually call out to anything
mock_successful_index_attempt(
@@ -1,13 +1,11 @@
import yaml
from sqlalchemy.orm import Session

from danswer.configs.chat_configs import INPUT_PROMPT_YAML
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from danswer.configs.chat_configs import PERSONAS_YAML
from danswer.configs.chat_configs import PROMPTS_YAML
from danswer.context.search.enums import RecencyBiasSetting
from danswer.db.document_set import get_or_create_document_set_by_name
from danswer.db.input_prompt import insert_input_prompt_if_not_exists
from danswer.db.models import DocumentSet as DocumentSetDBModel
from danswer.db.models import Persona
from danswer.db.models import Prompt as PromptDBModel
@@ -79,6 +77,9 @@ def load_personas_from_yaml(
if prompts:
prompt_ids = [prompt.id for prompt in prompts if prompt is not None]

if not prompt_ids:
raise ValueError("Invalid Persona config, no prompts exist")

p_id = persona.get("id")
tool_ids = []

@@ -123,45 +124,24 @@ def load_personas_from_yaml(
tool_ids=tool_ids,
builtin_persona=True,
is_public=True,
display_priority=existing_persona.display_priority
if existing_persona is not None
else persona.get("display_priority"),
is_visible=existing_persona.is_visible
if existing_persona is not None
else persona.get("is_visible"),
display_priority=(
existing_persona.display_priority
if existing_persona is not None
else persona.get("display_priority")
),
is_visible=(
existing_persona.is_visible
if existing_persona is not None
else persona.get("is_visible")
),
db_session=db_session,
)


def load_input_prompts_from_yaml(
db_session: Session, input_prompts_yaml: str = INPUT_PROMPT_YAML
) -> None:
with open(input_prompts_yaml, "r") as file:
data = yaml.safe_load(file)

all_input_prompts = data.get("input_prompts", [])
for input_prompt in all_input_prompts:
# If these prompts are deleted (which is a hard delete in the DB), on server startup
# they will be recreated, but the user can always just deactivate them, just a light inconvenience

insert_input_prompt_if_not_exists(
user=None,
input_prompt_id=input_prompt.get("id"),
prompt=input_prompt["prompt"],
content=input_prompt["content"],
is_public=input_prompt["is_public"],
active=input_prompt.get("active", True),
db_session=db_session,
commit=True,
)


def load_chat_yamls(
db_session: Session,
prompt_yaml: str = PROMPTS_YAML,
personas_yaml: str = PERSONAS_YAML,
input_prompts_yaml: str = INPUT_PROMPT_YAML,
) -> None:
load_prompts_from_yaml(db_session, prompt_yaml)
load_personas_from_yaml(db_session, personas_yaml)
load_input_prompts_from_yaml(db_session, input_prompts_yaml)
@@ -33,8 +33,6 @@ from danswer.db.engine import get_current_tenant_id
from danswer.db.engine import get_session
from danswer.db.enums import AccessType
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.index_attempt import cancel_indexing_attempts_for_ccpair
from danswer.db.index_attempt import cancel_indexing_attempts_past_model
from danswer.db.index_attempt import count_index_attempts_for_connector
from danswer.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
from danswer.db.index_attempt import get_paginated_index_attempts_for_cc_pair_id
@@ -45,6 +43,7 @@ from danswer.db.search_settings import get_current_search_settings
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_pool import get_redis_client
from danswer.server.documents.models import CCPairFullInfo
from danswer.server.documents.models import CCPropertyUpdateRequest
from danswer.server.documents.models import CCStatusUpdateRequest
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
from danswer.server.documents.models import ConnectorCredentialPairMetadata
@@ -192,9 +191,6 @@ def update_cc_pair_status(
db_session
)

cancel_indexing_attempts_for_ccpair(cc_pair_id, db_session)
cancel_indexing_attempts_past_model(db_session)

redis_connector = RedisConnector(tenant_id, cc_pair_id)

try:
@@ -308,6 +304,46 @@ def update_cc_pair_name(
raise HTTPException(status_code=400, detail="Name must be unique")


@router.put("/admin/cc-pair/{cc_pair_id}/property")
def update_cc_pair_property(
cc_pair_id: int,
update_request: CCPropertyUpdateRequest, # in seconds
user: User | None = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> StatusResponse[int]:
cc_pair = get_connector_credential_pair_from_id(
cc_pair_id=cc_pair_id,
db_session=db_session,
user=user,
get_editable=True,
)
if not cc_pair:
raise HTTPException(
status_code=400, detail="CC Pair not found for current user's permissions"
)

# Can we centralize logic for updating connector properties
# so that we don't need to manually validate everywhere?
if update_request.name == "refresh_frequency":
cc_pair.connector.refresh_freq = int(update_request.value)
cc_pair.connector.validate_refresh_freq()
db_session.commit()

msg = "Refresh frequency updated successfully"
elif update_request.name == "pruning_frequency":
cc_pair.connector.prune_freq = int(update_request.value)
cc_pair.connector.validate_prune_freq()
db_session.commit()

msg = "Pruning frequency updated successfully"
else:
raise HTTPException(
status_code=400, detail=f"Property name {update_request.name} is not valid."
)

return StatusResponse(success=True, message=msg, data=cc_pair_id)


@router.get("/admin/cc-pair/{cc_pair_id}/last_pruned")
def get_cc_pair_last_pruned(
cc_pair_id: int,
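A quick sketch of how the new property endpoint above might be driven from a script; the base URL, any global API prefix, and the authentication cookie/header are assumptions about a particular deployment, not part of the diff:

```python
import requests

BASE_URL = "http://localhost:8080"  # assumed deployment address
CC_PAIR_ID = 42                     # hypothetical connector-credential pair

# Matches CCPropertyUpdateRequest: both fields are strings; the value is seconds.
payload = {"name": "refresh_frequency", "value": "3600"}

resp = requests.put(
    f"{BASE_URL}/admin/cc-pair/{CC_PAIR_ID}/property",
    json=payload,
    # cookies or headers for an admin/curator session would go here
)
resp.raise_for_status()
print(resp.json())
```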
@@ -181,7 +181,13 @@ def update_credential_data(
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> CredentialBase:
credential = alter_credential(credential_id, credential_update, user, db_session)
credential = alter_credential(
credential_id,
credential_update.name,
credential_update.credential_json,
user,
db_session,
)

if credential is None:
raise HTTPException(
@@ -364,6 +364,11 @@ class RunConnectorRequest(BaseModel):
from_beginning: bool = False


class CCPropertyUpdateRequest(BaseModel):
name: str
value: str


"""Connectors Models"""

142 backend/danswer/server/documents/standard_oauth.py Normal file
@@ -0,0 +1,142 @@
import uuid
from typing import Annotated
from typing import cast

from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Query
from pydantic import BaseModel
from sqlalchemy.orm import Session

from danswer.auth.users import current_user
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import DocumentSource
from danswer.connectors.interfaces import OAuthConnector
from danswer.db.credentials import create_credential
from danswer.db.engine import get_current_tenant_id
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.redis.redis_pool import get_redis_client
from danswer.server.documents.models import CredentialBase
from danswer.utils.logger import setup_logger
from danswer.utils.subclasses import find_all_subclasses_in_dir

logger = setup_logger()

router = APIRouter(prefix="/connector/oauth")

_OAUTH_STATE_KEY_FMT = "oauth_state:{state}"
_OAUTH_STATE_EXPIRATION_SECONDS = 10 * 60 # 10 minutes

# Cache for OAuth connectors, populated at module load time
_OAUTH_CONNECTORS: dict[DocumentSource, type[OAuthConnector]] = {}


def _discover_oauth_connectors() -> dict[DocumentSource, type[OAuthConnector]]:
"""Walk through the connectors package to find all OAuthConnector implementations"""
global _OAUTH_CONNECTORS
if _OAUTH_CONNECTORS: # Return cached connectors if already discovered
return _OAUTH_CONNECTORS

oauth_connectors = find_all_subclasses_in_dir(
cast(type[OAuthConnector], OAuthConnector), "danswer.connectors"
)

_OAUTH_CONNECTORS = {cls.oauth_id(): cls for cls in oauth_connectors}
return _OAUTH_CONNECTORS


# Discover OAuth connectors at module load time
_discover_oauth_connectors()


class AuthorizeResponse(BaseModel):
redirect_url: str


@router.get("/authorize/{source}")
def oauth_authorize(
source: DocumentSource,
desired_return_url: Annotated[str | None, Query()] = None,
_: User = Depends(current_user),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> AuthorizeResponse:
"""Initiates the OAuth flow by redirecting to the provider's auth page"""
oauth_connectors = _discover_oauth_connectors()

if source not in oauth_connectors:
raise HTTPException(status_code=400, detail=f"Unknown OAuth source: {source}")

connector_cls = oauth_connectors[source]
base_url = WEB_DOMAIN

# store state in redis
if not desired_return_url:
desired_return_url = f"{base_url}/admin/connectors/{source}?step=0"
redis_client = get_redis_client(tenant_id=tenant_id)
state = str(uuid.uuid4())
redis_client.set(
_OAUTH_STATE_KEY_FMT.format(state=state),
desired_return_url,
ex=_OAUTH_STATE_EXPIRATION_SECONDS,
)

return AuthorizeResponse(
redirect_url=connector_cls.oauth_authorization_url(base_url, state)
)


class CallbackResponse(BaseModel):
redirect_url: str


@router.get("/callback/{source}")
def oauth_callback(
source: DocumentSource,
code: Annotated[str, Query()],
state: Annotated[str, Query()],
db_session: Session = Depends(get_session),
user: User = Depends(current_user),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> CallbackResponse:
"""Handles the OAuth callback and exchanges the code for tokens"""
oauth_connectors = _discover_oauth_connectors()

if source not in oauth_connectors:
raise HTTPException(status_code=400, detail=f"Unknown OAuth source: {source}")

connector_cls = oauth_connectors[source]

# get state from redis
redis_client = get_redis_client(tenant_id=tenant_id)
original_url_bytes = cast(
bytes, redis_client.get(_OAUTH_STATE_KEY_FMT.format(state=state))
)
if not original_url_bytes:
raise HTTPException(status_code=400, detail="Invalid OAuth state")
original_url = original_url_bytes.decode("utf-8")

token_info = connector_cls.oauth_code_to_token(code)

# Create a new credential with the token info
credential_data = CredentialBase(
credential_json=token_info,
admin_public=True, # Or based on some logic/parameter
source=source,
name=f"{source.title()} OAuth Credential",
)

credential = create_credential(
credential_data=credential_data,
user=user,
db_session=db_session,
)

return CallbackResponse(
redirect_url=(
f"{original_url}?credentialId={credential.id}"
if "?" not in original_url
else f"{original_url}&credentialId={credential.id}"
)
)
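To make the new OAuth flow above concrete, here is a rough sketch of how a client might start the dance; the base URL, any global API prefix, the session credentials, and the example source value are assumptions for illustration only:

```python
import requests

BASE_URL = "http://localhost:8080"  # assumed deployment address
SOURCE = "example_source"           # hypothetical DocumentSource that has an OAuthConnector

# Step 1: ask the backend for the provider's authorization URL.
resp = requests.get(
    f"{BASE_URL}/connector/oauth/authorize/{SOURCE}",
    # an authenticated session (cookie or header) is required by current_user
)
resp.raise_for_status()
redirect_url = resp.json()["redirect_url"]

# Step 2: the user opens redirect_url in a browser and approves access; the provider
# then redirects to /connector/oauth/callback/{source}?code=...&state=..., which
# creates a credential and returns a redirect_url carrying the new credentialId.
print(redirect_url)
```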
@@ -1,134 +0,0 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session

from danswer.auth.users import current_admin_user
from danswer.auth.users import current_user
from danswer.db.engine import get_session
from danswer.db.input_prompt import fetch_input_prompt_by_id
from danswer.db.input_prompt import fetch_input_prompts_by_user
from danswer.db.input_prompt import fetch_public_input_prompts
from danswer.db.input_prompt import insert_input_prompt
from danswer.db.input_prompt import remove_input_prompt
from danswer.db.input_prompt import remove_public_input_prompt
from danswer.db.input_prompt import update_input_prompt
from danswer.db.models import User
from danswer.server.features.input_prompt.models import CreateInputPromptRequest
from danswer.server.features.input_prompt.models import InputPromptSnapshot
from danswer.server.features.input_prompt.models import UpdateInputPromptRequest
from danswer.utils.logger import setup_logger

logger = setup_logger()

basic_router = APIRouter(prefix="/input_prompt")
admin_router = APIRouter(prefix="/admin/input_prompt")


@basic_router.get("")
def list_input_prompts(
user: User | None = Depends(current_user),
include_public: bool = False,
db_session: Session = Depends(get_session),
) -> list[InputPromptSnapshot]:
user_prompts = fetch_input_prompts_by_user(
user_id=user.id if user is not None else None,
db_session=db_session,
include_public=include_public,
)
return [InputPromptSnapshot.from_model(prompt) for prompt in user_prompts]


@basic_router.get("/{input_prompt_id}")
def get_input_prompt(
input_prompt_id: int,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> InputPromptSnapshot:
input_prompt = fetch_input_prompt_by_id(
id=input_prompt_id,
user_id=user.id if user is not None else None,
db_session=db_session,
)
return InputPromptSnapshot.from_model(input_prompt=input_prompt)


@basic_router.post("")
def create_input_prompt(
create_input_prompt_request: CreateInputPromptRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> InputPromptSnapshot:
input_prompt = insert_input_prompt(
prompt=create_input_prompt_request.prompt,
content=create_input_prompt_request.content,
is_public=create_input_prompt_request.is_public,
user=user,
db_session=db_session,
)
return InputPromptSnapshot.from_model(input_prompt)


@basic_router.patch("/{input_prompt_id}")
def patch_input_prompt(
input_prompt_id: int,
update_input_prompt_request: UpdateInputPromptRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> InputPromptSnapshot:
try:
updated_input_prompt = update_input_prompt(
user=user,
input_prompt_id=input_prompt_id,
prompt=update_input_prompt_request.prompt,
content=update_input_prompt_request.content,
active=update_input_prompt_request.active,
db_session=db_session,
)
except ValueError as e:
|
||||
error_msg = "Error occurred while updated input prompt"
|
||||
logger.warn(f"{error_msg}. Stack trace: {e}")
|
||||
raise HTTPException(status_code=404, detail=error_msg)
|
||||
|
||||
return InputPromptSnapshot.from_model(updated_input_prompt)
|
||||
|
||||
|
||||
@basic_router.delete("/{input_prompt_id}")
|
||||
def delete_input_prompt(
|
||||
input_prompt_id: int,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
try:
|
||||
remove_input_prompt(user, input_prompt_id, db_session)
|
||||
|
||||
except ValueError as e:
|
||||
error_msg = "Error occurred while deleting input prompt"
|
||||
logger.warn(f"{error_msg}. Stack trace: {e}")
|
||||
raise HTTPException(status_code=404, detail=error_msg)
|
||||
|
||||
|
||||
@admin_router.delete("/{input_prompt_id}")
|
||||
def delete_public_input_prompt(
|
||||
input_prompt_id: int,
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
try:
|
||||
remove_public_input_prompt(input_prompt_id, db_session)
|
||||
|
||||
except ValueError as e:
|
||||
error_msg = "Error occurred while deleting input prompt"
|
||||
logger.warn(f"{error_msg}. Stack trace: {e}")
|
||||
raise HTTPException(status_code=404, detail=error_msg)
|
||||
|
||||
|
||||
@admin_router.get("")
|
||||
def list_public_input_prompts(
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[InputPromptSnapshot]:
|
||||
user_prompts = fetch_public_input_prompts(
|
||||
db_session=db_session,
|
||||
)
|
||||
return [InputPromptSnapshot.from_model(prompt) for prompt in user_prompts]
|
||||
@@ -1,47 +0,0 @@
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.db.models import InputPrompt
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class CreateInputPromptRequest(BaseModel):
|
||||
prompt: str
|
||||
content: str
|
||||
is_public: bool
|
||||
|
||||
|
||||
class UpdateInputPromptRequest(BaseModel):
|
||||
prompt: str
|
||||
content: str
|
||||
active: bool
|
||||
|
||||
|
||||
class InputPromptResponse(BaseModel):
|
||||
id: int
|
||||
prompt: str
|
||||
content: str
|
||||
active: bool
|
||||
|
||||
|
||||
class InputPromptSnapshot(BaseModel):
|
||||
id: int
|
||||
prompt: str
|
||||
content: str
|
||||
active: bool
|
||||
user_id: UUID | None
|
||||
is_public: bool
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, input_prompt: InputPrompt) -> "InputPromptSnapshot":
|
||||
return InputPromptSnapshot(
|
||||
id=input_prompt.id,
|
||||
prompt=input_prompt.prompt,
|
||||
content=input_prompt.content,
|
||||
active=input_prompt.active,
|
||||
user_id=input_prompt.user_id,
|
||||
is_public=input_prompt.is_public,
|
||||
)
|
||||
@@ -266,5 +266,7 @@ class FullModelVersionResponse(BaseModel):
|
||||
class AllUsersResponse(BaseModel):
|
||||
accepted: list[FullUserSnapshot]
|
||||
invited: list[InvitedUserSnapshot]
|
||||
slack_users: list[FullUserSnapshot]
|
||||
accepted_pages: int
|
||||
invited_pages: int
|
||||
slack_users_pages: int
|
||||
|
||||
@@ -119,6 +119,7 @@ def set_user_role(
|
||||
def list_all_users(
|
||||
q: str | None = None,
|
||||
accepted_page: int | None = None,
|
||||
slack_users_page: int | None = None,
|
||||
invited_page: int | None = None,
|
||||
user: User | None = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
@@ -131,7 +132,12 @@ def list_all_users(
|
||||
for user in list_users(db_session, email_filter_string=q)
|
||||
if not is_api_key_email_address(user.email)
|
||||
]
|
||||
accepted_emails = {user.email for user in users}
|
||||
|
||||
slack_users = [user for user in users if user.role == UserRole.SLACK_USER]
|
||||
accepted_users = [user for user in users if user.role != UserRole.SLACK_USER]
|
||||
|
||||
accepted_emails = {user.email for user in accepted_users}
|
||||
slack_users_emails = {user.email for user in slack_users}
|
||||
invited_emails = get_invited_users()
|
||||
if q:
|
||||
invited_emails = [
|
||||
@@ -139,10 +145,11 @@ def list_all_users(
|
||||
]
|
||||
|
||||
accepted_count = len(accepted_emails)
|
||||
slack_users_count = len(slack_users_emails)
|
||||
invited_count = len(invited_emails)
|
||||
|
||||
# If any of q, accepted_page, or invited_page is None, return all users
|
||||
if accepted_page is None or invited_page is None:
|
||||
if accepted_page is None or invited_page is None or slack_users_page is None:
|
||||
return AllUsersResponse(
|
||||
accepted=[
|
||||
FullUserSnapshot(
|
||||
@@ -153,11 +160,23 @@ def list_all_users(
|
||||
UserStatus.LIVE if user.is_active else UserStatus.DEACTIVATED
|
||||
),
|
||||
)
|
||||
for user in users
|
||||
for user in accepted_users
|
||||
],
|
||||
slack_users=[
|
||||
FullUserSnapshot(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
role=user.role,
|
||||
status=(
|
||||
UserStatus.LIVE if user.is_active else UserStatus.DEACTIVATED
|
||||
),
|
||||
)
|
||||
for user in slack_users
|
||||
],
|
||||
invited=[InvitedUserSnapshot(email=email) for email in invited_emails],
|
||||
accepted_pages=1,
|
||||
invited_pages=1,
|
||||
slack_users_pages=1,
|
||||
)
|
||||
|
||||
# Otherwise, return paginated results
|
||||
@@ -169,13 +188,27 @@ def list_all_users(
|
||||
role=user.role,
|
||||
status=UserStatus.LIVE if user.is_active else UserStatus.DEACTIVATED,
|
||||
)
|
||||
for user in users
|
||||
for user in accepted_users
|
||||
][accepted_page * USERS_PAGE_SIZE : (accepted_page + 1) * USERS_PAGE_SIZE],
|
||||
slack_users=[
|
||||
FullUserSnapshot(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
role=user.role,
|
||||
status=UserStatus.LIVE if user.is_active else UserStatus.DEACTIVATED,
|
||||
)
|
||||
for user in slack_users
|
||||
][
|
||||
slack_users_page
|
||||
* USERS_PAGE_SIZE : (slack_users_page + 1)
|
||||
* USERS_PAGE_SIZE
|
||||
],
|
||||
invited=[InvitedUserSnapshot(email=email) for email in invited_emails][
|
||||
invited_page * USERS_PAGE_SIZE : (invited_page + 1) * USERS_PAGE_SIZE
|
||||
],
|
||||
accepted_pages=accepted_count // USERS_PAGE_SIZE + 1,
|
||||
invited_pages=invited_count // USERS_PAGE_SIZE + 1,
|
||||
slack_users_pages=slack_users_count // USERS_PAGE_SIZE + 1,
|
||||
)
|
||||
|
||||
|
||||
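As a side note, the pagination above slices each user list independently into USERS_PAGE_SIZE-sized windows and reports page counts via integer division. A minimal sketch of that arithmetic (the page size value here is illustrative):

USERS_PAGE_SIZE = 10  # illustrative value; the endpoint imports the real constant

emails = [f"user{i}@example.com" for i in range(23)]

page = 1  # zero-indexed page requested by the client
page_items = emails[page * USERS_PAGE_SIZE : (page + 1) * USERS_PAGE_SIZE]
total_pages = len(emails) // USERS_PAGE_SIZE + 1  # 23 // 10 + 1 == 3

assert len(page_items) == 10
assert total_pages == 3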
|
||||
@@ -48,6 +48,9 @@ from danswer.tools.tool_implementations.search_like_tool_utils import (
|
||||
from danswer.tools.tool_implementations.search_like_tool_utils import (
|
||||
FINAL_CONTEXT_DOCUMENTS_ID,
|
||||
)
|
||||
from danswer.tools.tool_implementations.search_like_tool_utils import (
|
||||
ORIGINAL_CONTEXT_DOCUMENTS_ID,
|
||||
)
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.special_types import JSON_ro
|
||||
|
||||
@@ -391,15 +394,35 @@ class SearchTool(Tool):
|
||||
"""Other utility functions"""
|
||||
|
||||
@classmethod
|
||||
def get_search_result(cls, llm_call: LLMCall) -> list[LlmDoc] | None:
|
||||
def get_search_result(
|
||||
cls, llm_call: LLMCall
|
||||
) -> tuple[list[LlmDoc], dict[str, int]] | None:
|
||||
"""
|
||||
Returns the final search results and a map of docs to their original search rank (which is what is displayed to user)
|
||||
"""
|
||||
if not llm_call.tool_call_info:
|
||||
return None
|
||||
|
||||
final_search_results = []
|
||||
doc_id_to_original_search_rank_map = {}
|
||||
|
||||
for yield_item in llm_call.tool_call_info:
|
||||
if (
|
||||
isinstance(yield_item, ToolResponse)
|
||||
and yield_item.id == FINAL_CONTEXT_DOCUMENTS_ID
|
||||
):
|
||||
return cast(list[LlmDoc], yield_item.response)
|
||||
final_search_results = cast(list[LlmDoc], yield_item.response)
|
||||
elif (
|
||||
isinstance(yield_item, ToolResponse)
|
||||
and yield_item.id == ORIGINAL_CONTEXT_DOCUMENTS_ID
|
||||
):
|
||||
search_contexts = yield_item.response.contexts
|
||||
original_doc_search_rank = 1
|
||||
for idx, doc in enumerate(search_contexts):
|
||||
if doc.document_id not in doc_id_to_original_search_rank_map:
|
||||
doc_id_to_original_search_rank_map[
|
||||
doc.document_id
|
||||
] = original_doc_search_rank
|
||||
original_doc_search_rank += 1
|
||||
|
||||
return None
|
||||
return final_search_results, doc_id_to_original_search_rank_map
|
||||
|
||||
@@ -15,6 +15,7 @@ from danswer.tools.message import ToolCallSummary
|
||||
from danswer.tools.models import ToolResponse
|
||||
|
||||
|
||||
ORIGINAL_CONTEXT_DOCUMENTS_ID = "search_doc_content"
|
||||
FINAL_CONTEXT_DOCUMENTS_ID = "final_context_documents"
|
||||
|
||||
|
||||
|
||||
backend/danswer/utils/subclasses.py (new file, 77 lines)
@@ -0,0 +1,77 @@
from __future__ import annotations

import importlib
import os
import pkgutil
import sys
from types import ModuleType
from typing import List
from typing import Type
from typing import TypeVar

T = TypeVar("T")


def import_all_modules_from_dir(dir_path: str) -> List[ModuleType]:
    """
    Imports all modules found in the given directory and its subdirectories,
    returning a list of imported module objects.
    """
    dir_path = os.path.abspath(dir_path)

    if dir_path not in sys.path:
        sys.path.insert(0, dir_path)

    imported_modules: List[ModuleType] = []

    for _, package_name, _ in pkgutil.walk_packages([dir_path]):
        try:
            module = importlib.import_module(package_name)
            imported_modules.append(module)
        except Exception as e:
            # Handle or log exceptions as needed
            print(f"Could not import {package_name}: {e}")

    return imported_modules


def all_subclasses(cls: Type[T]) -> List[Type[T]]:
    """
    Recursively find all subclasses of the given class.
    """
    direct_subs = cls.__subclasses__()
    result: List[Type[T]] = []
    for subclass in direct_subs:
        result.append(subclass)
        # Extend the result by recursively calling all_subclasses
        result.extend(all_subclasses(subclass))
    return result


def find_all_subclasses_in_dir(parent_class: Type[T], directory: str) -> List[Type[T]]:
    """
    Imports all modules from the given directory (and subdirectories),
    then returns all classes that are subclasses of parent_class.

    :param parent_class: The class to find subclasses of.
    :param directory: The directory to search for subclasses.
    :return: A list of all subclasses of parent_class found in the directory.
    """
    # First import all modules to ensure classes are loaded into memory
    import_all_modules_from_dir(directory)

    # Gather all subclasses of the given parent class
    subclasses = all_subclasses(parent_class)
    return subclasses


# Example usage:
if __name__ == "__main__":

    class Animal:
        pass

    # Suppose "mymodules" contains files that define classes inheriting from Animal
    found_subclasses = find_all_subclasses_in_dir(Animal, "mymodules")
    for sc in found_subclasses:
        print("Found subclass:", sc.__name__)
|
||||
@@ -11,6 +11,14 @@ SAML_CONF_DIR = os.environ.get("SAML_CONF_DIR") or "/app/ee/danswer/configs/saml
|
||||
#####
|
||||
# Auto Permission Sync
|
||||
#####
|
||||
# In seconds, default is 5 minutes
|
||||
CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY = int(
|
||||
os.environ.get("CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY") or 5 * 60
|
||||
)
|
||||
# In seconds, default is 5 minutes
|
||||
CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY = int(
|
||||
os.environ.get("CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
|
||||
)
|
||||
NUM_PERMISSION_WORKERS = int(os.environ.get("NUM_PERMISSION_WORKERS") or 2)
|
||||
|
||||
|
||||
|
||||
@@ -10,6 +10,9 @@ from danswer.access.utils import prefix_group_w_source
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.db.models import User__ExternalUserGroupId
|
||||
from danswer.db.users import batch_add_ext_perm_user_if_not_exists
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class ExternalUserGroup(BaseModel):
|
||||
@@ -73,7 +76,13 @@ def replace_user__ext_group_for_cc_pair(
|
||||
new_external_permissions = []
|
||||
for external_group in group_defs:
|
||||
for user_email in external_group.user_emails:
|
||||
user_id = email_id_map[user_email]
|
||||
user_id = email_id_map.get(user_email.lower())
|
||||
if user_id is None:
|
||||
logger.warning(
|
||||
f"User in group {external_group.id}"
|
||||
f" with email {user_email} not found"
|
||||
)
|
||||
continue
|
||||
new_external_permissions.append(
|
||||
User__ExternalUserGroupId(
|
||||
user_id=user_id,
|
||||
|
||||
@@ -195,6 +195,7 @@ def _fetch_all_page_restrictions_for_space(
|
||||
confluence_client: OnyxConfluence,
|
||||
slim_docs: list[SlimDocument],
|
||||
space_permissions_by_space_key: dict[str, ExternalAccess],
|
||||
is_cloud: bool,
|
||||
) -> list[DocExternalAccess]:
|
||||
"""
|
||||
For all pages, if a page has restrictions, then use those restrictions.
|
||||
@@ -222,29 +223,50 @@ def _fetch_all_page_restrictions_for_space(
|
||||
continue
|
||||
|
||||
space_key = slim_doc.perm_sync_data.get("space_key")
|
||||
if space_permissions := space_permissions_by_space_key.get(space_key):
|
||||
# If there are no restrictions, then use the space's restrictions
|
||||
document_restrictions.append(
|
||||
DocExternalAccess(
|
||||
doc_id=slim_doc.id,
|
||||
external_access=space_permissions,
|
||||
)
|
||||
if not (space_permissions := space_permissions_by_space_key.get(space_key)):
|
||||
logger.debug(
|
||||
f"Individually fetching space permissions for space {space_key}"
|
||||
)
|
||||
if (
|
||||
not space_permissions.is_public
|
||||
and not space_permissions.external_user_emails
|
||||
and not space_permissions.external_user_group_ids
|
||||
):
|
||||
try:
|
||||
# If the space permissions are not in the cache, then fetch them
|
||||
if is_cloud:
|
||||
retrieved_space_permissions = _get_cloud_space_permissions(
|
||||
confluence_client=confluence_client, space_key=space_key
|
||||
)
|
||||
else:
|
||||
retrieved_space_permissions = _get_server_space_permissions(
|
||||
confluence_client=confluence_client, space_key=space_key
|
||||
)
|
||||
space_permissions_by_space_key[space_key] = retrieved_space_permissions
|
||||
space_permissions = retrieved_space_permissions
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Permissions are empty for document: {slim_doc.id}\n"
|
||||
"This means space permissions are may be wrong for"
|
||||
f" Space key: {space_key}"
|
||||
f"Error fetching space permissions for space {space_key}: {e}"
|
||||
)
|
||||
|
||||
if not space_permissions:
|
||||
logger.warning(
|
||||
f"No permissions found for document {slim_doc.id} in space {space_key}"
|
||||
)
|
||||
continue
|
||||
|
||||
logger.warning(
|
||||
f"No permissions found for document {slim_doc.id} in space {space_key}"
|
||||
# If there are no restrictions, then use the space's restrictions
|
||||
document_restrictions.append(
|
||||
DocExternalAccess(
|
||||
doc_id=slim_doc.id,
|
||||
external_access=space_permissions,
|
||||
)
|
||||
)
|
||||
if (
|
||||
not space_permissions.is_public
|
||||
and not space_permissions.external_user_emails
|
||||
and not space_permissions.external_user_group_ids
|
||||
):
|
||||
logger.warning(
|
||||
f"Permissions are empty for document: {slim_doc.id}\n"
|
||||
"This means space permissions are may be wrong for"
|
||||
f" Space key: {space_key}"
|
||||
)
|
||||
|
||||
logger.debug("Finished fetching all page restrictions for space")
|
||||
return document_restrictions
|
||||
@@ -283,4 +305,5 @@ def confluence_doc_sync(
|
||||
confluence_client=confluence_connector.confluence_client,
|
||||
slim_docs=slim_docs,
|
||||
space_permissions_by_space_key=space_permissions_by_space_key,
|
||||
is_cloud=is_cloud,
|
||||
)
|
||||
|
||||
@@ -3,6 +3,8 @@ from collections.abc import Callable
|
||||
from danswer.access.models import DocExternalAccess
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from ee.danswer.configs.app_configs import CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY
|
||||
from ee.danswer.configs.app_configs import CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY
|
||||
from ee.danswer.db.external_perm import ExternalUserGroup
|
||||
from ee.danswer.external_permissions.confluence.doc_sync import confluence_doc_sync
|
||||
from ee.danswer.external_permissions.confluence.group_sync import confluence_group_sync
|
||||
@@ -56,7 +58,7 @@ GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC: set[DocumentSource] = {
|
||||
# If nothing is specified here, we run the doc_sync every time the celery beat runs
|
||||
DOC_PERMISSION_SYNC_PERIODS: dict[DocumentSource, int] = {
|
||||
# Polling is not supported so we fetch all doc permissions every 5 minutes
|
||||
DocumentSource.CONFLUENCE: 5 * 60,
|
||||
DocumentSource.CONFLUENCE: CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY,
|
||||
DocumentSource.SLACK: 5 * 60,
|
||||
}
|
||||
|
||||
@@ -64,7 +66,7 @@ DOC_PERMISSION_SYNC_PERIODS: dict[DocumentSource, int] = {
|
||||
EXTERNAL_GROUP_SYNC_PERIODS: dict[DocumentSource, int] = {
|
||||
# Polling is not supported so we fetch all group permissions every 30 minutes
|
||||
DocumentSource.GOOGLE_DRIVE: 5 * 60,
|
||||
DocumentSource.CONFLUENCE: 30 * 60,
|
||||
DocumentSource.CONFLUENCE: CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -132,13 +132,18 @@ def _seed_personas(db_session: Session, personas: list[CreatePersonaRequest]) ->
|
||||
if personas:
|
||||
logger.notice("Seeding Personas")
|
||||
for persona in personas:
|
||||
if not persona.prompt_ids:
|
||||
raise ValueError(
|
||||
f"Invalid Persona with name {persona.name}; no prompts exist"
|
||||
)
|
||||
|
||||
upsert_persona(
|
||||
user=None, # Seeding is done as admin
|
||||
name=persona.name,
|
||||
description=persona.description,
|
||||
num_chunks=persona.num_chunks
|
||||
if persona.num_chunks is not None
|
||||
else 0.0,
|
||||
num_chunks=(
|
||||
persona.num_chunks if persona.num_chunks is not None else 0.0
|
||||
),
|
||||
llm_relevance_filter=persona.llm_relevance_filter,
|
||||
llm_filter_extraction=persona.llm_filter_extraction,
|
||||
recency_bias=RecencyBiasSetting.AUTO,
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import asyncio
|
||||
import json
|
||||
from types import TracebackType
|
||||
from typing import cast
|
||||
from typing import Optional
|
||||
|
||||
@@ -6,11 +8,11 @@ import httpx
|
||||
import openai
|
||||
import vertexai # type: ignore
|
||||
import voyageai # type: ignore
|
||||
from cohere import Client as CohereClient
|
||||
from cohere import AsyncClient as CohereAsyncClient
|
||||
from fastapi import APIRouter
|
||||
from fastapi import HTTPException
|
||||
from google.oauth2 import service_account # type: ignore
|
||||
from litellm import embedding
|
||||
from litellm import aembedding
|
||||
from litellm.exceptions import RateLimitError
|
||||
from retry import retry
|
||||
from sentence_transformers import CrossEncoder # type: ignore
|
||||
@@ -63,22 +65,31 @@ class CloudEmbedding:
|
||||
provider: EmbeddingProvider,
|
||||
api_url: str | None = None,
|
||||
api_version: str | None = None,
|
||||
timeout: int = API_BASED_EMBEDDING_TIMEOUT,
|
||||
) -> None:
|
||||
self.provider = provider
|
||||
self.api_key = api_key
|
||||
self.api_url = api_url
|
||||
self.api_version = api_version
|
||||
self.timeout = timeout
|
||||
self.http_client = httpx.AsyncClient(timeout=timeout)
|
||||
self._closed = False
|
||||
|
||||
def _embed_openai(self, texts: list[str], model: str | None) -> list[Embedding]:
|
||||
async def _embed_openai(
|
||||
self, texts: list[str], model: str | None
|
||||
) -> list[Embedding]:
|
||||
if not model:
|
||||
model = DEFAULT_OPENAI_MODEL
|
||||
|
||||
client = openai.OpenAI(api_key=self.api_key, timeout=OPENAI_EMBEDDING_TIMEOUT)
|
||||
# Use the OpenAI specific timeout for this one
|
||||
client = openai.AsyncOpenAI(
|
||||
api_key=self.api_key, timeout=OPENAI_EMBEDDING_TIMEOUT
|
||||
)
|
||||
|
||||
final_embeddings: list[Embedding] = []
|
||||
try:
|
||||
for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN):
|
||||
response = client.embeddings.create(input=text_batch, model=model)
|
||||
response = await client.embeddings.create(input=text_batch, model=model)
|
||||
final_embeddings.extend(
|
||||
[embedding.embedding for embedding in response.data]
|
||||
)
|
||||
@@ -93,19 +104,19 @@ class CloudEmbedding:
|
||||
logger.error(error_string)
|
||||
raise RuntimeError(error_string)
|
||||
|
||||
def _embed_cohere(
|
||||
async def _embed_cohere(
|
||||
self, texts: list[str], model: str | None, embedding_type: str
|
||||
) -> list[Embedding]:
|
||||
if not model:
|
||||
model = DEFAULT_COHERE_MODEL
|
||||
|
||||
client = CohereClient(api_key=self.api_key, timeout=API_BASED_EMBEDDING_TIMEOUT)
|
||||
client = CohereAsyncClient(api_key=self.api_key)
|
||||
|
||||
final_embeddings: list[Embedding] = []
|
||||
for text_batch in batch_list(texts, _COHERE_MAX_INPUT_LEN):
|
||||
# Does not use the same tokenizer as the Danswer API server but it's approximately the same
|
||||
# empirically it's only off by a very few tokens so it's not a big deal
|
||||
response = client.embed(
|
||||
response = await client.embed(
|
||||
texts=text_batch,
|
||||
model=model,
|
||||
input_type=embedding_type,
|
||||
@@ -114,26 +125,29 @@ class CloudEmbedding:
|
||||
final_embeddings.extend(cast(list[Embedding], response.embeddings))
|
||||
return final_embeddings
|
||||
|
||||
def _embed_voyage(
|
||||
async def _embed_voyage(
|
||||
self, texts: list[str], model: str | None, embedding_type: str
|
||||
) -> list[Embedding]:
|
||||
if not model:
|
||||
model = DEFAULT_VOYAGE_MODEL
|
||||
|
||||
client = voyageai.Client(
|
||||
client = voyageai.AsyncClient(
|
||||
api_key=self.api_key, timeout=API_BASED_EMBEDDING_TIMEOUT
|
||||
)
|
||||
|
||||
response = client.embed(
|
||||
texts,
|
||||
response = await client.embed(
|
||||
texts=texts,
|
||||
model=model,
|
||||
input_type=embedding_type,
|
||||
truncation=True,
|
||||
)
|
||||
|
||||
return response.embeddings
|
||||
|
||||
def _embed_azure(self, texts: list[str], model: str | None) -> list[Embedding]:
|
||||
response = embedding(
|
||||
async def _embed_azure(
|
||||
self, texts: list[str], model: str | None
|
||||
) -> list[Embedding]:
|
||||
response = await aembedding(
|
||||
model=model,
|
||||
input=texts,
|
||||
timeout=API_BASED_EMBEDDING_TIMEOUT,
|
||||
@@ -142,10 +156,9 @@ class CloudEmbedding:
|
||||
api_version=self.api_version,
|
||||
)
|
||||
embeddings = [embedding["embedding"] for embedding in response.data]
|
||||
|
||||
return embeddings
|
||||
|
||||
def _embed_vertex(
|
||||
async def _embed_vertex(
|
||||
self, texts: list[str], model: str | None, embedding_type: str
|
||||
) -> list[Embedding]:
|
||||
if not model:
|
||||
@@ -158,7 +171,7 @@ class CloudEmbedding:
|
||||
vertexai.init(project=project_id, credentials=credentials)
|
||||
client = TextEmbeddingModel.from_pretrained(model)
|
||||
|
||||
embeddings = client.get_embeddings(
|
||||
embeddings = await client.get_embeddings_async(
|
||||
[
|
||||
TextEmbeddingInput(
|
||||
text,
|
||||
@@ -166,11 +179,11 @@ class CloudEmbedding:
|
||||
)
|
||||
for text in texts
|
||||
],
|
||||
auto_truncate=True, # Also this is default
|
||||
auto_truncate=True, # This is the default
|
||||
)
|
||||
return [embedding.values for embedding in embeddings]
|
||||
|
||||
def _embed_litellm_proxy(
|
||||
async def _embed_litellm_proxy(
|
||||
self, texts: list[str], model_name: str | None
|
||||
) -> list[Embedding]:
|
||||
if not model_name:
|
||||
@@ -183,22 +196,20 @@ class CloudEmbedding:
|
||||
{} if not self.api_key else {"Authorization": f"Bearer {self.api_key}"}
|
||||
)
|
||||
|
||||
with httpx.Client() as client:
|
||||
response = client.post(
|
||||
self.api_url,
|
||||
json={
|
||||
"model": model_name,
|
||||
"input": texts,
|
||||
},
|
||||
headers=headers,
|
||||
timeout=API_BASED_EMBEDDING_TIMEOUT,
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
return [embedding["embedding"] for embedding in result["data"]]
|
||||
response = await self.http_client.post(
|
||||
self.api_url,
|
||||
json={
|
||||
"model": model_name,
|
||||
"input": texts,
|
||||
},
|
||||
headers=headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
return [embedding["embedding"] for embedding in result["data"]]
|
||||
|
||||
@retry(tries=_RETRY_TRIES, delay=_RETRY_DELAY)
|
||||
def embed(
|
||||
async def embed(
|
||||
self,
|
||||
*,
|
||||
texts: list[str],
|
||||
@@ -207,19 +218,19 @@ class CloudEmbedding:
|
||||
deployment_name: str | None = None,
|
||||
) -> list[Embedding]:
|
||||
if self.provider == EmbeddingProvider.OPENAI:
|
||||
return self._embed_openai(texts, model_name)
|
||||
return await self._embed_openai(texts, model_name)
|
||||
elif self.provider == EmbeddingProvider.AZURE:
|
||||
return self._embed_azure(texts, f"azure/{deployment_name}")
|
||||
return await self._embed_azure(texts, f"azure/{deployment_name}")
|
||||
elif self.provider == EmbeddingProvider.LITELLM:
|
||||
return self._embed_litellm_proxy(texts, model_name)
|
||||
return await self._embed_litellm_proxy(texts, model_name)
|
||||
|
||||
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
|
||||
if self.provider == EmbeddingProvider.COHERE:
|
||||
return self._embed_cohere(texts, model_name, embedding_type)
|
||||
return await self._embed_cohere(texts, model_name, embedding_type)
|
||||
elif self.provider == EmbeddingProvider.VOYAGE:
|
||||
return self._embed_voyage(texts, model_name, embedding_type)
|
||||
return await self._embed_voyage(texts, model_name, embedding_type)
|
||||
elif self.provider == EmbeddingProvider.GOOGLE:
|
||||
return self._embed_vertex(texts, model_name, embedding_type)
|
||||
return await self._embed_vertex(texts, model_name, embedding_type)
|
||||
else:
|
||||
raise ValueError(f"Unsupported provider: {self.provider}")
|
||||
|
||||
@@ -233,6 +244,30 @@ class CloudEmbedding:
|
||||
logger.debug(f"Creating Embedding instance for provider: {provider}")
|
||||
return CloudEmbedding(api_key, provider, api_url, api_version)
|
||||
|
||||
async def aclose(self) -> None:
|
||||
"""Explicitly close the client."""
|
||||
if not self._closed:
|
||||
await self.http_client.aclose()
|
||||
self._closed = True
|
||||
|
||||
async def __aenter__(self) -> "CloudEmbedding":
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_val: BaseException | None,
|
||||
exc_tb: TracebackType | None,
|
||||
) -> None:
|
||||
await self.aclose()
|
||||
|
||||
def __del__(self) -> None:
|
||||
"""Finalizer to warn about unclosed clients."""
|
||||
if not self._closed:
|
||||
logger.warning(
|
||||
"CloudEmbedding was not properly closed. Use 'async with' or call aclose()"
|
||||
)
|
||||
|
||||
|
||||
def get_embedding_model(
|
||||
model_name: str,
|
||||
@@ -242,9 +277,6 @@ def get_embedding_model(
|
||||
|
||||
global _GLOBAL_MODELS_DICT # A dictionary to store models
|
||||
|
||||
if _GLOBAL_MODELS_DICT is None:
|
||||
_GLOBAL_MODELS_DICT = {}
|
||||
|
||||
if model_name not in _GLOBAL_MODELS_DICT:
|
||||
logger.notice(f"Loading {model_name}")
|
||||
# Some model architectures that aren't built into the Transformers or Sentence
|
||||
@@ -275,7 +307,7 @@ def get_local_reranking_model(
|
||||
|
||||
|
||||
@simple_log_function_time()
|
||||
def embed_text(
|
||||
async def embed_text(
|
||||
texts: list[str],
|
||||
text_type: EmbedTextType,
|
||||
model_name: str | None,
|
||||
@@ -311,18 +343,18 @@ def embed_text(
|
||||
"Cloud models take an explicit text type instead."
|
||||
)
|
||||
|
||||
cloud_model = CloudEmbedding(
|
||||
async with CloudEmbedding(
|
||||
api_key=api_key,
|
||||
provider=provider_type,
|
||||
api_url=api_url,
|
||||
api_version=api_version,
|
||||
)
|
||||
embeddings = cloud_model.embed(
|
||||
texts=texts,
|
||||
model_name=model_name,
|
||||
deployment_name=deployment_name,
|
||||
text_type=text_type,
|
||||
)
|
||||
) as cloud_model:
|
||||
embeddings = await cloud_model.embed(
|
||||
texts=texts,
|
||||
model_name=model_name,
|
||||
deployment_name=deployment_name,
|
||||
text_type=text_type,
|
||||
)
|
||||
|
||||
if any(embedding is None for embedding in embeddings):
|
||||
error_message = "Embeddings contain None values\n"
|
||||
@@ -338,8 +370,12 @@ def embed_text(
|
||||
local_model = get_embedding_model(
|
||||
model_name=model_name, max_context_length=max_context_length
|
||||
)
|
||||
embeddings_vectors = local_model.encode(
|
||||
prefixed_texts, normalize_embeddings=normalize_embeddings
|
||||
# Run CPU-bound embedding in a thread pool
|
||||
embeddings_vectors = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: local_model.encode(
|
||||
prefixed_texts, normalize_embeddings=normalize_embeddings
|
||||
),
|
||||
)
|
||||
embeddings = [
|
||||
embedding if isinstance(embedding, list) else embedding.tolist()
|
||||
@@ -357,27 +393,31 @@ def embed_text(
|
||||
|
||||
|
||||
@simple_log_function_time()
|
||||
def local_rerank(query: str, docs: list[str], model_name: str) -> list[float]:
|
||||
async def local_rerank(query: str, docs: list[str], model_name: str) -> list[float]:
|
||||
cross_encoder = get_local_reranking_model(model_name)
|
||||
return cross_encoder.predict([(query, doc) for doc in docs]).tolist() # type: ignore
|
||||
# Run CPU-bound reranking in a thread pool
|
||||
return await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: cross_encoder.predict([(query, doc) for doc in docs]).tolist(), # type: ignore
|
||||
)
|
||||
|
||||
|
||||
def cohere_rerank(
|
||||
async def cohere_rerank(
|
||||
query: str, docs: list[str], model_name: str, api_key: str
|
||||
) -> list[float]:
|
||||
cohere_client = CohereClient(api_key=api_key)
|
||||
response = cohere_client.rerank(query=query, documents=docs, model=model_name)
|
||||
cohere_client = CohereAsyncClient(api_key=api_key)
|
||||
response = await cohere_client.rerank(query=query, documents=docs, model=model_name)
|
||||
results = response.results
|
||||
sorted_results = sorted(results, key=lambda item: item.index)
|
||||
return [result.relevance_score for result in sorted_results]
|
||||
|
||||
|
||||
def litellm_rerank(
|
||||
async def litellm_rerank(
|
||||
query: str, docs: list[str], api_url: str, model_name: str, api_key: str | None
|
||||
) -> list[float]:
|
||||
headers = {} if not api_key else {"Authorization": f"Bearer {api_key}"}
|
||||
with httpx.Client() as client:
|
||||
response = client.post(
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
api_url,
|
||||
json={
|
||||
"model": model_name,
|
||||
@@ -411,7 +451,7 @@ async def process_embed_request(
|
||||
else:
|
||||
prefix = None
|
||||
|
||||
embeddings = embed_text(
|
||||
embeddings = await embed_text(
|
||||
texts=embed_request.texts,
|
||||
model_name=embed_request.model_name,
|
||||
deployment_name=embed_request.deployment_name,
|
||||
@@ -451,7 +491,7 @@ async def process_rerank_request(rerank_request: RerankRequest) -> RerankRespons
|
||||
|
||||
try:
|
||||
if rerank_request.provider_type is None:
|
||||
sim_scores = local_rerank(
|
||||
sim_scores = await local_rerank(
|
||||
query=rerank_request.query,
|
||||
docs=rerank_request.documents,
|
||||
model_name=rerank_request.model_name,
|
||||
@@ -461,7 +501,7 @@ async def process_rerank_request(rerank_request: RerankRequest) -> RerankRespons
|
||||
if rerank_request.api_url is None:
|
||||
raise ValueError("API URL is required for LiteLLM reranking.")
|
||||
|
||||
sim_scores = litellm_rerank(
|
||||
sim_scores = await litellm_rerank(
|
||||
query=rerank_request.query,
|
||||
docs=rerank_request.documents,
|
||||
api_url=rerank_request.api_url,
|
||||
@@ -474,7 +514,7 @@ async def process_rerank_request(rerank_request: RerankRequest) -> RerankRespons
|
||||
elif rerank_request.provider_type == RerankerProvider.COHERE:
|
||||
if rerank_request.api_key is None:
|
||||
raise RuntimeError("Cohere Rerank Requires an API Key")
|
||||
sim_scores = cohere_rerank(
|
||||
sim_scores = await cohere_rerank(
|
||||
query=rerank_request.query,
|
||||
docs=rerank_request.documents,
|
||||
model_name=rerank_request.model_name,
|
||||
|
||||
@@ -6,12 +6,12 @@ router = APIRouter(prefix="/api")
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
def healthcheck() -> Response:
|
||||
async def healthcheck() -> Response:
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
@router.get("/gpu-status")
|
||||
def gpu_status() -> dict[str, bool | str]:
|
||||
async def gpu_status() -> dict[str, bool | str]:
|
||||
if torch.cuda.is_available():
|
||||
return {"gpu_available": True, "type": "cuda"}
|
||||
elif torch.backends.mps.is_available():
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
@@ -21,21 +22,39 @@ def simple_log_function_time(
|
||||
include_args: bool = False,
|
||||
) -> Callable[[F], F]:
|
||||
def decorator(func: F) -> F:
|
||||
@wraps(func)
|
||||
def wrapped_func(*args: Any, **kwargs: Any) -> Any:
|
||||
start_time = time.time()
|
||||
result = func(*args, **kwargs)
|
||||
elapsed_time_str = str(time.time() - start_time)
|
||||
log_name = func_name or func.__name__
|
||||
args_str = f" args={args} kwargs={kwargs}" if include_args else ""
|
||||
final_log = f"{log_name}{args_str} took {elapsed_time_str} seconds"
|
||||
if debug_only:
|
||||
logger.debug(final_log)
|
||||
else:
|
||||
logger.notice(final_log)
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
|
||||
return result
|
||||
@wraps(func)
|
||||
async def wrapped_async_func(*args: Any, **kwargs: Any) -> Any:
|
||||
start_time = time.time()
|
||||
result = await func(*args, **kwargs)
|
||||
elapsed_time_str = str(time.time() - start_time)
|
||||
log_name = func_name or func.__name__
|
||||
args_str = f" args={args} kwargs={kwargs}" if include_args else ""
|
||||
final_log = f"{log_name}{args_str} took {elapsed_time_str} seconds"
|
||||
if debug_only:
|
||||
logger.debug(final_log)
|
||||
else:
|
||||
logger.notice(final_log)
|
||||
return result
|
||||
|
||||
return cast(F, wrapped_func)
|
||||
return cast(F, wrapped_async_func)
|
||||
else:
|
||||
|
||||
@wraps(func)
|
||||
def wrapped_sync_func(*args: Any, **kwargs: Any) -> Any:
|
||||
start_time = time.time()
|
||||
result = func(*args, **kwargs)
|
||||
elapsed_time_str = str(time.time() - start_time)
|
||||
log_name = func_name or func.__name__
|
||||
args_str = f" args={args} kwargs={kwargs}" if include_args else ""
|
||||
final_log = f"{log_name}{args_str} took {elapsed_time_str} seconds"
|
||||
if debug_only:
|
||||
logger.debug(final_log)
|
||||
else:
|
||||
logger.notice(final_log)
|
||||
return result
|
||||
|
||||
return cast(F, wrapped_sync_func)
|
||||
|
||||
return decorator
|
||||
|
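Since simple_log_function_time now dispatches on asyncio.iscoroutinefunction, the same decorator can wrap both sync and async callables. A minimal usage sketch follows; the import path is an assumption and the function names are illustrative:

import asyncio

from danswer.utils.timing import simple_log_function_time  # assumed module path


@simple_log_function_time(debug_only=True)
def sync_work(n: int) -> int:
    # Timed by wrapped_sync_func; logged at debug level because debug_only=True
    return sum(range(n))


@simple_log_function_time(func_name="async_work", include_args=True)
async def async_work(n: int) -> int:
    # Timed by wrapped_async_func; the decorator awaits the coroutine before logging
    await asyncio.sleep(0)
    return sum(range(n))


print(sync_work(1_000))
print(asyncio.run(async_work(1_000)))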
||||
@@ -29,7 +29,7 @@ trafilatura==1.12.2
|
||||
langchain==0.1.17
|
||||
langchain-core==0.1.50
|
||||
langchain-text-splitters==0.0.1
|
||||
litellm==1.53.1
|
||||
litellm==1.54.1
|
||||
lxml==5.3.0
|
||||
lxml_html_clean==0.2.2
|
||||
llama-index==0.9.45
|
||||
|
||||
@@ -1,30 +1,34 @@
|
||||
black==23.3.0
|
||||
boto3-stubs[s3]==1.34.133
|
||||
celery-types==0.19.0
|
||||
cohere==5.6.1
|
||||
google-cloud-aiplatform==1.58.0
|
||||
lxml==5.3.0
|
||||
lxml_html_clean==0.2.2
|
||||
mypy-extensions==1.0.0
|
||||
mypy==1.8.0
|
||||
pandas-stubs==2.2.3.241009
|
||||
pandas==2.2.3
|
||||
pre-commit==3.2.2
|
||||
pytest-asyncio==0.22.0
|
||||
pytest==7.4.4
|
||||
reorder-python-imports==3.9.0
|
||||
ruff==0.0.286
|
||||
types-PyYAML==6.0.12.11
|
||||
sentence-transformers==2.6.1
|
||||
trafilatura==1.12.2
|
||||
types-beautifulsoup4==4.12.0.3
|
||||
types-html5lib==1.1.11.13
|
||||
types-oauthlib==3.2.0.9
|
||||
types-setuptools==68.0.0.3
|
||||
types-Pillow==10.2.0.20240822
|
||||
types-passlib==1.7.7.20240106
|
||||
types-Pillow==10.2.0.20240822
|
||||
types-psutil==5.9.5.17
|
||||
types-psycopg2==2.9.21.10
|
||||
types-python-dateutil==2.8.19.13
|
||||
types-pytz==2023.3.1.1
|
||||
types-PyYAML==6.0.12.11
|
||||
types-regex==2023.3.23.1
|
||||
types-requests==2.28.11.17
|
||||
types-retry==0.9.9.3
|
||||
types-setuptools==68.0.0.3
|
||||
types-urllib3==1.26.25.11
|
||||
trafilatura==1.12.2
|
||||
lxml==5.3.0
|
||||
lxml_html_clean==0.2.2
|
||||
boto3-stubs[s3]==1.34.133
|
||||
pandas==2.2.3
|
||||
pandas-stubs==2.2.3.241009
|
||||
cohere==5.6.1
|
||||
voyageai==0.2.3
|
||||
|
||||
@@ -12,5 +12,5 @@ torch==2.2.0
|
||||
transformers==4.39.2
|
||||
uvicorn==0.21.1
|
||||
voyageai==0.2.3
|
||||
litellm==1.50.2
|
||||
litellm==1.54.1
|
||||
sentry-sdk[fastapi,celery,starlette]==2.14.0
|
||||
@@ -42,7 +42,7 @@ class PersonaManager:
|
||||
"is_public": is_public,
|
||||
"llm_filter_extraction": llm_filter_extraction,
|
||||
"recency_bias": recency_bias,
|
||||
"prompt_ids": prompt_ids or [],
|
||||
"prompt_ids": prompt_ids or [0],
|
||||
"document_set_ids": document_set_ids or [],
|
||||
"tool_ids": tool_ids or [],
|
||||
"llm_model_provider_override": llm_model_provider_override,
|
||||
|
||||
@@ -69,8 +69,10 @@ class TenantManager:
|
||||
return AllUsersResponse(
|
||||
accepted=[FullUserSnapshot(**user) for user in data["accepted"]],
|
||||
invited=[InvitedUserSnapshot(**user) for user in data["invited"]],
|
||||
slack_users=[FullUserSnapshot(**user) for user in data["slack_users"]],
|
||||
accepted_pages=data["accepted_pages"],
|
||||
invited_pages=data["invited_pages"],
|
||||
slack_users_pages=data["slack_users_pages"],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -130,8 +130,10 @@ class UserManager:
|
||||
all_users = AllUsersResponse(
|
||||
accepted=[FullUserSnapshot(**user) for user in data["accepted"]],
|
||||
invited=[InvitedUserSnapshot(**user) for user in data["invited"]],
|
||||
slack_users=[FullUserSnapshot(**user) for user in data["slack_users"]],
|
||||
accepted_pages=data["accepted_pages"],
|
||||
invited_pages=data["invited_pages"],
|
||||
slack_users_pages=data["slack_users_pages"],
|
||||
)
|
||||
for accepted_user in all_users.accepted:
|
||||
if accepted_user.email == user.email and accepted_user.id == user.id:
|
||||
|
||||
@@ -3,6 +3,8 @@ from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from danswer.connectors.models import InputType
|
||||
from danswer.db.enums import AccessType
|
||||
from danswer.server.documents.models import DocumentSource
|
||||
@@ -23,7 +25,7 @@ from tests.integration.common_utils.vespa import vespa_fixture
|
||||
from tests.integration.connector_job_tests.slack.slack_api_utils import SlackManager
|
||||
|
||||
|
||||
# @pytest.mark.xfail(reason="flaky - see DAN-789 for example", strict=False)
|
||||
@pytest.mark.xfail(reason="flaky - see DAN-789 for example", strict=False)
|
||||
def test_slack_permission_sync(
|
||||
reset: None,
|
||||
vespa_client: vespa_fixture,
|
||||
@@ -65,7 +67,6 @@ def test_slack_permission_sync(
|
||||
input_type=InputType.POLL,
|
||||
source=DocumentSource.SLACK,
|
||||
connector_specific_config={
|
||||
"workspace": "onyx-test-workspace",
|
||||
"channels": [public_channel["name"], private_channel["name"]],
|
||||
},
|
||||
access_type=AccessType.SYNC,
|
||||
@@ -279,7 +280,6 @@ def test_slack_group_permission_sync(
|
||||
input_type=InputType.POLL,
|
||||
source=DocumentSource.SLACK,
|
||||
connector_specific_config={
|
||||
"workspace": "onyx-test-workspace",
|
||||
"channels": [private_channel["name"]],
|
||||
},
|
||||
access_type=AccessType.SYNC,
|
||||
|
||||
@@ -61,7 +61,6 @@ def test_slack_prune(
|
||||
input_type=InputType.POLL,
|
||||
source=DocumentSource.SLACK,
|
||||
connector_specific_config={
|
||||
"workspace": "onyx-test-workspace",
|
||||
"channels": [public_channel["name"], private_channel["name"]],
|
||||
},
|
||||
access_type=AccessType.PUBLIC,
|
||||
|
||||
@@ -27,13 +27,6 @@ def test_limited(reset: None) -> None:
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
# test basic endpoints
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/input_prompt",
|
||||
headers=api_key.headers,
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
# test admin endpoints
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/admin/api-key",
|
||||
|
||||
@@ -72,8 +72,10 @@ def process_text(
|
||||
processor = CitationProcessor(
|
||||
context_docs=mock_docs,
|
||||
doc_id_to_rank_map=mapping,
|
||||
display_doc_order_dict=mock_doc_id_to_rank_map,
|
||||
stop_stream=None,
|
||||
)
|
||||
|
||||
result: list[DanswerAnswerPiece | CitationInfo] = []
|
||||
for token in tokens:
|
||||
result.extend(processor.process_token(token))
|
||||
@@ -86,6 +88,7 @@ def process_text(
|
||||
final_answer_text += piece.answer_piece or ""
|
||||
elif isinstance(piece, CitationInfo):
|
||||
citations.append(piece)
|
||||
|
||||
return final_answer_text, citations
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,132 @@
|
||||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.chat.stream_processing.citation_processing import CitationProcessor
|
||||
from danswer.chat.stream_processing.utils import DocumentIdOrderMapping
|
||||
from danswer.configs.constants import DocumentSource
|
||||
|
||||
|
||||
"""
|
||||
This module contains tests for the citation extraction functionality in Danswer,
|
||||
specifically the substitution of the citation number displayed in the UI. (The LLM
|
||||
will see the sources post re-ranking and relevance check, the UI before these steps.)
|
||||
This module is a derivative of test_citation_processing.py.
|
||||
|
||||
The tests focus specifically on substituting the cited document's number with the one displayed in the UI.
|
||||
|
||||
Key components:
|
||||
- mock_docs: A list of mock LlmDoc objects used for testing.
|
||||
- mock_doc_mapping: A dictionary mapping document IDs to their initial ranks.
|
||||
- mock_doc_mapping_rerank: A dictionary mapping document IDs to their ranks after re-ranking/relevance check.
|
||||
- process_text: A helper function that simulates the citation extraction process.
|
||||
- test_citation_substitution: A parametrized test function covering various citation scenarios.
|
||||
|
||||
To add new test cases:
|
||||
1. Add a new tuple to the @pytest.mark.parametrize decorator of test_citation_substitution.
|
||||
2. Each tuple should contain:
|
||||
- A descriptive test name (string)
|
||||
- Input tokens (list of strings)
|
||||
- Expected output text (string)
|
||||
- Expected citations (list of document IDs)
|
||||
"""
|
||||
|
||||
|
||||
mock_docs = [
|
||||
LlmDoc(
|
||||
document_id=f"doc_{int(id/2)}",
|
||||
content="Document is a doc",
|
||||
blurb=f"Document #{id}",
|
||||
semantic_identifier=f"Doc {id}",
|
||||
source_type=DocumentSource.WEB,
|
||||
metadata={},
|
||||
updated_at=datetime.now(),
|
||||
link=f"https://{int(id/2)}.com" if int(id / 2) % 2 == 0 else None,
|
||||
source_links={0: "https://mintlify.com/docs/settings/broken-links"},
|
||||
match_highlights=[],
|
||||
)
|
||||
for id in range(10)
|
||||
]
|
||||
|
||||
mock_doc_mapping = {
|
||||
"doc_0": 1,
|
||||
"doc_1": 2,
|
||||
"doc_2": 3,
|
||||
"doc_3": 4,
|
||||
"doc_4": 5,
|
||||
"doc_5": 6,
|
||||
}
|
||||
|
||||
mock_doc_mapping_rerank = {
|
||||
"doc_0": 2,
|
||||
"doc_1": 1,
|
||||
"doc_2": 4,
|
||||
"doc_3": 3,
|
||||
"doc_4": 6,
|
||||
"doc_5": 5,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_data() -> tuple[list[LlmDoc], dict[str, int]]:
|
||||
return mock_docs, mock_doc_mapping
|
||||
|
||||
|
||||
def process_text(
|
||||
tokens: list[str], mock_data: tuple[list[LlmDoc], dict[str, int]]
|
||||
) -> tuple[str, list[CitationInfo]]:
|
||||
mock_docs, mock_doc_id_to_rank_map = mock_data
|
||||
mapping = DocumentIdOrderMapping(order_mapping=mock_doc_id_to_rank_map)
|
||||
processor = CitationProcessor(
|
||||
context_docs=mock_docs,
|
||||
doc_id_to_rank_map=mapping,
|
||||
display_doc_order_dict=mock_doc_mapping_rerank,
|
||||
stop_stream=None,
|
||||
)
|
||||
|
||||
result: list[DanswerAnswerPiece | CitationInfo] = []
|
||||
for token in tokens:
|
||||
result.extend(processor.process_token(token))
|
||||
result.extend(processor.process_token(None))
|
||||
|
||||
final_answer_text = ""
|
||||
citations = []
|
||||
for piece in result:
|
||||
if isinstance(piece, DanswerAnswerPiece):
|
||||
final_answer_text += piece.answer_piece or ""
|
||||
elif isinstance(piece, CitationInfo):
|
||||
citations.append(piece)
|
||||
|
||||
return final_answer_text, citations
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"test_name, input_tokens, expected_text, expected_citations",
|
||||
[
|
||||
(
|
||||
"Single citation",
|
||||
["Gro", "wth! [", "1", "]", "."],
|
||||
"Growth! [[2]](https://0.com).",
|
||||
["doc_0"],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_citation_substitution(
|
||||
mock_data: tuple[list[LlmDoc], dict[str, int]],
|
||||
test_name: str,
|
||||
input_tokens: list[str],
|
||||
expected_text: str,
|
||||
expected_citations: list[str],
|
||||
) -> None:
|
||||
final_answer_text, citations = process_text(input_tokens, mock_data)
|
||||
assert (
|
||||
final_answer_text.strip() == expected_text.strip()
|
||||
), f"Test '{test_name}' failed: Final answer text does not match expected output."
|
||||
assert [
|
||||
citation.document_id for citation in citations
|
||||
] == expected_citations, (
|
||||
f"Test '{test_name}' failed: Citations do not match expected output."
|
||||
)
|
||||
backend/tests/unit/danswer/indexing/test_indexing_pipeline.py (new file, 120 lines)
@@ -0,0 +1,120 @@
|
||||
from typing import List
|
||||
|
||||
from danswer.configs.app_configs import MAX_DOCUMENT_CHARS
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import DocumentSource
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.indexing.indexing_pipeline import filter_documents
|
||||
|
||||
|
||||
def create_test_document(
|
||||
doc_id: str = "test_id",
|
||||
title: str | None = "Test Title",
|
||||
semantic_id: str = "test_semantic_id",
|
||||
sections: List[Section] | None = None,
|
||||
) -> Document:
|
||||
if sections is None:
|
||||
sections = [Section(text="Test content", link="test_link")]
|
||||
return Document(
|
||||
id=doc_id,
|
||||
title=title,
|
||||
semantic_identifier=semantic_id,
|
||||
sections=sections,
|
||||
source=DocumentSource.FILE,
|
||||
metadata={},
|
||||
)
|
||||
|
||||
|
||||
def test_filter_documents_empty_title_and_content() -> None:
|
||||
doc = create_test_document(
|
||||
title="", semantic_id="", sections=[Section(text="", link="test_link")]
|
||||
)
|
||||
result = filter_documents([doc])
|
||||
assert len(result) == 0
|
||||
|
||||
|
||||
def test_filter_documents_empty_title_with_content() -> None:
|
||||
doc = create_test_document(
|
||||
title="", sections=[Section(text="Valid content", link="test_link")]
|
||||
)
|
||||
result = filter_documents([doc])
|
||||
assert len(result) == 1
|
||||
assert result[0].id == "test_id"
|
||||
|
||||
|
||||
def test_filter_documents_empty_content_with_title() -> None:
|
||||
doc = create_test_document(
|
||||
title="Valid Title", sections=[Section(text="", link="test_link")]
|
||||
)
|
||||
result = filter_documents([doc])
|
||||
assert len(result) == 1
|
||||
assert result[0].id == "test_id"
|
||||
|
||||
|
||||
def test_filter_documents_exceeding_max_chars() -> None:
|
||||
if not MAX_DOCUMENT_CHARS: # Skip if no max chars configured
|
||||
return
|
||||
long_text = "a" * (MAX_DOCUMENT_CHARS + 1)
|
||||
doc = create_test_document(sections=[Section(text=long_text, link="test_link")])
|
||||
result = filter_documents([doc])
|
||||
assert len(result) == 0
|
||||
|
||||
|
||||
def test_filter_documents_valid_document() -> None:
|
||||
doc = create_test_document(
|
||||
title="Valid Title", sections=[Section(text="Valid content", link="test_link")]
|
||||
)
|
||||
result = filter_documents([doc])
|
||||
assert len(result) == 1
|
||||
assert result[0].id == "test_id"
|
||||
assert result[0].title == "Valid Title"
|
||||
|
||||
|
||||
def test_filter_documents_whitespace_only() -> None:
|
||||
doc = create_test_document(
|
||||
title=" ", semantic_id=" ", sections=[Section(text=" ", link="test_link")]
|
||||
)
|
||||
result = filter_documents([doc])
|
||||
assert len(result) == 0
|
||||
|
||||
|
||||
def test_filter_documents_semantic_id_no_title() -> None:
|
||||
doc = create_test_document(
|
||||
title=None,
|
||||
semantic_id="Valid Semantic ID",
|
||||
sections=[Section(text="Valid content", link="test_link")],
|
||||
)
|
||||
result = filter_documents([doc])
|
||||
assert len(result) == 1
|
||||
assert result[0].semantic_identifier == "Valid Semantic ID"
|
||||
|
||||
|
||||
def test_filter_documents_multiple_sections() -> None:
|
||||
doc = create_test_document(
|
||||
sections=[
|
||||
Section(text="Content 1", link="test_link"),
|
||||
Section(text="Content 2", link="test_link"),
|
||||
Section(text="Content 3", link="test_link"),
|
||||
]
|
||||
)
|
||||
result = filter_documents([doc])
|
||||
assert len(result) == 1
|
||||
assert len(result[0].sections) == 3
|
||||
|
||||
|
||||
def test_filter_documents_multiple_documents() -> None:
|
||||
docs = [
|
||||
create_test_document(doc_id="1", title="Title 1"),
|
||||
create_test_document(
|
||||
doc_id="2", title="", sections=[Section(text="", link="test_link")]
|
||||
), # Should be filtered
|
||||
create_test_document(doc_id="3", title="Title 3"),
|
||||
]
|
||||
result = filter_documents(docs)
|
||||
assert len(result) == 2
|
||||
assert {doc.id for doc in result} == {"1", "3"}
|
||||
|
||||
|
||||
def test_filter_documents_empty_batch() -> None:
|
||||
result = filter_documents([])
|
||||
assert len(result) == 0
|
||||
backend/tests/unit/model_server/test_embedding.py (new file, 198 lines)
@@ -0,0 +1,198 @@
|
||||
import asyncio
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
from typing import List
|
||||
from unittest.mock import AsyncMock
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
from litellm.exceptions import RateLimitError
|
||||
|
||||
from model_server.encoders import CloudEmbedding
|
||||
from model_server.encoders import embed_text
|
||||
from model_server.encoders import local_rerank
|
||||
from model_server.encoders import process_embed_request
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
from shared_configs.enums import EmbedTextType
|
||||
from shared_configs.model_server_models import EmbedRequest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def mock_http_client() -> AsyncGenerator[AsyncMock, None]:
|
||||
with patch("httpx.AsyncClient") as mock:
|
||||
client = AsyncMock(spec=AsyncClient)
|
||||
mock.return_value = client
|
||||
client.post = AsyncMock()
|
||||
async with client as c:
|
||||
yield c
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_embeddings() -> List[List[float]]:
|
||||
return [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cloud_embedding_context_manager() -> None:
|
||||
async with CloudEmbedding("fake-key", EmbeddingProvider.OPENAI) as embedding:
|
||||
assert not embedding._closed
|
||||
assert embedding._closed
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cloud_embedding_explicit_close() -> None:
|
||||
embedding = CloudEmbedding("fake-key", EmbeddingProvider.OPENAI)
|
||||
assert not embedding._closed
|
||||
await embedding.aclose()
|
||||
assert embedding._closed
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_embedding(
|
||||
mock_http_client: AsyncMock, sample_embeddings: List[List[float]]
|
||||
) -> None:
|
||||
with patch("openai.AsyncOpenAI") as mock_openai:
|
||||
mock_client = AsyncMock()
|
||||
mock_openai.return_value = mock_client
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.data = [MagicMock(embedding=emb) for emb in sample_embeddings]
|
||||
mock_client.embeddings.create = AsyncMock(return_value=mock_response)
|
||||
|
||||
embedding = CloudEmbedding("fake-key", EmbeddingProvider.OPENAI)
|
||||
result = await embedding._embed_openai(
|
||||
["test1", "test2"], "text-embedding-ada-002"
|
||||
)
|
||||
|
||||
assert result == sample_embeddings
|
||||
mock_client.embeddings.create.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_embed_text_cloud_provider() -> None:
|
||||
with patch("model_server.encoders.CloudEmbedding.embed") as mock_embed:
|
||||
mock_embed.return_value = [[0.1, 0.2], [0.3, 0.4]]
|
||||
mock_embed.side_effect = AsyncMock(return_value=[[0.1, 0.2], [0.3, 0.4]])
|
||||
|
||||
result = await embed_text(
|
||||
texts=["test1", "test2"],
|
||||
text_type=EmbedTextType.QUERY,
|
||||
model_name="fake-model",
|
||||
deployment_name=None,
|
||||
max_context_length=512,
|
||||
normalize_embeddings=True,
|
||||
api_key="fake-key",
|
||||
provider_type=EmbeddingProvider.OPENAI,
|
||||
prefix=None,
|
||||
api_url=None,
|
||||
api_version=None,
|
||||
)
|
||||
|
||||
assert result == [[0.1, 0.2], [0.3, 0.4]]
|
||||
mock_embed.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_embed_text_local_model() -> None:
|
||||
with patch("model_server.encoders.get_embedding_model") as mock_get_model:
|
||||
mock_model = MagicMock()
|
||||
mock_model.encode.return_value = [[0.1, 0.2], [0.3, 0.4]]
|
||||
mock_get_model.return_value = mock_model
|
||||
|
||||
result = await embed_text(
|
||||
texts=["test1", "test2"],
|
||||
text_type=EmbedTextType.QUERY,
|
||||
model_name="fake-local-model",
|
||||
deployment_name=None,
|
||||
max_context_length=512,
|
||||
normalize_embeddings=True,
|
||||
api_key=None,
|
||||
provider_type=None,
|
||||
prefix=None,
|
||||
api_url=None,
|
||||
api_version=None,
|
||||
)
|
||||
|
||||
assert result == [[0.1, 0.2], [0.3, 0.4]]
|
||||
mock_model.encode.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_local_rerank() -> None:
|
||||
with patch("model_server.encoders.get_local_reranking_model") as mock_get_model:
|
||||
mock_model = MagicMock()
|
||||
mock_array = MagicMock()
|
||||
mock_array.tolist.return_value = [0.8, 0.6]
|
||||
mock_model.predict.return_value = mock_array
|
||||
mock_get_model.return_value = mock_model
|
||||
|
||||
result = await local_rerank(
|
||||
query="test query", docs=["doc1", "doc2"], model_name="fake-rerank-model"
|
||||
)
|
||||
|
||||
assert result == [0.8, 0.6]
|
||||
mock_model.predict.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rate_limit_handling() -> None:
|
||||
with patch("model_server.encoders.CloudEmbedding.embed") as mock_embed:
|
||||
mock_embed.side_effect = RateLimitError(
|
||||
"Rate limit exceeded", llm_provider="openai", model="fake-model"
|
||||
)
|
||||
|
||||
with pytest.raises(RateLimitError):
|
||||
await embed_text(
|
||||
texts=["test"],
|
||||
text_type=EmbedTextType.QUERY,
|
||||
model_name="fake-model",
|
||||
deployment_name=None,
|
||||
max_context_length=512,
|
||||
normalize_embeddings=True,
|
||||
api_key="fake-key",
|
||||
provider_type=EmbeddingProvider.OPENAI,
|
||||
prefix=None,
|
||||
api_url=None,
|
||||
api_version=None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_embeddings() -> None:
|
||||
def mock_encode(*args: Any, **kwargs: Any) -> List[List[float]]:
|
||||
time.sleep(5)
|
||||
return [[0.1, 0.2, 0.3]]
|
||||
|
||||
test_req = EmbedRequest(
|
||||
texts=["test"],
|
||||
model_name="'nomic-ai/nomic-embed-text-v1'",
|
||||
deployment_name=None,
|
||||
max_context_length=512,
|
||||
normalize_embeddings=True,
|
||||
api_key=None,
|
||||
provider_type=None,
|
||||
text_type=EmbedTextType.QUERY,
|
||||
manual_query_prefix=None,
|
||||
manual_passage_prefix=None,
|
||||
api_url=None,
|
||||
api_version=None,
|
||||
)
|
||||
|
||||
with patch("model_server.encoders.get_embedding_model") as mock_get_model:
|
||||
mock_model = MagicMock()
|
||||
mock_model.encode = mock_encode
|
||||
mock_get_model.return_value = mock_model
|
||||
start_time = time.time()
|
||||
|
||||
tasks = [process_embed_request(test_req) for _ in range(5)]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
end_time = time.time()
|
||||
|
||||
        # 5 tasks x 5 seconds of blocking work would be 25 seconds if run serially.
        # This test ensures the embedding calls at least yield the thread (i.e. the
        # blocking encode runs off the event loop). A developer could still introduce
        # unnecessary blocking above the mock and the test would pass, as long as that
        # extra blocking stays under (7 - 5) / 5 = 0.4 seconds per request.
        assert end_time - start_time < 7
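The timing assertion in `test_concurrent_embeddings` only holds if `process_embed_request` keeps the event loop free while the synchronous `encode` call runs. A minimal sketch of that pattern, assuming the production code does something equivalent to off-loading the blocking call to a worker thread (the function names here are illustrative, not the actual `model_server.encoders` API):

```python
import asyncio
import time
from typing import List


def blocking_encode(texts: List[str]) -> List[List[float]]:
    # Stand-in for a synchronous model call such as SentenceTransformer.encode.
    time.sleep(5)
    return [[0.1, 0.2, 0.3] for _ in texts]


async def process_embed_request_sketch(texts: List[str]) -> List[List[float]]:
    # Run the blocking encode in a worker thread so other requests can proceed.
    return await asyncio.to_thread(blocking_encode, texts)


async def main() -> None:
    start = time.time()
    await asyncio.gather(*(process_embed_request_sketch(["test"]) for _ in range(5)))
    # The five 5-second encodes overlap, so this prints roughly 5s rather than 25s,
    # which is the property the test's "< 7" bound checks.
    print(f"elapsed: {time.time() - start:.1f}s")


if __name__ == "__main__":
    asyncio.run(main())
```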
ct.yaml (2 changed lines)
@@ -6,7 +6,7 @@ chart-dirs:

# must be kept in sync with Chart.yaml
chart-repos:
  - vespa=https://danswer-ai.github.io/vespa-helm-charts
  - vespa=https://onyx-dot-app.github.io/vespa-helm-charts
  - postgresql=https://charts.bitnami.com/bitnami

helm-extra-args: --debug --timeout 600s
@@ -183,6 +183,13 @@ services:
      - GONG_CONNECTOR_START_TIME=${GONG_CONNECTOR_START_TIME:-}
      - NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP=${NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP:-}
      - GITHUB_CONNECTOR_BASE_URL=${GITHUB_CONNECTOR_BASE_URL:-}
      - MAX_DOCUMENT_CHARS=${MAX_DOCUMENT_CHARS:-}
      - MAX_FILE_SIZE_BYTES=${MAX_FILE_SIZE_BYTES:-}
      # Egnyte OAuth Configs
      - EGNYTE_CLIENT_ID=${EGNYTE_CLIENT_ID:-}
      - EGNYTE_CLIENT_SECRET=${EGNYTE_CLIENT_SECRET:-}
      - EGNYTE_BASE_DOMAIN=${EGNYTE_BASE_DOMAIN:-}
      - EGNYTE_LOCALHOST_OVERRIDE=${EGNYTE_LOCALHOST_OVERRIDE:-}
      # Celery Configs (defaults are set in the supervisord.conf file.
      # prefer doing that to have one source of defaults)
      - CELERY_WORKER_INDEXING_CONCURRENCY=${CELERY_WORKER_INDEXING_CONCURRENCY:-}
deployment/docker_compose/docker-compose.resources.yml (new file, 74 lines)
@@ -0,0 +1,74 @@
# Docker service resource limits. Most are commented out by default.
# 'background' service has preset (override-able) limits due to variable resource needs.
# Uncomment and set env vars for specific service limits.
# See: https://docs.danswer.dev/deployment/resource-sizing for details.

services:
  background:
    deploy:
      resources:
        limits:
          cpus: ${BACKGROUND_CPU_LIMIT:-4}
          memory: ${BACKGROUND_MEM_LIMIT:-4g}
        # reservations:
        #   cpus: ${BACKGROUND_CPU_RESERVATION}
        #   memory: ${BACKGROUND_MEM_RESERVATION}

  # nginx:
  #   deploy:
  #     resources:
  #       limits:
  #         cpus: ${NGINX_CPU_LIMIT}
  #         memory: ${NGINX_MEM_LIMIT}
  #       reservations:
  #         cpus: ${NGINX_CPU_RESERVATION}
  #         memory: ${NGINX_MEM_RESERVATION}
  # api_server:
  #   deploy:
  #     resources:
  #       limits:
  #         cpus: ${API_SERVER_CPU_LIMIT}
  #         memory: ${API_SERVER_MEM_LIMIT}
  #       reservations:
  #         cpus: ${API_SERVER_CPU_RESERVATION}
  #         memory: ${API_SERVER_MEM_RESERVATION}

  # index:
  #   deploy:
  #     resources:
  #       limits:
  #         cpus: ${VESPA_CPU_LIMIT}
  #         memory: ${VESPA_MEM_LIMIT}
  #       reservations:
  #         cpus: ${VESPA_CPU_RESERVATION}
  #         memory: ${VESPA_MEM_RESERVATION}

  # inference_model_server:
  #   deploy:
  #     resources:
  #       limits:
  #         cpus: ${INFERENCE_CPU_LIMIT}
  #         memory: ${INFERENCE_MEM_LIMIT}
  #       reservations:
  #         cpus: ${INFERENCE_CPU_RESERVATION}
  #         memory: ${INFERENCE_MEM_RESERVATION}

  # indexing_model_server:
  #   deploy:
  #     resources:
  #       limits:
  #         cpus: ${INDEXING_CPU_LIMIT}
  #         memory: ${INDEXING_MEM_LIMIT}
  #       reservations:
  #         cpus: ${INDEXING_CPU_RESERVATION}
  #         memory: ${INDEXING_MEM_RESERVATION}

  # relational_db:
  #   deploy:
  #     resources:
  #       limits:
  #         cpus: ${POSTGRES_CPU_LIMIT}
  #         memory: ${POSTGRES_MEM_LIMIT}
  #       reservations:
  #         cpus: ${POSTGRES_CPU_RESERVATION}
  #         memory: ${POSTGRES_MEM_RESERVATION}
@@ -3,13 +3,13 @@ dependencies:
  repository: https://charts.bitnami.com/bitnami
  version: 14.3.1
- name: vespa
  repository: https://danswer-ai.github.io/vespa-helm-charts
  version: 0.2.16
  repository: https://onyx-dot-app.github.io/vespa-helm-charts
  version: 0.2.18
- name: nginx
  repository: oci://registry-1.docker.io/bitnamicharts
  version: 15.14.0
- name: redis
  repository: https://charts.bitnami.com/bitnami
  version: 20.1.0
digest: sha256:711bbb76ba6ab604a36c9bf1839ab6faa5610afb21e535afd933c78f2d102232
generated: "2024-11-07T09:39:30.17171-08:00"
digest: sha256:5c9eb3d55d5f8e3beb64f26d26f686c8d62755daa10e2e6d87530bdf2fbbf957
generated: "2024-12-10T10:47:35.812483-08:00"
@@ -23,8 +23,8 @@ dependencies:
    repository: https://charts.bitnami.com/bitnami
    condition: postgresql.enabled
  - name: vespa
    version: 0.2.16
    repository: https://danswer-ai.github.io/vespa-helm-charts
    version: 0.2.18
    repository: https://onyx-dot-app.github.io/vespa-helm-charts
    condition: vespa.enabled
  - name: nginx
    version: 15.14.0
@@ -61,6 +61,8 @@ data:
  WEB_CONNECTOR_VALIDATE_URLS: ""
  GONG_CONNECTOR_START_TIME: ""
  NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP: ""
  MAX_DOCUMENT_CHARS: ""
  MAX_FILE_SIZE_BYTES: ""
  # DanswerBot SlackBot Configs
  DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER: ""
  DANSWER_BOT_DISPLAY_ERROR_MSGS: ""
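Both new keys default to an empty string, matching the `${VAR:-}` entries in the Docker Compose hunk above: leaving them unset means the backend falls back to its built-in defaults. A sketch of the parsing pattern such settings typically use (illustrative only; the actual config module in the repository may read these differently):

```python
import os


def optional_int_from_env(name: str) -> int | None:
    """Treat an unset or empty environment variable as 'no limit configured'."""
    raw = os.environ.get(name, "")
    return int(raw) if raw.strip() else None


# Hypothetical usage mirroring the two keys added above.
MAX_DOCUMENT_CHARS = optional_int_from_env("MAX_DOCUMENT_CHARS")
MAX_FILE_SIZE_BYTES = optional_int_from_env("MAX_FILE_SIZE_BYTES")
```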
@@ -66,6 +66,9 @@ ARG NEXT_PUBLIC_POSTHOG_HOST
ENV NEXT_PUBLIC_POSTHOG_KEY=${NEXT_PUBLIC_POSTHOG_KEY}
ENV NEXT_PUBLIC_POSTHOG_HOST=${NEXT_PUBLIC_POSTHOG_HOST}

ARG NEXT_PUBLIC_CLOUD_ENABLED
ENV NEXT_PUBLIC_CLOUD_ENABLED=${NEXT_PUBLIC_CLOUD_ENABLED}

ARG NEXT_PUBLIC_SENTRY_DSN
ENV NEXT_PUBLIC_SENTRY_DSN=${NEXT_PUBLIC_SENTRY_DSN}

@@ -138,6 +141,9 @@ ARG NEXT_PUBLIC_POSTHOG_HOST
ENV NEXT_PUBLIC_POSTHOG_KEY=${NEXT_PUBLIC_POSTHOG_KEY}
ENV NEXT_PUBLIC_POSTHOG_HOST=${NEXT_PUBLIC_POSTHOG_HOST}

ARG NEXT_PUBLIC_CLOUD_ENABLED
ENV NEXT_PUBLIC_CLOUD_ENABLED=${NEXT_PUBLIC_CLOUD_ENABLED}

ARG NEXT_PUBLIC_SENTRY_DSN
ENV NEXT_PUBLIC_SENTRY_DSN=${NEXT_PUBLIC_SENTRY_DSN}

web/package-lock.json (generated, 4 changed lines)
@@ -1,12 +1,12 @@
{
  "name": "qa",
  "version": "0.2.0-dev",
  "version": "1.0.0-dev",
  "lockfileVersion": 3,
  "requires": true,
  "packages": {
    "": {
      "name": "qa",
      "version": "0.2.0-dev",
      "version": "1.0.0-dev",
      "dependencies": {
        "@dnd-kit/core": "^6.1.0",
        "@dnd-kit/modifiers": "^7.0.0",
@@ -1,6 +1,6 @@
{
  "name": "qa",
  "version": "0.2.0-dev",
  "version": "1.0.0-dev",
  "version-comment": "version field must be SemVer or chromatic will barf",
  "private": true,
  "scripts": {
web/public/Egnyte.png (new binary file, 12 KiB, not shown)
web/public/Wikipedia.png (new binary file, 769 KiB, not shown)
One file diff suppressed because one or more lines are too long (previous image size: 164 KiB).
@@ -82,7 +82,7 @@ export const DanswerApiKeyForm = ({
      }}
    >
      {({ isSubmitting, values, setFieldValue }) => (
        <Form>
        <Form className="w-full overflow-visible">
          <Text className="mb-4 text-lg">
            Choose a memorable name for your API key. This is optional and
            can be added or changed later!
@@ -45,9 +45,6 @@
    <div className="px-8 py-8">
      <div className="flex w-full border-b border-border mb-4 pb-4">
        <Title>New API Key</Title>
        <div onClick={onClose} className="ml-auto p-1 rounded hover:bg-hover">
          <FiX size={18} />
        </div>
      </div>
      <div className="h-32">
        <Text className="mb-4">
@@ -96,6 +96,16 @@ export function SlackBotTable({ slackBots }: { slackBots: SlackBot[] }) {
          </ClickableTableRow>
        );
      })}
      {slackBots.length === 0 && (
        <TableRow>
          <TableCell
            colSpan={4}
            className="text-center text-muted-foreground"
          >
            Please add a New Slack Bot to begin chatting with Danswer!
          </TableCell>
        </TableRow>
      )}
    </TableBody>
  </Table>
  {slackBots.length > NUM_IN_PAGE && (
@@ -275,8 +275,9 @@ export function CustomLLMProviderUpdateForm({
            <SubLabel>
              <>
                <div>
                  Additional configurations needed by the model provider. Are
                  passed to litellm via environment variables.
                  Additional configurations needed by the model provider. These
                  are passed to litellm via environment + as arguments into the
                  `completion` call.
                </div>

                <div className="mt-2">
@@ -290,14 +291,14 @@ export function CustomLLMProviderUpdateForm({
            <FieldArray
              name="custom_config_list"
              render={(arrayHelpers: ArrayHelpers<any[]>) => (
                <div>
                <div className="w-full">
                  {formikProps.values.custom_config_list.map((_, index) => {
                    return (
                      <div
                        key={index}
                        className={index === 0 ? "mt-2" : "mt-6"}
                        className={(index === 0 ? "mt-2" : "mt-6") + " w-full"}
                      >
                        <div className="flex">
                        <div className="flex w-full">
                          <div className="w-full mr-6 border border-border p-3 rounded">
                            <div>
                              <Label>Key</Label>
@@ -457,6 +458,7 @@ export function CustomLLMProviderUpdateForm({
                  <Button
                    type="button"
                    variant="destructive"
                    className="ml-3"
                    icon={FiTrash}
                    onClick={async () => {
                      const response = await fetch(
@@ -2,6 +2,11 @@ import CardSection from "@/components/admin/CardSection";
import { getNameFromPath } from "@/lib/fileUtils";
import { ValidSources } from "@/lib/types";
import Title from "@/components/ui/title";
import { EditIcon } from "@/components/icons/icons";

import { useState } from "react";
import { ChevronUpIcon } from "lucide-react";
import { ChevronDownIcon } from "@/components/icons/icons";

function convertObjectToString(obj: any): string | any {
  // Check if obj is an object and not an array or null
@@ -39,14 +44,83 @@ function buildConfigEntries(
  return obj;
}

function ConfigItem({ label, value }: { label: string; value: any }) {
  const [isExpanded, setIsExpanded] = useState(false);
  const isExpandable = Array.isArray(value) && value.length > 5;

  const renderValue = () => {
    if (Array.isArray(value)) {
      const displayedItems = isExpanded ? value : value.slice(0, 5);
      return (
        <ul className="list-disc max-w-full pl-4 mt-2 overflow-x-auto">
          {displayedItems.map((item, index) => (
            <li
              key={index}
              className="mb-1 max-w-full overflow-hidden text-right text-ellipsis whitespace-nowrap"
            >
              {convertObjectToString(item)}
            </li>
          ))}
        </ul>
      );
    } else if (typeof value === "object" && value !== null) {
      return (
        <div className="mt-2 overflow-x-auto">
          {Object.entries(value).map(([key, val]) => (
            <div key={key} className="mb-1">
              <span className="font-semibold">{key}:</span>{" "}
              {convertObjectToString(val)}
            </div>
          ))}
        </div>
      );
    }
    return convertObjectToString(value) || "-";
  };

  return (
    <li className="w-full py-2">
      <div className="flex items-center justify-between w-full">
        <span className="mb-2">{label}</span>
        <div className="mt-2 overflow-x-auto w-fit">
          {renderValue()}

          {isExpandable && (
            <button
              onClick={() => setIsExpanded(!isExpanded)}
              className="mt-2 ml-auto text-text-600 hover:text-text-800 flex items-center"
            >
              {isExpanded ? (
                <>
                  <ChevronUpIcon className="h-4 w-4 mr-1" />
                  Show less
                </>
              ) : (
                <>
                  <ChevronDownIcon className="h-4 w-4 mr-1" />
                  Show all ({value.length} items)
                </>
              )}
            </button>
          )}
        </div>
      </div>
    </li>
  );
}

export function AdvancedConfigDisplay({
  pruneFreq,
  refreshFreq,
  indexingStart,
  onRefreshEdit,
  onPruningEdit,
}: {
  pruneFreq: number | null;
  refreshFreq: number | null;
  indexingStart: Date | null;
  onRefreshEdit: () => void;
  onPruningEdit: () => void;
}) {
  const formatRefreshFrequency = (seconds: number | null): string => {
    if (seconds === null) return "-";
@@ -75,14 +149,21 @@ export function AdvancedConfigDisplay({
    <>
      <Title className="mt-8 mb-2">Advanced Configuration</Title>
      <CardSection>
        <ul className="w-full text-sm divide-y divide-neutral-200 dark:divide-neutral-700">
        <ul className="w-full text-sm divide-y divide-background-200 dark:divide-background-700">
          {pruneFreq && (
            <li
              key={0}
              className="w-full flex justify-between items-center py-2"
            >
              <span>Pruning Frequency</span>
              <span>{formatPruneFrequency(pruneFreq)}</span>
              <span className="ml-auto w-24">
                {formatPruneFrequency(pruneFreq)}
              </span>
              <span className="w-8 text-right">
                <button onClick={() => onPruningEdit()}>
                  <EditIcon size={12} />
                </button>
              </span>
            </li>
          )}
          {refreshFreq && (
@@ -91,7 +172,14 @@ export function AdvancedConfigDisplay({
              className="w-full flex justify-between items-center py-2"
            >
              <span>Refresh Frequency</span>
              <span>{formatRefreshFrequency(refreshFreq)}</span>
              <span className="ml-auto w-24">
                {formatRefreshFrequency(refreshFreq)}
              </span>
              <span className="w-8 text-right">
                <button onClick={() => onRefreshEdit()}>
                  <EditIcon size={12} />
                </button>
              </span>
            </li>
          )}
          {indexingStart && (
@@ -127,15 +215,9 @@ export function ConfigDisplay({
    <>
      <Title className="mb-2">Configuration</Title>
      <CardSection>
        <ul className="w-full text-sm divide-y divide-neutral-200 dark:divide-neutral-700">
        <ul className="w-full text-sm divide-y divide-background-200 dark:divide-background-700">
          {configEntries.map(([key, value]) => (
            <li
              key={key}
              className="w-full flex justify-between items-center py-2"
            >
              <span>{key}</span>
              <span>{convertObjectToString(value) || "-"}</span>
            </li>
            <ConfigItem key={key} label={key} value={value} />
          ))}
        </ul>
      </CardSection>
@@ -7,7 +7,10 @@ import { SourceIcon } from "@/components/SourceIcon";
import { CCPairStatus } from "@/components/Status";
import { usePopup } from "@/components/admin/connectors/Popup";
import CredentialSection from "@/components/credentials/CredentialSection";
import { updateConnectorCredentialPairName } from "@/lib/connector";
import {
  updateConnectorCredentialPairName,
  updateConnectorCredentialPairProperty,
} from "@/lib/connector";
import { credentialTemplates } from "@/lib/connectors/credentials";
import { errorHandlingFetcher } from "@/lib/fetcher";
import { ValidSources } from "@/lib/types";
@@ -26,12 +29,33 @@ import { buildCCPairInfoUrl } from "./lib";
import { CCPairFullInfo, ConnectorCredentialPairStatus } from "./types";
import { EditableStringFieldDisplay } from "@/components/EditableStringFieldDisplay";
import { Button } from "@/components/ui/button";
import EditPropertyModal from "@/components/modals/EditPropertyModal";

import * as Yup from "yup";

// since the uploaded files are cleaned up after some period of time
// re-indexing will not work for the file connector. Also, it would not
// make sense to re-index, since the files will not have changed.
const CONNECTOR_TYPES_THAT_CANT_REINDEX: ValidSources[] = [ValidSources.File];

// synchronize these validations with the SQLAlchemy connector class until we have a
// centralized schema for both frontend and backend
const RefreshFrequencySchema = Yup.object().shape({
  propertyValue: Yup.number()
    .typeError("Property value must be a valid number")
    .integer("Property value must be an integer")
    .min(60, "Property value must be greater than or equal to 60")
    .required("Property value is required"),
});

const PruneFrequencySchema = Yup.object().shape({
  propertyValue: Yup.number()
    .typeError("Property value must be a valid number")
    .integer("Property value must be an integer")
    .min(86400, "Property value must be greater than or equal to 86400")
    .required("Property value is required"),
});

function Main({ ccPairId }: { ccPairId: number }) {
  const router = useRouter(); // Initialize the router
  const {
@@ -45,6 +69,8 @@ function Main({ ccPairId }: { ccPairId: number }) {
  );

  const [hasLoadedOnce, setHasLoadedOnce] = useState(false);
  const [editingRefreshFrequency, setEditingRefreshFrequency] = useState(false);
  const [editingPruningFrequency, setEditingPruningFrequency] = useState(false);
  const { popup, setPopup } = usePopup();

  const finishConnectorDeletion = useCallback(() => {
@@ -90,6 +116,86 @@ function Main({ ccPairId }: { ccPairId: number }) {
    }
  };

  const handleRefreshEdit = async () => {
    setEditingRefreshFrequency(true);
  };

  const handlePruningEdit = async () => {
    setEditingPruningFrequency(true);
  };

  const handleRefreshSubmit = async (
    propertyName: string,
    propertyValue: string
  ) => {
    const parsedRefreshFreq = parseInt(propertyValue, 10);

    if (isNaN(parsedRefreshFreq)) {
      setPopup({
        message: "Invalid refresh frequency: must be an integer",
        type: "error",
      });
      return;
    }

    try {
      const response = await updateConnectorCredentialPairProperty(
        ccPairId,
        propertyName,
        String(parsedRefreshFreq)
      );
      if (!response.ok) {
        throw new Error(await response.text());
      }
      mutate(buildCCPairInfoUrl(ccPairId));
      setPopup({
        message: "Connector refresh frequency updated successfully",
        type: "success",
      });
    } catch (error) {
      setPopup({
        message: "Failed to update connector refresh frequency",
        type: "error",
      });
    }
  };

  const handlePruningSubmit = async (
    propertyName: string,
    propertyValue: string
  ) => {
    const parsedFreq = parseInt(propertyValue, 10);

    if (isNaN(parsedFreq)) {
      setPopup({
        message: "Invalid pruning frequency: must be an integer",
        type: "error",
      });
      return;
    }

    try {
      const response = await updateConnectorCredentialPairProperty(
        ccPairId,
        propertyName,
        String(parsedFreq)
      );
      if (!response.ok) {
        throw new Error(await response.text());
      }
      mutate(buildCCPairInfoUrl(ccPairId));
      setPopup({
        message: "Connector pruning frequency updated successfully",
        type: "success",
      });
    } catch (error) {
      setPopup({
        message: "Failed to update connector pruning frequency",
        type: "error",
      });
    }
  };

  if (isLoading) {
    return <ThreeDotsLoader />;
  }
@@ -114,9 +220,35 @@ function Main({ ccPairId }: { ccPairId: number }) {
    refresh_freq: refreshFreq,
    indexing_start: indexingStart,
  } = ccPair.connector;

  return (
    <>
      {popup}

      {editingRefreshFrequency && (
        <EditPropertyModal
          propertyTitle="Refresh Frequency"
          propertyDetails="How often the connector should refresh (in seconds)"
          propertyName="refresh_frequency"
          propertyValue={String(refreshFreq)}
          validationSchema={RefreshFrequencySchema}
          onSubmit={handleRefreshSubmit}
          onClose={() => setEditingRefreshFrequency(false)}
        />
      )}

      {editingPruningFrequency && (
        <EditPropertyModal
          propertyTitle="Pruning Frequency"
          propertyDetails="How often the connector should be pruned (in seconds)"
          propertyName="pruning_frequency"
          propertyValue={String(pruneFreq)}
          validationSchema={PruneFrequencySchema}
          onSubmit={handlePruningSubmit}
          onClose={() => setEditingPruningFrequency(false)}
        />
      )}

      <BackButton
        behaviorOverride={() => router.push("/admin/indexing/status")}
      />
@@ -125,7 +257,7 @@ function Main({ ccPairId }: { ccPairId: number }) {
        <SourceIcon iconSize={32} sourceType={ccPair.connector.source} />
      </div>

      <div className="ml-1">
      <div className="ml-1 overflow-hidden text-ellipsis whitespace-nowrap flex-1 mr-4">
        <EditableStringFieldDisplay
          value={ccPair.name}
          isEditable={ccPair.is_editable_for_current_user}
@@ -213,6 +345,8 @@ function Main({ ccPairId }: { ccPairId: number }) {
          pruneFreq={pruneFreq}
          indexingStart={indexingStart}
          refreshFreq={refreshFreq}
          onRefreshEdit={handleRefreshEdit}
          onPruningEdit={handlePruningEdit}
        />
      )}
@@ -19,7 +19,11 @@ import AdvancedFormPage from "./pages/Advanced";
import DynamicConnectionForm from "./pages/DynamicConnectorCreationForm";
import CreateCredential from "@/components/credentials/actions/CreateCredential";
import ModifyCredential from "@/components/credentials/actions/ModifyCredential";
import { ConfigurableSources, ValidSources } from "@/lib/types";
import {
  ConfigurableSources,
  oauthSupportedSources,
  ValidSources,
} from "@/lib/types";
import { Credential, credentialTemplates } from "@/lib/connectors/credentials";
import {
  ConnectionConfiguration,
@@ -45,6 +49,8 @@ import { useRouter } from "next/navigation";
import CardSection from "@/components/admin/CardSection";
import { prepareOAuthAuthorizationRequest } from "@/lib/oauth_utils";
import { EE_ENABLED, NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants";
import TemporaryLoadingModal from "@/components/TemporaryLoadingModal";
import { getConnectorOauthRedirectUrl } from "@/lib/connectors/oauth";
export interface AdvancedConfig {
  refreshFreq: number;
  pruneFreq: number;
@@ -154,14 +160,9 @@ export default function AddConnector({
  const configuration: ConnectionConfiguration = connectorConfigs[connector];

  // Form context and popup management
  const {
    setFormStep,
    setAllowCreate: setAllowCreate,
    formStep,
    nextFormStep,
    prevFormStep,
  } = useFormContext();
  const { setFormStep, setAllowCreate, formStep } = useFormContext();
  const { popup, setPopup } = usePopup();
  const [uploading, setUploading] = useState(false);

  // Hooks for Google Drive and Gmail credentials
  const { liveGDriveCredential } = useGoogleDriveCredentials(connector);
@@ -339,16 +340,24 @@ export default function AddConnector({
    }
    // File-specific handling
    if (connector == "file") {
      const response = await submitFiles(
        selectedFiles,
        setPopup,
        name,
        access_type,
        groups
      );
      if (response) {
        onSuccess();
      setUploading(true);
      try {
        const response = await submitFiles(
          selectedFiles,
          setPopup,
          name,
          access_type,
          groups
        );
        if (response) {
          onSuccess();
        }
      } catch (error) {
        setPopup({ message: "Error uploading files", type: "error" });
      } finally {
        setUploading(false);
      }

      return;
    }
@@ -410,9 +419,9 @@ export default function AddConnector({
    <div className="mx-auto mb-8 w-full">
      {popup}

      <div className="mb-4">
        <HealthCheckBanner />
      </div>
      {uploading && (
        <TemporaryLoadingModal content="Uploading files..." />
      )}

      <AdminPageTitle
        includeDivider={false}
@@ -444,26 +453,38 @@ export default function AddConnector({
              {/* Button to pop up a form to manually enter credentials */}
              <button
                className="mt-6 text-sm bg-background-900 px-2 py-1.5 flex text-text-200 flex-none rounded mr-4"
                onClick={() =>
                  setCreateConnectorToggle(
                    (createConnectorToggle) => !createConnectorToggle
                  )
                }
                onClick={async () => {
                  const redirectUrl =
                    await getConnectorOauthRedirectUrl(connector);
                  // if redirect is supported, just use it
                  if (redirectUrl) {
                    window.location.href = redirectUrl;
                  } else {
                    setCreateConnectorToggle(
                      (createConnectorToggle) =>
                        !createConnectorToggle
                    );
                  }
                }}
              >
                Create New
              </button>

              {/* Button to sign in via OAuth */}
              <button
                onClick={handleAuthorize}
                className="mt-6 text-sm bg-blue-500 px-2 py-1.5 flex text-text-200 flex-none rounded"
                disabled={isAuthorizing}
                hidden={!isAuthorizeVisible}
              >
                {isAuthorizing
                  ? "Authorizing..."
                  : `Authorize with ${getSourceDisplayName(connector)}`}
              </button>
              {oauthSupportedSources.includes(connector) &&
                NEXT_PUBLIC_CLOUD_ENABLED && (
                  <button
                    onClick={handleAuthorize}
                    className="mt-6 text-sm bg-blue-500 px-2 py-1.5 flex text-text-200 flex-none rounded"
                    disabled={isAuthorizing}
                    hidden={!isAuthorizeVisible}
                  >
                    {isAuthorizing
                      ? "Authorizing..."
                      : `Authorize with ${getSourceDisplayName(
                          connector
                        )}`}
                  </button>
                )}
            </div>
          )}
@@ -104,7 +104,9 @@ const GDriveMain = ({}: {}) => {
  const googleDriveServiceAccountCredential:
    | Credential<GoogleDriveServiceAccountCredentialJson>
    | undefined = credentialsData.find(
    (credential) => credential.credential_json?.google_service_account_key
    (credential) =>
      credential.credential_json?.google_service_account_key &&
      credential.source === "google_drive"
  );

  const googleDriveConnectorIndexingStatuses: ConnectorIndexingStatus<
@@ -135,7 +135,7 @@ export const DocumentFeedbackTable = ({
            />
          </TableCell>
          <TableCell>
            <div className="ml-auto flex w-16">
            <div className="relative">
              <div
                key={document.document_id}
                className="h-10 ml-auto mr-8"
@@ -353,13 +353,9 @@ export function CCPairIndexingStatusTable({
    );
  };
  const toggleSources = () => {
    const currentToggledCount =
      Object.values(connectorsToggled).filter(Boolean).length;
    const shouldToggleOn = currentToggledCount < sortedSources.length / 2;

    const connectors = sortedSources.reduce(
      (acc, source) => {
        acc[source] = shouldToggleOn;
        acc[source] = shouldExpand;
        return acc;
      },
      {} as Record<ValidSources, boolean>
@@ -368,6 +364,7 @@ export function CCPairIndexingStatusTable({
    setConnectorsToggled(connectors);
    Cookies.set(TOGGLED_CONNECTORS_COOKIE_NAME, JSON.stringify(connectors));
  };

  const shouldExpand =
    Object.values(connectorsToggled).filter(Boolean).length <
    sortedSources.length;
@@ -1,46 +0,0 @@
import useSWR from "swr";
import { InputPrompt } from "./interfaces";

const fetcher = (url: string) => fetch(url).then((res) => res.json());

export const useAdminInputPrompts = () => {
  const { data, error, mutate } = useSWR<InputPrompt[]>(
    `/api/admin/input_prompt`,
    fetcher
  );

  return {
    data,
    error,
    isLoading: !error && !data,
    refreshInputPrompts: mutate,
  };
};

export const useInputPrompts = (includePublic: boolean = false) => {
  const { data, error, mutate } = useSWR<InputPrompt[]>(
    `/api/input_prompt${includePublic ? "?include_public=true" : ""}`,
    fetcher
  );

  return {
    data,
    error,
    isLoading: !error && !data,
    refreshInputPrompts: mutate,
  };
};

export const useInputPrompt = (id: number) => {
  const { data, error, mutate } = useSWR<InputPrompt>(
    `/api/input_prompt/${id}`,
    fetcher
  );

  return {
    data,
    error,
    isLoading: !error && !data,
    refreshInputPrompt: mutate,
  };
};
Some files were not shown because too many files have changed in this diff.