Compare commits

..

88 Commits

Author SHA1 Message Date
pablodanswer
da86610022 nit 2025-01-06 18:36:38 -08:00
pablodanswer
0027759dbf nit 2025-01-06 17:55:02 -08:00
pablodanswer
595ef152d2 updated UX 2025-01-05 14:38:52 -08:00
pablodanswer
083d669d1b minor logic update 2025-01-05 13:22:15 -08:00
pablodanswer
3ac31136b2 base functional 2025-01-03 17:11:14 -08:00
pablodanswer
a73a438d95 k 2025-01-03 14:33:24 -08:00
pablodanswer
c0770481e8 finalize 2025-01-03 14:29:52 -08:00
pablodanswer
c27d13c07f rm danswer 2025-01-03 14:27:25 -08:00
pablodanswer
ab34c4e772 add my docs v1 2025-01-03 14:25:56 -08:00
rkuo-danswer
66f9124135 Merge pull request #3584 from onyx-dot-app/bugfix/log_spacing
fix formatting
2025-01-03 00:43:36 -08:00
Richard Kuo
8f0fb70bbf fix formatting 2025-01-02 23:21:54 -08:00
rkuo-danswer
ef5e5c80bb Merge pull request #3577 from onyx-dot-app/bugfix/model_server_exception_logging
fix response logging
2025-01-02 23:08:46 -08:00
rkuo-danswer
03acb6587a Feature/model server logging (#3579)
* improve model server logging

* improve exception logging with provider/model names

* get everything into one log line

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2025-01-03 01:40:29 +00:00
hagen-danswer
d1ec72b5e5 Reworked salesforce connector to use bulk api (#3581) 2025-01-02 18:09:02 -08:00
Weves
3b214133a8 Airtable improvement 2025-01-02 17:56:05 -08:00
rkuo-danswer
2232702e99 retry the individual delete's (#3580)
* retry the individual delete's

* need to raise inside the retry

* just use retry for now

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2025-01-02 17:39:37 -08:00
hagen-danswer
8108ff0a4b Added logging for permissions upsert queue length 2025-01-02 17:39:01 -08:00
Richard Kuo
f64e78e986 fix response logging 2025-01-02 13:39:19 -08:00
Chris Weaver
08312a4394 Update Slack link in README.md 2025-01-01 10:03:59 -08:00
Weves
92add655e0 Slack fixes 2024-12-31 18:04:12 -08:00
Chris Weaver
d64464ca7c Add support for OAuth connectors that require user input (#3571)
* Add support for OAuth connectors that require user input

* Cleanup

* Fix linear

* Small re-naming

* Remove console.log
2024-12-31 18:03:33 -08:00
Yuhong Sun
ccd3983802 Linear OAuth Connector (#3570) 2024-12-31 16:11:09 -08:00
pablonyx
240f3e4fff Ensure users cannot modify their roles on main
Ensure users cannot modify their roles
2024-12-31 15:59:27 -05:00
pablonyx
1291b3d930 Add anonymous user to main
Anonymous user
2024-12-31 15:58:52 -05:00
rkuo-danswer
d05f1997b5 Merge pull request #3569 from onyx-dot-app/bugfix/alt_index
we didn't want to rename the alt index suffix, reverting
2024-12-31 12:39:00 -08:00
Chris Weaver
aa2e2a62b9 Small Egnyte tweaks (#3568) 2024-12-31 19:28:38 +00:00
Richard Kuo
174e5968f8 we didn't want to rename the alt index suffix, reverting 2024-12-31 11:28:11 -08:00
pablodanswer
1f27606e17 minor clean up 2024-12-31 13:04:02 -05:00
pablodanswer
60355b84c1 quick nits 2024-12-31 13:04:02 -05:00
pablodanswer
680ab9ea30 updated logic 2024-12-31 13:04:02 -05:00
pablodanswer
c2447dbb1c cosmetic updates 2024-12-31 13:04:02 -05:00
pablodanswer
52bad522f8 update for multi-tenant clarity 2024-12-31 13:04:02 -05:00
pablodanswer
63e5e58313 add anonymous user 2024-12-31 13:04:02 -05:00
rkuo-danswer
2643782e30 Merge pull request #3567 from onyx-dot-app/bugfix/revert_vespa
Revert "More efficient Vespa indexing (#3552)"
2024-12-31 09:47:00 -08:00
Richard Kuo
3eb72e5c1d Revert "More efficient Vespa indexing (#3552)"
This reverts commit 2783216781.
2024-12-31 09:40:23 -08:00
rkuo-danswer
9b65c23a7e Merge pull request #3566 from onyx-dot-app/bugfix/primary_task_timings
re-enable celery task execution logging in primary worker
2024-12-31 01:29:05 -08:00
Richard Kuo (Danswer)
b43a8e48c6 add some return types to distinguish when the task is actually performing work 2024-12-31 00:10:33 -08:00
Richard Kuo (Danswer)
1955c1d67b re-enable celery task execution logging in primary worker 2024-12-30 21:53:00 -08:00
Chris Weaver
3f92ed9d29 Airtable connector (#3564)
* Airtable connector

* Improvements

* improve

* Clean up test

* Add secrets

* Fix mypy + add access token

* Mock unstructured call

* Adjust comments

* Fix ID in test
2024-12-31 03:06:28 +00:00
Weves
618369f4a1 Small fix 2024-12-30 19:20:30 -08:00
pablonyx
2783216781 More efficient Vespa indexing (#3552)
---------

Co-authored-by: Chris Weaver <25087905+Weves@users.noreply.github.com>
2024-12-30 18:51:14 -08:00
rkuo-danswer
bec0f9fb23 permission sync in cloud and beat expiry adjustment (#3544)
* try fixing exception in cloud

* raise beat expiry ... 60 seconds might be starving certain tasks completely

* adjust expiry down to 10 min

* raise concurrency overflow for indexing worker.

* parent pid check

* fix comment

* fix parent pid check, also actually raise an exception from the task if the spawned task exit status is bad

* fix pid check

* some cleanup and task wait fixes

* review fixes

* comment some code so we don't change too many things at once

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2024-12-31 01:05:57 +00:00
pablodanswer
97a03e7fc8 nit 2024-12-29 21:07:12 -05:00
pablodanswer
8d6e8269b7 k 2024-12-29 21:07:12 -05:00
pablodanswer
9ce2c6c517 minor change 2024-12-29 21:07:12 -05:00
pablodanswer
2ad8bdbc65 k 2024-12-29 21:07:12 -05:00
rkuo-danswer
a83c9b40d5 Bugfix/oauth fix (#3507)
* old oauth file left behind

* fix function change that was lost in merge

* fix some testing vars

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2024-12-30 01:49:12 +00:00
Chris Weaver
340fab1375 Additional error handling + logging for google drive connector (#3563)
* Additional error handling + logging for google drive connector

* Fix mypy
2024-12-29 17:48:02 -08:00
hagen-danswer
3ec338307f Fixed indexing issues with Salesforce 2024-12-29 16:45:29 -08:00
pablonyx
27acd3387a Auth specific rate limiting (#3463)
* k

* v1

* fully functional

* finalize

* nit

* nit

* nit

* clean up with wrapper + comments

* k

* update

* minor clean
2024-12-29 23:34:23 +00:00
hagen-danswer
d14ef431a7 Improve Salesforce connector 2024-12-29 14:18:40 -08:00
pablonyx
9bffeb65af Eagerly load CCpair connectors (#3531)
* remove left over vim command

* eager loading

* Revert "remove left over vim command"

This reverts commit 184a134ae0.
2024-12-29 15:58:38 +00:00
Yuhong Sun
f4806da653 Fix Null Value in PG (#3559)
* k

* k

* k

* k

* k
2024-12-29 01:53:16 +00:00
pablonyx
e2700b2bbd Remove left over yaml errors (#3527)
* remove left over vim command

* additional misconfigurations

* ensure all regions updated
2024-12-29 01:45:07 +00:00
Yuhong Sun
fc81a3fb12 Zendesk Retries (#3558)
* k

* k

* k

* k
2024-12-28 23:51:49 +00:00
pablonyx
2203cfabea Prevent SSRF risk
Prevent SSRF risk
2024-12-28 15:25:57 -05:00
pablodanswer
f4050306d6 Prevent SSRF risk 2024-12-28 15:25:12 -05:00
Weves
2d960a477f Fix discourse connector 2024-12-24 12:43:10 -08:00
hagen-danswer
8837b8ea79 Curators can now update the curator relationship (#3536)
* Curators can now update the curator relationship

* mypy

* mypy

* whoops haha
2024-12-24 18:49:58 +00:00
hagen-danswer
3dfb214f73 Slackbot polish (#3547) 2024-12-24 16:19:15 +00:00
pablonyx
18d7262608 fix logo rendering (#3542) 2024-12-22 23:00:33 +00:00
pablonyx
09b879ee73 Ensure gmail works for personal accounts (#3541)
* Ensure gmail works for personal accounts

* nit

* minor update
2024-12-22 23:00:14 +00:00
rkuo-danswer
aaa668c963 Merge pull request #3534 from onyx-dot-app/bugfix/validate_ttl
raise activity timeout to one hour
2024-12-22 13:13:57 -08:00
pablonyx
edb877f4bc fix NUL character (#3540) 2024-12-21 23:30:25 +00:00
rkuo-danswer
eb369caefb log attempt id, log elapsed since task execution start, remove log spam (#3539)
* log attempt id, log elapsed since task execution start, remove log spam

* diagnostic lock logs

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2024-12-21 23:03:50 +00:00
Chris Weaver
b9567eabd7 Fix bedrock w/ access keys (#3538)
* Fix bedrock w/ access keys

* cleanup

* Remove extra #
2024-12-21 02:24:11 +00:00
Richard Kuo (Danswer)
13bbf67091 raise activity timeout to one hour 2024-12-20 16:18:35 -08:00
hagen-danswer
457a4c73f0 Made sure confluence connector recursive by page includes top level page (#3532)
* Made sure confluence connector by page includes top level page

* surface level change
2024-12-20 21:53:59 +00:00
rkuo-danswer
ce37688b5b allow limited user to create chat session (#3533)
Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2024-12-20 21:36:41 +00:00
pablonyx
4e2c90f4af Proper user deletion / organization leaving (#3460)
* Proper user deletion / organization leaving

* minor nit

* update

* udpate provisioning

* minor cleanup

* typing

* post rebase
2024-12-20 21:01:03 +00:00
pablonyx
513dd8a319 update toggling states (#3519) 2024-12-20 20:27:22 +00:00
hagen-danswer
71c5043832 Added filter to exclude attachments with unsupported file extensions (#3530)
* Added filter to exclude attachments with unsupported file extensions

* extension
2024-12-20 19:48:15 +00:00
pablonyx
64b6f15e95 AWS extraneous error fix (#3529)
* remove left over vim command

* aws fix

* k

* remove double
2024-12-20 19:31:04 +00:00
hagen-danswer
35022f5f09 Fix group table (#3523) 2024-12-20 17:51:26 +00:00
hagen-danswer
0d44014c16 Cleanup PR template to make it more concise (#3524) 2024-12-20 17:49:31 +00:00
Yuhong Sun
1b9e9f48fa Update README.md 2024-12-20 10:26:57 -08:00
Yuhong Sun
05fb5aa27b Update README.md 2024-12-20 10:25:34 -08:00
Yuhong Sun
3b645b72a3 Crop Logo Closer 2024-12-20 10:23:52 -08:00
Yuhong Sun
fe770b5c3a Fix Logo On DarkMode (#3525) 2024-12-20 10:15:48 -08:00
hagen-danswer
1eaf885f50 associating credentials with connectors is not considered editing (#3522)
* associating credentials with connectors is not considered editing

* formatting

* formatting

* Update credentials.py

---------

Co-authored-by: Yuhong Sun <yuhongsun96@gmail.com>
2024-12-20 17:36:25 +00:00
rkuo-danswer
a187aa508c use redis exclusively with active signal renewal in more places to perform indexing fence validation (#3517)
Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2024-12-20 06:54:00 +00:00
pablonyx
aa4bfa2a78 Forgot password feature (#3437)
* forgot password feature

* improved config

* nit

* nit
2024-12-20 04:53:37 +00:00
pablonyx
9011b8a139 Update citations in shared chat display (#3487)
* update shared chat display

* Change Copy

* fix icon

* remove secret!

---------

Co-authored-by: Yuhong Sun <yuhongsun96@gmail.com>
2024-12-20 01:48:29 +00:00
pablonyx
59c774353a Latex formatting (#3499) 2024-12-19 14:48:06 -08:00
pablonyx
b458d504af Sidebar Default Open (#3488) 2024-12-19 14:04:50 -08:00
Yuhong Sun
f83e7bfcd9 Fix Default CC Pair (#3513) 2024-12-19 09:43:12 -08:00
pablonyx
4d2e26ce4b MT Cloud Tracking Fix (#3514) 2024-12-19 08:47:02 -08:00
pablonyx
817fdc1f36 Ensure metadata overrides file contents (#3512)
* ensure metadata overrides file contents

* update more blocks
2024-12-19 04:44:24 +00:00
222 changed files with 9597 additions and 3203 deletions

View File

@@ -6,24 +6,6 @@
[Describe the tests you ran to verify your changes]
## Accepted Risk (provide if relevant)
N/A
## Related Issue(s) (provide if relevant)
N/A
## Mental Checklist:
- All of the automated tests pass
- All PR comments are addressed and marked resolved
- If there are migrations, they have been rebased to latest main
- If there are new dependencies, they are added to the requirements
- If there are new environment variables, they are added to all of the deployment methods
- If there are new APIs that don't require auth, they are added to PUBLIC_ENDPOINT_SPECS
- Docker images build and basic functionalities work
- Author has done a final read through of the PR right before merge
## Backporting (check the box to trigger backport action)
Note: You have to check that the action passes, otherwise resolve the conflicts manually and tag the patches.
- [ ] This PR should be backported (make sure to check that the backport attempt succeeds)

View File

@@ -66,6 +66,7 @@ jobs:
NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
NEXT_PUBLIC_GTM_ENABLED=true
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
# needed due to weird interactions with the builds for different platforms
no-cache: true
labels: ${{ steps.meta.outputs.labels }}

View File

@@ -26,7 +26,15 @@ env:
GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR }}
# Slab
SLAB_BOT_TOKEN: ${{ secrets.SLAB_BOT_TOKEN }}
# Salesforce
SF_USERNAME: ${{ secrets.SF_USERNAME }}
SF_PASSWORD: ${{ secrets.SF_PASSWORD }}
SF_SECURITY_TOKEN: ${{ secrets.SF_SECURITY_TOKEN }}
# Airtable
AIRTABLE_TEST_BASE_ID: ${{ secrets.AIRTABLE_TEST_BASE_ID }}
AIRTABLE_TEST_TABLE_ID: ${{ secrets.AIRTABLE_TEST_TABLE_ID }}
AIRTABLE_TEST_TABLE_NAME: ${{ secrets.AIRTABLE_TEST_TABLE_NAME }}
AIRTABLE_ACCESS_TOKEN: ${{ secrets.AIRTABLE_ACCESS_TOKEN }}
jobs:
connectors-check:
# See https://runs-on.com/runners/linux/

View File

@@ -3,7 +3,7 @@
<a name="readme-top"></a>
<h2 align="center">
<a href="https://www.onyx.app/"> <img width="50%" src="https://github.com/onyx-dot-app/onyx/blob/logo/LogoOnyx.png?raw=true)" /></a>
<a href="https://www.onyx.app/"> <img width="50%" src="https://github.com/onyx-dot-app/onyx/blob/logo/OnyxLogoCropped.jpg?raw=true)" /></a>
</h2>
<p align="center">
@@ -13,7 +13,7 @@
<a href="https://docs.onyx.app/" target="_blank">
<img src="https://img.shields.io/badge/docs-view-blue" alt="Documentation">
</a>
<a href="https://join.slack.com/t/danswer/shared_invite/zt-1w76msxmd-HJHLe3KNFIAIzk_0dSOKaQ" target="_blank">
<a href="https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA" target="_blank">
<img src="https://img.shields.io/badge/slack-join-blue.svg?logo=slack" alt="Slack">
</a>
<a href="https://discord.gg/TDJ59cGV2X" target="_blank">
@@ -24,7 +24,7 @@
</a>
</p>
<strong>[Onyx](https://www.onyx.app/)</strong> (Formerly Danswer) is the AI Assistant connected to your company's docs, apps, and people.
<strong>[Onyx](https://www.onyx.app/)</strong> (formerly Danswer) is the AI Assistant connected to your company's docs, apps, and people.
Onyx provides a Chat interface and plugs into any LLM of your choice. Onyx can be deployed anywhere and for any
scale - on a laptop, on-premise, or to cloud. Since you own the deployment, your user data and chats are fully in your
own control. Onyx is dual Licensed with most of it under MIT license and designed to be modular and easily extensible. The system also comes fully ready
@@ -133,15 +133,3 @@ Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md
## ⭐Star History
[![Star History Chart](https://api.star-history.com/svg?repos=onyx-dot-app/onyx&type=Date)](https://star-history.com/#onyx-dot-app/onyx&Date)
## ✨Contributors
<a href="https://github.com/onyx-dot-app/onyx/graphs/contributors">
<img alt="contributors" src="https://contrib.rocks/image?repo=onyx-dot-app/onyx"/>
</a>
<p align="right" style="font-size: 14px; color: #555; margin-top: 20px;">
<a href="#readme-top" style="text-decoration: none; color: #007bff; font-weight: bold;">
↑ Back to Top ↑
</a>
</p>

1
backend/.gitignore vendored
View File

@@ -9,3 +9,4 @@ api_keys.py
vespa-app.zip
dynamic_config_storage/
celerybeat-schedule*
onyx/connectors/salesforce/data/

View File

@@ -4,7 +4,7 @@ from onyx.configs.app_configs import USE_IAM_AUTH
from onyx.configs.app_configs import POSTGRES_HOST
from onyx.configs.app_configs import POSTGRES_PORT
from onyx.configs.app_configs import POSTGRES_USER
from onyx.configs.app_configs import AWS_REGION
from onyx.configs.app_configs import AWS_REGION_NAME
from onyx.db.engine import build_connection_string
from onyx.db.engine import get_all_tenant_ids
from sqlalchemy import event
@@ -120,7 +120,7 @@ def provide_iam_token_for_alembic(
) -> None:
if USE_IAM_AUTH:
# Database connection settings
region = AWS_REGION
region = AWS_REGION_NAME
host = POSTGRES_HOST
port = POSTGRES_PORT
user = POSTGRES_USER

View File

@@ -0,0 +1,129 @@
from alembic import op
import sqlalchemy as sa
import datetime
# revision identifiers, used by Alembic.
revision = "25d86cbfce78"
down_revision = "c0aab6edb6dd"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Create user_folder table with additional 'display_priority' field
op.create_table(
"user_folder",
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
sa.Column("user_id", sa.UUID(), sa.ForeignKey("user.id"), nullable=True),
sa.Column(
"parent_id", sa.Integer(), sa.ForeignKey("user_folder.id"), nullable=True
),
sa.Column("name", sa.String(length=255), nullable=True),
sa.Column("display_priority", sa.Integer(), nullable=True, default=0),
sa.Column("created_at", sa.DateTime(), default=datetime.datetime.utcnow),
)
# Migrate data from chat_folder to user_folder
op.execute(
"""
INSERT INTO user_folder (id, user_id, name, display_priority, created_at)
SELECT id, user_id, name, display_priority, CURRENT_TIMESTAMP FROM chat_folder
"""
)
# Update chat_session table to reference user_folder instead of chat_folder
op.drop_constraint(
"chat_session_chat_folder_fk", "chat_session", type_="foreignkey"
)
op.alter_column(
"chat_session",
"folder_id",
existing_type=sa.Integer(),
nullable=True,
existing_nullable=True,
existing_server_default=None,
)
op.create_foreign_key(
"fk_chat_session_folder_id_user_folder",
"chat_session",
"user_folder",
["folder_id"],
["id"],
ondelete="SET NULL",
)
# Drop the chat_folder table
op.drop_table("chat_folder")
# Create user_file table
op.create_table(
"user_file",
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
sa.Column("user_id", sa.UUID(), sa.ForeignKey("user.id"), nullable=True),
sa.Column(
"parent_folder_id",
sa.Integer(),
sa.ForeignKey("user_folder.id"),
nullable=True,
),
sa.Column("file_type", sa.String(), nullable=True),
sa.Column("file_id", sa.String(length=255), nullable=False),
sa.Column("document_id", sa.String(length=255), nullable=False),
sa.Column("name", sa.String(length=255), nullable=False),
sa.Column(
"created_at",
sa.DateTime(),
default=datetime.datetime.utcnow,
),
)
def downgrade() -> None:
# Recreate chat_folder table
op.create_table(
"chat_folder",
sa.Column("id", sa.Integer(), primary_key=True),
sa.Column(
"user_id",
sa.UUID(),
sa.ForeignKey("user.id", ondelete="CASCADE"),
nullable=True,
),
sa.Column("name", sa.String(length=255), nullable=True),
sa.Column("display_priority", sa.Integer(), nullable=True, default=0),
)
# Migrate data back from user_folder to chat_folder
op.execute(
"""
INSERT INTO chat_folder (id, user_id, name, display_priority)
SELECT id, user_id, name, display_priority FROM user_folder
WHERE id IN (SELECT DISTINCT folder_id FROM chat_session WHERE folder_id IS NOT NULL)
"""
)
# Update chat_session table to reference chat_folder again
op.drop_constraint(
"fk_chat_session_folder_id_user_folder", "chat_session", type_="foreignkey"
)
op.alter_column(
"chat_session",
"folder_id",
existing_type=sa.Integer(),
nullable=True,
existing_nullable=True,
existing_server_default=None,
)
op.create_foreign_key(
"chat_session_chat_folder_fk",
"chat_session",
"chat_folder",
["folder_id"],
["id"],
ondelete="SET NULL",
)
# Drop the user_file table
op.drop_table("user_file")
# Drop the user_folder table
op.drop_table("user_folder")

View File

@@ -122,7 +122,7 @@ def _cleanup_document_set__user_group_relationships__no_commit(
)
def validate_user_creation_permissions(
def validate_object_creation_for_user(
db_session: Session,
user: User | None,
target_group_ids: list[int] | None = None,
@@ -440,32 +440,108 @@ def remove_curator_status__no_commit(db_session: Session, user: User) -> None:
_validate_curator_status__no_commit(db_session, [user])
def update_user_curator_relationship(
def _validate_curator_relationship_update_requester(
db_session: Session,
user_group_id: int,
set_curator_request: SetCuratorRequest,
user_making_change: User | None = None,
) -> None:
user = fetch_user_by_id(db_session, set_curator_request.user_id)
if not user:
raise ValueError(f"User with id '{set_curator_request.user_id}' not found")
"""
This function validates that the user making the change has the necessary permissions
to update the curator relationship for the target user in the given user group.
"""
if user.role == UserRole.ADMIN:
if user_making_change is None or user_making_change.role == UserRole.ADMIN:
return
# check if the user making the change is a curator in the group they are changing the curator relationship for
user_making_change_curator_groups = fetch_user_groups_for_user(
db_session=db_session,
user_id=user_making_change.id,
# only check if the user making the change is a curator if they are a curator
# otherwise, they are a global_curator and can update the curator relationship
# for any group they are a member of
only_curator_groups=user_making_change.role == UserRole.CURATOR,
)
requestor_curator_group_ids = [
group.id for group in user_making_change_curator_groups
]
if user_group_id not in requestor_curator_group_ids:
raise ValueError(
f"User '{user.email}' is an admin and therefore has all permissions "
f"user making change {user_making_change.email} is not a curator,"
f" admin, or global_curator for group '{user_group_id}'"
)
def _validate_curator_relationship_update_request(
db_session: Session,
user_group_id: int,
target_user: User,
) -> None:
"""
This function validates that the curator_relationship_update request itself is valid.
"""
if target_user.role == UserRole.ADMIN:
raise ValueError(
f"User '{target_user.email}' is an admin and therefore has all permissions "
"of a curator. If you'd like this user to only have curator permissions, "
"you must update their role to BASIC then assign them to be CURATOR in the "
"appropriate groups."
)
elif target_user.role == UserRole.GLOBAL_CURATOR:
raise ValueError(
f"User '{target_user.email}' is a global_curator and therefore has all "
"permissions of a curator for all groups. If you'd like this user to only "
"have curator permissions for a specific group, you must update their role "
"to BASIC then assign them to be CURATOR in the appropriate groups."
)
elif target_user.role not in [UserRole.CURATOR, UserRole.BASIC]:
raise ValueError(
f"This endpoint can only be used to update the curator relationship for "
"users with the CURATOR or BASIC role. \n"
f"Target user: {target_user.email} \n"
f"Target user role: {target_user.role} \n"
)
# check if the target user is in the group they are changing the curator relationship for
requested_user_groups = fetch_user_groups_for_user(
db_session=db_session,
user_id=set_curator_request.user_id,
user_id=target_user.id,
only_curator_groups=False,
)
group_ids = [group.id for group in requested_user_groups]
if user_group_id not in group_ids:
raise ValueError(f"user is not in group '{user_group_id}'")
raise ValueError(
f"target user {target_user.email} is not in group '{user_group_id}'"
)
def update_user_curator_relationship(
db_session: Session,
user_group_id: int,
set_curator_request: SetCuratorRequest,
user_making_change: User | None = None,
) -> None:
target_user = fetch_user_by_id(db_session, set_curator_request.user_id)
if not target_user:
raise ValueError(f"User with id '{set_curator_request.user_id}' not found")
_validate_curator_relationship_update_request(
db_session=db_session,
user_group_id=user_group_id,
target_user=target_user,
)
_validate_curator_relationship_update_requester(
db_session=db_session,
user_group_id=user_group_id,
user_making_change=user_making_change,
)
logger.info(
f"user_making_change={user_making_change.email if user_making_change else 'None'} is "
f"updating the curator relationship for user={target_user.email} "
f"in group={user_group_id} to is_curator={set_curator_request.is_curator}"
)
relationship_to_update = (
db_session.query(User__UserGroup)
@@ -486,7 +562,7 @@ def update_user_curator_relationship(
)
db_session.add(relationship_to_update)
_validate_curator_status__no_commit(db_session, [user])
_validate_curator_status__no_commit(db_session, [target_user])
db_session.commit()

View File

@@ -40,6 +40,7 @@ from onyx.configs.app_configs import USER_AUTH_SECRET
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import AuthType
from onyx.main import get_application as get_application_base
from onyx.main import include_auth_router_with_prefix
from onyx.main import include_router_with_global_prefix_prepended
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import global_version
@@ -62,7 +63,7 @@ def get_application() -> FastAPI:
if AUTH_TYPE == AuthType.CLOUD:
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
include_router_with_global_prefix_prepended(
include_auth_router_with_prefix(
application,
create_onyx_oauth_router(
oauth_client,
@@ -74,19 +75,17 @@ def get_application() -> FastAPI:
redirect_url=f"{WEB_DOMAIN}/auth/oauth/callback",
),
prefix="/auth/oauth",
tags=["auth"],
)
# Need basic auth router for `logout` endpoint
include_router_with_global_prefix_prepended(
include_auth_router_with_prefix(
application,
fastapi_users.get_logout_router(auth_backend),
prefix="/auth",
tags=["auth"],
)
if AUTH_TYPE == AuthType.OIDC:
include_router_with_global_prefix_prepended(
include_auth_router_with_prefix(
application,
create_onyx_oauth_router(
OpenID(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET, OPENID_CONFIG_URL),
@@ -97,19 +96,21 @@ def get_application() -> FastAPI:
redirect_url=f"{WEB_DOMAIN}/auth/oidc/callback",
),
prefix="/auth/oidc",
tags=["auth"],
)
# need basic auth router for `logout` endpoint
include_router_with_global_prefix_prepended(
include_auth_router_with_prefix(
application,
fastapi_users.get_auth_router(auth_backend),
prefix="/auth",
tags=["auth"],
)
elif AUTH_TYPE == AuthType.SAML:
include_router_with_global_prefix_prepended(application, saml_router)
include_auth_router_with_prefix(
application,
saml_router,
prefix="/auth/saml",
)
# RBAC / group access control
include_router_with_global_prefix_prepended(application, user_group_router)

View File

@@ -1,5 +1,7 @@
import base64
import json
import uuid
from typing import Any
from typing import cast
import requests
@@ -10,11 +12,29 @@ from fastapi.responses import JSONResponse
from pydantic import BaseModel
from sqlalchemy.orm import Session
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLIENT_SECRET
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_SECRET
from onyx.auth.users import current_user
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import DocumentSource
from onyx.connectors.google_utils.google_auth import get_google_oauth_creds
from onyx.connectors.google_utils.google_auth import sanitize_oauth_credentials
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_AUTHENTICATION_METHOD,
)
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_DICT_TOKEN_KEY,
)
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
)
from onyx.connectors.google_utils.shared_constants import (
GoogleOAuthAuthenticationMethod,
)
from onyx.db.credentials import create_credential
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
@@ -62,14 +82,7 @@ class SlackOAuth:
@classmethod
def generate_oauth_url(cls, state: str) -> str:
url = (
f"https://slack.com/oauth/v2/authorize"
f"?client_id={cls.CLIENT_ID}"
f"&redirect_uri={cls.REDIRECT_URI}"
f"&scope={cls.BOT_SCOPE}"
f"&state={state}"
)
return url
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
@classmethod
def generate_dev_oauth_url(cls, state: str) -> str:
@@ -77,10 +90,14 @@ class SlackOAuth:
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
"""
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
@classmethod
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
url = (
f"https://slack.com/oauth/v2/authorize"
f"?client_id={cls.CLIENT_ID}"
f"&redirect_uri={cls.DEV_REDIRECT_URI}"
f"&redirect_uri={redirect_uri}"
f"&scope={cls.BOT_SCOPE}"
f"&state={state}"
)
@@ -102,82 +119,151 @@ class SlackOAuth:
return session
# Work in progress
# class ConfluenceCloudOAuth:
# """work in progress"""
class ConfluenceCloudOAuth:
"""work in progress"""
# # https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps/
# https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps/
# class OAuthSession(BaseModel):
# """Stored in redis to be looked up on callback"""
class OAuthSession(BaseModel):
"""Stored in redis to be looked up on callback"""
# email: str
# redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
email: str
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
# CLIENT_ID = OAUTH_CONFLUENCE_CLIENT_ID
# CLIENT_SECRET = OAUTH_CONFLUENCE_CLIENT_SECRET
# TOKEN_URL = "https://auth.atlassian.com/oauth/token"
CLIENT_ID = OAUTH_CONFLUENCE_CLIENT_ID
CLIENT_SECRET = OAUTH_CONFLUENCE_CLIENT_SECRET
TOKEN_URL = "https://auth.atlassian.com/oauth/token"
# # All read scopes per https://developer.atlassian.com/cloud/confluence/scopes-for-oauth-2-3LO-and-forge-apps/
# CONFLUENCE_OAUTH_SCOPE = (
# "read:confluence-props%20"
# "read:confluence-content.all%20"
# "read:confluence-content.summary%20"
# "read:confluence-content.permission%20"
# "read:confluence-user%20"
# "read:confluence-groups%20"
# "readonly:content.attachment:confluence"
# )
# All read scopes per https://developer.atlassian.com/cloud/confluence/scopes-for-oauth-2-3LO-and-forge-apps/
CONFLUENCE_OAUTH_SCOPE = (
"read:confluence-props%20"
"read:confluence-content.all%20"
"read:confluence-content.summary%20"
"read:confluence-content.permission%20"
"read:confluence-user%20"
"read:confluence-groups%20"
"readonly:content.attachment:confluence"
)
# REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/confluence/oauth/callback"
# DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/confluence/oauth/callback"
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
# # eventually for Confluence Data Center
# # oauth_url = (
# # f"http://localhost:8090/rest/oauth/v2/authorize?client_id={CONFLUENCE_OAUTH_CLIENT_ID}"
# # f"&scope={CONFLUENCE_OAUTH_SCOPE_2}"
# # f"&redirect_uri={redirectme_uri}"
# # )
# eventually for Confluence Data Center
# oauth_url = (
# f"http://localhost:8090/rest/oauth/v2/authorize?client_id={CONFLUENCE_OAUTH_CLIENT_ID}"
# f"&scope={CONFLUENCE_OAUTH_SCOPE_2}"
# f"&redirect_uri={redirectme_uri}"
# )
# @classmethod
# def generate_oauth_url(cls, state: str) -> str:
# return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
@classmethod
def generate_oauth_url(cls, state: str) -> str:
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
# @classmethod
# def generate_dev_oauth_url(cls, state: str) -> str:
# """dev mode workaround for localhost testing
# - https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
# """
# return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
@classmethod
def generate_dev_oauth_url(cls, state: str) -> str:
"""dev mode workaround for localhost testing
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
"""
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
# @classmethod
# def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
# url = (
# "https://auth.atlassian.com/authorize"
# f"?audience=api.atlassian.com"
# f"&client_id={cls.CLIENT_ID}"
# f"&redirect_uri={redirect_uri}"
# f"&scope={cls.CONFLUENCE_OAUTH_SCOPE}"
# f"&state={state}"
# "&response_type=code"
# "&prompt=consent"
# )
# return url
@classmethod
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
url = (
"https://auth.atlassian.com/authorize"
f"?audience=api.atlassian.com"
f"&client_id={cls.CLIENT_ID}"
f"&redirect_uri={redirect_uri}"
f"&scope={cls.CONFLUENCE_OAUTH_SCOPE}"
f"&state={state}"
"&response_type=code"
"&prompt=consent"
)
return url
# @classmethod
# def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
# """Temporary state to store in redis. to be looked up on auth response.
# Returns a json string.
# """
# session = ConfluenceCloudOAuth.OAuthSession(
# email=email, redirect_on_success=redirect_on_success
# )
# return session.model_dump_json()
@classmethod
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
"""Temporary state to store in redis. to be looked up on auth response.
Returns a json string.
"""
session = ConfluenceCloudOAuth.OAuthSession(
email=email, redirect_on_success=redirect_on_success
)
return session.model_dump_json()
# @classmethod
# def parse_session(cls, session_json: str) -> SlackOAuth.OAuthSession:
# session = SlackOAuth.OAuthSession.model_validate_json(session_json)
# return session
@classmethod
def parse_session(cls, session_json: str) -> SlackOAuth.OAuthSession:
session = SlackOAuth.OAuthSession.model_validate_json(session_json)
return session
class GoogleDriveOAuth:
# https://developers.google.com/identity/protocols/oauth2
# https://developers.google.com/identity/protocols/oauth2/web-server
class OAuthSession(BaseModel):
"""Stored in redis to be looked up on callback"""
email: str
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
CLIENT_ID = OAUTH_GOOGLE_DRIVE_CLIENT_ID
CLIENT_SECRET = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
TOKEN_URL = "https://oauth2.googleapis.com/token"
# SCOPE is per https://docs.onyx.app/connectors/google-drive
# TODO: Merge with or use google_utils.GOOGLE_SCOPES
SCOPE = (
"https://www.googleapis.com/auth/drive.readonly%20"
"https://www.googleapis.com/auth/drive.metadata.readonly%20"
"https://www.googleapis.com/auth/admin.directory.user.readonly%20"
"https://www.googleapis.com/auth/admin.directory.group.readonly"
)
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/google-drive/oauth/callback"
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
@classmethod
def generate_oauth_url(cls, state: str) -> str:
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
@classmethod
def generate_dev_oauth_url(cls, state: str) -> str:
"""dev mode workaround for localhost testing
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
"""
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
@classmethod
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
# without prompt=consent, a refresh token is only issued the first time the user approves
url = (
f"https://accounts.google.com/o/oauth2/v2/auth"
f"?client_id={cls.CLIENT_ID}"
f"&redirect_uri={redirect_uri}"
"&response_type=code"
f"&scope={cls.SCOPE}"
"&access_type=offline"
f"&state={state}"
"&prompt=consent"
)
return url
@classmethod
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
"""Temporary state to store in redis. to be looked up on auth response.
Returns a json string.
"""
session = GoogleDriveOAuth.OAuthSession(
email=email, redirect_on_success=redirect_on_success
)
return session.model_dump_json()
@classmethod
def parse_session(cls, session_json: str) -> OAuthSession:
session = GoogleDriveOAuth.OAuthSession.model_validate_json(session_json)
return session
@router.post("/prepare-authorization-request")
@@ -192,8 +278,11 @@ def prepare_authorization_request(
Example: https://www.oauth.com/oauth2-servers/authorization/the-authorization-request/
"""
# create random oauth state param for security and to retrieve user data later
oauth_uuid = uuid.uuid4()
oauth_uuid_str = str(oauth_uuid)
# urlsafe b64 encode the uuid for the oauth url
oauth_state = (
base64.urlsafe_b64encode(oauth_uuid.bytes).rstrip(b"=").decode("utf-8")
)
@@ -203,6 +292,11 @@ def prepare_authorization_request(
session = SlackOAuth.session_dump_json(
email=user.email, redirect_on_success=redirect_on_success
)
elif connector == DocumentSource.GOOGLE_DRIVE:
oauth_url = GoogleDriveOAuth.generate_oauth_url(oauth_state)
session = GoogleDriveOAuth.session_dump_json(
email=user.email, redirect_on_success=redirect_on_success
)
# elif connector == DocumentSource.CONFLUENCE:
# oauth_url = ConfluenceCloudOAuth.generate_oauth_url(oauth_state)
# session = ConfluenceCloudOAuth.session_dump_json(
@@ -210,8 +304,6 @@ def prepare_authorization_request(
# )
# elif connector == DocumentSource.JIRA:
# oauth_url = JiraCloudOAuth.generate_dev_oauth_url(oauth_state)
# elif connector == DocumentSource.GOOGLE_DRIVE:
# oauth_url = GoogleDriveOAuth.generate_dev_oauth_url(oauth_state)
else:
oauth_url = None
@@ -223,6 +315,7 @@ def prepare_authorization_request(
r = get_redis_client(tenant_id=tenant_id)
# store important session state to retrieve when the user is redirected back
# 10 min is the max we want an oauth flow to be valid
r.set(f"da_oauth:{oauth_uuid_str}", session, ex=600)
@@ -421,3 +514,116 @@ def handle_slack_oauth_callback(
# "redirect_on_success": session.redirect_on_success,
# }
# )
@router.post("/connector/google-drive/callback")
def handle_google_drive_oauth_callback(
code: str,
state: str,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
if not GoogleDriveOAuth.CLIENT_ID or not GoogleDriveOAuth.CLIENT_SECRET:
raise HTTPException(
status_code=500,
detail="Google Drive client ID or client secret is not configured.",
)
r = get_redis_client(tenant_id=tenant_id)
# recover the state
padded_state = state + "=" * (
-len(state) % 4
) # Add padding back (Base64 decoding requires padding)
uuid_bytes = base64.urlsafe_b64decode(
padded_state
) # Decode the Base64 string back to bytes
# Convert bytes back to a UUID
oauth_uuid = uuid.UUID(bytes=uuid_bytes)
oauth_uuid_str = str(oauth_uuid)
r_key = f"da_oauth:{oauth_uuid_str}"
session_json_bytes = cast(bytes, r.get(r_key))
if not session_json_bytes:
raise HTTPException(
status_code=400,
detail=f"Google Drive OAuth failed - OAuth state key not found: key={r_key}",
)
session_json = session_json_bytes.decode("utf-8")
try:
session = GoogleDriveOAuth.parse_session(session_json)
# Exchange the authorization code for an access token
response = requests.post(
GoogleDriveOAuth.TOKEN_URL,
headers={"Content-Type": "application/x-www-form-urlencoded"},
data={
"client_id": GoogleDriveOAuth.CLIENT_ID,
"client_secret": GoogleDriveOAuth.CLIENT_SECRET,
"code": code,
"redirect_uri": GoogleDriveOAuth.REDIRECT_URI,
"grant_type": "authorization_code",
},
)
response.raise_for_status()
authorization_response: dict[str, Any] = response.json()
# the connector wants us to store the json in its authorized_user_info format
# returned from OAuthCredentials.get_authorized_user_info().
# So refresh immediately via get_google_oauth_creds with the params filled in
# from fields in authorization_response to get the json we need
authorized_user_info = {}
authorized_user_info["client_id"] = OAUTH_GOOGLE_DRIVE_CLIENT_ID
authorized_user_info["client_secret"] = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
authorized_user_info["refresh_token"] = authorization_response["refresh_token"]
token_json_str = json.dumps(authorized_user_info)
oauth_creds = get_google_oauth_creds(
token_json_str=token_json_str, source=DocumentSource.GOOGLE_DRIVE
)
if not oauth_creds:
raise RuntimeError("get_google_oauth_creds returned None.")
# save off the credentials
oauth_creds_sanitized_json_str = sanitize_oauth_credentials(oauth_creds)
credential_dict: dict[str, str] = {}
credential_dict[DB_CREDENTIALS_DICT_TOKEN_KEY] = oauth_creds_sanitized_json_str
credential_dict[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] = session.email
credential_dict[
DB_CREDENTIALS_AUTHENTICATION_METHOD
] = GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE.value
credential_info = CredentialBase(
credential_json=credential_dict,
admin_public=True,
source=DocumentSource.GOOGLE_DRIVE,
name="OAuth (interactive)",
)
create_credential(credential_info, user, db_session)
except Exception as e:
return JSONResponse(
status_code=500,
content={
"success": False,
"message": f"An error occurred during Google Drive OAuth: {str(e)}",
},
)
finally:
r.delete(r_key)
# return the result
return JSONResponse(
content={
"success": True,
"message": "Google Drive OAuth completed successfully.",
"redirect_on_success": session.redirect_on_success,
}
)

View File

@@ -3,6 +3,7 @@ from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Response
from sqlalchemy.orm import Session
from ee.onyx.auth.users import current_cloud_superuser
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
@@ -12,15 +13,23 @@ from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
from ee.onyx.server.tenants.models import BillingInformation
from ee.onyx.server.tenants.models import ImpersonateRequest
from ee.onyx.server.tenants.models import ProductGatingRequest
from ee.onyx.server.tenants.provisioning import delete_user_from_control_plane
from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
from ee.onyx.server.tenants.user_mapping import remove_all_users_from_tenant
from ee.onyx.server.tenants.user_mapping import remove_users_from_tenant
from onyx.auth.users import auth_backend
from onyx.auth.users import current_admin_user
from onyx.auth.users import get_jwt_strategy
from onyx.auth.users import User
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.db.auth import get_user_count
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.engine import get_session_with_tenant
from onyx.db.notification import create_notification
from onyx.db.users import delete_user_from_db
from onyx.db.users import get_user_by_email
from onyx.server.manage.models import UserByEmail
from onyx.server.settings.store import load_settings
from onyx.server.settings.store import store_settings
from onyx.utils.logger import setup_logger
@@ -114,3 +123,48 @@ async def impersonate_user(
samesite="lax",
)
return response
@router.post("/leave-organization")
async def leave_organization(
user_email: UserByEmail,
current_user: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str = Depends(get_current_tenant_id),
) -> None:
if current_user is None or current_user.email != user_email.user_email:
raise HTTPException(
status_code=403, detail="You can only leave the organization as yourself"
)
user_to_delete = get_user_by_email(user_email.user_email, db_session)
if user_to_delete is None:
raise HTTPException(status_code=404, detail="User not found")
num_admin_users = await get_user_count(only_admin_users=True)
should_delete_tenant = num_admin_users == 1
if should_delete_tenant:
logger.info(
"Last admin user is leaving the organization. Deleting tenant from control plane."
)
try:
await delete_user_from_control_plane(tenant_id, user_to_delete.email)
logger.debug("User deleted from control plane")
except Exception as e:
logger.exception(
f"Failed to delete user from control plane for tenant {tenant_id}: {e}"
)
raise HTTPException(
status_code=500,
detail=f"Failed to remove user from control plane: {str(e)}",
)
db_session.expunge(user_to_delete)
delete_user_from_db(user_to_delete, db_session)
if should_delete_tenant:
remove_all_users_from_tenant(tenant_id)
else:
remove_users_from_tenant([user_to_delete.email], tenant_id)

View File

@@ -39,3 +39,8 @@ class TenantCreationPayload(BaseModel):
tenant_id: str
email: str
referral_source: str | None = None
class TenantDeletionPayload(BaseModel):
tenant_id: str
email: str

View File

@@ -15,6 +15,7 @@ from ee.onyx.configs.app_configs import HUBSPOT_TRACKING_URL
from ee.onyx.configs.app_configs import OPENAI_DEFAULT_API_KEY
from ee.onyx.server.tenants.access import generate_data_plane_token
from ee.onyx.server.tenants.models import TenantCreationPayload
from ee.onyx.server.tenants.models import TenantDeletionPayload
from ee.onyx.server.tenants.schema_management import create_schema_if_not_exists
from ee.onyx.server.tenants.schema_management import drop_schema
from ee.onyx.server.tenants.schema_management import run_alembic_migrations
@@ -185,6 +186,7 @@ async def rollback_tenant_provisioning(tenant_id: str) -> None:
try:
# Drop the tenant's schema to rollback provisioning
drop_schema(tenant_id)
# Remove tenant mapping
with Session(get_sqlalchemy_engine()) as db_session:
db_session.query(UserTenantMapping).filter(
@@ -320,3 +322,26 @@ async def submit_to_hubspot(
if response.status_code != 200:
logger.error(f"Failed to submit to HubSpot: {response.text}")
async def delete_user_from_control_plane(tenant_id: str, email: str) -> None:
token = generate_data_plane_token()
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
}
payload = TenantDeletionPayload(tenant_id=tenant_id, email=email)
async with aiohttp.ClientSession() as session:
async with session.delete(
f"{CONTROL_PLANE_API_BASE_URL}/tenants/delete",
headers=headers,
json=payload.model_dump(),
) as response:
print(response)
if response.status != 200:
error_text = await response.text()
logger.error(f"Control plane tenant creation failed: {error_text}")
raise Exception(
f"Failed to delete tenant on control plane: {error_text}"
)

View File

@@ -68,3 +68,11 @@ def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
f"Failed to remove users from tenant {tenant_id}: {str(e)}"
)
db_session.rollback()
def remove_all_users_from_tenant(tenant_id: str) -> None:
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
db_session.query(UserTenantMapping).filter(
UserTenantMapping.tenant_id == tenant_id
).delete()
db_session.commit()

View File

@@ -83,7 +83,7 @@ def patch_user_group(
def set_user_curator(
user_group_id: int,
set_curator_request: SetCuratorRequest,
_: User | None = Depends(current_admin_user),
user: User | None = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> None:
try:
@@ -91,6 +91,7 @@ def set_user_curator(
db_session=db_session,
user_group_id=user_group_id,
set_curator_request=set_curator_request,
user_making_change=user,
)
except ValueError as e:
logger.error(f"Error setting user curator: {e}")

View File

@@ -10,6 +10,7 @@ logger = setup_logger()
def posthog_on_error(error: Any, items: Any) -> None:
"""Log any PostHog delivery errors."""
logger.error(f"PostHog error: {error}, items: {items}")
@@ -24,15 +25,10 @@ posthog = Posthog(
def event_telemetry(
distinct_id: str, event: str, properties: dict | None = None
) -> None:
logger.info(f"Capturing Posthog event: {distinct_id} {event} {properties}")
print("API KEY", POSTHOG_API_KEY)
print("HOST", POSTHOG_HOST)
"""Capture and send an event to PostHog, flushing immediately."""
logger.info(f"Capturing PostHog event: {distinct_id} {event} {properties}")
try:
print(type(distinct_id))
print(type(event))
print(type(properties))
response = posthog.capture(distinct_id, event, properties)
posthog.capture(distinct_id, event, properties)
posthog.flush()
print(response)
except Exception as e:
logger.error(f"Error capturing Posthog event: {e}")
logger.error(f"Error capturing PostHog event: {e}")

View File

@@ -1,5 +1,6 @@
import asyncio
import json
import time
from types import TracebackType
from typing import cast
from typing import Optional
@@ -320,8 +321,6 @@ async def embed_text(
api_url: str | None,
api_version: str | None,
) -> list[Embedding]:
logger.info(f"Embedding {len(texts)} texts with provider: {provider_type}")
if not all(texts):
logger.error("Empty strings provided for embedding")
raise ValueError("Empty strings are not allowed for embedding.")
@@ -330,8 +329,17 @@ async def embed_text(
logger.error("No texts provided for embedding")
raise ValueError("No texts provided for embedding.")
start = time.monotonic()
total_chars = 0
for text in texts:
total_chars += len(text)
if provider_type is not None:
logger.debug(f"Using cloud provider {provider_type} for embedding")
logger.info(
f"Embedding {len(texts)} texts with {total_chars} total characters with provider: {provider_type}"
)
if api_key is None:
logger.error("API key not provided for cloud model")
raise RuntimeError("API key not provided for cloud model")
@@ -363,8 +371,16 @@ async def embed_text(
logger.error(error_message)
raise ValueError(error_message)
elapsed = time.monotonic() - start
logger.info(
f"Successfully embedded {len(texts)} texts with {total_chars} total characters "
f"with provider {provider_type} in {elapsed:.2f}"
)
elif model_name is not None:
logger.debug(f"Using local model {model_name} for embedding")
logger.info(
f"Embedding {len(texts)} texts with {total_chars} total characters with local model: {model_name}"
)
prefixed_texts = [f"{prefix}{text}" for text in texts] if prefix else texts
local_model = get_embedding_model(
@@ -382,13 +398,17 @@ async def embed_text(
for embedding in embeddings_vectors
]
elapsed = time.monotonic() - start
logger.info(
f"Successfully embedded {len(texts)} texts with {total_chars} total characters "
f"with local model {model_name} in {elapsed:.2f}"
)
else:
logger.error("Neither model name nor provider specified for embedding")
raise ValueError(
"Either model name or provider must be provided to run embeddings."
)
logger.info(f"Successfully embedded {len(texts)} texts")
return embeddings
@@ -440,7 +460,8 @@ async def process_embed_request(
) -> EmbedResponse:
if not embed_request.texts:
raise HTTPException(status_code=400, detail="No texts to be embedded")
elif not all(embed_request.texts):
if not all(embed_request.texts):
raise ValueError("Empty strings are not allowed for embedding.")
try:
@@ -471,9 +492,12 @@ async def process_embed_request(
detail=str(e),
)
except Exception as e:
exception_detail = f"Error during embedding process:\n{str(e)}"
logger.exception(exception_detail)
raise HTTPException(status_code=500, detail=exception_detail)
logger.exception(
f"Error during embedding process: provider={embed_request.provider_type} model={embed_request.model_name}"
)
raise HTTPException(
status_code=500, detail=f"Error during embedding process: {e}"
)
@router.post("/cross-encoder-scores")

View File

@@ -44,6 +44,7 @@ def _move_files_recursively(source: Path, dest: Path, overwrite: bool = False) -
the files in the existing huggingface cache that don't exist in the temp
huggingface cache.
"""
for item in source.iterdir():
target_path = dest / item.relative_to(source)
if item.is_dir():

View File

@@ -0,0 +1,80 @@
import smtplib
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from textwrap import dedent
from onyx.configs.app_configs import EMAIL_CONFIGURED
from onyx.configs.app_configs import EMAIL_FROM
from onyx.configs.app_configs import SMTP_PASS
from onyx.configs.app_configs import SMTP_PORT
from onyx.configs.app_configs import SMTP_SERVER
from onyx.configs.app_configs import SMTP_USER
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.db.models import User
def send_email(
user_email: str,
subject: str,
body: str,
mail_from: str = EMAIL_FROM,
) -> None:
if not EMAIL_CONFIGURED:
raise ValueError("Email is not configured.")
msg = MIMEMultipart()
msg["Subject"] = subject
msg["To"] = user_email
if mail_from:
msg["From"] = mail_from
msg.attach(MIMEText(body))
try:
with smtplib.SMTP(SMTP_SERVER, SMTP_PORT) as s:
s.starttls()
s.login(SMTP_USER, SMTP_PASS)
s.send_message(msg)
except Exception as e:
raise e
def send_user_email_invite(user_email: str, current_user: User) -> None:
subject = "Invitation to Join Onyx Workspace"
body = dedent(
f"""\
Hello,
You have been invited to join a workspace on Onyx.
To join the workspace, please visit the following link:
{WEB_DOMAIN}/auth/login
Best regards,
The Onyx Team
"""
)
send_email(user_email, subject, body, current_user.email)
def send_forgot_password_email(
user_email: str,
token: str,
mail_from: str = EMAIL_FROM,
) -> None:
subject = "Onyx Forgot Password"
link = f"{WEB_DOMAIN}/auth/reset-password?token={token}"
body = f"Click the following link to reset your password: {link}"
send_email(user_email, subject, body, mail_from)
def send_user_verification_email(
user_email: str,
token: str,
mail_from: str = EMAIL_FROM,
) -> None:
subject = "Onyx Email Verification"
link = f"{WEB_DOMAIN}/auth/verify-email?token={token}"
body = f"Click the following link to verify your email address: {link}"
send_email(user_email, subject, body, mail_from)

View File

@@ -30,13 +30,16 @@ def load_no_auth_user_preferences(store: KeyValueStore) -> UserPreferences:
)
def fetch_no_auth_user(store: KeyValueStore) -> UserInfo:
def fetch_no_auth_user(
store: KeyValueStore, *, anonymous_user_enabled: bool | None = None
) -> UserInfo:
return UserInfo(
id=NO_AUTH_USER_ID,
email=NO_AUTH_USER_EMAIL,
is_active=True,
is_superuser=False,
is_verified=True,
role=UserRole.ADMIN,
role=UserRole.BASIC if anonymous_user_enabled else UserRole.ADMIN,
preferences=load_no_auth_user_preferences(store),
is_anonymous_user=anonymous_user_enabled,
)

View File

@@ -49,4 +49,7 @@ class UserCreate(schemas.BaseUserCreate):
class UserUpdate(schemas.BaseUserUpdate):
role: UserRole
"""
Role updates are not allowed through the user update endpoint for security reasons
Role changes should be handled through a separate, admin-only process
"""

View File

@@ -1,10 +1,7 @@
import smtplib
import uuid
from collections.abc import AsyncGenerator
from datetime import datetime
from datetime import timezone
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from typing import cast
from typing import Dict
from typing import List
@@ -53,19 +50,17 @@ from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession
from onyx.auth.api_key import get_hashed_api_key_from_request
from onyx.auth.email_utils import send_forgot_password_email
from onyx.auth.email_utils import send_user_verification_email
from onyx.auth.invited_users import get_invited_users
from onyx.auth.schemas import UserCreate
from onyx.auth.schemas import UserRole
from onyx.auth.schemas import UserUpdate
from onyx.configs.app_configs import AUTH_TYPE
from onyx.configs.app_configs import DISABLE_AUTH
from onyx.configs.app_configs import EMAIL_FROM
from onyx.configs.app_configs import EMAIL_CONFIGURED
from onyx.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
from onyx.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
from onyx.configs.app_configs import SMTP_PASS
from onyx.configs.app_configs import SMTP_PORT
from onyx.configs.app_configs import SMTP_SERVER
from onyx.configs.app_configs import SMTP_USER
from onyx.configs.app_configs import TRACK_EXTERNAL_IDP_EXPIRY
from onyx.configs.app_configs import USER_AUTH_SECRET
from onyx.configs.app_configs import VALID_EMAIL_DOMAINS
@@ -74,6 +69,7 @@ from onyx.configs.constants import AuthType
from onyx.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
from onyx.configs.constants import DANSWER_API_KEY_PREFIX
from onyx.configs.constants import MilestoneRecordType
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import PASSWORD_SPECIAL_CHARS
from onyx.configs.constants import UNNAMED_KEY_PLACEHOLDER
from onyx.db.api_key import fetch_user_for_api_key
@@ -89,7 +85,7 @@ from onyx.db.models import AccessToken
from onyx.db.models import OAuthAccount
from onyx.db.models import User
from onyx.db.users import get_user_by_email
from onyx.server.utils import BasicAuthenticationError
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import create_milestone_and_report
from onyx.utils.telemetry import optional_telemetry
@@ -103,6 +99,11 @@ from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
class BasicAuthenticationError(HTTPException):
def __init__(self, detail: str):
super().__init__(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
def is_user_admin(user: User | None) -> bool:
if AUTH_TYPE == AuthType.DISABLED:
return True
@@ -143,6 +144,20 @@ def user_needs_to_be_verified() -> bool:
return False
def anonymous_user_enabled() -> bool:
if MULTI_TENANT:
return False
redis_client = get_redis_client(tenant_id=None)
value = redis_client.get(OnyxRedisLocks.ANONYMOUS_USER_ENABLED)
if value is None:
return False
assert isinstance(value, bytes)
return int(value.decode("utf-8")) == 1
def verify_email_is_invited(email: str) -> None:
whitelist = get_invited_users()
if not whitelist:
@@ -193,30 +208,6 @@ def verify_email_domain(email: str) -> None:
)
def send_user_verification_email(
user_email: str,
token: str,
mail_from: str = EMAIL_FROM,
) -> None:
msg = MIMEMultipart()
msg["Subject"] = "Onyx Email Verification"
msg["To"] = user_email
if mail_from:
msg["From"] = mail_from
link = f"{WEB_DOMAIN}/auth/verify-email?token={token}"
body = MIMEText(f"Click the following link to verify your email address: {link}")
msg.attach(body)
with smtplib.SMTP(SMTP_SERVER, SMTP_PORT) as s:
s.starttls()
# If credentials fails with gmail, check (You need an app password, not just the basic email password)
# https://support.google.com/accounts/answer/185833?sjid=8512343437447396151-NA
s.login(SMTP_USER, SMTP_PASS)
s.send_message(msg)
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
reset_password_token_secret = USER_AUTH_SECRET
verification_token_secret = USER_AUTH_SECRET
@@ -281,7 +272,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
if not user.role.is_web_login() and user_create.role.is_web_login():
user_update = UserUpdate(
password=user_create.password,
role=user_create.role,
is_verified=user_create.is_verified,
)
user = await self.update(user_update, user)
@@ -506,7 +496,15 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
async def on_after_forgot_password(
self, user: User, token: str, request: Optional[Request] = None
) -> None:
logger.notice(f"User {user.id} has forgot their password. Reset token: {token}")
if not EMAIL_CONFIGURED:
logger.error(
"Email is not configured. Please configure email in the admin panel"
)
raise HTTPException(
status.HTTP_500_INTERNAL_SERVER_ERROR,
"Your admin has not enbaled this feature.",
)
send_forgot_password_email(user.email, token)
async def on_after_request_verify(
self, user: User, token: str, request: Optional[Request] = None
@@ -624,9 +622,7 @@ def get_database_strategy(
auth_backend = AuthenticationBackend(
name="jwt" if MULTI_TENANT else "database",
transport=cookie_transport,
get_strategy=get_jwt_strategy if MULTI_TENANT else get_database_strategy, # type: ignore
name="jwt", transport=cookie_transport, get_strategy=get_jwt_strategy
) # type: ignore
@@ -713,30 +709,36 @@ async def double_check_user(
user: User | None,
optional: bool = DISABLE_AUTH,
include_expired: bool = False,
allow_anonymous_access: bool = False,
) -> User | None:
if optional:
return user
if user is not None:
# If user attempted to authenticate, verify them, do not default
# to anonymous access if it fails.
if user_needs_to_be_verified() and not user.is_verified:
raise BasicAuthenticationError(
detail="Access denied. User is not verified.",
)
if (
user.oidc_expiry
and user.oidc_expiry < datetime.now(timezone.utc)
and not include_expired
):
raise BasicAuthenticationError(
detail="Access denied. User's OIDC token has expired.",
)
return user
if allow_anonymous_access:
return None
if user is None:
raise BasicAuthenticationError(
detail="Access denied. User is not authenticated.",
)
if user_needs_to_be_verified() and not user.is_verified:
raise BasicAuthenticationError(
detail="Access denied. User is not verified.",
)
if (
user.oidc_expiry
and user.oidc_expiry < datetime.now(timezone.utc)
and not include_expired
):
raise BasicAuthenticationError(
detail="Access denied. User's OIDC token has expired.",
)
return user
raise BasicAuthenticationError(
detail="Access denied. User is not authenticated.",
)
async def current_user_with_expired_token(
@@ -751,6 +753,14 @@ async def current_limited_user(
return await double_check_user(user)
async def current_chat_accesssible_user(
user: User | None = Depends(optional_user),
) -> User | None:
return await double_check_user(
user, allow_anonymous_access=anonymous_user_enabled()
)
async def current_user(
user: User | None = Depends(optional_user),
) -> User | None:

View File

@@ -414,11 +414,21 @@ def on_setup_logging(
task_logger.setLevel(loglevel)
task_logger.propagate = False
# Hide celery task received and succeeded/failed messages
# hide celery task received spam
# e.g. "Task check_for_pruning[a1e96171-0ba8-4e00-887b-9fbf7442eab3] received"
strategy.logger.setLevel(logging.WARNING)
# uncomment this to hide celery task succeeded/failed spam
# e.g. "Task check_for_pruning[a1e96171-0ba8-4e00-887b-9fbf7442eab3] succeeded in 0.03137450001668185s: None"
trace.logger.setLevel(logging.WARNING)
def set_task_finished_log_level(logLevel: int) -> None:
"""call this to override the setLevel in on_setup_logging. We are interested
in the task timings in the cloud but it can be spammy for self hosted."""
trace.logger.setLevel(logLevel)
class TenantContextFilter(logging.Filter):
"""Logging filter to inject tenant ID into the logger's name."""

View File

@@ -60,7 +60,12 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=sender.concurrency)
# rkuo: been seeing transient connection exceptions here, so upping the connection count
# from just concurrency/concurrency to concurrency/concurrency*2
SqlEngine.init_engine(
pool_size=sender.concurrency, max_overflow=sender.concurrency * 2
)
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)

View File

@@ -1,3 +1,4 @@
import logging
import multiprocessing
from typing import Any
from typing import cast
@@ -194,6 +195,10 @@ def on_setup_logging(
) -> None:
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
# this can be spammy, so just enable it in the cloud for now
if MULTI_TENANT:
app_base.set_task_finished_log_level(logging.INFO)
class HubPeriodicTask(bootsteps.StartStopStep):
"""Regularly reacquires the primary worker lock outside of the task queue.

View File

@@ -3,12 +3,54 @@ import json
from typing import Any
from typing import cast
from celery import Celery
from redis import Redis
from onyx.background.celery.configs.base import CELERY_SEPARATOR
from onyx.configs.constants import OnyxCeleryPriority
def celery_get_unacked_length(r: Redis) -> int:
"""Checking the unacked queue is useful because a non-zero length tells us there
may be prefetched tasks.
There can be other tasks in here besides indexing tasks, so this is mostly useful
just to see if the task count is non zero.
ref: https://blog.hikaru.run/2022/08/29/get-waiting-tasks-count-in-celery.html
"""
length = cast(int, r.hlen("unacked"))
return length
def celery_get_unacked_task_ids(queue: str, r: Redis) -> set[str]:
"""Gets the set of task id's matching the given queue in the unacked hash.
Unacked entries belonging to the indexing queue are "prefetched", so this gives
us crucial visibility as to what tasks are in that state.
"""
tasks: set[str] = set()
for _, v in r.hscan_iter("unacked"):
v_bytes = cast(bytes, v)
v_str = v_bytes.decode("utf-8")
task = json.loads(v_str)
task_description = task[0]
task_queue = task[2]
if task_queue != queue:
continue
task_id = task_description.get("headers", {}).get("id")
if not task_id:
continue
# if the queue matches and we see the task_id, add it
tasks.add(task_id)
return tasks
def celery_get_queue_length(queue: str, r: Redis) -> int:
"""This is a redis specific way to get the length of a celery queue.
It is priority aware and knows how to count across the multiple redis lists
@@ -47,3 +89,74 @@ def celery_find_task(task_id: str, queue: str, r: Redis) -> int:
return True
return False
def celery_inspect_get_workers(name_filter: str | None, app: Celery) -> list[str]:
"""Returns a list of current workers containing name_filter, or all workers if
name_filter is None.
We've empirically discovered that the celery inspect API is potentially unstable
and may hang or return empty results when celery is under load. Suggest using this
more to debug and troubleshoot than in production code.
"""
worker_names: list[str] = []
# filter for and create an indexing specific inspect object
inspect = app.control.inspect()
workers: dict[str, Any] = inspect.ping() # type: ignore
if workers:
for worker_name in list(workers.keys()):
# if the name filter not set, return all worker names
if not name_filter:
worker_names.append(worker_name)
continue
# if the name filter is set, return only worker names that contain the name filter
if name_filter not in worker_name:
continue
worker_names.append(worker_name)
return worker_names
def celery_inspect_get_reserved(worker_names: list[str], app: Celery) -> set[str]:
"""Returns a list of reserved tasks on the specified workers.
We've empirically discovered that the celery inspect API is potentially unstable
and may hang or return empty results when celery is under load. Suggest using this
more to debug and troubleshoot than in production code.
"""
reserved_task_ids: set[str] = set()
inspect = app.control.inspect(destination=worker_names)
# get the list of reserved tasks
reserved_tasks: dict[str, list] | None = inspect.reserved() # type: ignore
if reserved_tasks:
for _, task_list in reserved_tasks.items():
for task in task_list:
reserved_task_ids.add(task["id"])
return reserved_task_ids
def celery_inspect_get_active(worker_names: list[str], app: Celery) -> set[str]:
"""Returns a list of active tasks on the specified workers.
We've empirically discovered that the celery inspect API is potentially unstable
and may hang or return empty results when celery is under load. Suggest using this
more to debug and troubleshoot than in production code.
"""
active_task_ids: set[str] = set()
inspect = app.control.inspect(destination=worker_names)
# get the list of reserved tasks
active_tasks: dict[str, list] | None = inspect.active() # type: ignore
if active_tasks:
for _, task_list in active_tasks.items():
for task in task_list:
active_task_ids.add(task["id"])
return active_task_ids

View File

@@ -16,6 +16,14 @@ result_expires = shared_config.result_expires # 86400 seconds is the default
task_default_priority = shared_config.task_default_priority
task_acks_late = shared_config.task_acks_late
# Indexing worker specific ... this lets us track the transition to STARTED in redis
# We don't currently rely on this but it has the potential to be useful and
# indexing tasks are not high volume
# we don't turn this on yet because celery occasionally runs tasks more than once
# which means a duplicate run might change the task state unexpectedly
# task_track_started = True
worker_concurrency = CELERY_WORKER_INDEXING_CONCURRENCY
worker_pool = "threads"
worker_prefetch_multiplier = 1

View File

@@ -4,6 +4,12 @@ from typing import Any
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryTask
# choosing 15 minutes because it roughly gives us enough time to process many tasks
# we might be able to reduce this greatly if we can run a unified
# loop across all tenants rather than tasks per tenant
BEAT_EXPIRES_DEFAULT = 15 * 60 # 15 minutes (in seconds)
# we set expires because it isn't necessary to queue up these tasks
# it's only important that they run relatively regularly
tasks_to_schedule = [
@@ -13,7 +19,7 @@ tasks_to_schedule = [
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.HIGH,
"expires": 60,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
@@ -22,7 +28,7 @@ tasks_to_schedule = [
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.HIGH,
"expires": 60,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
@@ -31,7 +37,7 @@ tasks_to_schedule = [
"schedule": timedelta(seconds=15),
"options": {
"priority": OnyxCeleryPriority.HIGH,
"expires": 60,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
@@ -40,7 +46,7 @@ tasks_to_schedule = [
"schedule": timedelta(seconds=15),
"options": {
"priority": OnyxCeleryPriority.HIGH,
"expires": 60,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
@@ -49,7 +55,7 @@ tasks_to_schedule = [
"schedule": timedelta(seconds=3600),
"options": {
"priority": OnyxCeleryPriority.LOWEST,
"expires": 60,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
@@ -58,7 +64,7 @@ tasks_to_schedule = [
"schedule": timedelta(seconds=5),
"options": {
"priority": OnyxCeleryPriority.HIGH,
"expires": 60,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
@@ -67,7 +73,7 @@ tasks_to_schedule = [
"schedule": timedelta(seconds=30),
"options": {
"priority": OnyxCeleryPriority.HIGH,
"expires": 60,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
@@ -76,7 +82,7 @@ tasks_to_schedule = [
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.HIGH,
"expires": 60,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
]

View File

@@ -34,7 +34,9 @@ class TaskDependencyError(RuntimeError):
trail=False,
bind=True,
)
def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> None:
def check_for_connector_deletion_task(
self: Task, *, tenant_id: str | None
) -> bool | None:
r = get_redis_client(tenant_id=tenant_id)
lock_beat: RedisLock = r.lock(
@@ -45,7 +47,7 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> N
try:
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return
return None
# collect cc_pair_ids
cc_pair_ids: list[int] = []
@@ -81,6 +83,8 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> N
if lock_beat.owned():
lock_beat.release()
return True
def try_generate_document_cc_pair_cleanup_tasks(
app: Celery,

View File

@@ -1,6 +1,8 @@
import time
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from time import sleep
from uuid import uuid4
from celery import Celery
@@ -18,6 +20,7 @@ from onyx.access.models import DocExternalAccess
from onyx.background.celery.apps.app_base import task_logger
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from onyx.configs.constants import DocumentSource
@@ -88,10 +91,10 @@ def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> b
soft_time_limit=JOB_TIMEOUT,
bind=True,
)
def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> None:
def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> bool | None:
r = get_redis_client(tenant_id=tenant_id)
lock_beat = r.lock(
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
@@ -99,7 +102,7 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> None
try:
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return
return None
# get all cc pairs that need to be synced
cc_pair_ids_to_sync: list[int] = []
@@ -128,6 +131,8 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> None
if lock_beat.owned():
lock_beat.release()
return True
def try_creating_permissions_sync_task(
app: Celery,
@@ -219,6 +224,43 @@ def connector_permission_sync_generator_task(
r = get_redis_client(tenant_id=tenant_id)
# this wait is needed to avoid a race condition where
# the primary worker sends the task and it is immediately executed
# before the primary worker can finalize the fence
start = time.monotonic()
while True:
if time.monotonic() - start > CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT:
raise ValueError(
f"connector_permission_sync_generator_task - timed out waiting for fence to be ready: "
f"fence={redis_connector.permissions.fence_key}"
)
if not redis_connector.permissions.fenced: # The fence must exist
raise ValueError(
f"connector_permission_sync_generator_task - fence not found: "
f"fence={redis_connector.permissions.fence_key}"
)
payload = redis_connector.permissions.payload # The payload must exist
if not payload:
raise ValueError(
"connector_permission_sync_generator_task: payload invalid or not found"
)
if payload.celery_task_id is None:
logger.info(
f"connector_permission_sync_generator_task - Waiting for fence: "
f"fence={redis_connector.permissions.fence_key}"
)
sleep(1)
continue
logger.info(
f"connector_permission_sync_generator_task - Fence found, continuing...: "
f"fence={redis_connector.permissions.fence_key}"
)
break
lock: RedisLock = r.lock(
OnyxRedisLocks.CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX
+ f"_{redis_connector.id}",
@@ -254,8 +296,11 @@ def connector_permission_sync_generator_task(
if not payload:
raise ValueError(f"No fence payload found: cc_pair={cc_pair_id}")
payload.started = datetime.now(timezone.utc)
redis_connector.permissions.set_fence(payload)
new_payload = RedisConnectorPermissionSyncPayload(
started=datetime.now(timezone.utc),
celery_task_id=payload.celery_task_id,
)
redis_connector.permissions.set_fence(new_payload)
document_external_accesses: list[DocExternalAccess] = doc_sync_func(cc_pair)

View File

@@ -94,10 +94,10 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
soft_time_limit=JOB_TIMEOUT,
bind=True,
)
def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool | None:
r = get_redis_client(tenant_id=tenant_id)
lock_beat = r.lock(
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
@@ -105,7 +105,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
try:
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return
return None
cc_pair_ids_to_sync: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:
@@ -149,6 +149,8 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
if lock_beat.owned():
lock_beat.release()
return True
def try_creating_external_group_sync_task(
app: Celery,
@@ -162,7 +164,7 @@ def try_creating_external_group_sync_task(
LOCK_TIMEOUT = 30
lock = r.lock(
lock: RedisLock = r.lock(
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_external_group_sync_tasks",
timeout=LOCK_TIMEOUT,
)

View File

@@ -1,9 +1,11 @@
import os
import sys
import time
from datetime import datetime
from datetime import timezone
from http import HTTPStatus
from time import sleep
from typing import Any
from typing import cast
import redis
import sentry_sdk
@@ -18,10 +20,12 @@ from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.indexing.job_client import SimpleJobClient
from onyx.background.indexing.run_indexing import run_indexing_entrypoint
from onyx.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from onyx.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from onyx.configs.constants import DocumentSource
@@ -29,6 +33,7 @@ from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import OnyxRedisSignals
from onyx.db.connector import mark_ccpair_with_indexing_trigger
from onyx.db.connector_credential_pair import fetch_connector_credential_pairs
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
@@ -69,14 +74,18 @@ logger = setup_logger()
class IndexingCallback(IndexingHeartbeatInterface):
PARENT_CHECK_INTERVAL = 60
def __init__(
self,
parent_pid: int,
stop_key: str,
generator_progress_key: str,
redis_lock: RedisLock,
redis_client: Redis,
):
super().__init__()
self.parent_pid = parent_pid
self.redis_lock: RedisLock = redis_lock
self.stop_key: str = stop_key
self.generator_progress_key: str = generator_progress_key
@@ -87,25 +96,68 @@ class IndexingCallback(IndexingHeartbeatInterface):
self.last_tag: str = "IndexingCallback.__init__"
self.last_lock_reacquire: datetime = datetime.now(timezone.utc)
self.last_parent_check = time.monotonic()
def should_stop(self) -> bool:
if self.redis_client.exists(self.stop_key):
return True
return False
def progress(self, tag: str, amount: int) -> None:
# rkuo: this shouldn't be necessary yet because we spawn the process this runs inside
# with daemon = True. It seems likely some indexing tasks will need to spawn other processes eventually
# so leave this code in until we're ready to test it.
# if self.parent_pid:
# # check if the parent pid is alive so we aren't running as a zombie
# now = time.monotonic()
# if now - self.last_parent_check > IndexingCallback.PARENT_CHECK_INTERVAL:
# try:
# # this is unintuitive, but it checks if the parent pid is still running
# os.kill(self.parent_pid, 0)
# except Exception:
# logger.exception("IndexingCallback - parent pid check exceptioned")
# raise
# self.last_parent_check = now
try:
self.redis_lock.reacquire()
self.last_tag = tag
self.last_lock_reacquire = datetime.now(timezone.utc)
except LockError:
logger.exception(
f"IndexingCallback - lock.reacquire exceptioned. "
f"IndexingCallback - lock.reacquire exceptioned: "
f"lock_timeout={self.redis_lock.timeout} "
f"start={self.started} "
f"last_tag={self.last_tag} "
f"last_reacquired={self.last_lock_reacquire} "
f"now={datetime.now(timezone.utc)}"
)
# diagnostic logging for lock errors
name = self.redis_lock.name
ttl = self.redis_client.ttl(name)
locked = self.redis_lock.locked()
owned = self.redis_lock.owned()
local_token: str | None = self.redis_lock.local.token # type: ignore
remote_token_raw = self.redis_client.get(self.redis_lock.name)
if remote_token_raw:
remote_token_bytes = cast(bytes, remote_token_raw)
remote_token = remote_token_bytes.decode("utf-8")
else:
remote_token = None
logger.warning(
f"IndexingCallback - lock diagnostics: "
f"name={name} "
f"locked={locked} "
f"owned={owned} "
f"local_token={local_token} "
f"remote_token={remote_token} "
f"ttl={ttl}"
)
raise
self.redis_client.incrby(self.generator_progress_key, amount)
@@ -175,7 +227,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
# we need to use celery's redis client to access its redis data
# (which lives on a different db number)
# redis_client_celery: Redis = self.app.broker_connection().channel().client # type: ignore
redis_client_celery: Redis = self.app.broker_connection().channel().client # type: ignore
lock_beat: RedisLock = redis_client.lock(
OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK,
@@ -318,23 +370,19 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
attempt.id, db_session, failure_reason=failure_reason
)
# rkuo: The following code logically appears to work, but the celery inspect code may be unstable
# turning off for the moment to see if it helps cloud stability
# we want to run this less frequently than the overall task
# if not redis_client.exists(OnyxRedisSignals.VALIDATE_INDEXING_FENCES):
# # clear any indexing fences that don't have associated celery tasks in progress
# # tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
# # or be currently executing
# try:
# task_logger.info("Validating indexing fences...")
# validate_indexing_fences(
# tenant_id, self.app, redis_client, redis_client_celery, lock_beat
# )
# except Exception:
# task_logger.exception("Exception while validating indexing fences")
if not redis_client.exists(OnyxRedisSignals.VALIDATE_INDEXING_FENCES):
# clear any indexing fences that don't have associated celery tasks in progress
# tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
# or be currently executing
try:
validate_indexing_fences(
tenant_id, self.app, redis_client, redis_client_celery, lock_beat
)
except Exception:
task_logger.exception("Exception while validating indexing fences")
# redis_client.set(OnyxRedisSignals.VALIDATE_INDEXING_FENCES, 1, ex=60)
redis_client.set(OnyxRedisSignals.VALIDATE_INDEXING_FENCES, 1, ex=60)
except SoftTimeLimitExceeded:
task_logger.info(
@@ -353,7 +401,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
)
time_elapsed = time.monotonic() - time_start
task_logger.info(f"check_for_indexing finished: elapsed={time_elapsed:.2f}")
task_logger.debug(f"check_for_indexing finished: elapsed={time_elapsed:.2f}")
return tasks_created
@@ -364,46 +412,9 @@ def validate_indexing_fences(
r_celery: Redis,
lock_beat: RedisLock,
) -> None:
reserved_indexing_tasks: set[str] = set()
active_indexing_tasks: set[str] = set()
indexing_worker_names: list[str] = []
# filter for and create an indexing specific inspect object
inspect = celery_app.control.inspect()
workers: dict[str, Any] = inspect.ping() # type: ignore
if not workers:
raise ValueError("No workers found!")
for worker_name in list(workers.keys()):
if "indexing" in worker_name:
indexing_worker_names.append(worker_name)
if len(indexing_worker_names) == 0:
raise ValueError("No indexing workers found!")
inspect_indexing = celery_app.control.inspect(destination=indexing_worker_names)
# NOTE: each dict entry is a map of worker name to a list of tasks
# we want sets for reserved task and active task id's to optimize
# subsequent validation lookups
# get the list of reserved tasks
reserved_tasks: dict[str, list] | None = inspect_indexing.reserved() # type: ignore
if reserved_tasks is None:
raise ValueError("inspect_indexing.reserved() returned None!")
for _, task_list in reserved_tasks.items():
for task in task_list:
reserved_indexing_tasks.add(task["id"])
# get the list of active tasks
active_tasks: dict[str, list] | None = inspect_indexing.active() # type: ignore
if active_tasks is None:
raise ValueError("inspect_indexing.active() returned None!")
for _, task_list in active_tasks.items():
for task in task_list:
active_indexing_tasks.add(task["id"])
reserved_indexing_tasks = celery_get_unacked_task_ids(
OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
)
# validate all existing indexing jobs
for key_bytes in r.scan_iter(RedisConnectorIndex.FENCE_PREFIX + "*"):
@@ -413,7 +424,6 @@ def validate_indexing_fences(
tenant_id,
key_bytes,
reserved_indexing_tasks,
active_indexing_tasks,
r_celery,
db_session,
)
@@ -424,7 +434,6 @@ def validate_indexing_fence(
tenant_id: str | None,
key_bytes: bytes,
reserved_tasks: set[str],
active_tasks: set[str],
r_celery: Redis,
db_session: Session,
) -> None:
@@ -434,11 +443,15 @@ def validate_indexing_fence(
gives the help.
How this works:
1. Active signal is renewed with a 5 minute TTL
1.1 When the fence is created
1. This function renews the active signal with a 5 minute TTL under the following conditions
1.2. When the task is seen in the redis queue
1.3. When the task is seen in the reserved or active list for a worker
2. The TTL allows us to get through the transitions on fence startup
1.3. When the task is seen in the reserved / prefetched list
2. Externally, the active signal is renewed when:
2.1. The fence is created
2.2. The indexing watchdog checks the spawned task.
3. The TTL allows us to get through the transitions on fence startup
and when the task starts executing.
More TTL clarification: it is seemingly impossible to exactly query Celery for
@@ -466,6 +479,8 @@ def validate_indexing_fence(
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings_id)
# check to see if the fence/payload exists
if not redis_connector_index.fenced:
return
@@ -501,24 +516,24 @@ def validate_indexing_fence(
redis_connector_index.set_active()
return
if payload.celery_task_id in active_tasks:
# the celery task is active (aka currently executing)
redis_connector_index.set_active()
return
# we may want to enable this check if using the active task list somehow isn't good enough
# if redis_connector_index.generator_locked():
# logger.info(f"{payload.celery_task_id} is currently executing.")
# we didn't find any direct indication that associated celery tasks exist, but they still might be there
# due to gaps in our ability to check states during transitions
# Rely on the active signal (which has a duration that allows us to bridge those gaps)
# if we get here, we didn't find any direct indication that the associated celery tasks exist,
# but they still might be there due to gaps in our ability to check states during transitions
# Checking the active signal safeguards us against these transition periods
# (which has a duration that allows us to bridge those gaps)
if redis_connector_index.active():
return
# celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
logger.warning(
f"validate_indexing_fence - Resetting fence because no associated celery tasks were found: fence={fence_key}"
f"validate_indexing_fence - Resetting fence because no associated celery tasks were found: "
f"index_attempt={payload.index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"fence={fence_key}"
)
if payload.index_attempt_id:
try:
@@ -783,7 +798,6 @@ def connector_indexing_proxy_task(
return
task_logger.info(
f"Indexing proxy - spawn succeeded: attempt={index_attempt_id} "
f"Indexing watchdog - spawn succeeded: attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
@@ -795,6 +809,58 @@ def connector_indexing_proxy_task(
while True:
sleep(5)
# renew active signal
redis_connector_index.set_active()
# if the job is done, clean up and break
if job.done():
try:
if job.status == "error":
ignore_exitcode = False
exit_code: int | None = None
if job.process:
exit_code = job.process.exitcode
# seeing odd behavior where spawned tasks usually return exit code 1 in the cloud,
# even though logging clearly indicates successful completion
# to work around this, we ignore the job error state if the completion signal is OK
status_int = redis_connector_index.get_completion()
if status_int:
status_enum = HTTPStatus(status_int)
if status_enum == HTTPStatus.OK:
ignore_exitcode = True
if not ignore_exitcode:
raise RuntimeError("Spawned task exceptioned.")
task_logger.warning(
"Indexing watchdog - spawned task has non-zero exit code "
"but completion signal is OK. Continuing...: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"exit_code={exit_code}"
)
except Exception:
task_logger.error(
"Indexing watchdog - spawned task exceptioned: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"exit_code={exit_code} "
f"error={job.exception()}"
)
raise
finally:
job.release()
break
# if a termination signal is detected, clean up and break
if self.request.id and redis_connector_index.terminating(self.request.id):
task_logger.warning(
"Indexing watchdog - termination signal detected: "
@@ -821,75 +887,33 @@ def connector_indexing_proxy_task(
f"search_settings={search_settings_id}"
)
job.cancel()
job.cancel()
break
if not job.done():
# if the spawned task is still running, restart the check once again
# if the index attempt is not in a finished status
try:
with get_session_with_tenant(tenant_id) as db_session:
index_attempt = get_index_attempt(
db_session=db_session, index_attempt_id=index_attempt_id
)
if not index_attempt:
continue
if not index_attempt.is_finished():
continue
except Exception:
# if the DB exceptioned, just restart the check.
# polling the index attempt status doesn't need to be strongly consistent
logger.exception(
"Indexing watchdog - transient exception looking up index attempt: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
continue
if job.status == "error":
ignore_exitcode = False
exit_code: int | None = None
if job.process:
exit_code = job.process.exitcode
# seeing odd behavior where spawned tasks usually return exit code 1 in the cloud,
# even though logging clearly indicates that they completed successfully
# to work around this, we ignore the job error state if the completion signal is OK
status_int = redis_connector_index.get_completion()
if status_int:
status_enum = HTTPStatus(status_int)
if status_enum == HTTPStatus.OK:
ignore_exitcode = True
if ignore_exitcode:
task_logger.warning(
"Indexing watchdog - spawned task has non-zero exit code "
"but completion signal is OK. Continuing...: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"exit_code={exit_code}"
)
else:
task_logger.error(
"Indexing watchdog - spawned task exceptioned: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"exit_code={exit_code} "
f"error={job.exception()}"
# if the spawned task is still running, restart the check once again
# if the index attempt is not in a finished status
try:
with get_session_with_tenant(tenant_id) as db_session:
index_attempt = get_index_attempt(
db_session=db_session, index_attempt_id=index_attempt_id
)
job.release()
break
if not index_attempt:
continue
if not index_attempt.is_finished():
continue
except Exception:
# if the DB exceptioned, just restart the check.
# polling the index attempt status doesn't need to be strongly consistent
logger.exception(
"Indexing watchdog - transient exception looking up index attempt: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
continue
task_logger.info(
f"Indexing watchdog - finished: attempt={index_attempt_id} "
@@ -918,7 +942,7 @@ def connector_indexing_task_wrapper(
tenant_id,
is_ee,
)
except:
except Exception:
logger.exception(
f"connector_indexing_task exceptioned: "
f"tenant={tenant_id} "
@@ -926,13 +950,20 @@ def connector_indexing_task_wrapper(
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
raise
# There is a cloud related bug outside of our code
# where spawned tasks return with an exit code of 1.
# Unfortunately, exceptions also return with an exit code of 1,
# so just raising an exception isn't informative
# Exiting with 255 makes it possible to distinguish between normal exits
# and exceptions.
sys.exit(255)
return result
def connector_indexing_task(
index_attempt_id: int,
index_attempt_id: int | None,
cc_pair_id: int,
search_settings_id: int,
tenant_id: str | None,
@@ -998,7 +1029,17 @@ def connector_indexing_task(
f"fence={redis_connector.stop.fence_key}"
)
# this wait is needed to avoid a race condition where
# the primary worker sends the task and it is immediately executed
# before the primary worker can finalize the fence
start = time.monotonic()
while True:
if time.monotonic() - start > CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT:
raise ValueError(
f"connector_indexing_task - timed out waiting for fence to be ready: "
f"fence={redis_connector.permissions.fence_key}"
)
if not redis_connector_index.fenced: # The fence must exist
raise ValueError(
f"connector_indexing_task - fence not found: fence={redis_connector_index.fence_key}"
@@ -1039,7 +1080,9 @@ def connector_indexing_task(
if not acquired:
logger.warning(
f"Indexing task already running, exiting...: "
f"index_attempt={index_attempt_id} cc_pair={cc_pair_id} search_settings={search_settings_id}"
f"index_attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
return None
@@ -1075,6 +1118,7 @@ def connector_indexing_task(
# define a callback class
callback = IndexingCallback(
os.getppid(),
redis_connector.stop.fence_key,
redis_connector_index.generator_progress_key,
lock,
@@ -1108,8 +1152,19 @@ def connector_indexing_task(
f"search_settings={search_settings_id}"
)
if attempt_found:
with get_session_with_tenant(tenant_id) as db_session:
mark_attempt_failed(index_attempt_id, db_session, failure_reason=str(e))
try:
with get_session_with_tenant(tenant_id) as db_session:
mark_attempt_failed(
index_attempt_id, db_session, failure_reason=str(e)
)
except Exception:
logger.exception(
"Indexing watchdog - transient exception looking up index attempt: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
raise e
finally:

View File

@@ -81,10 +81,10 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
soft_time_limit=JOB_TIMEOUT,
bind=True,
)
def check_for_pruning(self: Task, *, tenant_id: str | None) -> None:
def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
r = get_redis_client(tenant_id=tenant_id)
lock_beat = r.lock(
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_PRUNE_BEAT_LOCK,
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
)
@@ -92,7 +92,7 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> None:
try:
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return
return None
cc_pair_ids: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:
@@ -127,6 +127,8 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> None:
if lock_beat.owned():
lock_beat.release()
return True
def try_creating_prune_generator_task(
celery_app: Celery,
@@ -283,6 +285,7 @@ def connector_pruning_generator_task(
)
callback = IndexingCallback(
0,
redis_connector.stop.fence_key,
redis_connector.prune.generator_progress_key,
lock,

View File

@@ -20,6 +20,7 @@ from tenacity import RetryError
from onyx.access.access import get_access_for_document
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
from onyx.background.celery.tasks.shared.tasks import LIGHT_SOFT_TIME_LIMIT
from onyx.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
@@ -87,7 +88,7 @@ logger = setup_logger()
trail=False,
bind=True,
)
def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> None:
def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | None:
"""Runs periodically to check if any document needs syncing.
Generates sets of tasks for Celery if syncing is needed."""
time_start = time.monotonic()
@@ -102,7 +103,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> None:
try:
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return
return None
with get_session_with_tenant(tenant_id) as db_session:
try_generate_stale_document_sync_tasks(
@@ -164,8 +165,8 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> None:
lock_beat.release()
time_elapsed = time.monotonic() - time_start
task_logger.info(f"check_for_vespa_sync_task finished: elapsed={time_elapsed:.2f}")
return
task_logger.debug(f"check_for_vespa_sync_task finished: elapsed={time_elapsed:.2f}")
return True
def try_generate_stale_document_sync_tasks(
@@ -636,15 +637,23 @@ def monitor_ccpair_indexing_taskset(
if not payload:
return
elapsed_started_str = None
if payload.started:
elapsed_started = datetime.now(timezone.utc) - payload.started
elapsed_started_str = f"{elapsed_started.total_seconds():.2f}"
elapsed_submitted = datetime.now(timezone.utc) - payload.submitted
progress = redis_connector_index.get_progress()
if progress is not None:
task_logger.info(
f"Connector indexing progress: cc_pair={cc_pair_id} "
f"Connector indexing progress: "
f"attempt={payload.index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"progress={progress} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
f"elapsed_started={elapsed_started_str}"
)
if payload.index_attempt_id is None or payload.celery_task_id is None:
@@ -715,11 +724,14 @@ def monitor_ccpair_indexing_taskset(
status_enum = HTTPStatus(status_int)
task_logger.info(
f"Connector indexing finished: cc_pair={cc_pair_id} "
f"Connector indexing finished: "
f"attempt={payload.index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"progress={progress} "
f"status={status_enum.name} "
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
f"elapsed_started={elapsed_started_str}"
)
redis_connector_index.reset()
@@ -765,32 +777,43 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
n_permissions_sync = celery_get_queue_length(
OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC, r_celery
)
n_external_group_sync = celery_get_queue_length(
OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC, r_celery
)
n_permissions_upsert = celery_get_queue_length(
OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT, r_celery
)
prefetched = celery_get_unacked_task_ids(
OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
)
task_logger.info(
f"Queue lengths: celery={n_celery} "
f"indexing={n_indexing} "
f"indexing_prefetched={len(prefetched)} "
f"sync={n_sync} "
f"deletion={n_deletion} "
f"pruning={n_pruning} "
f"permissions_sync={n_permissions_sync} "
f"external_group_sync={n_external_group_sync} "
f"permissions_upsert={n_permissions_upsert} "
)
# scan and monitor activity to completion
lock_beat.reacquire()
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
monitor_connector_taskset(r)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisConnectorDelete.FENCE_PREFIX + "*"):
lock_beat.reacquire()
monitor_connector_deletion_taskset(tenant_id, key_bytes, r)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
monitor_document_set_taskset(tenant_id, key_bytes, r, db_session)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
lock_beat.reacquire()
monitor_usergroup_taskset = fetch_versioned_implementation_with_fallback(
@@ -801,28 +824,21 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
with get_session_with_tenant(tenant_id) as db_session:
monitor_usergroup_taskset(tenant_id, key_bytes, r, db_session)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisConnectorPrune.FENCE_PREFIX + "*"):
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_pruning_taskset(tenant_id, key_bytes, r, db_session)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisConnectorIndex.FENCE_PREFIX + "*"):
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_indexing_taskset(tenant_id, key_bytes, r, db_session)
lock_beat.reacquire()
for key_bytes in r.scan_iter(RedisConnectorPermissionSync.FENCE_PREFIX + "*"):
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_permissions_taskset(tenant_id, key_bytes, r, db_session)
# uncomment for debugging if needed
# r_celery = celery_app.broker_connection().channel().client
# length = celery_get_queue_length(OnyxCeleryQueues.VESPA_METADATA_SYNC, r_celery)
# task_logger.warning(f"queue={OnyxCeleryQueues.VESPA_METADATA_SYNC} length={length}")
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
@@ -832,7 +848,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
lock_beat.release()
time_elapsed = time.monotonic() - time_start
task_logger.info(f"monitor_vespa_sync finished: elapsed={time_elapsed:.2f}")
task_logger.debug(f"monitor_vespa_sync finished: elapsed={time_elapsed:.2f}")
return True

View File

@@ -14,6 +14,7 @@ from onyx.configs.app_configs import POLL_CONNECTOR_OFFSET
from onyx.configs.constants import MilestoneRecordType
from onyx.connectors.connector_runner import ConnectorRunner
from onyx.connectors.factory import instantiate_connector
from onyx.connectors.models import Document
from onyx.connectors.models import IndexAttemptMetadata
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import get_last_successful_attempt_time
@@ -90,6 +91,35 @@ def _get_connector_runner(
)
def strip_null_characters(doc_batch: list[Document]) -> list[Document]:
cleaned_batch = []
for doc in doc_batch:
cleaned_doc = doc.model_copy()
if "\x00" in cleaned_doc.id:
logger.warning(f"NUL characters found in document ID: {cleaned_doc.id}")
cleaned_doc.id = cleaned_doc.id.replace("\x00", "")
if "\x00" in cleaned_doc.semantic_identifier:
logger.warning(
f"NUL characters found in document semantic identifier: {cleaned_doc.semantic_identifier}"
)
cleaned_doc.semantic_identifier = cleaned_doc.semantic_identifier.replace(
"\x00", ""
)
for section in cleaned_doc.sections:
if section.link and "\x00" in section.link:
logger.warning(
f"NUL characters found in document link for document: {cleaned_doc.id}"
)
section.link = section.link.replace("\x00", "")
cleaned_batch.append(cleaned_doc)
return cleaned_batch
class ConnectorStopSignal(Exception):
"""A custom exception used to signal a stop in processing."""
@@ -238,7 +268,9 @@ def _run_indexing(
)
batch_description = []
for doc in doc_batch:
doc_batch_cleaned = strip_null_characters(doc_batch)
for doc in doc_batch_cleaned:
batch_description.append(doc.to_short_descriptor())
doc_size = 0
@@ -258,15 +290,15 @@ def _run_indexing(
# real work happens here!
new_docs, total_batch_chunks = indexing_pipeline(
document_batch=doc_batch,
document_batch=doc_batch_cleaned,
index_attempt_metadata=index_attempt_md,
)
batch_num += 1
net_doc_change += new_docs
chunk_count += total_batch_chunks
document_count += len(doc_batch)
all_connector_doc_ids.update(doc.id for doc in doc_batch)
document_count += len(doc_batch_cleaned)
all_connector_doc_ids.update(doc.id for doc in doc_batch_cleaned)
# commit transaction so that the `update` below begins
# with a brand new transaction. Postgres uses the start
@@ -276,7 +308,7 @@ def _run_indexing(
db_session.commit()
if callback:
callback.progress("_run_indexing", len(doc_batch))
callback.progress("_run_indexing", len(doc_batch_cleaned))
# This new value is updated every batch, so UI can refresh per batch update
update_docs_indexed(

View File

@@ -58,6 +58,9 @@ SESSION_EXPIRE_TIME_SECONDS = int(
os.environ.get("SESSION_EXPIRE_TIME_SECONDS") or 86400 * 7
) # 7 days
# Default request timeout, mostly used by connectors
REQUEST_TIMEOUT_SECONDS = int(os.environ.get("REQUEST_TIMEOUT_SECONDS") or 60)
# set `VALID_EMAIL_DOMAINS` to a comma seperated list of domains in order to
# restrict access to Onyx to only users with emails from those domains.
# E.g. `VALID_EMAIL_DOMAINS=example.com,example.org` will restrict Onyx
@@ -92,6 +95,7 @@ SMTP_SERVER = os.environ.get("SMTP_SERVER") or "smtp.gmail.com"
SMTP_PORT = int(os.environ.get("SMTP_PORT") or "587")
SMTP_USER = os.environ.get("SMTP_USER", "your-email@gmail.com")
SMTP_PASS = os.environ.get("SMTP_PASS", "your-gmail-password")
EMAIL_CONFIGURED = all([SMTP_SERVER, SMTP_USER, SMTP_PASS])
EMAIL_FROM = os.environ.get("EMAIL_FROM") or SMTP_USER
# If set, Onyx will listen to the `expires_at` returned by the identity
@@ -145,7 +149,7 @@ POSTGRES_PASSWORD = urllib.parse.quote_plus(
POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost"
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432"
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
AWS_REGION = os.environ.get("AWS_REGION") or "us-east-2"
AWS_REGION_NAME = os.environ.get("AWS_REGION_NAME") or "us-east-2"
POSTGRES_API_SERVER_POOL_SIZE = int(
os.environ.get("POSTGRES_API_SERVER_POOL_SIZE") or 40
@@ -184,6 +188,25 @@ REDIS_HOST = os.environ.get("REDIS_HOST") or "localhost"
REDIS_PORT = int(os.environ.get("REDIS_PORT", 6379))
REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD") or ""
# Rate limiting for auth endpoints
RATE_LIMIT_WINDOW_SECONDS: int | None = None
_rate_limit_window_seconds_str = os.environ.get("RATE_LIMIT_WINDOW_SECONDS")
if _rate_limit_window_seconds_str is not None:
try:
RATE_LIMIT_WINDOW_SECONDS = int(_rate_limit_window_seconds_str)
except ValueError:
pass
RATE_LIMIT_MAX_REQUESTS: int | None = None
_rate_limit_max_requests_str = os.environ.get("RATE_LIMIT_MAX_REQUESTS")
if _rate_limit_max_requests_str is not None:
try:
RATE_LIMIT_MAX_REQUESTS = int(_rate_limit_max_requests_str)
except ValueError:
pass
# Used for general redis things
REDIS_DB_NUMBER = int(os.environ.get("REDIS_DB_NUMBER", 0))
@@ -347,12 +370,17 @@ GITLAB_CONNECTOR_INCLUDE_CODE_FILES = (
os.environ.get("GITLAB_CONNECTOR_INCLUDE_CODE_FILES", "").lower() == "true"
)
# Typically set to http://localhost:3000 for OAuth connector development
CONNECTOR_LOCALHOST_OVERRIDE = os.getenv("CONNECTOR_LOCALHOST_OVERRIDE")
# Egnyte specific configs
EGNYTE_LOCALHOST_OVERRIDE = os.getenv("EGNYTE_LOCALHOST_OVERRIDE")
EGNYTE_BASE_DOMAIN = os.getenv("EGNYTE_DOMAIN")
EGNYTE_CLIENT_ID = os.getenv("EGNYTE_CLIENT_ID")
EGNYTE_CLIENT_SECRET = os.getenv("EGNYTE_CLIENT_SECRET")
# Linear specific configs
LINEAR_CLIENT_ID = os.getenv("LINEAR_CLIENT_ID")
LINEAR_CLIENT_SECRET = os.getenv("LINEAR_CLIENT_SECRET")
DASK_JOB_CLIENT_ENABLED = (
os.environ.get("DASK_JOB_CLIENT_ENABLED", "").lower() == "true"
)

View File

@@ -36,6 +36,8 @@ DISABLED_GEN_AI_MSG = (
DEFAULT_PERSONA_ID = 0
DEFAULT_CC_PAIR_ID = 1
# Postgres connection constants for application_name
POSTGRES_WEB_APP_NAME = "web"
POSTGRES_INDEXER_APP_NAME = "indexer"
@@ -81,6 +83,9 @@ CELERY_PRIMARY_WORKER_LOCK_TIMEOUT = 120
# if we can get callbacks as object bytes download, we could lower this a lot.
CELERY_INDEXING_LOCK_TIMEOUT = 3 * 60 * 60 # 60 min
# how long a task should wait for associated fence to be ready
CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT = 5 * 60 # 5 min
# needs to be long enough to cover the maximum time it takes to download an object
# if we can get callbacks as object bytes download, we could lower this a lot.
CELERY_PRUNING_LOCK_TIMEOUT = 300 # 5 min
@@ -137,6 +142,7 @@ class DocumentSource(str, Enum):
FRESHDESK = "freshdesk"
FIREFLIES = "fireflies"
EGNYTE = "egnyte"
AIRTABLE = "airtable"
DocumentSourceRequiringTenantContext: list[DocumentSource] = [DocumentSource.FILE]
@@ -215,6 +221,7 @@ class FileOrigin(str, Enum):
CHAT_IMAGE_GEN = "chat_image_gen"
CONNECTOR = "connector"
GENERATED_REPORT = "generated_report"
MY_DOCUMENTS = "my_documents"
OTHER = "other"
@@ -273,6 +280,7 @@ class OnyxRedisLocks:
SLACK_BOT_LOCK = "da_lock:slack_bot"
SLACK_BOT_HEARTBEAT_PREFIX = "da_heartbeat:slack_bot"
ANONYMOUS_USER_ENABLED = "anonymous_user_enabled"
class OnyxRedisSignals:

View File

@@ -0,0 +1,268 @@
from io import BytesIO
from typing import Any
import requests
from pyairtable import Api as AirtableApi
from retry import retry
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.utils.logger import setup_logger
logger = setup_logger()
# NOTE: all are made lowercase to avoid case sensitivity issues
# these are the field types that are considered metadata rather
# than sections
_METADATA_FIELD_TYPES = {
"singlecollaborator",
"collaborator",
"createdby",
"singleselect",
"multipleselects",
"checkbox",
"date",
"datetime",
"email",
"phone",
"url",
"number",
"currency",
"duration",
"percent",
"rating",
"createdtime",
"lastmodifiedtime",
"autonumber",
"rollup",
"lookup",
"count",
"formula",
"date",
}
class AirtableClientNotSetUpError(PermissionError):
def __init__(self) -> None:
super().__init__("Airtable Client is not set up, was load_credentials called?")
class AirtableConnector(LoadConnector):
def __init__(
self,
base_id: str,
table_name_or_id: str,
batch_size: int = INDEX_BATCH_SIZE,
) -> None:
self.base_id = base_id
self.table_name_or_id = table_name_or_id
self.batch_size = batch_size
self.airtable_client: AirtableApi | None = None
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self.airtable_client = AirtableApi(credentials["airtable_access_token"])
return None
def _get_field_value(self, field_info: Any, field_type: str) -> list[str]:
"""
Extract value(s) from a field regardless of its type.
Returns either a single string or list of strings for attachments.
"""
if field_info is None:
return []
# skip references to other records for now (would need to do another
# request to get the actual record name/type)
# TODO: support this
if field_type == "multipleRecordLinks":
return []
if field_type == "multipleAttachments":
attachment_texts: list[str] = []
for attachment in field_info:
url = attachment.get("url")
filename = attachment.get("filename", "")
if not url:
continue
@retry(
tries=5,
delay=1,
backoff=2,
max_delay=10,
)
def get_attachment_with_retry(url: str) -> bytes | None:
attachment_response = requests.get(url)
if attachment_response.status_code == 200:
return attachment_response.content
return None
attachment_content = get_attachment_with_retry(url)
if attachment_content:
try:
file_ext = get_file_ext(filename)
attachment_text = extract_file_text(
BytesIO(attachment_content),
filename,
break_on_unprocessable=False,
extension=file_ext,
)
if attachment_text:
attachment_texts.append(f"{filename}:\n{attachment_text}")
except Exception as e:
logger.warning(
f"Failed to process attachment {filename}: {str(e)}"
)
return attachment_texts
if field_type in ["singleCollaborator", "collaborator", "createdBy"]:
combined = []
collab_name = field_info.get("name")
collab_email = field_info.get("email")
if collab_name:
combined.append(collab_name)
if collab_email:
combined.append(f"({collab_email})")
return [" ".join(combined) if combined else str(field_info)]
if isinstance(field_info, list):
return [str(item) for item in field_info]
return [str(field_info)]
def _should_be_metadata(self, field_type: str) -> bool:
"""Determine if a field type should be treated as metadata."""
return field_type.lower() in _METADATA_FIELD_TYPES
def _process_field(
self,
field_name: str,
field_info: Any,
field_type: str,
table_id: str,
record_id: str,
) -> tuple[list[Section], dict[str, Any]]:
"""
Process a single Airtable field and return sections or metadata.
Args:
field_name: Name of the field
field_info: Raw field information from Airtable
field_type: Airtable field type
Returns:
(list of Sections, dict of metadata)
"""
if field_info is None:
return [], {}
# Get the value(s) for the field
field_values = self._get_field_value(field_info, field_type)
if len(field_values) == 0:
return [], {}
# Determine if it should be metadata or a section
if self._should_be_metadata(field_type):
if len(field_values) > 1:
return [], {field_name: field_values}
return [], {field_name: field_values[0]}
# Otherwise, create relevant sections
sections = [
Section(
link=f"https://airtable.com/{self.base_id}/{table_id}/{record_id}",
text=(
f"{field_name}:\n"
"------------------------\n"
f"{text}\n"
"------------------------"
),
)
for text in field_values
]
return sections, {}
def load_from_state(self) -> GenerateDocumentsOutput:
"""
Fetch all records from the table.
NOTE: Airtable does not support filtering by time updated, so
we have to fetch all records every time.
"""
if not self.airtable_client:
raise AirtableClientNotSetUpError()
table = self.airtable_client.table(self.base_id, self.table_name_or_id)
table_id = table.id
# due to https://community.airtable.com/t5/development-apis/pagination-returns-422-error/td-p/54778,
# we can't user the `iterate()` method - we need to get everything up front
# this also means we can't handle tables that won't fit in memory
records = table.all()
table_schema = table.schema()
# have to get the name from the schema, since the table object will
# give back the ID instead of the name if the ID is used to create
# the table object
table_name = table_schema.name
primary_field_name = None
# Find a primary field from the schema
for field in table_schema.fields:
if field.id == table_schema.primary_field_id:
primary_field_name = field.name
break
record_documents: list[Document] = []
for record in records:
record_id = record["id"]
fields = record["fields"]
sections: list[Section] = []
metadata: dict[str, Any] = {}
# Possibly retrieve the primary field's value
primary_field_value = (
fields.get(primary_field_name) if primary_field_name else None
)
for field_schema in table_schema.fields:
field_name = field_schema.name
field_val = fields.get(field_name)
field_type = field_schema.type
field_sections, field_metadata = self._process_field(
field_name=field_name,
field_info=field_val,
field_type=field_type,
table_id=table_id,
record_id=record_id,
)
sections.extend(field_sections)
metadata.update(field_metadata)
semantic_id = (
f"{table_name}: {primary_field_value}"
if primary_field_value
else table_name
)
record_document = Document(
id=f"airtable__{record_id}",
sections=sections,
source=DocumentSource.AIRTABLE,
semantic_identifier=semantic_id,
metadata=metadata,
)
record_documents.append(record_document)
if len(record_documents) >= self.batch_size:
yield record_documents
record_documents = []
if record_documents:
yield record_documents

View File

@@ -56,6 +56,23 @@ _RESTRICTIONS_EXPANSION_FIELDS = [
_SLIM_DOC_BATCH_SIZE = 5000
_ATTACHMENT_EXTENSIONS_TO_FILTER_OUT = [
"png",
"jpg",
"jpeg",
"gif",
"mp4",
"mov",
"mp3",
"wav",
]
_FULL_EXTENSION_FILTER_STRING = "".join(
[
f" and title!~'*.{extension}'"
for extension in _ATTACHMENT_EXTENSIONS_TO_FILTER_OUT
]
)
class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
def __init__(
@@ -64,7 +81,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
is_cloud: bool,
space: str = "",
page_id: str = "",
index_recursively: bool = True,
index_recursively: bool = False,
cql_query: str | None = None,
batch_size: int = INDEX_BATCH_SIZE,
continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE,
@@ -82,23 +99,25 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
# Remove trailing slash from wiki_base if present
self.wiki_base = wiki_base.rstrip("/")
# if nothing is provided, we will fetch all pages
cql_page_query = "type=page"
"""
If nothing is provided, we default to fetching all pages
Only one or none of the following options should be specified so
the order shouldn't matter
However, we use elif to ensure that only of the following is enforced
"""
base_cql_page_query = "type=page"
if cql_query:
# if a cql_query is provided, we will use it to fetch the pages
cql_page_query = cql_query
base_cql_page_query = cql_query
elif page_id:
# if a cql_query is not provided, we will use the page_id to fetch the page
if index_recursively:
cql_page_query += f" and ancestor='{page_id}'"
base_cql_page_query += f" and (ancestor='{page_id}' or id='{page_id}')"
else:
cql_page_query += f" and id='{page_id}'"
base_cql_page_query += f" and id='{page_id}'"
elif space:
# if no cql_query or page_id is provided, we will use the space to fetch the pages
cql_page_query += f" and space='{quote(space)}'"
uri_safe_space = quote(space)
base_cql_page_query += f" and space='{uri_safe_space}'"
self.cql_page_query = cql_page_query
self.cql_time_filter = ""
self.base_cql_page_query = base_cql_page_query
self.cql_label_filter = ""
if labels_to_skip:
@@ -126,6 +145,33 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
)
return None
def _construct_page_query(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> str:
page_query = self.base_cql_page_query + self.cql_label_filter
# Add time filters
if start:
formatted_start_time = datetime.fromtimestamp(
start, tz=self.timezone
).strftime("%Y-%m-%d %H:%M")
page_query += f" and lastmodified >= '{formatted_start_time}'"
if end:
formatted_end_time = datetime.fromtimestamp(end, tz=self.timezone).strftime(
"%Y-%m-%d %H:%M"
)
page_query += f" and lastmodified <= '{formatted_end_time}'"
return page_query
def _construct_attachment_query(self, confluence_page_id: str) -> str:
attachment_query = f"type=attachment and container='{confluence_page_id}'"
attachment_query += self.cql_label_filter
attachment_query += _FULL_EXTENSION_FILTER_STRING
return attachment_query
def _get_comment_string_for_page_id(self, page_id: str) -> str:
comment_string = ""
@@ -205,11 +251,15 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
metadata=doc_metadata,
)
def _fetch_document_batches(self) -> GenerateDocumentsOutput:
def _fetch_document_batches(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateDocumentsOutput:
doc_batch: list[Document] = []
confluence_page_ids: list[str] = []
page_query = self.cql_page_query + self.cql_label_filter + self.cql_time_filter
page_query = self._construct_page_query(start, end)
logger.debug(f"page_query: {page_query}")
# Fetch pages as Documents
for page in self.confluence_client.paginated_cql_retrieval(
@@ -228,11 +278,10 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
# Fetch attachments as Documents
for confluence_page_id in confluence_page_ids:
attachment_cql = f"type=attachment and container='{confluence_page_id}'"
attachment_cql += self.cql_label_filter
attachment_query = self._construct_attachment_query(confluence_page_id)
# TODO: maybe should add time filter as well?
for attachment in self.confluence_client.paginated_cql_retrieval(
cql=attachment_cql,
cql=attachment_query,
expand=",".join(_ATTACHMENT_EXPANSION_FIELDS),
):
doc = self._convert_object_to_document(attachment)
@@ -248,17 +297,12 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
def load_from_state(self) -> GenerateDocumentsOutput:
return self._fetch_document_batches()
def poll_source(self, start: float, end: float) -> GenerateDocumentsOutput:
# Add time filters
formatted_start_time = datetime.fromtimestamp(start, tz=self.timezone).strftime(
"%Y-%m-%d %H:%M"
)
formatted_end_time = datetime.fromtimestamp(end, tz=self.timezone).strftime(
"%Y-%m-%d %H:%M"
)
self.cql_time_filter = f" and lastmodified >= '{formatted_start_time}'"
self.cql_time_filter += f" and lastmodified <= '{formatted_end_time}'"
return self._fetch_document_batches()
def poll_source(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateDocumentsOutput:
return self._fetch_document_batches(start, end)
def retrieve_all_slim_documents(
self,
@@ -269,7 +313,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
restrictions_expand = ",".join(_RESTRICTIONS_EXPANSION_FIELDS)
page_query = self.cql_page_query + self.cql_label_filter
page_query = self.base_cql_page_query + self.cql_label_filter
for page in self.confluence_client.cql_paginate_all_expansions(
cql=page_query,
expand=restrictions_expand,
@@ -294,10 +338,9 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
perm_sync_data=page_perm_sync_data,
)
)
attachment_cql = f"type=attachment and container='{page['id']}'"
attachment_cql += self.cql_label_filter
attachment_query = self._construct_attachment_query(page["id"])
for attachment in self.confluence_client.cql_paginate_all_expansions(
cql=attachment_cql,
cql=attachment_query,
expand=restrictions_expand,
limit=_SLIM_DOC_BATCH_SIZE,
):

View File

@@ -153,7 +153,7 @@ class OnyxConfluence(Confluence):
try:
response = self.get(url, params=params)
except HTTPError as e:
if e.response.status_code == 403:
if e.response is not None and e.response.status_code == 403:
raise ApiPermissionError(
"The calling user does not have permission", reason=e
)

View File

@@ -6,6 +6,7 @@ from typing import TypeVar
from dateutil.parser import parse
from onyx.configs.app_configs import CONNECTOR_LOCALHOST_OVERRIDE
from onyx.configs.constants import IGNORE_FOR_QA
from onyx.connectors.models import BasicExpertInfo
from onyx.utils.text_processing import is_valid_email
@@ -71,3 +72,10 @@ def process_in_batches(
def get_metadata_keys_to_ignore() -> list[str]:
return [IGNORE_FOR_QA]
def get_oauth_callback_uri(base_domain: str, connector_id: str) -> str:
if CONNECTOR_LOCALHOST_OVERRIDE:
# Used for development
base_domain = CONNECTOR_LOCALHOST_OVERRIDE
return f"{base_domain.strip('/')}/connector/oauth/callback/{connector_id}"

View File

@@ -190,7 +190,7 @@ class DiscourseConnector(PollConnector):
start: datetime,
end: datetime,
) -> GenerateDocumentsOutput:
page = 1
page = 0
while topic_ids := self._get_latest_topics(start, end, page):
doc_batch: list[Document] = []
for topic_id in topic_ids:

View File

@@ -3,20 +3,19 @@ import os
from collections.abc import Generator
from datetime import datetime
from datetime import timezone
from logging import Logger
from typing import Any
from typing import cast
from typing import IO
from urllib.parse import quote
import requests
from retry import retry
from pydantic import Field
from onyx.configs.app_configs import EGNYTE_BASE_DOMAIN
from onyx.configs.app_configs import EGNYTE_CLIENT_ID
from onyx.configs.app_configs import EGNYTE_CLIENT_SECRET
from onyx.configs.app_configs import EGNYTE_LOCALHOST_OVERRIDE
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
get_oauth_callback_uri,
)
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import OAuthConnector
@@ -33,53 +32,13 @@ from onyx.file_processing.extract_file_text import is_text_file_extension
from onyx.file_processing.extract_file_text import is_valid_file_ext
from onyx.file_processing.extract_file_text import read_text_file
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import request_with_retries
logger = setup_logger()
_EGNYTE_API_BASE = "https://{domain}.egnyte.com/pubapi/v1"
_EGNYTE_APP_BASE = "https://{domain}.egnyte.com"
_TIMEOUT = 60
def _request_with_retries(
method: str,
url: str,
data: dict[str, Any] | None = None,
headers: dict[str, Any] | None = None,
params: dict[str, Any] | None = None,
timeout: int = _TIMEOUT,
stream: bool = False,
tries: int = 8,
delay: float = 1,
backoff: float = 2,
) -> requests.Response:
@retry(tries=tries, delay=delay, backoff=backoff, logger=cast(Logger, logger))
def _make_request() -> requests.Response:
response = requests.request(
method,
url,
data=data,
headers=headers,
params=params,
timeout=timeout,
stream=stream,
)
try:
response.raise_for_status()
except requests.exceptions.HTTPError as e:
if e.response.status_code != 403:
logger.exception(
f"Failed to call Egnyte API.\n"
f"URL: {url}\n"
f"Headers: {headers}\n"
f"Data: {data}\n"
f"Params: {params}"
)
raise e
return response
return _make_request()
def _parse_last_modified(last_modified: str) -> datetime:
@@ -166,6 +125,15 @@ def _process_egnyte_file(
class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
class AdditionalOauthKwargs(OAuthConnector.AdditionalOauthKwargs):
egnyte_domain: str = Field(
title="Egnyte Domain",
description=(
"The domain for the Egnyte instance "
"(e.g. 'company' for company.egnyte.com)"
),
)
def __init__(
self,
folder_path: str | None = None,
@@ -181,18 +149,20 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
return DocumentSource.EGNYTE
@classmethod
def oauth_authorization_url(cls, base_domain: str, state: str) -> str:
def oauth_authorization_url(
cls,
base_domain: str,
state: str,
additional_kwargs: dict[str, str],
) -> str:
if not EGNYTE_CLIENT_ID:
raise ValueError("EGNYTE_CLIENT_ID environment variable must be set")
if not EGNYTE_BASE_DOMAIN:
raise ValueError("EGNYTE_DOMAIN environment variable must be set")
if EGNYTE_LOCALHOST_OVERRIDE:
base_domain = EGNYTE_LOCALHOST_OVERRIDE
oauth_kwargs = cls.AdditionalOauthKwargs(**additional_kwargs)
callback_uri = f"{base_domain.strip('/')}/connector/oauth/callback/egnyte"
callback_uri = get_oauth_callback_uri(base_domain, "egnyte")
return (
f"https://{EGNYTE_BASE_DOMAIN}.egnyte.com/puboauth/token"
f"https://{oauth_kwargs.egnyte_domain}.egnyte.com/puboauth/token"
f"?client_id={EGNYTE_CLIENT_ID}"
f"&redirect_uri={callback_uri}"
f"&scope=Egnyte.filesystem"
@@ -201,17 +171,23 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
)
@classmethod
def oauth_code_to_token(cls, base_domain: str, code: str) -> dict[str, Any]:
def oauth_code_to_token(
cls,
base_domain: str,
code: str,
additional_kwargs: dict[str, str],
) -> dict[str, Any]:
if not EGNYTE_CLIENT_ID:
raise ValueError("EGNYTE_CLIENT_ID environment variable must be set")
if not EGNYTE_CLIENT_SECRET:
raise ValueError("EGNYTE_CLIENT_SECRET environment variable must be set")
if not EGNYTE_BASE_DOMAIN:
raise ValueError("EGNYTE_DOMAIN environment variable must be set")
oauth_kwargs = cls.AdditionalOauthKwargs(**additional_kwargs)
# Exchange code for token
url = f"https://{EGNYTE_BASE_DOMAIN}.egnyte.com/puboauth/token"
redirect_uri = f"{EGNYTE_LOCALHOST_OVERRIDE or base_domain}/connector/oauth/callback/egnyte"
url = f"https://{oauth_kwargs.egnyte_domain}.egnyte.com/puboauth/token"
redirect_uri = get_oauth_callback_uri(base_domain, "egnyte")
data = {
"client_id": EGNYTE_CLIENT_ID,
"client_secret": EGNYTE_CLIENT_SECRET,
@@ -222,7 +198,7 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
}
headers = {"Content-Type": "application/x-www-form-urlencoded"}
response = _request_with_retries(
response = request_with_retries(
method="POST",
url=url,
data=data,
@@ -236,7 +212,7 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
token_data = response.json()
return {
"domain": EGNYTE_BASE_DOMAIN,
"domain": oauth_kwargs.egnyte_domain,
"access_token": token_data["access_token"],
}
@@ -260,9 +236,10 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
"list_content": True,
}
url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs/{path or ''}"
response = _request_with_retries(
method="GET", url=url, headers=headers, params=params, timeout=_TIMEOUT
url_encoded_path = quote(path or "")
url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs/{url_encoded_path}"
response = request_with_retries(
method="GET", url=url, headers=headers, params=params
)
if not response.ok:
raise RuntimeError(f"Failed to fetch files from Egnyte: {response.text}")
@@ -315,12 +292,12 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
headers = {
"Authorization": f"Bearer {self.access_token}",
}
url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs-content/{file['path']}"
response = _request_with_retries(
url_encoded_path = quote(file["path"])
url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs-content/{url_encoded_path}"
response = request_with_retries(
method="GET",
url=url,
headers=headers,
timeout=_TIMEOUT,
stream=True,
)

View File

@@ -5,6 +5,7 @@ from sqlalchemy.orm import Session
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import DocumentSourceRequiringTenantContext
from onyx.connectors.airtable.airtable_connector import AirtableConnector
from onyx.connectors.asana.connector import AsanaConnector
from onyx.connectors.axero.connector import AxeroConnector
from onyx.connectors.blob.connector import BlobStorageConnector
@@ -103,6 +104,7 @@ def identify_connector_class(
DocumentSource.FRESHDESK: FreshdeskConnector,
DocumentSource.FIREFLIES: FirefliesConnector,
DocumentSource.EGNYTE: EgnyteConnector,
DocumentSource.AIRTABLE: AirtableConnector,
}
connector_by_source = connector_map.get(source, {})

View File

@@ -4,6 +4,7 @@ from typing import Dict
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
from googleapiclient.errors import HttpError # type: ignore
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
@@ -249,17 +250,36 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnector):
return new_creds_dict
def _get_all_user_emails(self) -> list[str]:
admin_service = get_admin_service(self.creds, self.primary_admin_email)
emails = []
for user in execute_paginated_retrieval(
retrieval_function=admin_service.users().list,
list_key="users",
fields=USER_FIELDS,
domain=self.google_domain,
):
if email := user.get("primaryEmail"):
emails.append(email)
return emails
"""
List all user emails if we are on a Google Workspace domain.
If the domain is gmail.com, or if we attempt to call the Admin SDK and
get a 404, fall back to using the single user.
"""
try:
admin_service = get_admin_service(self.creds, self.primary_admin_email)
emails = []
for user in execute_paginated_retrieval(
retrieval_function=admin_service.users().list,
list_key="users",
fields=USER_FIELDS,
domain=self.google_domain,
):
if email := user.get("primaryEmail"):
emails.append(email)
return emails
except HttpError as e:
if e.resp.status == 404:
logger.warning(
"Received 404 from Admin SDK; this may indicate a personal Gmail account "
"with no Workspace domain. Falling back to single user."
)
return [self.primary_admin_email]
raise
except Exception:
raise
def _fetch_threads(
self,

View File

@@ -8,6 +8,7 @@ from typing import cast
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
from googleapiclient.errors import HttpError # type: ignore
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.app_configs import MAX_FILE_SIZE_BYTES
@@ -20,6 +21,7 @@ from onyx.connectors.google_drive.file_retrieval import crawl_folders_for_files
from onyx.connectors.google_drive.file_retrieval import get_all_files_for_oauth
from onyx.connectors.google_drive.file_retrieval import get_all_files_in_my_drive
from onyx.connectors.google_drive.file_retrieval import get_files_in_shared_drive
from onyx.connectors.google_drive.file_retrieval import get_root_folder_id
from onyx.connectors.google_drive.models import GoogleDriveFileType
from onyx.connectors.google_utils.google_auth import get_google_creds
from onyx.connectors.google_utils.google_utils import execute_paginated_retrieval
@@ -41,6 +43,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder
logger = setup_logger()
# TODO: Improve this by using the batch utility: https://googleapis.github.io/google-api-python-client/docs/batch.html
@@ -286,13 +289,30 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
logger.info(f"Impersonating user {user_email}")
drive_service = get_drive_service(self.creds, user_email)
# validate that the user has access to the drive APIs by performing a simple
# request and checking for a 401
try:
retry_builder()(get_root_folder_id)(drive_service)
except HttpError as e:
if e.status_code == 401:
# fail gracefully, let the other impersonations continue
# one user without access shouldn't block the entire connector
logger.exception(
f"User '{user_email}' does not have access to the drive APIs."
)
return
raise
# if we are including my drives, try to get the current user's my
# drive if any of the following are true:
# - include_my_drives is true
# - the current user's email is in the requested emails
if self.include_my_drives or user_email in self._requested_my_drive_emails:
logger.info(f"Getting all files in my drive as '{user_email}'")
yield from get_all_files_in_my_drive(
service=drive_service,
update_traversed_ids_func=self._update_traversed_parent_ids,
@@ -303,6 +323,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
remaining_drive_ids = filtered_drive_ids - self._retrieved_ids
for drive_id in remaining_drive_ids:
logger.info(f"Getting files in shared drive '{drive_id}' as '{user_email}'")
yield from get_files_in_shared_drive(
service=drive_service,
drive_id=drive_id,
@@ -314,6 +335,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
remaining_folders = filtered_folder_ids - self._retrieved_ids
for folder_id in remaining_folders:
logger.info(f"Getting files in folder '{folder_id}' as '{user_email}'")
yield from crawl_folders_for_files(
service=drive_service,
parent_id=folder_id,
@@ -344,6 +366,15 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
elif self.include_shared_drives:
drive_ids_to_retrieve = all_drive_ids
# checkpoint - we've found all users and drives, now time to actually start
# fetching stuff
logger.info(f"Found {len(all_org_emails)} users to impersonate")
logger.debug(f"Users: {all_org_emails}")
logger.info(f"Found {len(drive_ids_to_retrieve)} drives to retrieve")
logger.debug(f"Drives: {drive_ids_to_retrieve}")
logger.info(f"Found {len(folder_ids_to_retrieve)} folders to retrieve")
logger.debug(f"Folders: {folder_ids_to_retrieve}")
# Process users in parallel using ThreadPoolExecutor
with ThreadPoolExecutor(max_workers=10) as executor:
future_to_email = {
@@ -380,6 +411,13 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
drive_service = get_drive_service(self.creds, self.primary_admin_email)
if self.include_files_shared_with_me or self.include_my_drives:
logger.info(
f"Getting shared files/my drive files for OAuth "
f"with include_files_shared_with_me={self.include_files_shared_with_me}, "
f"include_my_drives={self.include_my_drives}, "
f"include_shared_drives={self.include_shared_drives}."
f"Using '{self.primary_admin_email}' as the account."
)
yield from get_all_files_for_oauth(
service=drive_service,
include_files_shared_with_me=self.include_files_shared_with_me,
@@ -412,6 +450,9 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
drive_ids_to_retrieve = all_drive_ids
for drive_id in drive_ids_to_retrieve:
logger.info(
f"Getting files in shared drive '{drive_id}' as '{self.primary_admin_email}'"
)
yield from get_files_in_shared_drive(
service=drive_service,
drive_id=drive_id,
@@ -425,6 +466,9 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
# that could be folders.
remaining_folders = folder_ids_to_retrieve - self._retrieved_ids
for folder_id in remaining_folders:
logger.info(
f"Getting files in folder '{folder_id}' as '{self.primary_admin_email}'"
)
yield from crawl_folders_for_files(
service=drive_service,
parent_id=folder_id,

View File

@@ -2,6 +2,8 @@ import abc
from collections.abc import Iterator
from typing import Any
from pydantic import BaseModel
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import Document
from onyx.connectors.models import SlimDocument
@@ -66,6 +68,10 @@ class SlimConnector(BaseConnector):
class OAuthConnector(BaseConnector):
class AdditionalOauthKwargs(BaseModel):
# if overridden, all fields should be str type
pass
@classmethod
@abc.abstractmethod
def oauth_id(cls) -> DocumentSource:
@@ -73,12 +79,22 @@ class OAuthConnector(BaseConnector):
@classmethod
@abc.abstractmethod
def oauth_authorization_url(cls, base_domain: str, state: str) -> str:
def oauth_authorization_url(
cls,
base_domain: str,
state: str,
additional_kwargs: dict[str, str],
) -> str:
raise NotImplementedError
@classmethod
@abc.abstractmethod
def oauth_code_to_token(cls, base_domain: str, code: str) -> dict[str, Any]:
def oauth_code_to_token(
cls,
base_domain: str,
code: str,
additional_kwargs: dict[str, str],
) -> dict[str, Any]:
raise NotImplementedError

View File

@@ -7,16 +7,23 @@ from typing import cast
import requests
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.app_configs import LINEAR_CLIENT_ID
from onyx.configs.app_configs import LINEAR_CLIENT_SECRET
from onyx.configs.constants import DocumentSource
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
get_oauth_callback_uri,
)
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import OAuthConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import request_with_retries
logger = setup_logger()
@@ -57,7 +64,7 @@ def _make_query(request_body: dict[str, Any], api_key: str) -> requests.Response
)
class LinearConnector(LoadConnector, PollConnector):
class LinearConnector(LoadConnector, PollConnector, OAuthConnector):
def __init__(
self,
batch_size: int = INDEX_BATCH_SIZE,
@@ -65,8 +72,68 @@ class LinearConnector(LoadConnector, PollConnector):
self.batch_size = batch_size
self.linear_api_key: str | None = None
@classmethod
def oauth_id(cls) -> DocumentSource:
return DocumentSource.LINEAR
@classmethod
def oauth_authorization_url(
cls, base_domain: str, state: str, additional_kwargs: dict[str, str]
) -> str:
if not LINEAR_CLIENT_ID:
raise ValueError("LINEAR_CLIENT_ID environment variable must be set")
callback_uri = get_oauth_callback_uri(base_domain, DocumentSource.LINEAR.value)
return (
f"https://linear.app/oauth/authorize"
f"?client_id={LINEAR_CLIENT_ID}"
f"&redirect_uri={callback_uri}"
f"&response_type=code"
f"&scope=read"
f"&state={state}"
)
@classmethod
def oauth_code_to_token(
cls, base_domain: str, code: str, additional_kwargs: dict[str, str]
) -> dict[str, Any]:
data = {
"code": code,
"redirect_uri": get_oauth_callback_uri(
base_domain, DocumentSource.LINEAR.value
),
"client_id": LINEAR_CLIENT_ID,
"client_secret": LINEAR_CLIENT_SECRET,
"grant_type": "authorization_code",
}
headers = {"Content-Type": "application/x-www-form-urlencoded"}
response = request_with_retries(
method="POST",
url="https://api.linear.app/oauth/token",
data=data,
headers=headers,
backoff=0,
delay=0.1,
)
if not response.ok:
raise RuntimeError(f"Failed to exchange code for token: {response.text}")
token_data = response.json()
return {
"access_token": token_data["access_token"],
}
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self.linear_api_key = cast(str, credentials["linear_api_key"])
if "linear_api_key" in credentials:
self.linear_api_key = cast(str, credentials["linear_api_key"])
elif "access_token" in credentials:
self.linear_api_key = "Bearer " + cast(str, credentials["access_token"])
else:
# May need to handle case in the future if the OAuth flow expires
raise ConnectorMissingCredentialError("Linear")
return None
def _process_issues(

View File

@@ -1,11 +1,7 @@
import os
from collections.abc import Iterator
from datetime import datetime
from datetime import timezone
from typing import Any
from simple_salesforce import Salesforce
from simple_salesforce import SFType
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
@@ -19,24 +15,25 @@ from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import SlimDocument
from onyx.connectors.salesforce.utils import extract_dict_text
from onyx.connectors.salesforce.doc_conversion import extract_section
from onyx.connectors.salesforce.salesforce_calls import fetch_all_csvs_in_parallel
from onyx.connectors.salesforce.salesforce_calls import get_all_children_of_sf_type
from onyx.connectors.salesforce.sqlite_functions import get_affected_parent_ids_by_type
from onyx.connectors.salesforce.sqlite_functions import get_child_ids
from onyx.connectors.salesforce.sqlite_functions import get_record
from onyx.connectors.salesforce.sqlite_functions import init_db
from onyx.connectors.salesforce.sqlite_functions import update_sf_db_with_csv
from onyx.connectors.salesforce.utils import SalesforceObject
from onyx.utils.logger import setup_logger
# TODO: this connector does not work well at large scales
# the large query against a large Salesforce instance has been reported to take 1.5 hours.
# Additionally it seems to eat up more memory over time if the connection is long running (again a scale issue).
DEFAULT_PARENT_OBJECT_TYPES = ["Account"]
MAX_QUERY_LENGTH = 10000 # max query length is 20,000 characters
ID_PREFIX = "SALESFORCE_"
logger = setup_logger()
_DEFAULT_PARENT_OBJECT_TYPES = ["Account"]
_ID_PREFIX = "SALESFORCE_"
class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
def __init__(
self,
@@ -44,200 +41,170 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
requested_objects: list[str] = [],
) -> None:
self.batch_size = batch_size
self.sf_client: Salesforce | None = None
self._sf_client: Salesforce | None = None
self.parent_object_list = (
[obj.capitalize() for obj in requested_objects]
if requested_objects
else DEFAULT_PARENT_OBJECT_TYPES
else _DEFAULT_PARENT_OBJECT_TYPES
)
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self.sf_client = Salesforce(
def load_credentials(
self,
credentials: dict[str, Any],
) -> dict[str, Any] | None:
self._sf_client = Salesforce(
username=credentials["sf_username"],
password=credentials["sf_password"],
security_token=credentials["sf_security_token"],
)
return None
def _get_sf_type_object_json(self, type_name: str) -> Any:
if self.sf_client is None:
@property
def sf_client(self) -> Salesforce:
if self._sf_client is None:
raise ConnectorMissingCredentialError("Salesforce")
sf_object = SFType(
type_name, self.sf_client.session_id, self.sf_client.sf_instance
)
return sf_object.describe()
return self._sf_client
def _get_name_from_id(self, id: str) -> str:
if self.sf_client is None:
raise ConnectorMissingCredentialError("Salesforce")
try:
user_object_info = self.sf_client.query(
f"SELECT Name FROM User WHERE Id = '{id}'"
)
name = user_object_info.get("Records", [{}])[0].get("Name", "Null User")
return name
except Exception:
logger.warning(f"Couldnt find name for object id: {id}")
return "Null User"
def _extract_primary_owners(
self, sf_object: SalesforceObject
) -> list[BasicExpertInfo] | None:
object_dict = sf_object.data
if not (last_modified_by_id := object_dict.get("LastModifiedById")):
return None
if not (last_modified_by := get_record(last_modified_by_id)):
return None
if not (last_modified_by_name := last_modified_by.data.get("Name")):
return None
primary_owners = [BasicExpertInfo(display_name=last_modified_by_name)]
return primary_owners
def _convert_object_instance_to_document(
self, object_dict: dict[str, Any]
self, sf_object: SalesforceObject
) -> Document:
if self.sf_client is None:
raise ConnectorMissingCredentialError("Salesforce")
object_dict = sf_object.data
salesforce_id = object_dict["Id"]
onyx_salesforce_id = f"{ID_PREFIX}{salesforce_id}"
extracted_link = f"https://{self.sf_client.sf_instance}/{salesforce_id}"
onyx_salesforce_id = f"{_ID_PREFIX}{salesforce_id}"
base_url = f"https://{self.sf_client.sf_instance}"
extracted_doc_updated_at = time_str_to_utc(object_dict["LastModifiedDate"])
extracted_object_text = extract_dict_text(object_dict)
extracted_semantic_identifier = object_dict.get("Name", "Unknown Object")
extracted_primary_owners = [
BasicExpertInfo(
display_name=self._get_name_from_id(object_dict["LastModifiedById"])
)
]
sections = [extract_section(sf_object, base_url)]
for id in get_child_ids(sf_object.id):
if not (child_object := get_record(id)):
continue
sections.append(extract_section(child_object, base_url))
doc = Document(
id=onyx_salesforce_id,
sections=[Section(link=extracted_link, text=extracted_object_text)],
sections=sections,
source=DocumentSource.SALESFORCE,
semantic_identifier=extracted_semantic_identifier,
doc_updated_at=extracted_doc_updated_at,
primary_owners=extracted_primary_owners,
primary_owners=self._extract_primary_owners(sf_object),
metadata={},
)
return doc
def _is_valid_child_object(self, child_relationship: dict) -> bool:
if self.sf_client is None:
raise ConnectorMissingCredentialError("Salesforce")
if not child_relationship["childSObject"]:
return False
if not child_relationship["relationshipName"]:
return False
sf_type = child_relationship["childSObject"]
object_description = self._get_sf_type_object_json(sf_type)
if not object_description["queryable"]:
return False
try:
query = f"SELECT Count() FROM {sf_type} LIMIT 1"
result = self.sf_client.query(query)
if result["totalSize"] == 0:
return False
except Exception as e:
logger.warning(f"Object type {sf_type} doesn't support query: {e}")
return False
if child_relationship["field"]:
if child_relationship["field"] == "RelatedToId":
return False
else:
return False
return True
def _get_all_children_of_sf_type(self, sf_type: str) -> list[dict]:
if self.sf_client is None:
raise ConnectorMissingCredentialError("Salesforce")
object_description = self._get_sf_type_object_json(sf_type)
children_objects: list[dict] = []
for child_relationship in object_description["childRelationships"]:
if self._is_valid_child_object(child_relationship):
children_objects.append(
{
"relationship_name": child_relationship["relationshipName"],
"object_type": child_relationship["childSObject"],
}
)
return children_objects
def _get_all_fields_for_sf_type(self, sf_type: str) -> list[str]:
if self.sf_client is None:
raise ConnectorMissingCredentialError("Salesforce")
object_description = self._get_sf_type_object_json(sf_type)
fields = [
field.get("name")
for field in object_description["fields"]
if field.get("type", "base64") != "base64"
]
return fields
def _generate_query_per_parent_type(self, parent_sf_type: str) -> Iterator[str]:
"""
This function takes in an object_type and generates query(s) designed to grab
information associated to objects of that type.
It does that by getting all the fields of the parent object type.
Then it gets all the child objects of that object type and all the fields of
those children as well.
"""
parent_fields = self._get_all_fields_for_sf_type(parent_sf_type)
child_sf_types = self._get_all_children_of_sf_type(parent_sf_type)
query = f"SELECT {', '.join(parent_fields)}"
for child_object_dict in child_sf_types:
fields = self._get_all_fields_for_sf_type(child_object_dict["object_type"])
query_addition = f", \n(SELECT {', '.join(fields)} FROM {child_object_dict['relationship_name']})"
if len(query_addition) + len(query) > MAX_QUERY_LENGTH:
query += f"\n FROM {parent_sf_type}"
yield query
query = "SELECT Id" + query_addition
else:
query += query_addition
query += f"\n FROM {parent_sf_type}"
yield query
def _fetch_from_salesforce(
self,
start: datetime | None = None,
end: datetime | None = None,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateDocumentsOutput:
if self.sf_client is None:
raise ConnectorMissingCredentialError("Salesforce")
init_db()
all_object_types: set[str] = set(self.parent_object_list)
doc_batch: list[Document] = []
logger.info(f"Starting with {len(self.parent_object_list)} parent object types")
logger.debug(f"Parent object types: {self.parent_object_list}")
# This takes like 20 seconds
for parent_object_type in self.parent_object_list:
logger.debug(f"Processing: {parent_object_type}")
query_results: dict = {}
for query in self._generate_query_per_parent_type(parent_object_type):
if start is not None and end is not None:
if start and start.tzinfo is None:
start = start.replace(tzinfo=timezone.utc)
if end and end.tzinfo is None:
end = end.replace(tzinfo=timezone.utc)
query += f" WHERE LastModifiedDate > {start.isoformat()} AND LastModifiedDate < {end.isoformat()}"
query_result = self.sf_client.query_all(query)
for record_dict in query_result["records"]:
query_results.setdefault(record_dict["Id"], {}).update(record_dict)
logger.info(
f"Number of {parent_object_type} Objects processed: {len(query_results)}"
child_types = get_all_children_of_sf_type(
self.sf_client, parent_object_type
)
all_object_types.update(child_types)
logger.debug(
f"Found {len(child_types)} child types for {parent_object_type}"
)
for combined_object_dict in query_results.values():
doc_batch.append(
self._convert_object_instance_to_document(combined_object_dict)
)
logger.info(f"Found total of {len(all_object_types)} object types to fetch")
logger.debug(f"All object types: {all_object_types}")
if len(doc_batch) > self.batch_size:
yield doc_batch
doc_batch = []
yield doc_batch
# checkpoint - we've found all object types, now time to fetch the data
logger.info("Starting to fetch CSVs for all object types")
# This takes like 30 minutes first time and <2 minutes for updates
object_type_to_csv_path = fetch_all_csvs_in_parallel(
sf_client=self.sf_client,
object_types=all_object_types,
start=start,
end=end,
)
updated_ids: set[str] = set()
# This takes like 10 seconds
# This is for testing the rest of the functionality if data has
# already been fetched and put in sqlite
# from import onyx.connectors.salesforce.sf_db.sqlite_functions find_ids_by_type
# for object_type in self.parent_object_list:
# updated_ids.update(list(find_ids_by_type(object_type)))
# This takes 10-70 minutes first time (idk why the range is so big)
total_types = len(object_type_to_csv_path)
logger.info(f"Starting to process {total_types} object types")
for i, (object_type, csv_paths) in enumerate(
object_type_to_csv_path.items(), 1
):
logger.info(f"Processing object type {object_type} ({i}/{total_types})")
# If path is None, it means it failed to fetch the csv
if csv_paths is None:
continue
# Go through each csv path and use it to update the db
for csv_path in csv_paths:
logger.debug(f"Updating {object_type} with {csv_path}")
new_ids = update_sf_db_with_csv(
object_type=object_type,
csv_download_path=csv_path,
)
updated_ids.update(new_ids)
logger.debug(
f"Added {len(new_ids)} new/updated records for {object_type}"
)
# Remove the csv file after it has been used
# to successfully update the db
os.remove(csv_path)
logger.info(f"Found {len(updated_ids)} total updated records")
logger.info(
f"Starting to process parent objects of types: {self.parent_object_list}"
)
docs_to_yield: list[Document] = []
docs_processed = 0
# Takes 15-20 seconds per batch
for parent_type, parent_id_batch in get_affected_parent_ids_by_type(
updated_ids=list(updated_ids),
parent_types=self.parent_object_list,
):
logger.info(
f"Processing batch of {len(parent_id_batch)} {parent_type} objects"
)
for parent_id in parent_id_batch:
if not (parent_object := get_record(parent_id, parent_type)):
logger.warning(
f"Failed to get parent object {parent_id} for {parent_type}"
)
continue
docs_to_yield.append(
self._convert_object_instance_to_document(parent_object)
)
docs_processed += 1
if len(docs_to_yield) >= self.batch_size:
yield docs_to_yield
docs_to_yield = []
yield docs_to_yield
def load_from_state(self) -> GenerateDocumentsOutput:
return self._fetch_from_salesforce()
@@ -245,26 +212,20 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
if self.sf_client is None:
raise ConnectorMissingCredentialError("Salesforce")
start_datetime = datetime.utcfromtimestamp(start)
end_datetime = datetime.utcfromtimestamp(end)
return self._fetch_from_salesforce(start=start_datetime, end=end_datetime)
return self._fetch_from_salesforce(start=start, end=end)
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
if self.sf_client is None:
raise ConnectorMissingCredentialError("Salesforce")
doc_metadata_list: list[SlimDocument] = []
for parent_object_type in self.parent_object_list:
query = f"SELECT Id FROM {parent_object_type}"
query_result = self.sf_client.query_all(query)
doc_metadata_list.extend(
SlimDocument(
id=f"{ID_PREFIX}{instance_dict.get('Id', '')}",
id=f"{_ID_PREFIX}{instance_dict.get('Id', '')}",
perm_sync_data={},
)
for instance_dict in query_result["records"]
@@ -274,9 +235,9 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
if __name__ == "__main__":
connector = SalesforceConnector(
requested_objects=os.environ["REQUESTED_OBJECTS"].split(",")
)
import time
connector = SalesforceConnector(requested_objects=["Account"])
connector.load_credentials(
{
@@ -285,5 +246,20 @@ if __name__ == "__main__":
"sf_security_token": os.environ["SF_SECURITY_TOKEN"],
}
)
document_batches = connector.load_from_state()
print(next(document_batches))
start_time = time.time()
doc_count = 0
section_count = 0
text_count = 0
for doc_batch in connector.load_from_state():
doc_count += len(doc_batch)
print(f"doc_count: {doc_count}")
for doc in doc_batch:
section_count += len(doc.sections)
for section in doc.sections:
text_count += len(section.text)
end_time = time.time()
print(f"Doc count: {doc_count}")
print(f"Section count: {section_count}")
print(f"Text count: {text_count}")
print(f"Time taken: {end_time - start_time}")

View File

@@ -0,0 +1,156 @@
import re
from collections import OrderedDict
from onyx.connectors.models import Section
from onyx.connectors.salesforce.utils import SalesforceObject
# All of these types of keys are handled by specific fields in the doc
# conversion process (E.g. URLs) or are not useful for the user (E.g. UUIDs)
_SF_JSON_FILTER = r"Id$|Date$|stamp$|url$"
def _clean_salesforce_dict(data: dict | list) -> dict | list:
"""Clean and transform Salesforce API response data by recursively:
1. Extracting records from the response if present
2. Merging attributes into the main dictionary
3. Filtering out keys matching certain patterns (Id, Date, stamp, url)
4. Removing '__c' suffix from custom field names
5. Removing None values and empty containers
Args:
data: A dictionary or list from Salesforce API response
Returns:
Cleaned dictionary or list with transformed keys and filtered values
"""
if isinstance(data, dict):
if "records" in data.keys():
data = data["records"]
if isinstance(data, dict):
if "attributes" in data.keys():
if isinstance(data["attributes"], dict):
data.update(data.pop("attributes"))
if isinstance(data, dict):
filtered_dict = {}
for key, value in data.items():
if not re.search(_SF_JSON_FILTER, key, re.IGNORECASE):
# remove the custom object indicator for display
if "__c" in key:
key = key[:-3]
if isinstance(value, (dict, list)):
filtered_value = _clean_salesforce_dict(value)
# Only add non-empty dictionaries or lists
if filtered_value:
filtered_dict[key] = filtered_value
elif value is not None:
filtered_dict[key] = value
return filtered_dict
elif isinstance(data, list):
filtered_list = []
for item in data:
if isinstance(item, (dict, list)):
filtered_item = _clean_salesforce_dict(item)
# Only add non-empty dictionaries or lists
if filtered_item:
filtered_list.append(filtered_item)
elif item is not None:
filtered_list.append(filtered_item)
return filtered_list
else:
return data
def _json_to_natural_language(data: dict | list, indent: int = 0) -> str:
"""Convert a nested dictionary or list into a human-readable string format.
Recursively traverses the data structure and formats it with:
- Key-value pairs on separate lines
- Nested structures indented for readability
- Lists and dictionaries handled with appropriate formatting
Args:
data: The dictionary or list to convert
indent: Number of spaces to indent (default: 0)
Returns:
A formatted string representation of the data structure
"""
result = []
indent_str = " " * indent
if isinstance(data, dict):
for key, value in data.items():
if isinstance(value, (dict, list)):
result.append(f"{indent_str}{key}:")
result.append(_json_to_natural_language(value, indent + 2))
else:
result.append(f"{indent_str}{key}: {value}")
elif isinstance(data, list):
for item in data:
result.append(_json_to_natural_language(item, indent + 2))
return "\n".join(result)
def _extract_dict_text(raw_dict: dict) -> str:
"""Extract text from a Salesforce API response dictionary by:
1. Cleaning the dictionary
2. Converting the cleaned dictionary to natural language
"""
processed_dict = _clean_salesforce_dict(raw_dict)
natural_language_for_dict = _json_to_natural_language(processed_dict)
return natural_language_for_dict
def extract_section(salesforce_object: SalesforceObject, base_url: str) -> Section:
return Section(
text=_extract_dict_text(salesforce_object.data),
link=f"{base_url}/{salesforce_object.id}",
)
def _field_value_is_child_object(field_value: dict) -> bool:
"""
Checks if the field value is a child object.
"""
return (
isinstance(field_value, OrderedDict)
and "records" in field_value.keys()
and isinstance(field_value["records"], list)
and len(field_value["records"]) > 0
and "Id" in field_value["records"][0].keys()
)
def _extract_sections(salesforce_object: dict, base_url: str) -> list[Section]:
"""
This goes through the salesforce_object and extracts the top level fields as a Section.
It also goes through the child objects and extracts them as Sections.
"""
top_level_dict = {}
child_object_sections = []
for field_name, field_value in salesforce_object.items():
# If the field value is not a child object, add it to the top level dict
# to turn into text for the top level section
if not _field_value_is_child_object(field_value):
top_level_dict[field_name] = field_value
continue
# If the field value is a child object, extract the child objects and add them as sections
for record in field_value["records"]:
child_object_id = record["Id"]
child_object_sections.append(
Section(
text=f"Child Object(s): {field_name}\n{_extract_dict_text(record)}",
link=f"{base_url}/{child_object_id}",
)
)
top_level_id = salesforce_object["Id"]
top_level_section = Section(
text=_extract_dict_text(top_level_dict),
link=f"{base_url}/{top_level_id}",
)
return [top_level_section, *child_object_sections]

View File

@@ -0,0 +1,210 @@
import os
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from typing import Any
from pytz import UTC
from simple_salesforce import Salesforce
from simple_salesforce import SFType
from simple_salesforce.bulk2 import SFBulk2Handler
from simple_salesforce.bulk2 import SFBulk2Type
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.salesforce.sqlite_functions import has_at_least_one_object_of_type
from onyx.connectors.salesforce.utils import get_object_type_path
from onyx.utils.logger import setup_logger
logger = setup_logger()
def _build_time_filter_for_salesforce(
start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None
) -> str:
if start is None or end is None:
return ""
start_datetime = datetime.fromtimestamp(start, UTC)
end_datetime = datetime.fromtimestamp(end, UTC)
return (
f" WHERE LastModifiedDate > {start_datetime.isoformat()} "
f"AND LastModifiedDate < {end_datetime.isoformat()}"
)
def _get_sf_type_object_json(sf_client: Salesforce, type_name: str) -> Any:
sf_object = SFType(type_name, sf_client.session_id, sf_client.sf_instance)
return sf_object.describe()
def _is_valid_child_object(
sf_client: Salesforce, child_relationship: dict[str, Any]
) -> bool:
if not child_relationship["childSObject"]:
return False
if not child_relationship["relationshipName"]:
return False
sf_type = child_relationship["childSObject"]
object_description = _get_sf_type_object_json(sf_client, sf_type)
if not object_description["queryable"]:
return False
if child_relationship["field"]:
if child_relationship["field"] == "RelatedToId":
return False
else:
return False
return True
def get_all_children_of_sf_type(sf_client: Salesforce, sf_type: str) -> set[str]:
object_description = _get_sf_type_object_json(sf_client, sf_type)
child_object_types = set()
for child_relationship in object_description["childRelationships"]:
if _is_valid_child_object(sf_client, child_relationship):
logger.debug(
f"Found valid child object {child_relationship['childSObject']}"
)
child_object_types.add(child_relationship["childSObject"])
return child_object_types
def _get_all_queryable_fields_of_sf_type(
sf_client: Salesforce,
sf_type: str,
) -> list[str]:
object_description = _get_sf_type_object_json(sf_client, sf_type)
fields: list[dict[str, Any]] = object_description["fields"]
valid_fields: set[str] = set()
compound_field_names: set[str] = set()
for field in fields:
if compound_field_name := field.get("compoundFieldName"):
compound_field_names.add(compound_field_name)
if field.get("type", "base64") == "base64":
continue
if field_name := field.get("name"):
valid_fields.add(field_name)
return list(valid_fields - compound_field_names)
def _check_if_object_type_is_empty(sf_client: Salesforce, sf_type: str) -> bool:
"""
Send a small query to check if the object type is empty so we don't
perform extra bulk queries
"""
try:
query = f"SELECT Count() FROM {sf_type} LIMIT 1"
result = sf_client.query(query)
if result["totalSize"] == 0:
return False
except Exception as e:
if "OPERATION_TOO_LARGE" not in str(e):
logger.warning(f"Object type {sf_type} doesn't support query: {e}")
return False
return True
def _check_for_existing_csvs(sf_type: str) -> list[str] | None:
# Check if the csv already exists
if os.path.exists(get_object_type_path(sf_type)):
existing_csvs = [
os.path.join(get_object_type_path(sf_type), f)
for f in os.listdir(get_object_type_path(sf_type))
if f.endswith(".csv")
]
# If the csv already exists, return the path
# This is likely due to a previous run that failed
# after downloading the csv but before the data was
# written to the db
if existing_csvs:
return existing_csvs
return None
def _build_bulk_query(sf_client: Salesforce, sf_type: str, time_filter: str) -> str:
queryable_fields = _get_all_queryable_fields_of_sf_type(sf_client, sf_type)
query = f"SELECT {', '.join(queryable_fields)} FROM {sf_type}{time_filter}"
return query
def _bulk_retrieve_from_salesforce(
sf_client: Salesforce,
sf_type: str,
time_filter: str,
) -> tuple[str, list[str] | None]:
if not _check_if_object_type_is_empty(sf_client, sf_type):
return sf_type, None
if existing_csvs := _check_for_existing_csvs(sf_type):
return sf_type, existing_csvs
query = _build_bulk_query(sf_client, sf_type, time_filter)
bulk_2_handler = SFBulk2Handler(
session_id=sf_client.session_id,
bulk2_url=sf_client.bulk2_url,
proxies=sf_client.proxies,
session=sf_client.session,
)
bulk_2_type = SFBulk2Type(
object_name=sf_type,
bulk2_url=bulk_2_handler.bulk2_url,
headers=bulk_2_handler.headers,
session=bulk_2_handler.session,
)
logger.info(f"Downloading {sf_type}")
logger.info(f"Query: {query}")
try:
# This downloads the file to a file in the target path with a random name
results = bulk_2_type.download(
query=query,
path=get_object_type_path(sf_type),
max_records=1000000,
)
all_download_paths = [result["file"] for result in results]
logger.info(f"Downloaded {sf_type} to {all_download_paths}")
return sf_type, all_download_paths
except Exception as e:
logger.info(f"Failed to download salesforce csv for object type {sf_type}: {e}")
return sf_type, None
def fetch_all_csvs_in_parallel(
sf_client: Salesforce,
object_types: set[str],
start: SecondsSinceUnixEpoch | None,
end: SecondsSinceUnixEpoch | None,
) -> dict[str, list[str] | None]:
"""
Fetches all the csvs in parallel for the given object types
Returns a dict of (sf_type, full_download_path)
"""
time_filter = _build_time_filter_for_salesforce(start, end)
time_filter_for_each_object_type = {}
# We do this outside of the thread pool executor because this requires
# a database connection and we don't want to block the thread pool
# executor from running
for sf_type in object_types:
"""Only add time filter if there is at least one object of the type
in the database. We aren't worried about partially completed object update runs
because this occurs after we check for existing csvs which covers this case"""
if has_at_least_one_object_of_type(sf_type):
time_filter_for_each_object_type[sf_type] = time_filter
else:
time_filter_for_each_object_type[sf_type] = ""
# Run the bulk retrieve in parallel
with ThreadPoolExecutor() as executor:
results = executor.map(
lambda object_type: _bulk_retrieve_from_salesforce(
sf_client=sf_client,
sf_type=object_type,
time_filter=time_filter_for_each_object_type[object_type],
),
object_types,
)
return dict(results)

View File

@@ -0,0 +1,209 @@
import csv
import shelve
from onyx.connectors.salesforce.shelve_stuff.shelve_utils import (
get_child_to_parent_shelf_path,
)
from onyx.connectors.salesforce.shelve_stuff.shelve_utils import get_id_type_shelf_path
from onyx.connectors.salesforce.shelve_stuff.shelve_utils import get_object_shelf_path
from onyx.connectors.salesforce.shelve_stuff.shelve_utils import (
get_parent_to_child_shelf_path,
)
from onyx.connectors.salesforce.utils import SalesforceObject
from onyx.connectors.salesforce.utils import validate_salesforce_id
from onyx.utils.logger import setup_logger
logger = setup_logger()
def _update_relationship_shelves(
child_id: str,
parent_ids: set[str],
) -> None:
"""Update the relationship shelf when a record is updated."""
try:
# Convert child_id to string once
str_child_id = str(child_id)
# First update child to parent mapping
with shelve.open(
get_child_to_parent_shelf_path(),
flag="c",
protocol=None,
writeback=True,
) as child_to_parent_db:
old_parent_ids = set(child_to_parent_db.get(str_child_id, []))
child_to_parent_db[str_child_id] = list(parent_ids)
# Calculate differences outside the next context manager
parent_ids_to_remove = old_parent_ids - parent_ids
parent_ids_to_add = parent_ids - old_parent_ids
# Only sync once at the end
child_to_parent_db.sync()
# Then update parent to child mapping in a single transaction
if not parent_ids_to_remove and not parent_ids_to_add:
return
with shelve.open(
get_parent_to_child_shelf_path(),
flag="c",
protocol=None,
writeback=True,
) as parent_to_child_db:
# Process all removals first
for parent_id in parent_ids_to_remove:
str_parent_id = str(parent_id)
existing_children = set(parent_to_child_db.get(str_parent_id, []))
if str_child_id in existing_children:
existing_children.remove(str_child_id)
parent_to_child_db[str_parent_id] = list(existing_children)
# Then process all additions
for parent_id in parent_ids_to_add:
str_parent_id = str(parent_id)
existing_children = set(parent_to_child_db.get(str_parent_id, []))
existing_children.add(str_child_id)
parent_to_child_db[str_parent_id] = list(existing_children)
# Single sync at the end
parent_to_child_db.sync()
except Exception as e:
logger.error(f"Error updating relationship shelves: {e}")
logger.error(f"Child ID: {child_id}, Parent IDs: {parent_ids}")
raise
def get_child_ids(parent_id: str) -> set[str]:
"""Get all child IDs for a given parent ID.
Args:
parent_id: The ID of the parent object
Returns:
A set of child object IDs
"""
with shelve.open(get_parent_to_child_shelf_path()) as parent_to_child_db:
return set(parent_to_child_db.get(parent_id, []))
def update_sf_db_with_csv(
object_type: str,
csv_download_path: str,
) -> list[str]:
"""Update the SF DB with a CSV file using shelve storage."""
updated_ids = []
shelf_path = get_object_shelf_path(object_type)
# First read the CSV to get all the data
with open(csv_download_path, "r", newline="", encoding="utf-8") as f:
reader = csv.DictReader(f)
for row in reader:
id = row["Id"]
parent_ids = set()
field_to_remove: set[str] = set()
# Update relationship shelves for any parent references
for field, value in row.items():
if validate_salesforce_id(value) and field != "Id":
parent_ids.add(value)
field_to_remove.add(field)
if not value:
field_to_remove.add(field)
_update_relationship_shelves(id, parent_ids)
for field in field_to_remove:
# We use this to extract the Primary Owner later
if field != "LastModifiedById":
del row[field]
# Update the main object shelf
with shelve.open(shelf_path) as object_type_db:
object_type_db[id] = row
# Update the ID-to-type mapping shelf
with shelve.open(get_id_type_shelf_path()) as id_type_db:
id_type_db[id] = object_type
updated_ids.append(id)
# os.remove(csv_download_path)
return updated_ids
def get_type_from_id(object_id: str) -> str | None:
"""Get the type of an object from its ID."""
# Look up the object type from the ID-to-type mapping
with shelve.open(get_id_type_shelf_path()) as id_type_db:
if object_id not in id_type_db:
logger.warning(f"Object ID {object_id} not found in ID-to-type mapping")
return None
return id_type_db[object_id]
def get_record(
object_id: str, object_type: str | None = None
) -> SalesforceObject | None:
"""
Retrieve the record and return it as a SalesforceObject.
The object type will be looked up from the ID-to-type mapping shelf.
"""
if object_type is None:
if not (object_type := get_type_from_id(object_id)):
return None
shelf_path = get_object_shelf_path(object_type)
with shelve.open(shelf_path) as db:
if object_id not in db:
logger.warning(f"Object ID {object_id} not found in {shelf_path}")
return None
data = db[object_id]
return SalesforceObject(
id=object_id,
type=object_type,
data=data,
)
def find_ids_by_type(object_type: str) -> list[str]:
"""
Find all object IDs for rows of the specified type.
"""
shelf_path = get_object_shelf_path(object_type)
try:
with shelve.open(shelf_path) as db:
return list(db.keys())
except FileNotFoundError:
return []
def get_affected_parent_ids_by_type(
updated_ids: set[str], parent_types: list[str]
) -> dict[str, set[str]]:
"""Get IDs of objects that are of the specified parent types and are either in the updated_ids
or have children in the updated_ids.
Args:
updated_ids: List of IDs that were updated
parent_types: List of object types to filter by
Returns:
A dictionary of IDs that match the criteria
"""
affected_ids_by_type: dict[str, set[str]] = {}
# Check each updated ID
for updated_id in updated_ids:
# Add the ID itself if it's of a parent type
updated_type = get_type_from_id(updated_id)
if updated_type in parent_types:
affected_ids_by_type.setdefault(updated_type, set()).add(updated_id)
continue
# Get parents of this ID and add them if they're of a parent type
with shelve.open(get_child_to_parent_shelf_path()) as child_to_parent_db:
parent_ids = child_to_parent_db.get(updated_id, [])
for parent_id in parent_ids:
parent_type = get_type_from_id(parent_id)
if parent_type in parent_types:
affected_ids_by_type.setdefault(parent_type, set()).add(parent_id)
return affected_ids_by_type

View File

@@ -0,0 +1,29 @@
import os
from onyx.connectors.salesforce.utils import BASE_DATA_PATH
from onyx.connectors.salesforce.utils import get_object_type_path
def get_object_shelf_path(object_type: str) -> str:
"""Get the path to the shelf file for a specific object type."""
base_path = get_object_type_path(object_type)
os.makedirs(base_path, exist_ok=True)
return os.path.join(base_path, "data.shelf")
def get_id_type_shelf_path() -> str:
"""Get the path to the ID-to-type mapping shelf."""
os.makedirs(BASE_DATA_PATH, exist_ok=True)
return os.path.join(BASE_DATA_PATH, "id_type_mapping.shelf.4g")
def get_parent_to_child_shelf_path() -> str:
"""Get the path to the parent-to-child mapping shelf."""
os.makedirs(BASE_DATA_PATH, exist_ok=True)
return os.path.join(BASE_DATA_PATH, "parent_to_child_mapping.shelf.4g")
def get_child_to_parent_shelf_path() -> str:
"""Get the path to the child-to-parent mapping shelf."""
os.makedirs(BASE_DATA_PATH, exist_ok=True)
return os.path.join(BASE_DATA_PATH, "child_to_parent_mapping.shelf.4g")

View File

@@ -0,0 +1,737 @@
import csv
import os
import shutil
from onyx.connectors.salesforce.shelve_stuff.shelve_functions import find_ids_by_type
from onyx.connectors.salesforce.shelve_stuff.shelve_functions import (
get_affected_parent_ids_by_type,
)
from onyx.connectors.salesforce.shelve_stuff.shelve_functions import get_child_ids
from onyx.connectors.salesforce.shelve_stuff.shelve_functions import get_record
from onyx.connectors.salesforce.shelve_stuff.shelve_functions import (
update_sf_db_with_csv,
)
from onyx.connectors.salesforce.utils import BASE_DATA_PATH
from onyx.connectors.salesforce.utils import get_object_type_path
_VALID_SALESFORCE_IDS = [
"001bm00000fd9Z3AAI",
"001bm00000fdYTdAAM",
"001bm00000fdYTeAAM",
"001bm00000fdYTfAAM",
"001bm00000fdYTgAAM",
"001bm00000fdYThAAM",
"001bm00000fdYTiAAM",
"001bm00000fdYTjAAM",
"001bm00000fdYTkAAM",
"001bm00000fdYTlAAM",
"001bm00000fdYTmAAM",
"001bm00000fdYTnAAM",
"001bm00000fdYToAAM",
"500bm00000XoOxtAAF",
"500bm00000XoOxuAAF",
"500bm00000XoOxvAAF",
"500bm00000XoOxwAAF",
"500bm00000XoOxxAAF",
"500bm00000XoOxyAAF",
"500bm00000XoOxzAAF",
"500bm00000XoOy0AAF",
"500bm00000XoOy1AAF",
"500bm00000XoOy2AAF",
"500bm00000XoOy3AAF",
"500bm00000XoOy4AAF",
"500bm00000XoOy5AAF",
"500bm00000XoOy6AAF",
"500bm00000XoOy7AAF",
"500bm00000XoOy8AAF",
"500bm00000XoOy9AAF",
"500bm00000XoOyAAAV",
"500bm00000XoOyBAAV",
"500bm00000XoOyCAAV",
"500bm00000XoOyDAAV",
"500bm00000XoOyEAAV",
"500bm00000XoOyFAAV",
"500bm00000XoOyGAAV",
"500bm00000XoOyHAAV",
"500bm00000XoOyIAAV",
"003bm00000EjHCjAAN",
"003bm00000EjHCkAAN",
"003bm00000EjHClAAN",
"003bm00000EjHCmAAN",
"003bm00000EjHCnAAN",
"003bm00000EjHCoAAN",
"003bm00000EjHCpAAN",
"003bm00000EjHCqAAN",
"003bm00000EjHCrAAN",
"003bm00000EjHCsAAN",
"003bm00000EjHCtAAN",
"003bm00000EjHCuAAN",
"003bm00000EjHCvAAN",
"003bm00000EjHCwAAN",
"003bm00000EjHCxAAN",
"003bm00000EjHCyAAN",
"003bm00000EjHCzAAN",
"003bm00000EjHD0AAN",
"003bm00000EjHD1AAN",
"003bm00000EjHD2AAN",
"550bm00000EXc2tAAD",
"006bm000006kyDpAAI",
"006bm000006kyDqAAI",
"006bm000006kyDrAAI",
"006bm000006kyDsAAI",
"006bm000006kyDtAAI",
"006bm000006kyDuAAI",
"006bm000006kyDvAAI",
"006bm000006kyDwAAI",
"006bm000006kyDxAAI",
"006bm000006kyDyAAI",
"006bm000006kyDzAAI",
"006bm000006kyE0AAI",
"006bm000006kyE1AAI",
"006bm000006kyE2AAI",
"006bm000006kyE3AAI",
"006bm000006kyE4AAI",
"006bm000006kyE5AAI",
"006bm000006kyE6AAI",
"006bm000006kyE7AAI",
"006bm000006kyE8AAI",
"006bm000006kyE9AAI",
"006bm000006kyEAAAY",
"006bm000006kyEBAAY",
"006bm000006kyECAAY",
"006bm000006kyEDAAY",
"006bm000006kyEEAAY",
"006bm000006kyEFAAY",
"006bm000006kyEGAAY",
"006bm000006kyEHAAY",
"006bm000006kyEIAAY",
"006bm000006kyEJAAY",
"005bm000009zy0TAAQ",
"005bm000009zy25AAA",
"005bm000009zy26AAA",
"005bm000009zy28AAA",
"005bm000009zy29AAA",
"005bm000009zy2AAAQ",
"005bm000009zy2BAAQ",
]
def clear_sf_db() -> None:
"""
Clears the SF DB by deleting all files in the data directory.
"""
shutil.rmtree(BASE_DATA_PATH)
def create_csv_file(
object_type: str, records: list[dict], filename: str = "test_data.csv"
) -> None:
"""
Creates a CSV file for the given object type and records.
Args:
object_type: The Salesforce object type (e.g. "Account", "Contact")
records: List of dictionaries containing the record data
filename: Name of the CSV file to create (default: test_data.csv)
"""
if not records:
return
# Get all unique fields from records
fields: set[str] = set()
for record in records:
fields.update(record.keys())
fields = set(sorted(list(fields))) # Sort for consistent order
# Create CSV file
csv_path = os.path.join(get_object_type_path(object_type), filename)
with open(csv_path, "w", newline="", encoding="utf-8") as f:
writer = csv.DictWriter(f, fieldnames=fields)
writer.writeheader()
for record in records:
writer.writerow(record)
# Update the database with the CSV
update_sf_db_with_csv(object_type, csv_path)
def create_csv_with_example_data() -> None:
"""
Creates CSV files with example data, organized by object type.
"""
example_data: dict[str, list[dict]] = {
"Account": [
{
"Id": _VALID_SALESFORCE_IDS[0],
"Name": "Acme Inc.",
"BillingCity": "New York",
"Industry": "Technology",
},
{
"Id": _VALID_SALESFORCE_IDS[1],
"Name": "Globex Corp",
"BillingCity": "Los Angeles",
"Industry": "Manufacturing",
},
{
"Id": _VALID_SALESFORCE_IDS[2],
"Name": "Initech",
"BillingCity": "Austin",
"Industry": "Software",
},
{
"Id": _VALID_SALESFORCE_IDS[3],
"Name": "TechCorp Solutions",
"BillingCity": "San Francisco",
"Industry": "Software",
"AnnualRevenue": 5000000,
},
{
"Id": _VALID_SALESFORCE_IDS[4],
"Name": "BioMed Research",
"BillingCity": "Boston",
"Industry": "Healthcare",
"AnnualRevenue": 12000000,
},
{
"Id": _VALID_SALESFORCE_IDS[5],
"Name": "Green Energy Co",
"BillingCity": "Portland",
"Industry": "Energy",
"AnnualRevenue": 8000000,
},
{
"Id": _VALID_SALESFORCE_IDS[6],
"Name": "DataFlow Analytics",
"BillingCity": "Seattle",
"Industry": "Technology",
"AnnualRevenue": 3000000,
},
{
"Id": _VALID_SALESFORCE_IDS[7],
"Name": "Cloud Nine Services",
"BillingCity": "Denver",
"Industry": "Cloud Computing",
"AnnualRevenue": 7000000,
},
],
"Contact": [
{
"Id": _VALID_SALESFORCE_IDS[40],
"FirstName": "John",
"LastName": "Doe",
"Email": "john.doe@acme.com",
"Title": "CEO",
},
{
"Id": _VALID_SALESFORCE_IDS[41],
"FirstName": "Jane",
"LastName": "Smith",
"Email": "jane.smith@acme.com",
"Title": "CTO",
},
{
"Id": _VALID_SALESFORCE_IDS[42],
"FirstName": "Bob",
"LastName": "Johnson",
"Email": "bob.j@globex.com",
"Title": "Sales Director",
},
{
"Id": _VALID_SALESFORCE_IDS[43],
"FirstName": "Sarah",
"LastName": "Chen",
"Email": "sarah.chen@techcorp.com",
"Title": "Product Manager",
"Phone": "415-555-0101",
},
{
"Id": _VALID_SALESFORCE_IDS[44],
"FirstName": "Michael",
"LastName": "Rodriguez",
"Email": "m.rodriguez@biomed.com",
"Title": "Research Director",
"Phone": "617-555-0202",
},
{
"Id": _VALID_SALESFORCE_IDS[45],
"FirstName": "Emily",
"LastName": "Green",
"Email": "emily.g@greenenergy.com",
"Title": "Sustainability Lead",
"Phone": "503-555-0303",
},
{
"Id": _VALID_SALESFORCE_IDS[46],
"FirstName": "David",
"LastName": "Kim",
"Email": "david.kim@dataflow.com",
"Title": "Data Scientist",
"Phone": "206-555-0404",
},
{
"Id": _VALID_SALESFORCE_IDS[47],
"FirstName": "Rachel",
"LastName": "Taylor",
"Email": "r.taylor@cloudnine.com",
"Title": "Cloud Architect",
"Phone": "303-555-0505",
},
],
"Opportunity": [
{
"Id": _VALID_SALESFORCE_IDS[62],
"Name": "Acme Server Upgrade",
"Amount": 50000,
"Stage": "Prospecting",
"CloseDate": "2024-06-30",
},
{
"Id": _VALID_SALESFORCE_IDS[63],
"Name": "Globex Manufacturing Line",
"Amount": 150000,
"Stage": "Negotiation",
"CloseDate": "2024-03-15",
},
{
"Id": _VALID_SALESFORCE_IDS[64],
"Name": "Initech Software License",
"Amount": 75000,
"Stage": "Closed Won",
"CloseDate": "2024-01-30",
},
{
"Id": _VALID_SALESFORCE_IDS[65],
"Name": "TechCorp AI Implementation",
"Amount": 250000,
"Stage": "Needs Analysis",
"CloseDate": "2024-08-15",
"Probability": 60,
},
{
"Id": _VALID_SALESFORCE_IDS[66],
"Name": "BioMed Lab Equipment",
"Amount": 500000,
"Stage": "Value Proposition",
"CloseDate": "2024-09-30",
"Probability": 75,
},
{
"Id": _VALID_SALESFORCE_IDS[67],
"Name": "Green Energy Solar Project",
"Amount": 750000,
"Stage": "Proposal",
"CloseDate": "2024-07-15",
"Probability": 80,
},
{
"Id": _VALID_SALESFORCE_IDS[68],
"Name": "DataFlow Analytics Platform",
"Amount": 180000,
"Stage": "Negotiation",
"CloseDate": "2024-05-30",
"Probability": 90,
},
{
"Id": _VALID_SALESFORCE_IDS[69],
"Name": "Cloud Nine Infrastructure",
"Amount": 300000,
"Stage": "Qualification",
"CloseDate": "2024-10-15",
"Probability": 40,
},
],
}
# Create CSV files for each object type
for object_type, records in example_data.items():
create_csv_file(object_type, records)
def test_query() -> None:
"""
Tests querying functionality by verifying:
1. All expected Account IDs are found
2. Each Account's data matches what was inserted
"""
# Expected test data for verification
expected_accounts: dict[str, dict[str, str | int]] = {
_VALID_SALESFORCE_IDS[0]: {
"Name": "Acme Inc.",
"BillingCity": "New York",
"Industry": "Technology",
},
_VALID_SALESFORCE_IDS[1]: {
"Name": "Globex Corp",
"BillingCity": "Los Angeles",
"Industry": "Manufacturing",
},
_VALID_SALESFORCE_IDS[2]: {
"Name": "Initech",
"BillingCity": "Austin",
"Industry": "Software",
},
_VALID_SALESFORCE_IDS[3]: {
"Name": "TechCorp Solutions",
"BillingCity": "San Francisco",
"Industry": "Software",
"AnnualRevenue": 5000000,
},
_VALID_SALESFORCE_IDS[4]: {
"Name": "BioMed Research",
"BillingCity": "Boston",
"Industry": "Healthcare",
"AnnualRevenue": 12000000,
},
_VALID_SALESFORCE_IDS[5]: {
"Name": "Green Energy Co",
"BillingCity": "Portland",
"Industry": "Energy",
"AnnualRevenue": 8000000,
},
_VALID_SALESFORCE_IDS[6]: {
"Name": "DataFlow Analytics",
"BillingCity": "Seattle",
"Industry": "Technology",
"AnnualRevenue": 3000000,
},
_VALID_SALESFORCE_IDS[7]: {
"Name": "Cloud Nine Services",
"BillingCity": "Denver",
"Industry": "Cloud Computing",
"AnnualRevenue": 7000000,
},
}
# Get all Account IDs
account_ids = find_ids_by_type("Account")
# Verify we found all expected accounts
assert len(account_ids) == len(
expected_accounts
), f"Expected {len(expected_accounts)} accounts, found {len(account_ids)}"
assert set(account_ids) == set(
expected_accounts.keys()
), "Found account IDs don't match expected IDs"
# Verify each account's data
for acc_id in account_ids:
combined = get_record(acc_id)
assert combined is not None, f"Could not find account {acc_id}"
expected = expected_accounts[acc_id]
# Verify account data matches
for key, value in expected.items():
value = str(value)
assert (
combined.data[key] == value
), f"Account {acc_id} field {key} expected {value}, got {combined.data[key]}"
print("All query tests passed successfully!")
def test_upsert() -> None:
"""
Tests upsert functionality by:
1. Updating an existing account
2. Creating a new account
3. Verifying both operations were successful
"""
# Create CSV for updating an existing account and adding a new one
update_data: list[dict[str, str | int]] = [
{
"Id": _VALID_SALESFORCE_IDS[0],
"Name": "Acme Inc. Updated",
"BillingCity": "New York",
"Industry": "Technology",
"Description": "Updated company info",
},
{
"Id": _VALID_SALESFORCE_IDS[2],
"Name": "New Company Inc.",
"BillingCity": "Miami",
"Industry": "Finance",
"AnnualRevenue": 1000000,
},
]
create_csv_file("Account", update_data, "update_data.csv")
# Verify the update worked
updated_record = get_record(_VALID_SALESFORCE_IDS[0])
assert updated_record is not None, "Updated record not found"
assert updated_record.data["Name"] == "Acme Inc. Updated", "Name not updated"
assert (
updated_record.data["Description"] == "Updated company info"
), "Description not added"
# Verify the new record was created
new_record = get_record(_VALID_SALESFORCE_IDS[2])
assert new_record is not None, "New record not found"
assert new_record.data["Name"] == "New Company Inc.", "New record name incorrect"
assert new_record.data["AnnualRevenue"] == "1000000", "New record revenue incorrect"
print("All upsert tests passed successfully!")
def test_relationships() -> None:
"""
Tests relationship shelf updates and queries by:
1. Creating test data with relationships
2. Verifying the relationships are correctly stored
3. Testing relationship queries
"""
# Create test data for each object type
test_data: dict[str, list[dict[str, str | int]]] = {
"Case": [
{
"Id": _VALID_SALESFORCE_IDS[13],
"AccountId": _VALID_SALESFORCE_IDS[0],
"Subject": "Test Case 1",
},
{
"Id": _VALID_SALESFORCE_IDS[14],
"AccountId": _VALID_SALESFORCE_IDS[0],
"Subject": "Test Case 2",
},
],
"Contact": [
{
"Id": _VALID_SALESFORCE_IDS[48],
"AccountId": _VALID_SALESFORCE_IDS[0],
"FirstName": "Test",
"LastName": "Contact",
}
],
"Opportunity": [
{
"Id": _VALID_SALESFORCE_IDS[62],
"AccountId": _VALID_SALESFORCE_IDS[0],
"Name": "Test Opportunity",
"Amount": 100000,
}
],
}
# Create and update CSV files for each object type
for object_type, records in test_data.items():
create_csv_file(object_type, records, "relationship_test.csv")
# Test relationship queries
# All these objects should be children of Acme Inc.
child_ids = get_child_ids(_VALID_SALESFORCE_IDS[0])
assert len(child_ids) == 4, f"Expected 4 child objects, found {len(child_ids)}"
assert _VALID_SALESFORCE_IDS[13] in child_ids, "Case 1 not found in relationship"
assert _VALID_SALESFORCE_IDS[14] in child_ids, "Case 2 not found in relationship"
assert _VALID_SALESFORCE_IDS[48] in child_ids, "Contact not found in relationship"
assert (
_VALID_SALESFORCE_IDS[62] in child_ids
), "Opportunity not found in relationship"
# Test querying relationships for a different account (should be empty)
other_account_children = get_child_ids(_VALID_SALESFORCE_IDS[1])
assert (
len(other_account_children) == 0
), "Expected no children for different account"
print("All relationship tests passed successfully!")
def test_account_with_children() -> None:
"""
Tests querying all accounts and retrieving their child objects.
This test verifies that:
1. All accounts can be retrieved
2. Child objects are correctly linked
3. Child object data is complete and accurate
"""
# First get all account IDs
account_ids = find_ids_by_type("Account")
assert len(account_ids) > 0, "No accounts found"
# For each account, get its children and verify the data
for account_id in account_ids:
account = get_record(account_id)
assert account is not None, f"Could not find account {account_id}"
# Get all child objects
child_ids = get_child_ids(account_id)
# For Acme Inc., verify specific relationships
if account_id == _VALID_SALESFORCE_IDS[0]: # Acme Inc.
assert (
len(child_ids) == 4
), f"Expected 4 children for Acme Inc., found {len(child_ids)}"
# Get all child records
child_records = []
for child_id in child_ids:
child_record = get_record(child_id)
if child_record is not None:
child_records.append(child_record)
# Verify Cases
cases = [r for r in child_records if r.type == "Case"]
assert (
len(cases) == 2
), f"Expected 2 cases for Acme Inc., found {len(cases)}"
case_subjects = {case.data["Subject"] for case in cases}
assert "Test Case 1" in case_subjects, "Test Case 1 not found"
assert "Test Case 2" in case_subjects, "Test Case 2 not found"
# Verify Contacts
contacts = [r for r in child_records if r.type == "Contact"]
assert (
len(contacts) == 1
), f"Expected 1 contact for Acme Inc., found {len(contacts)}"
contact = contacts[0]
assert contact.data["FirstName"] == "Test", "Contact first name mismatch"
assert contact.data["LastName"] == "Contact", "Contact last name mismatch"
# Verify Opportunities
opportunities = [r for r in child_records if r.type == "Opportunity"]
assert (
len(opportunities) == 1
), f"Expected 1 opportunity for Acme Inc., found {len(opportunities)}"
opportunity = opportunities[0]
assert (
opportunity.data["Name"] == "Test Opportunity"
), "Opportunity name mismatch"
assert opportunity.data["Amount"] == "100000", "Opportunity amount mismatch"
print("All account with children tests passed successfully!")
def test_relationship_updates() -> None:
"""
Tests that relationships are properly updated when a child object's parent reference changes.
This test verifies:
1. Initial relationship is created correctly
2. When parent reference is updated, old relationship is removed
3. New relationship is created correctly
"""
# Create initial test data - Contact linked to Acme Inc.
initial_contact = [
{
"Id": _VALID_SALESFORCE_IDS[40],
"AccountId": _VALID_SALESFORCE_IDS[0],
"FirstName": "Test",
"LastName": "Contact",
}
]
create_csv_file("Contact", initial_contact, "initial_contact.csv")
# Verify initial relationship
acme_children = get_child_ids(_VALID_SALESFORCE_IDS[0])
assert (
_VALID_SALESFORCE_IDS[40] in acme_children
), "Initial relationship not created"
# Update contact to be linked to Globex Corp instead
updated_contact = [
{
"Id": _VALID_SALESFORCE_IDS[40],
"AccountId": _VALID_SALESFORCE_IDS[1],
"FirstName": "Test",
"LastName": "Contact",
}
]
create_csv_file("Contact", updated_contact, "updated_contact.csv")
# Verify old relationship is removed
acme_children = get_child_ids(_VALID_SALESFORCE_IDS[0])
assert (
_VALID_SALESFORCE_IDS[40] not in acme_children
), "Old relationship not removed"
# Verify new relationship is created
globex_children = get_child_ids(_VALID_SALESFORCE_IDS[1])
assert _VALID_SALESFORCE_IDS[40] in globex_children, "New relationship not created"
print("All relationship update tests passed successfully!")
def test_get_affected_parent_ids() -> None:
"""
Tests get_affected_parent_ids functionality by verifying:
1. IDs that are directly in the parent_types list are included
2. IDs that have children in the updated_ids list are included
3. IDs that are neither of the above are not included
"""
# Create test data with relationships
test_data = {
"Account": [
{
"Id": _VALID_SALESFORCE_IDS[0],
"Name": "Parent Account 1",
},
{
"Id": _VALID_SALESFORCE_IDS[1],
"Name": "Parent Account 2",
},
{
"Id": _VALID_SALESFORCE_IDS[2],
"Name": "Not Affected Account",
},
],
"Contact": [
{
"Id": _VALID_SALESFORCE_IDS[40],
"AccountId": _VALID_SALESFORCE_IDS[0],
"FirstName": "Child",
"LastName": "Contact",
}
],
}
# Create and update CSV files for test data
for object_type, records in test_data.items():
create_csv_file(object_type, records)
# Test Case 1: Account directly in updated_ids and parent_types
updated_ids = {_VALID_SALESFORCE_IDS[1]} # Parent Account 2
parent_types = ["Account"]
affected_ids = get_affected_parent_ids_by_type(updated_ids, parent_types)
assert _VALID_SALESFORCE_IDS[1] in affected_ids, "Direct parent ID not included"
# Test Case 2: Account with child in updated_ids
updated_ids = {_VALID_SALESFORCE_IDS[40]} # Child Contact
parent_types = ["Account"]
affected_ids = get_affected_parent_ids_by_type(updated_ids, parent_types)
assert (
_VALID_SALESFORCE_IDS[0] in affected_ids
), "Parent of updated child not included"
# Test Case 3: Both direct and indirect affects
updated_ids = {_VALID_SALESFORCE_IDS[1], _VALID_SALESFORCE_IDS[40]} # Both cases
parent_types = ["Account"]
affected_ids = get_affected_parent_ids_by_type(updated_ids, parent_types)
assert len(affected_ids) == 2, "Expected exactly two affected parent IDs"
assert _VALID_SALESFORCE_IDS[0] in affected_ids, "Parent of child not included"
assert _VALID_SALESFORCE_IDS[1] in affected_ids, "Direct parent ID not included"
assert (
_VALID_SALESFORCE_IDS[2] not in affected_ids
), "Unaffected ID incorrectly included"
# Test Case 4: No matches
updated_ids = {_VALID_SALESFORCE_IDS[40]} # Child Contact
parent_types = ["Opportunity"] # Wrong type
affected_ids = get_affected_parent_ids_by_type(updated_ids, parent_types)
assert len(affected_ids) == 0, "Should return empty list when no matches"
print("All get_affected_parent_ids tests passed successfully!")
def main_build() -> None:
clear_sf_db()
create_csv_with_example_data()
test_query()
test_upsert()
test_relationships()
test_account_with_children()
test_relationship_updates()
test_get_affected_parent_ids()
if __name__ == "__main__":
main_build()

View File

@@ -0,0 +1,386 @@
import csv
import json
import os
import sqlite3
from collections.abc import Iterator
from contextlib import contextmanager
from onyx.connectors.salesforce.utils import get_sqlite_db_path
from onyx.connectors.salesforce.utils import SalesforceObject
from onyx.connectors.salesforce.utils import validate_salesforce_id
from onyx.utils.logger import setup_logger
from shared_configs.utils import batch_list
logger = setup_logger()
@contextmanager
def get_db_connection(
isolation_level: str | None = None,
) -> Iterator[sqlite3.Connection]:
"""Get a database connection with proper isolation level and error handling.
Args:
isolation_level: SQLite isolation level. None = default "DEFERRED",
can be "IMMEDIATE" or "EXCLUSIVE" for more strict isolation.
"""
# 60 second timeout for locks
conn = sqlite3.connect(get_sqlite_db_path(), timeout=60.0)
if isolation_level is not None:
conn.isolation_level = isolation_level
try:
yield conn
except Exception:
conn.rollback()
raise
finally:
conn.close()
def init_db() -> None:
"""Initialize the SQLite database with required tables if they don't exist."""
if os.path.exists(get_sqlite_db_path()):
return
# Create database directory if it doesn't exist
os.makedirs(os.path.dirname(get_sqlite_db_path()), exist_ok=True)
with get_db_connection("EXCLUSIVE") as conn:
cursor = conn.cursor()
# Enable WAL mode for better concurrent access and write performance
cursor.execute("PRAGMA journal_mode=WAL")
cursor.execute("PRAGMA synchronous=NORMAL")
cursor.execute("PRAGMA temp_store=MEMORY")
cursor.execute("PRAGMA cache_size=-2000000") # Use 2GB memory for cache
# Main table for storing Salesforce objects
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS salesforce_objects (
id TEXT PRIMARY KEY,
object_type TEXT NOT NULL,
data TEXT NOT NULL, -- JSON serialized data
last_modified INTEGER DEFAULT (strftime('%s', 'now')) -- Add timestamp for better cache management
) WITHOUT ROWID -- Optimize for primary key lookups
"""
)
# Table for parent-child relationships with covering index
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS relationships (
child_id TEXT NOT NULL,
parent_id TEXT NOT NULL,
PRIMARY KEY (child_id, parent_id)
) WITHOUT ROWID -- Optimize for primary key lookups
"""
)
# New table for caching parent-child relationships with object types
cursor.execute(
"""
CREATE TABLE IF NOT EXISTS relationship_types (
child_id TEXT NOT NULL,
parent_id TEXT NOT NULL,
parent_type TEXT NOT NULL,
PRIMARY KEY (child_id, parent_id, parent_type)
) WITHOUT ROWID
"""
)
# Always recreate indexes to ensure they exist
cursor.execute("DROP INDEX IF EXISTS idx_object_type")
cursor.execute("DROP INDEX IF EXISTS idx_parent_id")
cursor.execute("DROP INDEX IF EXISTS idx_child_parent")
cursor.execute("DROP INDEX IF EXISTS idx_object_type_id")
cursor.execute("DROP INDEX IF EXISTS idx_relationship_types_lookup")
# Create covering indexes for common queries
cursor.execute(
"""
CREATE INDEX idx_object_type
ON salesforce_objects(object_type, id)
WHERE object_type IS NOT NULL
"""
)
cursor.execute(
"""
CREATE INDEX idx_parent_id
ON relationships(parent_id, child_id)
"""
)
cursor.execute(
"""
CREATE INDEX idx_child_parent
ON relationships(child_id)
WHERE child_id IS NOT NULL
"""
)
# New composite index for fast parent type lookups
cursor.execute(
"""
CREATE INDEX idx_relationship_types_lookup
ON relationship_types(parent_type, child_id, parent_id)
"""
)
# Analyze tables to help query planner
cursor.execute("ANALYZE relationships")
cursor.execute("ANALYZE salesforce_objects")
cursor.execute("ANALYZE relationship_types")
conn.commit()
def _update_relationship_tables(
conn: sqlite3.Connection, child_id: str, parent_ids: set[str]
) -> None:
"""Update the relationship tables when a record is updated.
Args:
conn: The database connection to use (must be in a transaction)
child_id: The ID of the child record
parent_ids: Set of parent IDs to link to
"""
try:
cursor = conn.cursor()
# Get existing parent IDs
cursor.execute(
"SELECT parent_id FROM relationships WHERE child_id = ?", (child_id,)
)
old_parent_ids = {row[0] for row in cursor.fetchall()}
# Calculate differences
parent_ids_to_remove = old_parent_ids - parent_ids
parent_ids_to_add = parent_ids - old_parent_ids
# Remove old relationships
if parent_ids_to_remove:
cursor.executemany(
"DELETE FROM relationships WHERE child_id = ? AND parent_id = ?",
[(child_id, pid) for pid in parent_ids_to_remove],
)
# Also remove from relationship_types
cursor.executemany(
"DELETE FROM relationship_types WHERE child_id = ? AND parent_id = ?",
[(child_id, pid) for pid in parent_ids_to_remove],
)
# Add new relationships
if parent_ids_to_add:
# First add to relationships table
cursor.executemany(
"INSERT INTO relationships (child_id, parent_id) VALUES (?, ?)",
[(child_id, pid) for pid in parent_ids_to_add],
)
# Then get the types of the parent objects and add to relationship_types
for parent_id in parent_ids_to_add:
cursor.execute(
"SELECT object_type FROM salesforce_objects WHERE id = ?",
(parent_id,),
)
result = cursor.fetchone()
if result:
parent_type = result[0]
cursor.execute(
"""
INSERT INTO relationship_types (child_id, parent_id, parent_type)
VALUES (?, ?, ?)
""",
(child_id, parent_id, parent_type),
)
except Exception as e:
logger.error(f"Error updating relationship tables: {e}")
logger.error(f"Child ID: {child_id}, Parent IDs: {parent_ids}")
raise
def update_sf_db_with_csv(object_type: str, csv_download_path: str) -> list[str]:
"""Update the SF DB with a CSV file using SQLite storage."""
updated_ids = []
# Use IMMEDIATE to get a write lock at the start of the transaction
with get_db_connection("IMMEDIATE") as conn:
cursor = conn.cursor()
with open(csv_download_path, "r", newline="", encoding="utf-8") as f:
reader = csv.DictReader(f)
for row in reader:
if "Id" not in row:
logger.warning(
f"Row {row} does not have an Id field in {csv_download_path}"
)
continue
id = row["Id"]
parent_ids = set()
field_to_remove: set[str] = set()
# Process relationships and clean data
for field, value in row.items():
if validate_salesforce_id(value) and field != "Id":
parent_ids.add(value)
field_to_remove.add(field)
if not value:
field_to_remove.add(field)
# Remove unwanted fields
for field in field_to_remove:
if field != "LastModifiedById":
del row[field]
# Update main object data
cursor.execute(
"""
INSERT OR REPLACE INTO salesforce_objects (id, object_type, data)
VALUES (?, ?, ?)
""",
(id, object_type, json.dumps(row)),
)
# Update relationships using the same connection
_update_relationship_tables(conn, id, parent_ids)
updated_ids.append(id)
conn.commit()
return updated_ids
def get_child_ids(parent_id: str) -> set[str]:
"""Get all child IDs for a given parent ID."""
with get_db_connection() as conn:
cursor = conn.cursor()
# Force index usage with INDEXED BY
cursor.execute(
"SELECT child_id FROM relationships INDEXED BY idx_parent_id WHERE parent_id = ?",
(parent_id,),
)
child_ids = {row[0] for row in cursor.fetchall()}
return child_ids
def get_type_from_id(object_id: str) -> str | None:
"""Get the type of an object from its ID."""
with get_db_connection() as conn:
cursor = conn.cursor()
cursor.execute(
"SELECT object_type FROM salesforce_objects WHERE id = ?", (object_id,)
)
result = cursor.fetchone()
if not result:
logger.warning(f"Object ID {object_id} not found")
return None
return result[0]
def get_record(
object_id: str, object_type: str | None = None
) -> SalesforceObject | None:
"""Retrieve the record and return it as a SalesforceObject."""
if object_type is None:
object_type = get_type_from_id(object_id)
if not object_type:
return None
with get_db_connection() as conn:
cursor = conn.cursor()
cursor.execute("SELECT data FROM salesforce_objects WHERE id = ?", (object_id,))
result = cursor.fetchone()
if not result:
logger.warning(f"Object ID {object_id} not found")
return None
data = json.loads(result[0])
return SalesforceObject(id=object_id, type=object_type, data=data)
def find_ids_by_type(object_type: str) -> list[str]:
"""Find all object IDs for rows of the specified type."""
with get_db_connection() as conn:
cursor = conn.cursor()
cursor.execute(
"SELECT id FROM salesforce_objects WHERE object_type = ?", (object_type,)
)
return [row[0] for row in cursor.fetchall()]
def get_affected_parent_ids_by_type(
updated_ids: list[str],
parent_types: list[str],
batch_size: int = 500,
) -> Iterator[tuple[str, set[str]]]:
"""Get IDs of objects that are of the specified parent types and are either in the
updated_ids or have children in the updated_ids. Yields tuples of (parent_type, affected_ids).
"""
# SQLite typically has a limit of 999 variables
updated_ids_batches = batch_list(updated_ids, batch_size)
updated_parent_ids: set[str] = set()
with get_db_connection() as conn:
cursor = conn.cursor()
for batch_ids in updated_ids_batches:
id_placeholders = ",".join(["?" for _ in batch_ids])
for parent_type in parent_types:
affected_ids: set[str] = set()
# Get directly updated objects of parent types - using index on object_type
cursor.execute(
f"""
SELECT id FROM salesforce_objects
WHERE id IN ({id_placeholders})
AND object_type = ?
""",
batch_ids + [parent_type],
)
affected_ids.update(row[0] for row in cursor.fetchall())
# Get parent objects of updated objects - using optimized relationship_types table
cursor.execute(
f"""
SELECT DISTINCT parent_id
FROM relationship_types
INDEXED BY idx_relationship_types_lookup
WHERE parent_type = ?
AND child_id IN ({id_placeholders})
""",
[parent_type] + batch_ids,
)
affected_ids.update(row[0] for row in cursor.fetchall())
# Remove any parent IDs that have already been processed
new_affected_ids = affected_ids - updated_parent_ids
# Add the new affected IDs to the set of updated parent IDs
updated_parent_ids.update(new_affected_ids)
if new_affected_ids:
yield parent_type, new_affected_ids
def has_at_least_one_object_of_type(object_type: str) -> bool:
"""Check if there is at least one object of the specified type in the database.
Args:
object_type: The Salesforce object type to check
Returns:
bool: True if at least one object exists, False otherwise
"""
with get_db_connection() as conn:
cursor = conn.cursor()
cursor.execute(
"SELECT COUNT(*) FROM salesforce_objects WHERE object_type = ?",
(object_type,),
)
count = cursor.fetchone()[0]
return count > 0

View File

@@ -1,66 +1,72 @@
import re
from typing import Union
SF_JSON_FILTER = r"Id$|Date$|stamp$|url$"
import os
from dataclasses import dataclass
from typing import Any
def _clean_salesforce_dict(data: Union[dict, list]) -> Union[dict, list]:
if isinstance(data, dict):
if "records" in data.keys():
data = data["records"]
if isinstance(data, dict):
if "attributes" in data.keys():
if isinstance(data["attributes"], dict):
data.update(data.pop("attributes"))
@dataclass
class SalesforceObject:
id: str
type: str
data: dict[str, Any]
if isinstance(data, dict):
filtered_dict = {}
for key, value in data.items():
if not re.search(SF_JSON_FILTER, key, re.IGNORECASE):
if "__c" in key: # remove the custom object indicator for display
key = key[:-3]
if isinstance(value, (dict, list)):
filtered_value = _clean_salesforce_dict(value)
if filtered_value: # Only add non-empty dictionaries or lists
filtered_dict[key] = filtered_value
elif value is not None:
filtered_dict[key] = value
return filtered_dict
elif isinstance(data, list):
filtered_list = []
for item in data:
if isinstance(item, (dict, list)):
filtered_item = _clean_salesforce_dict(item)
if filtered_item: # Only add non-empty dictionaries or lists
filtered_list.append(filtered_item)
elif item is not None:
filtered_list.append(filtered_item)
return filtered_list
else:
return data
def to_dict(self) -> dict[str, Any]:
return {
"ID": self.id,
"Type": self.type,
"Data": self.data,
}
@classmethod
def from_dict(cls, data: dict[str, Any]) -> "SalesforceObject":
return cls(
id=data["Id"],
type=data["Type"],
data=data,
)
def _json_to_natural_language(data: Union[dict, list], indent: int = 0) -> str:
result = []
indent_str = " " * indent
if isinstance(data, dict):
for key, value in data.items():
if isinstance(value, (dict, list)):
result.append(f"{indent_str}{key}:")
result.append(_json_to_natural_language(value, indent + 2))
else:
result.append(f"{indent_str}{key}: {value}")
elif isinstance(data, list):
for item in data:
result.append(_json_to_natural_language(item, indent))
else:
result.append(f"{indent_str}{data}")
return "\n".join(result)
# This defines the base path for all data files relative to this file
# AKA BE CAREFUL WHEN MOVING THIS FILE
BASE_DATA_PATH = os.path.join(os.path.dirname(__file__), "data")
def extract_dict_text(raw_dict: dict) -> str:
processed_dict = _clean_salesforce_dict(raw_dict)
natural_language_dict = _json_to_natural_language(processed_dict)
return natural_language_dict
def get_sqlite_db_path() -> str:
"""Get the path to the sqlite db file."""
return os.path.join(BASE_DATA_PATH, "salesforce_db.sqlite")
def get_object_type_path(object_type: str) -> str:
"""Get the directory path for a specific object type."""
type_dir = os.path.join(BASE_DATA_PATH, object_type)
os.makedirs(type_dir, exist_ok=True)
return type_dir
_CHECKSUM_CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ012345"
_LOOKUP = {format(i, "05b"): _CHECKSUM_CHARS[i] for i in range(32)}
def validate_salesforce_id(salesforce_id: str) -> bool:
"""Validate the checksum portion of an 18-character Salesforce ID.
Args:
salesforce_id: An 18-character Salesforce ID
Returns:
bool: True if the checksum is valid, False otherwise
"""
if len(salesforce_id) != 18:
return False
chunks = [salesforce_id[0:5], salesforce_id[5:10], salesforce_id[10:15]]
checksum = salesforce_id[15:18]
calculated_checksum = ""
for chunk in chunks:
result_string = "".join(
"1" if char.isupper() else "0" for char in reversed(chunk)
)
calculated_checksum += _LOOKUP[result_string]
return checksum == calculated_checksum

View File

@@ -264,24 +264,6 @@ class SlackTextCleaner:
message = message.replace("<!everyone>", "@everyone")
return message
@staticmethod
def replace_links(message: str) -> str:
"""Replaces slack links e.g. `<URL>` -> `URL` and `<URL|DISPLAY>` -> `DISPLAY`"""
# Find user IDs in the message
possible_link_matches = re.findall(r"<(.*?)>", message)
for possible_link in possible_link_matches:
if not possible_link:
continue
# Special slack patterns that aren't for links
if possible_link[0] not in ["#", "@", "!"]:
link_display = (
possible_link
if "|" not in possible_link
else possible_link.split("|")[1]
)
message = message.replace(f"<{possible_link}>", link_display)
return message
@staticmethod
def replace_special_catchall(message: str) -> str:
"""Replaces pattern of <!something|another-thing> with another-thing

View File

@@ -33,6 +33,7 @@ from onyx.file_processing.extract_file_text import read_pdf_file
from onyx.file_processing.html_utils import web_html_cleanup
from onyx.utils.logger import setup_logger
from onyx.utils.sitemap import list_pages_for_site
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
@@ -241,6 +242,12 @@ class WebConnector(LoadConnector):
self.to_visit_list = extract_urls_from_sitemap(_ensure_valid_url(base_url))
elif web_connector_type == WEB_CONNECTOR_VALID_SETTINGS.UPLOAD:
# Explicitly check if running in multi-tenant mode to prevent potential security risks
if MULTI_TENANT:
raise ValueError(
"Upload input for web connector is not supported in cloud environments"
)
logger.warning(
"This is not a UI supported Web Connector flow, "
"are you sure you want to do this?"

View File

@@ -40,6 +40,13 @@ class ZendeskClient:
response = requests.get(
f"{self.base_url}/{endpoint}", auth=self.auth, params=params
)
if response.status_code == 429:
retry_after = response.headers.get("Retry-After")
if retry_after is not None:
# Sleep for the duration indicated by the Retry-After header
time.sleep(int(retry_after))
response.raise_for_status()
return response.json()

View File

@@ -96,6 +96,8 @@ class Tag(BaseModel):
class BaseFilters(BaseModel):
source_type: list[DocumentSource] | None = None
document_set: list[str] | None = None
user_folders: list[str] | None = None
document_ids: list[str] | None = None
time_cutoff: datetime | None = None
tags: list[Tag] | None = None

View File

@@ -54,9 +54,11 @@ def get_total_users_count(db_session: Session) -> int:
return user_count + invited_users
async def get_user_count() -> int:
async def get_user_count(only_admin_users: bool = False) -> int:
async with get_async_session_with_tenant() as session:
stmt = select(func.count(User.id))
if only_admin_users:
stmt = stmt.where(User.role == UserRole.ADMIN)
result = await session.execute(stmt)
user_count = result.scalar()
if user_count is None:

View File

@@ -7,6 +7,7 @@ from sqlalchemy import exists
from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy.orm import aliased
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from onyx.configs.constants import DocumentSource
@@ -90,15 +91,22 @@ def get_connector_credential_pairs(
user: User | None = None,
get_editable: bool = True,
ids: list[int] | None = None,
eager_load_connector: bool = False,
) -> list[ConnectorCredentialPair]:
stmt = select(ConnectorCredentialPair).distinct()
if eager_load_connector:
stmt = stmt.options(joinedload(ConnectorCredentialPair.connector))
stmt = _add_user_filters(stmt, user, get_editable)
if not include_disabled:
stmt = stmt.where(
ConnectorCredentialPair.status == ConnectorCredentialPairStatus.ACTIVE
) # noqa
)
if ids:
stmt = stmt.where(ConnectorCredentialPair.id.in_(ids))
return list(db_session.scalars(stmt).all())
@@ -310,6 +318,9 @@ def associate_default_cc_pair(db_session: Session) -> None:
if existing_association is not None:
return
# DefaultCCPair has id 1 since it is the first CC pair created
# It is DEFAULT_CC_PAIR_ID, but can't set it explicitly because it messed with the
# auto-incrementing id
association = ConnectorCredentialPair(
connector_id=0,
credential_id=0,
@@ -350,7 +361,12 @@ def add_credential_to_connector(
last_successful_index_time: datetime | None = None,
) -> StatusResponse:
connector = fetch_connector_by_id(connector_id, db_session)
credential = fetch_credential_by_id(credential_id, user, db_session)
credential = fetch_credential_by_id(
credential_id,
user,
db_session,
get_editable=False,
)
if connector is None:
raise HTTPException(status_code=404, detail="Connector does not exist")
@@ -427,7 +443,12 @@ def remove_credential_from_connector(
db_session: Session,
) -> StatusResponse[int]:
connector = fetch_connector_by_id(connector_id, db_session)
credential = fetch_credential_by_id(credential_id, user, db_session)
credential = fetch_credential_by_id(
credential_id,
user,
db_session,
get_editable=False,
)
if connector is None:
raise HTTPException(status_code=404, detail="Connector does not exist")

View File

@@ -86,7 +86,7 @@ def _add_user_filters(
"""
Filter Credentials by:
- if the user is in the user_group that owns the Credential
- if the user is not a global_curator, they must also have a curator relationship
- if the user is a curator, they must also have a curator relationship
to the user_group
- if editing is being done, we also filter out Credentials that are owned by groups
that the user isn't a curator for
@@ -97,6 +97,7 @@ def _add_user_filters(
where_clause = User__UserGroup.user_id == user.id
if user.role == UserRole.CURATOR:
where_clause &= User__UserGroup.is_curator == True # noqa: E712
if get_editable:
user_groups = select(User__UserGroup.user_group_id).where(
User__UserGroup.user_id == user.id
@@ -152,10 +153,16 @@ def fetch_credential_by_id(
user: User | None,
db_session: Session,
assume_admin: bool = False,
get_editable: bool = True,
) -> Credential | None:
stmt = select(Credential).distinct()
stmt = stmt.where(Credential.id == credential_id)
stmt = _add_user_filters(stmt, user, assume_admin=assume_admin)
stmt = _add_user_filters(
stmt=stmt,
user=user,
assume_admin=assume_admin,
get_editable=get_editable,
)
result = db_session.execute(stmt)
credential = result.scalar_one_or_none()
return credential

View File

@@ -27,7 +27,7 @@ from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker
from onyx.configs.app_configs import AWS_REGION
from onyx.configs.app_configs import AWS_REGION_NAME
from onyx.configs.app_configs import LOG_POSTGRES_CONN_COUNTS
from onyx.configs.app_configs import LOG_POSTGRES_LATENCY
from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_OVERFLOW
@@ -273,7 +273,7 @@ async def get_async_connection() -> Any:
port = POSTGRES_PORT
user = POSTGRES_USER
db = POSTGRES_DB
token = get_iam_auth_token(host, port, user, AWS_REGION)
token = get_iam_auth_token(host, port, user, AWS_REGION_NAME)
# asyncpg requires 'ssl="require"' if SSL needed
return await asyncpg.connect(
@@ -315,7 +315,7 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
host = POSTGRES_HOST
port = POSTGRES_PORT
user = POSTGRES_USER
token = get_iam_auth_token(host, port, user, AWS_REGION)
token = get_iam_auth_token(host, port, user, AWS_REGION_NAME)
cparams["password"] = token
cparams["ssl"] = ssl_context
@@ -525,6 +525,6 @@ def provide_iam_token(dialect: Any, conn_rec: Any, cargs: Any, cparams: Any) ->
host = POSTGRES_HOST
port = POSTGRES_PORT
user = POSTGRES_USER
region = os.getenv("AWS_REGION", "us-east-2")
region = os.getenv("AWS_REGION_NAME", "us-east-2")
# Configure for psycopg2 with IAM token
configure_psycopg2_iam_auth(cparams, host, port, user, region)

View File

@@ -1,132 +0,0 @@
from uuid import UUID
from sqlalchemy.orm import Session
from onyx.db.chat import delete_chat_session
from onyx.db.models import ChatFolder
from onyx.db.models import ChatSession
from onyx.utils.logger import setup_logger
logger = setup_logger()
def get_user_folders(
user_id: UUID | None,
db_session: Session,
) -> list[ChatFolder]:
return db_session.query(ChatFolder).filter(ChatFolder.user_id == user_id).all()
def update_folder_display_priority(
user_id: UUID | None,
display_priority_map: dict[int, int],
db_session: Session,
) -> None:
folders = get_user_folders(user_id=user_id, db_session=db_session)
folder_ids = {folder.id for folder in folders}
if folder_ids != set(display_priority_map.keys()):
raise ValueError("Invalid Folder IDs provided")
for folder in folders:
folder.display_priority = display_priority_map[folder.id]
db_session.commit()
def get_folder_by_id(
user_id: UUID | None,
folder_id: int,
db_session: Session,
) -> ChatFolder:
folder = (
db_session.query(ChatFolder).filter(ChatFolder.id == folder_id).one_or_none()
)
if not folder:
raise ValueError("Folder by specified id does not exist")
if folder.user_id != user_id:
raise PermissionError(f"Folder does not belong to user: {user_id}")
return folder
def create_folder(
user_id: UUID | None, folder_name: str | None, db_session: Session
) -> int:
new_folder = ChatFolder(
user_id=user_id,
name=folder_name,
)
db_session.add(new_folder)
db_session.commit()
return new_folder.id
def rename_folder(
user_id: UUID | None, folder_id: int, folder_name: str | None, db_session: Session
) -> None:
folder = get_folder_by_id(
user_id=user_id, folder_id=folder_id, db_session=db_session
)
folder.name = folder_name
db_session.commit()
def add_chat_to_folder(
user_id: UUID | None, folder_id: int, chat_session: ChatSession, db_session: Session
) -> None:
folder = get_folder_by_id(
user_id=user_id, folder_id=folder_id, db_session=db_session
)
chat_session.folder_id = folder.id
db_session.commit()
def remove_chat_from_folder(
user_id: UUID | None, folder_id: int, chat_session: ChatSession, db_session: Session
) -> None:
folder = get_folder_by_id(
user_id=user_id, folder_id=folder_id, db_session=db_session
)
if chat_session.folder_id != folder.id:
raise ValueError("The chat session is not in the specified folder.")
if folder.user_id != user_id:
raise ValueError(
f"Tried to remove a chat session from a folder that does not below to "
f"this user, user id: {user_id}"
)
chat_session.folder_id = None
if chat_session in folder.chat_sessions:
folder.chat_sessions.remove(chat_session)
db_session.commit()
def delete_folder(
user_id: UUID | None,
folder_id: int,
including_chats: bool,
db_session: Session,
) -> None:
folder = get_folder_by_id(
user_id=user_id, folder_id=folder_id, db_session=db_session
)
# Assuming there will not be a massive number of chats in any given folder
if including_chats:
for chat_session in folder.chat_sessions:
delete_chat_session(
user_id=user_id,
chat_session_id=chat_session.id,
db_session=db_session,
)
db_session.delete(folder)
db_session.commit()

View File

@@ -54,6 +54,7 @@ from onyx.db.enums import IndexingStatus
from onyx.db.enums import IndexModelStatus
from onyx.db.enums import TaskStatus
from onyx.db.pydantic_type import PydanticType
from onyx.utils.logger import setup_logger
from onyx.utils.special_types import JSON_ro
from onyx.file_store.models import FileDescriptor
from onyx.llm.override_models import LLMOverride
@@ -65,6 +66,8 @@ from onyx.utils.headers import HeaderItemDict
from shared_configs.enums import EmbeddingProvider
from shared_configs.enums import RerankerProvider
logger = setup_logger()
class Base(DeclarativeBase):
__abstract__ = True
@@ -72,6 +75,8 @@ class Base(DeclarativeBase):
class EncryptedString(TypeDecorator):
impl = LargeBinary
# This type's behavior is fully deterministic and doesn't depend on any external factors.
cache_ok = True
def process_bind_param(self, value: str | None, dialect: Dialect) -> bytes | None:
if value is not None:
@@ -86,6 +91,8 @@ class EncryptedString(TypeDecorator):
class EncryptedJson(TypeDecorator):
impl = LargeBinary
# This type's behavior is fully deterministic and doesn't depend on any external factors.
cache_ok = True
def process_bind_param(self, value: dict | None, dialect: Dialect) -> bytes | None:
if value is not None:
@@ -102,11 +109,76 @@ class EncryptedJson(TypeDecorator):
return value
class NullFilteredString(TypeDecorator):
impl = String
# This type's behavior is fully deterministic and doesn't depend on any external factors.
cache_ok = True
def process_bind_param(self, value: str | None, dialect: Dialect) -> str | None:
if value is not None and "\x00" in value:
logger.warning(f"NUL characters found in value: {value}")
return value.replace("\x00", "")
return value
def process_result_value(self, value: str | None, dialect: Dialect) -> str | None:
return value
"""
Auth/Authz (users, permissions, access) Tables
"""
class UserFolder(Base):
__tablename__ = "user_folder"
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
user_id: Mapped[int] = mapped_column(ForeignKey("user.id"), nullable=False)
parent_id: Mapped[int | None] = mapped_column(
ForeignKey("user_folder.id"), nullable=True
)
name: Mapped[str] = mapped_column(nullable=False)
created_at: Mapped[datetime.datetime] = mapped_column(
default=datetime.datetime.utcnow
)
user: Mapped["User"] = relationship(back_populates="folders")
parent: Mapped["UserFolder"] = relationship(
remote_side=[id], back_populates="children"
)
children: Mapped[list["UserFolder"]] = relationship(back_populates="parent")
files: Mapped[list["UserFile"]] = relationship(back_populates="folder")
chat_sessions: Mapped[list["ChatSession"]] = relationship(back_populates="folder")
class UserDocument(str, Enum):
CHAT = "chat"
RECENT = "recent"
FILE = "file"
class UserFile(Base):
__tablename__ = "user_file"
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
user_id: Mapped[int | None] = mapped_column(ForeignKey("user.id"), nullable=False)
parent_folder_id: Mapped[int | None] = mapped_column(
ForeignKey("user_folder.id"), nullable=True
)
file_id: Mapped[str] = mapped_column(nullable=False)
document_id: Mapped[str] = mapped_column(nullable=False)
name: Mapped[str] = mapped_column(nullable=False)
created_at: Mapped[datetime.datetime] = mapped_column(
default=datetime.datetime.utcnow
)
ccpair_id: Mapped[int | None] = mapped_column(
ForeignKey("connector_credential_pair.id"), nullable=False
)
user: Mapped["User"] = relationship(back_populates="files")
folder: Mapped["UserFolder"] = relationship(back_populates="files")
class OAuthAccount(SQLAlchemyBaseOAuthAccountTableUUID, Base):
# even an almost empty token from keycloak will not fit the default 1024 bytes
access_token: Mapped[str] = mapped_column(Text, nullable=False) # type: ignore
@@ -156,9 +228,6 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
chat_sessions: Mapped[list["ChatSession"]] = relationship(
"ChatSession", back_populates="user"
)
chat_folders: Mapped[list["ChatFolder"]] = relationship(
"ChatFolder", back_populates="user"
)
prompts: Mapped[list["Prompt"]] = relationship("Prompt", back_populates="user")
@@ -176,6 +245,11 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
primaryjoin="User.id == foreign(ConnectorCredentialPair.creator_id)",
)
folders: Mapped[list["UserFolder"]] = relationship(
"UserFolder", back_populates="user"
)
files: Mapped[list["UserFile"]] = relationship("UserFile", back_populates="user")
class AccessToken(SQLAlchemyBaseAccessTokenTableUUID, Base):
pass
@@ -451,16 +525,16 @@ class Document(Base):
# this should correspond to the ID of the document
# (as is passed around in Onyx)
id: Mapped[str] = mapped_column(String, primary_key=True)
id: Mapped[str] = mapped_column(NullFilteredString, primary_key=True)
from_ingestion_api: Mapped[bool] = mapped_column(
Boolean, default=False, nullable=True
)
# 0 for neutral, positive for mostly endorse, negative for mostly reject
boost: Mapped[int] = mapped_column(Integer, default=DEFAULT_BOOST)
hidden: Mapped[bool] = mapped_column(Boolean, default=False)
semantic_id: Mapped[str] = mapped_column(String)
semantic_id: Mapped[str] = mapped_column(NullFilteredString)
# First Section's link
link: Mapped[str | None] = mapped_column(String, nullable=True)
link: Mapped[str | None] = mapped_column(NullFilteredString, nullable=True)
# The updated time is also used as a measure of the last successful state of the doc
# pulled from the source (to help skip reindexing already updated docs in case of
@@ -976,7 +1050,7 @@ class ChatSession(Base):
default=ChatSessionSharedStatus.PRIVATE,
)
folder_id: Mapped[int | None] = mapped_column(
ForeignKey("chat_folder.id"), nullable=True
ForeignKey("user_folder.id", ondelete="SET NULL"), nullable=True
)
current_alternate_model: Mapped[str | None] = mapped_column(String, default=None)
@@ -1006,8 +1080,8 @@ class ChatSession(Base):
DateTime(timezone=True), server_default=func.now()
)
user: Mapped[User] = relationship("User", back_populates="chat_sessions")
folder: Mapped["ChatFolder"] = relationship(
"ChatFolder", back_populates="chat_sessions"
folder: Mapped["UserFolder"] = relationship(
"UserFolder", back_populates="chat_sessions"
)
messages: Mapped[list["ChatMessage"]] = relationship(
"ChatMessage", back_populates="chat_session", cascade="all, delete-orphan"
@@ -1095,33 +1169,6 @@ class ChatMessage(Base):
)
class ChatFolder(Base):
"""For organizing chat sessions"""
__tablename__ = "chat_folder"
id: Mapped[int] = mapped_column(primary_key=True)
# Only null if auth is off
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
)
name: Mapped[str | None] = mapped_column(String, nullable=True)
display_priority: Mapped[int] = mapped_column(Integer, nullable=True, default=0)
user: Mapped[User] = relationship("User", back_populates="chat_folders")
chat_sessions: Mapped[list["ChatSession"]] = relationship(
"ChatSession", back_populates="folder"
)
def __lt__(self, other: Any) -> bool:
if not isinstance(other, ChatFolder):
return NotImplemented
if self.display_priority == other.display_priority:
# Bigger ID (created later) show earlier
return self.id > other.id
return self.display_priority < other.display_priority
"""
Feedback, Logging, Metrics Tables
"""

View File

@@ -0,0 +1,36 @@
from typing import List
from fastapi import UploadFile
from sqlalchemy.orm import Session
from onyx.db.models import User
from onyx.db.models import UserFile
from onyx.server.documents.connector import upload_files
from onyx.server.documents.models import FileUploadResponse
CHAT_FOLDER_ID = -1
RECENT_DOCUMENTS_FOLDER_ID = -2
def create_user_files(
files: List[UploadFile],
folder_id: int | None,
user: User | None,
db_session: Session,
) -> FileUploadResponse:
upload_response = upload_files(files, db_session)
for file_path, file in zip(upload_response.file_paths, files):
new_file = UserFile(
user_id=user.id if user else None,
parent_folder_id=folder_id,
file_id=file_path,
document_id=file_path, # We'll use the same ID for now
name=file.filename,
)
db_session.add(new_file)
db_session.commit()
return upload_response
# def trigger_document_indexing(db_session: Session, user_id: int) -> None:

View File

@@ -0,0 +1,29 @@
from typing import List
from fastapi import UploadFile
from sqlalchemy.orm import Session
from onyx.db.models import User
from onyx.db.models import UserFile
from onyx.server.documents.connector import upload_files
from onyx.server.documents.models import FileUploadResponse
def create_user_files(
files: List[UploadFile],
folder_id: int | None,
user: User,
db_session: Session,
) -> FileUploadResponse:
upload_response = upload_files(files, db_session)
for file_path, file in zip(upload_response.file_paths, files):
new_file = UserFile(
user_id=user.id if user else None,
parent_folder_id=folder_id if folder_id != -1 else None,
file_id=file_path,
document_id=file_path,
name=file.filename,
)
db_session.add(new_file)
db_session.commit()
return upload_response

View File

View File

@@ -7,8 +7,15 @@ from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.auth.invited_users import get_invited_users
from onyx.auth.invited_users import write_invited_users
from onyx.auth.schemas import UserRole
from onyx.db.models import DocumentSet__User
from onyx.db.models import Persona__User
from onyx.db.models import SamlAccount
from onyx.db.models import User
from onyx.db.models import User__UserGroup
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
def validate_user_role_update(requested_role: UserRole, current_role: UserRole) -> None:
@@ -185,3 +192,43 @@ def batch_add_ext_perm_user_if_not_exists(
db_session.commit()
return found_users + new_users
def delete_user_from_db(
user_to_delete: User,
db_session: Session,
) -> None:
for oauth_account in user_to_delete.oauth_accounts:
db_session.delete(oauth_account)
fetch_ee_implementation_or_noop(
"onyx.db.external_perm",
"delete_user__ext_group_for_user__no_commit",
)(
db_session=db_session,
user_id=user_to_delete.id,
)
db_session.query(SamlAccount).filter(
SamlAccount.user_id == user_to_delete.id
).delete()
db_session.query(DocumentSet__User).filter(
DocumentSet__User.user_id == user_to_delete.id
).delete()
db_session.query(Persona__User).filter(
Persona__User.user_id == user_to_delete.id
).delete()
db_session.query(User__UserGroup).filter(
User__UserGroup.user_id == user_to_delete.id
).delete()
db_session.delete(user_to_delete)
db_session.commit()
# NOTE: edge case may exist with race conditions
# with this `invited user` scheme generally.
user_emails = get_invited_users()
remaining_users = [
remaining_user_email
for remaining_user_email in user_emails
if remaining_user_email != user_to_delete.email
]
write_invited_users(remaining_users)

View File

@@ -112,6 +112,11 @@ schema DANSWER_CHUNK_NAME {
rank: filter
attribute: fast-search
}
field user_folders type weightedset<string> {
indexing: summary | attribute
rank: filter
attribute: fast-search
}
}
# If using different tokenization settings, the fieldset has to be removed, and the field must

View File

@@ -16,6 +16,12 @@ logger = setup_logger()
CONTENT_SUMMARY = "content_summary"
@retry(tries=10, delay=1, backoff=2)
def _retryable_http_delete(http_client: httpx.Client, url: str) -> None:
res = http_client.delete(url)
res.raise_for_status()
@retry(tries=3, delay=1, backoff=2)
def _delete_vespa_doc_chunks(
document_id: str, index_name: str, http_client: httpx.Client
@@ -28,10 +34,10 @@ def _delete_vespa_doc_chunks(
for chunk_id in doc_chunk_ids:
try:
res = http_client.delete(
f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{chunk_id}"
_retryable_http_delete(
http_client,
f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{chunk_id}",
)
res.raise_for_status()
except httpx.HTTPStatusError as e:
logger.error(f"Failed to delete chunk, details: {e.response.text}")
raise

View File

@@ -313,6 +313,7 @@ class VespaIndex(DocumentIndex):
with updating the associated permissions. Assumes that a document will not be split into
multiple chunk batches calling this function multiple times, otherwise only the last set of
chunks will be kept"""
# IMPORTANT: This must be done one index at a time, do not use secondary index here
cleaned_chunks = [clean_chunk_id_copy(chunk) for chunk in chunks]
@@ -706,6 +707,8 @@ class VespaIndex(DocumentIndex):
offset: int = 0,
title_content_ratio: float | None = TITLE_CONTENT_RATIO,
) -> list[InferenceChunkUncleaned]:
print("filters", filters)
print("filters.user_folders", filters.__dict__)
vespa_where_clauses = build_vespa_filters(filters)
# Needs to be at least as much as the value set in Vespa schema config
target_hits = max(10 * num_to_retrieve, 1000)

View File

@@ -64,10 +64,10 @@ def _does_document_exist(
if doc_fetch_response.status_code != 200:
logger.debug(f"Failed to check for document with URL {doc_url}")
raise RuntimeError(
f"Unexpected fetch document by ID value from Vespa "
f"with error {doc_fetch_response.status_code}"
f"Index name: {index_name}"
f"Doc chunk id: {doc_chunk_id}"
f"Unexpected fetch document by ID value from Vespa: "
f"error={doc_fetch_response.status_code} "
f"index={index_name} "
f"doc_chunk_id={doc_chunk_id}"
)
return True

View File

@@ -55,9 +55,7 @@ def remove_invalid_unicode_chars(text: str) -> str:
return _illegal_xml_chars_RE.sub("", text)
def get_vespa_http_client(
no_timeout: bool = False, http2: bool = False
) -> httpx.Client:
def get_vespa_http_client(no_timeout: bool = False, http2: bool = True) -> httpx.Client:
"""
Configure and return an HTTP client for communicating with Vespa,
including authentication if needed.

View File

@@ -9,11 +9,13 @@ from onyx.document_index.vespa_constants import ACCESS_CONTROL_LIST
from onyx.document_index.vespa_constants import CHUNK_ID
from onyx.document_index.vespa_constants import DOC_UPDATED_AT
from onyx.document_index.vespa_constants import DOCUMENT_ID
from onyx.document_index.vespa_constants import DOCUMENT_IDS
from onyx.document_index.vespa_constants import DOCUMENT_SETS
from onyx.document_index.vespa_constants import HIDDEN
from onyx.document_index.vespa_constants import METADATA_LIST
from onyx.document_index.vespa_constants import SOURCE_TYPE
from onyx.document_index.vespa_constants import TENANT_ID
from onyx.document_index.vespa_constants import USER_FOLDERS
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -77,10 +79,15 @@ def build_vespa_filters(
tags = filters.tags
if tags:
tag_attributes = [tag.tag_key + INDEX_SEPARATOR + tag.tag_value for tag in tags]
filter_str += _build_or_filters(METADATA_LIST, tag_attributes)
filter_str += _build_or_filters(DOCUMENT_SETS, filters.document_set)
filter_str += _build_or_filters(USER_FOLDERS, filters.user_folders)
filter_str += _build_or_filters(DOCUMENT_IDS, filters.document_ids)
filter_str += _build_time_filter(filters.time_cutoff)
if remove_trailing_and and filter_str.endswith(" and "):

View File

@@ -64,6 +64,8 @@ EMBEDDINGS = "embeddings"
TITLE_EMBEDDING = "title_embedding"
ACCESS_CONTROL_LIST = "access_control_list"
DOCUMENT_SETS = "document_sets"
USER_FOLDERS = "user_folders"
DOCUMENT_IDS = "document_ids"
LARGE_CHUNK_REFERENCE_IDS = "large_chunk_reference_ids"
METADATA = "metadata"
METADATA_LIST = "metadata_list"

View File

@@ -260,6 +260,21 @@ def index_doc_batch_prepare(
def filter_documents(document_batch: list[Document]) -> list[Document]:
documents: list[Document] = []
for document in document_batch:
# Remove any NUL characters from title/semantic_id
# This is a known issue with the Zendesk connector
# Postgres cannot handle NUL characters in text fields
if document.title:
document.title = document.title.replace("\x00", "")
if document.semantic_identifier:
document.semantic_identifier = document.semantic_identifier.replace(
"\x00", ""
)
# Remove NUL characters from all sections
for section in document.sections:
if section.text is not None:
section.text = section.text.replace("\x00", "")
empty_contents = not any(section.text.strip() for section in document.sections)
if (
(not document.title or not document.title.strip())

View File

@@ -266,18 +266,27 @@ class DefaultMultiLLM(LLM):
# )
self._custom_config = custom_config
# Create a dictionary for model-specific arguments if it's None
model_kwargs = model_kwargs or {}
# NOTE: have to set these as environment variables for Litellm since
# not all are able to passed in but they always support them set as env
# variables. We'll also try passing them in, since litellm just ignores
# addtional kwargs (and some kwargs MUST be passed in rather than set as
# env variables)
if custom_config:
for k, v in custom_config.items():
os.environ[k] = v
# Specifically pass in "vertex_credentials" as a model_kwarg to the
# completion call for vertex AI. More details here:
# https://docs.litellm.ai/docs/providers/vertex
vertex_credentials_key = "vertex_credentials"
vertex_credentials = custom_config.get(vertex_credentials_key)
if vertex_credentials and model_provider == "vertex_ai":
model_kwargs[vertex_credentials_key] = vertex_credentials
else:
# standard case
for k, v in custom_config.items():
os.environ[k] = v
model_kwargs = model_kwargs or {}
if custom_config:
model_kwargs.update(custom_config)
if extra_headers:
model_kwargs.update({"extra_headers": extra_headers})
if extra_body:

View File

@@ -53,7 +53,6 @@ from onyx.server.documents.document import router as document_router
from onyx.server.documents.indexing import router as indexing_router
from onyx.server.documents.standard_oauth import router as oauth_router
from onyx.server.features.document_set.api import router as document_set_router
from onyx.server.features.folder.api import router as folder_router
from onyx.server.features.notifications.api import router as notification_router
from onyx.server.features.persona.api import admin_router as admin_persona_router
from onyx.server.features.persona.api import basic_router as persona_router
@@ -74,6 +73,9 @@ from onyx.server.manage.search_settings import router as search_settings_router
from onyx.server.manage.slack_bot import router as slack_bot_management_router
from onyx.server.manage.users import router as user_router
from onyx.server.middleware.latency_logging import add_latency_logging_middleware
from onyx.server.middleware.rate_limiting import close_limiter
from onyx.server.middleware.rate_limiting import get_auth_rate_limiters
from onyx.server.middleware.rate_limiting import setup_limiter
from onyx.server.onyx_api.ingestion import router as onyx_api_router
from onyx.server.openai_assistants_api.full_openai_assistants_api import (
get_full_openai_assistants_api_router,
@@ -88,6 +90,7 @@ from onyx.server.settings.api import basic_router as settings_router
from onyx.server.token_rate_limits.api import (
router as token_rate_limit_settings_router,
)
from onyx.server.user_documents.api import router as user_documents_router
from onyx.server.utils import BasicAuthenticationError
from onyx.setup import setup_multitenant_onyx
from onyx.setup import setup_onyx
@@ -153,6 +156,20 @@ def include_router_with_global_prefix_prepended(
application.include_router(router, **final_kwargs)
def include_auth_router_with_prefix(
application: FastAPI, router: APIRouter, prefix: str, tags: list[str] | None = None
) -> None:
"""Wrapper function to include an 'auth' router with prefix + rate-limiting dependencies."""
final_tags = tags or ["auth"]
include_router_with_global_prefix_prepended(
application,
router,
prefix=prefix,
tags=final_tags,
dependencies=get_auth_rate_limiters(),
)
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator:
# Set recursion limit
@@ -194,8 +211,15 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
setup_multitenant_onyx()
optional_telemetry(record_type=RecordType.VERSION, data={"version": __version__})
# Set up rate limiter
await setup_limiter()
yield
# Close rate limiter
await close_limiter()
def log_http_error(_: Request, exc: Exception) -> JSONResponse:
status_code = getattr(exc, "status_code", 500)
@@ -246,8 +270,6 @@ def get_application() -> FastAPI:
include_router_with_global_prefix_prepended(application, user_router)
include_router_with_global_prefix_prepended(application, credential_router)
include_router_with_global_prefix_prepended(application, cc_pair_router)
include_router_with_global_prefix_prepended(application, folder_router)
include_router_with_global_prefix_prepended(application, document_set_router)
include_router_with_global_prefix_prepended(application, search_settings_router)
include_router_with_global_prefix_prepended(
application, slack_bot_management_router
@@ -262,15 +284,18 @@ def get_application() -> FastAPI:
include_router_with_global_prefix_prepended(application, onyx_api_router)
include_router_with_global_prefix_prepended(application, gpts_router)
include_router_with_global_prefix_prepended(application, settings_router)
include_router_with_global_prefix_prepended(application, user_documents_router)
include_router_with_global_prefix_prepended(application, settings_admin_router)
include_router_with_global_prefix_prepended(application, llm_admin_router)
include_router_with_global_prefix_prepended(application, llm_router)
include_router_with_global_prefix_prepended(application, embedding_admin_router)
include_router_with_global_prefix_prepended(application, embedding_router)
include_router_with_global_prefix_prepended(application, document_set_router)
include_router_with_global_prefix_prepended(application, indexing_router)
include_router_with_global_prefix_prepended(
application, token_rate_limit_settings_router
)
include_router_with_global_prefix_prepended(application, indexing_router)
include_router_with_global_prefix_prepended(
application, get_full_openai_assistants_api_router()
)
@@ -283,42 +308,37 @@ def get_application() -> FastAPI:
pass
if AUTH_TYPE == AuthType.BASIC or AUTH_TYPE == AuthType.CLOUD:
include_router_with_global_prefix_prepended(
include_auth_router_with_prefix(
application,
fastapi_users.get_auth_router(auth_backend),
prefix="/auth",
tags=["auth"],
)
include_router_with_global_prefix_prepended(
include_auth_router_with_prefix(
application,
fastapi_users.get_register_router(UserRead, UserCreate),
prefix="/auth",
tags=["auth"],
)
include_router_with_global_prefix_prepended(
include_auth_router_with_prefix(
application,
fastapi_users.get_reset_password_router(),
prefix="/auth",
tags=["auth"],
)
include_router_with_global_prefix_prepended(
include_auth_router_with_prefix(
application,
fastapi_users.get_verify_router(UserRead),
prefix="/auth",
tags=["auth"],
)
include_router_with_global_prefix_prepended(
include_auth_router_with_prefix(
application,
fastapi_users.get_users_router(UserRead, UserUpdate),
prefix="/users",
tags=["users"],
)
if AUTH_TYPE == AuthType.GOOGLE_OAUTH:
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
include_router_with_global_prefix_prepended(
include_auth_router_with_prefix(
application,
create_onyx_oauth_router(
oauth_client,
@@ -330,15 +350,13 @@ def get_application() -> FastAPI:
redirect_url=f"{WEB_DOMAIN}/auth/oauth/callback",
),
prefix="/auth/oauth",
tags=["auth"],
)
# Need basic auth router for `logout` endpoint
include_router_with_global_prefix_prepended(
include_auth_router_with_prefix(
application,
fastapi_users.get_logout_router(auth_backend),
prefix="/auth",
tags=["auth"],
)
application.add_exception_handler(

View File

@@ -131,10 +131,15 @@ class EmbeddingModel:
tries=10, delay=10, exceptions=ModelServerRateLimitError
)(final_make_request_func)
response: Response | None = None
try:
response = final_make_request_func()
return EmbedResponse(**response.json())
except requests.HTTPError as e:
if not response:
raise HTTPError("HTTP error occurred - response is None.") from e
try:
error_detail = response.json().get("detail", str(e))
except Exception:

View File

@@ -1,6 +1,4 @@
import re
from datetime import datetime
from re import Match
import pytz
import timeago # type: ignore
@@ -59,33 +57,6 @@ def get_feedback_reminder_blocks(thread_link: str, include_followup: bool) -> Bl
return SectionBlock(text=text)
def _process_citations_for_slack(text: str) -> str:
"""
Converts instances of [[x]](LINK) in the input text to Slack's link format <LINK|[x]>.
Args:
- text (str): The input string containing markdown links.
Returns:
- str: The string with markdown links converted to Slack format.
"""
# Regular expression to find all instances of [[x]](LINK)
pattern = r"\[\[(.*?)\]\]\((.*?)\)"
# Function to replace each found instance with Slack's format
def slack_link_format(match: Match) -> str:
link_text = match.group(1)
link_url = match.group(2)
# Account for empty link citations
if link_url == "":
return f"[{link_text}]"
return f"<{link_url}|[{link_text}]>"
# Substitute all matches in the input text
return re.sub(pattern, slack_link_format, text)
def _split_text(text: str, limit: int = 3000) -> list[str]:
if len(text) <= limit:
return [text]
@@ -369,15 +340,12 @@ def _build_citations_blocks(
def _build_qa_response_blocks(
answer: ChatOnyxBotResponse,
process_message_for_citations: bool = False,
) -> list[Block]:
retrieval_info = answer.docs
if not retrieval_info:
# This should not happen, even with no docs retrieved, there is still info returned
raise RuntimeError("Failed to retrieve docs, cannot answer question.")
formatted_answer = format_slack_message(answer.answer) if answer.answer else None
if DISABLE_GENERATIVE_AI:
return []
@@ -408,18 +376,18 @@ def _build_qa_response_blocks(
filter_block = SectionBlock(text=f"_{filter_text}_")
if not formatted_answer:
if not answer.answer:
answer_blocks = [
SectionBlock(
text="Sorry, I was unable to find an answer, but I did find some potentially relevant docs 🤓"
)
]
else:
# replaces markdown links with slack format links
formatted_answer = format_slack_message(answer.answer)
answer_processed = decode_escapes(
remove_slack_text_interactions(formatted_answer)
)
if process_message_for_citations:
answer_processed = _process_citations_for_slack(answer_processed)
answer_blocks = [
SectionBlock(text=text) for text in _split_text(answer_processed)
]
@@ -525,7 +493,6 @@ def build_slack_response_blocks(
answer_blocks = _build_qa_response_blocks(
answer=answer,
process_message_for_citations=use_citations,
)
web_follow_up_block = []

View File

@@ -4,73 +4,55 @@ from onyx.configs.constants import DocumentSource
def source_to_github_img_link(source: DocumentSource) -> str | None:
# TODO: store these images somewhere better
if source == DocumentSource.WEB.value:
return "https://raw.githubusercontent.com/onyx-ai/onyx/main/backend/slackbot_images/Web.png"
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/backend/slackbot_images/Web.png"
if source == DocumentSource.FILE.value:
return "https://raw.githubusercontent.com/onyx-ai/onyx/main/backend/slackbot_images/File.png"
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/backend/slackbot_images/File.png"
if source == DocumentSource.GOOGLE_SITES.value:
return "https://raw.githubusercontent.com/onyx-ai/onyx/main/web/public/GoogleSites.png"
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/GoogleSites.png"
if source == DocumentSource.SLACK.value:
return (
"https://raw.githubusercontent.com/onyx-ai/onyx/main/web/public/Slack.png"
)
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/Slack.png"
if source == DocumentSource.GMAIL.value:
return (
"https://raw.githubusercontent.com/onyx-ai/onyx/main/web/public/Gmail.png"
)
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/Gmail.png"
if source == DocumentSource.GOOGLE_DRIVE.value:
return "https://raw.githubusercontent.com/onyx-ai/onyx/main/web/public/GoogleDrive.png"
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/GoogleDrive.png"
if source == DocumentSource.GITHUB.value:
return (
"https://raw.githubusercontent.com/onyx-ai/onyx/main/web/public/Github.png"
)
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/Github.png"
if source == DocumentSource.GITLAB.value:
return (
"https://raw.githubusercontent.com/onyx-ai/onyx/main/web/public/Gitlab.png"
)
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/Gitlab.png"
if source == DocumentSource.CONFLUENCE.value:
return "https://raw.githubusercontent.com/onyx-ai/onyx/main/backend/slackbot_images/Confluence.png"
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/backend/slackbot_images/Confluence.png"
if source == DocumentSource.JIRA.value:
return "https://raw.githubusercontent.com/onyx-ai/onyx/main/backend/slackbot_images/Jira.png"
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/backend/slackbot_images/Jira.png"
if source == DocumentSource.NOTION.value:
return (
"https://raw.githubusercontent.com/onyx-ai/onyx/main/web/public/Notion.png"
)
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/Notion.png"
if source == DocumentSource.ZENDESK.value:
return "https://raw.githubusercontent.com/onyx-ai/onyx/main/backend/slackbot_images/Zendesk.png"
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/backend/slackbot_images/Zendesk.png"
if source == DocumentSource.GONG.value:
return "https://raw.githubusercontent.com/onyx-ai/onyx/main/web/public/Gong.png"
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/Gong.png"
if source == DocumentSource.LINEAR.value:
return (
"https://raw.githubusercontent.com/onyx-ai/onyx/main/web/public/Linear.png"
)
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/Linear.png"
if source == DocumentSource.PRODUCTBOARD.value:
return "https://raw.githubusercontent.com/onyx-ai/onyx/main/web/public/Productboard.webp"
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/Productboard.webp"
if source == DocumentSource.SLAB.value:
return "https://raw.githubusercontent.com/onyx-ai/onyx/main/web/public/SlabLogo.png"
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/SlabLogo.png"
if source == DocumentSource.ZULIP.value:
return (
"https://raw.githubusercontent.com/onyx-ai/onyx/main/web/public/Zulip.png"
)
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/Zulip.png"
if source == DocumentSource.GURU.value:
return "https://raw.githubusercontent.com/onyx-ai/onyx/main/backend/slackbot_images/Guru.png"
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/backend/slackbot_images/Guru.png"
if source == DocumentSource.HUBSPOT.value:
return (
"https://raw.githubusercontent.com/onyx-ai/onyx/main/web/public/HubSpot.png"
)
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/HubSpot.png"
if source == DocumentSource.DOCUMENT360.value:
return "https://raw.githubusercontent.com/onyx-ai/onyx/main/web/public/Document360.png"
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/Document360.png"
if source == DocumentSource.BOOKSTACK.value:
return "https://raw.githubusercontent.com/onyx-ai/onyx/main/web/public/Bookstack.png"
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/Bookstack.png"
if source == DocumentSource.LOOPIO.value:
return (
"https://raw.githubusercontent.com/onyx-ai/onyx/main/web/public/Loopio.png"
)
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/Loopio.png"
if source == DocumentSource.SHAREPOINT.value:
return "https://raw.githubusercontent.com/onyx-ai/onyx/main/web/public/Sharepoint.png"
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/Sharepoint.png"
if source == DocumentSource.REQUESTTRACKER.value:
# just use file icon for now
return "https://raw.githubusercontent.com/onyx-ai/onyx/main/backend/slackbot_images/File.png"
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/backend/slackbot_images/File.png"
if source == DocumentSource.INGESTION_API.value:
return "https://raw.githubusercontent.com/onyx-ai/onyx/main/backend/slackbot_images/File.png"
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/backend/slackbot_images/File.png"
return "https://raw.githubusercontent.com/onyx-ai/onyx/main/backend/slackbot_images/File.png"
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/backend/slackbot_images/File.png"

View File

@@ -375,7 +375,6 @@ def remove_slack_text_interactions(slack_str: str) -> str:
slack_str = SlackTextCleaner.replace_tags_basic(slack_str)
slack_str = SlackTextCleaner.replace_channels_basic(slack_str)
slack_str = SlackTextCleaner.replace_special_mentions(slack_str)
slack_str = SlackTextCleaner.replace_links(slack_str)
slack_str = SlackTextCleaner.replace_special_catchall(slack_str)
slack_str = SlackTextCleaner.add_zero_width_whitespace_after_tag(slack_str)
return slack_str

View File

@@ -162,7 +162,7 @@ class RedisConnectorPermissionSync:
),
queue=OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT,
task_id=custom_task_id,
priority=OnyxCeleryPriority.MEDIUM,
priority=OnyxCeleryPriority.HIGH,
)
async_results.append(result)

View File

@@ -118,7 +118,7 @@ class RedisConnectorIndex:
The slack in timing is needed to avoid race conditions where simply checking
the celery queue and task status could result in race conditions."""
self.redis.set(self.active_key, 0, ex=300)
self.redis.set(self.active_key, 0, ex=3600)
def active(self) -> bool:
if self.redis.exists(self.active_key):
@@ -172,6 +172,9 @@ class RedisConnectorIndex:
@staticmethod
def reset_all(r: redis.Redis) -> None:
"""Deletes all redis values for all connectors"""
for key in r.scan_iter(RedisConnectorIndex.ACTIVE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndex.GENERATOR_LOCK_PREFIX + "*"):
r.delete(key)

View File

@@ -1,3 +1,4 @@
import asyncio
import functools
import threading
from collections.abc import Callable
@@ -5,6 +6,7 @@ from typing import Any
from typing import Optional
import redis
from redis import asyncio as aioredis
from redis.client import Redis
from onyx.configs.app_configs import REDIS_DB_NUMBER
@@ -196,3 +198,33 @@ def get_redis_client(*, tenant_id: str | None) -> Redis:
# redis_client.set('key', 'value')
# value = redis_client.get('key')
# print(value.decode()) # Output: 'value'
_async_redis_connection: aioredis.Redis | None = None
_async_lock = asyncio.Lock()
async def get_async_redis_connection() -> aioredis.Redis:
"""
Provides a shared async Redis connection, using the same configs (host, port, SSL, etc.).
Ensures that the connection is created only once (lazily) and reused for all future calls.
"""
global _async_redis_connection
# If we haven't yet created an async Redis connection, we need to create one
if _async_redis_connection is None:
# Acquire the lock to ensure that only one coroutine attempts to create the connection
async with _async_lock:
# Double-check inside the lock to avoid race conditions
if _async_redis_connection is None:
scheme = "rediss" if REDIS_SSL else "redis"
url = f"{scheme}://{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER}"
# Create a new Redis connection (or connection pool) from the URL
_async_redis_connection = aioredis.from_url(
url,
password=REDIS_PASSWORD,
max_connections=REDIS_POOL_MAX_CONNECTIONS,
)
# Return the established connection (or pool) for all future operations
return _async_redis_connection

View File

@@ -216,7 +216,7 @@ def seed_initial_documents(
# Retries here because the index may take a few seconds to become ready
# as we just sent over the Vespa schema and there is a slight delay
index_with_retries = retry_builder()(document_index.index)
index_with_retries = retry_builder(tries=15)(document_index.index)
index_with_retries(chunks=chunks, fresh_index=cohere_enabled)
# Mock a run for the UI even though it did not actually call out to anything

View File

@@ -0,0 +1,9 @@
# This file is used to seed the default chat folder for a user
default_folder:
name: "Chat"
description: "This is the default chat folder for a user"
name: "Recent Documents"
description: "This is the default folder for users to store their recent documents"

View File

@@ -5,6 +5,7 @@ from fastapi.dependencies.models import Dependant
from starlette.routing import BaseRoute
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_chat_accesssible_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_limited_user
from onyx.auth.users import current_user
@@ -109,6 +110,7 @@ def check_router_auth(
or depends_fn == current_curator_or_admin_user
or depends_fn == api_key_dep
or depends_fn == current_user_with_expired_token
or depends_fn == current_chat_accesssible_user
or depends_fn == control_plane_dep
or depends_fn == current_cloud_superuser
):

View File

@@ -510,7 +510,7 @@ def associate_credential_to_connector(
db_session: Session = Depends(get_session),
) -> StatusResponse[int]:
fetch_ee_implementation_or_noop(
"onyx.db.user_group", "validate_user_creation_permissions", None
"onyx.db.user_group", "validate_object_creation_for_user", None
)(
db_session=db_session,
user=user,
@@ -532,7 +532,8 @@ def associate_credential_to_connector(
)
return response
except IntegrityError:
except IntegrityError as e:
logger.error(f"IntegrityError: {e}")
raise HTTPException(status_code=400, detail="Name must be unique")

View File

@@ -14,6 +14,7 @@ from pydantic import BaseModel
from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_chat_accesssible_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_user
from onyx.background.celery.celery_utils import get_deletion_attempt_snapshot
@@ -374,10 +375,8 @@ def check_drive_tokens(
return AuthStatus(authenticated=True)
@router.post("/admin/connector/file/upload")
def upload_files(
files: list[UploadFile],
_: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> FileUploadResponse:
for file in files:
@@ -407,6 +406,15 @@ def upload_files(
return FileUploadResponse(file_paths=deduped_file_paths)
@router.post("/admin/connector/file/upload")
def upload_files_api(
files: list[UploadFile],
_: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> FileUploadResponse:
return upload_files(files, db_session)
# Retrieves most recent failure cases for connectors that are currently failing
@router.get("/admin/connector/failed-indexing-status")
def get_currently_failed_indexing_status(
@@ -680,7 +688,7 @@ def create_connector_from_model(
_validate_connector_allowed(connector_data.source)
fetch_ee_implementation_or_noop(
"onyx.db.user_group", "validate_user_creation_permissions", None
"onyx.db.user_group", "validate_object_creation_for_user", None
)(
db_session=db_session,
user=user,
@@ -716,7 +724,7 @@ def create_connector_with_mock_credential(
tenant_id: str = Depends(get_current_tenant_id),
) -> StatusResponse:
fetch_ee_implementation_or_noop(
"onyx.db.user_group", "validate_user_creation_permissions", None
"onyx.db.user_group", "validate_object_creation_for_user", None
)(
db_session=db_session,
user=user,
@@ -776,7 +784,7 @@ def update_connector_from_model(
try:
_validate_connector_allowed(connector_data.source)
fetch_ee_implementation_or_noop(
"onyx.db.user_group", "validate_user_creation_permissions", None
"onyx.db.user_group", "validate_object_creation_for_user", None
)(
db_session=db_session,
user=user,
@@ -1055,10 +1063,10 @@ class BasicCCPairInfo(BaseModel):
@router.get("/connector-status")
def get_basic_connector_indexing_status(
_: User = Depends(current_user),
_: User = Depends(current_chat_accesssible_user),
db_session: Session = Depends(get_session),
) -> list[BasicCCPairInfo]:
cc_pairs = get_connector_credential_pairs(db_session)
cc_pairs = get_connector_credential_pairs(db_session, eager_load_connector=True)
return [
BasicCCPairInfo(
has_successful_run=cc_pair.last_successful_index_time is not None,

View File

@@ -122,7 +122,7 @@ def create_credential_from_model(
) -> ObjectCreationIdResponse:
if not _ignore_credential_permissions(credential_info.source):
fetch_ee_implementation_or_noop(
"onyx.db.user_group", "validate_user_creation_permissions", None
"onyx.db.user_group", "validate_object_creation_for_user", None
)(
db_session=db_session,
user=user,
@@ -164,7 +164,12 @@ def get_credential_by_id(
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> CredentialSnapshot | StatusResponse[int]:
credential = fetch_credential_by_id(credential_id, user, db_session)
credential = fetch_credential_by_id(
credential_id,
user,
db_session,
get_editable=False,
)
if credential is None:
raise HTTPException(
status_code=401,

View File

@@ -1,3 +1,4 @@
import json
import uuid
from typing import Annotated
from typing import cast
@@ -6,7 +7,9 @@ from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Query
from fastapi import Request
from pydantic import BaseModel
from pydantic import ValidationError
from sqlalchemy.orm import Session
from onyx.auth.users import current_user
@@ -28,6 +31,8 @@ router = APIRouter(prefix="/connector/oauth")
_OAUTH_STATE_KEY_FMT = "oauth_state:{state}"
_OAUTH_STATE_EXPIRATION_SECONDS = 10 * 60 # 10 minutes
_DESIRED_RETURN_URL_KEY = "desired_return_url"
_ADDITIONAL_KWARGS_KEY = "additional_kwargs"
# Cache for OAuth connectors, populated at module load time
_OAUTH_CONNECTORS: dict[DocumentSource, type[OAuthConnector]] = {}
@@ -51,12 +56,36 @@ def _discover_oauth_connectors() -> dict[DocumentSource, type[OAuthConnector]]:
_discover_oauth_connectors()
def _get_additional_kwargs(
request: Request, connector_cls: type[OAuthConnector], args_to_ignore: list[str]
) -> dict[str, str]:
# get additional kwargs from request
# e.g. anything except for desired_return_url
additional_kwargs_dict = {
k: v for k, v in request.query_params.items() if k not in args_to_ignore
}
try:
# validate
connector_cls.AdditionalOauthKwargs(**additional_kwargs_dict)
except ValidationError:
raise HTTPException(
status_code=400,
detail=(
f"Invalid additional kwargs. Got {additional_kwargs_dict}, expected "
f"{connector_cls.AdditionalOauthKwargs.model_json_schema()}"
),
)
return additional_kwargs_dict
class AuthorizeResponse(BaseModel):
redirect_url: str
@router.get("/authorize/{source}")
def oauth_authorize(
request: Request,
source: DocumentSource,
desired_return_url: Annotated[str | None, Query()] = None,
_: User = Depends(current_user),
@@ -71,6 +100,12 @@ def oauth_authorize(
connector_cls = oauth_connectors[source]
base_url = WEB_DOMAIN
# get additional kwargs from request
# e.g. anything except for desired_return_url
additional_kwargs = _get_additional_kwargs(
request, connector_cls, ["desired_return_url"]
)
# store state in redis
if not desired_return_url:
desired_return_url = f"{base_url}/admin/connectors/{source}?step=0"
@@ -78,12 +113,19 @@ def oauth_authorize(
state = str(uuid.uuid4())
redis_client.set(
_OAUTH_STATE_KEY_FMT.format(state=state),
desired_return_url,
json.dumps(
{
_DESIRED_RETURN_URL_KEY: desired_return_url,
_ADDITIONAL_KWARGS_KEY: additional_kwargs,
}
),
ex=_OAUTH_STATE_EXPIRATION_SECONDS,
)
return AuthorizeResponse(
redirect_url=connector_cls.oauth_authorization_url(base_url, state)
redirect_url=connector_cls.oauth_authorization_url(
base_url, state, additional_kwargs
)
)
@@ -110,15 +152,18 @@ def oauth_callback(
# get state from redis
redis_client = get_redis_client(tenant_id=tenant_id)
original_url_bytes = cast(
oauth_state_bytes = cast(
bytes, redis_client.get(_OAUTH_STATE_KEY_FMT.format(state=state))
)
if not original_url_bytes:
if not oauth_state_bytes:
raise HTTPException(status_code=400, detail="Invalid OAuth state")
original_url = original_url_bytes.decode("utf-8")
oauth_state = json.loads(oauth_state_bytes.decode("utf-8"))
desired_return_url = cast(str, oauth_state[_DESIRED_RETURN_URL_KEY])
additional_kwargs = cast(dict[str, str], oauth_state[_ADDITIONAL_KWARGS_KEY])
base_url = WEB_DOMAIN
token_info = connector_cls.oauth_code_to_token(base_url, code)
token_info = connector_cls.oauth_code_to_token(base_url, code, additional_kwargs)
# Create a new credential with the token info
credential_data = CredentialBase(
@@ -136,8 +181,52 @@ def oauth_callback(
return CallbackResponse(
redirect_url=(
f"{original_url}?credentialId={credential.id}"
if "?" not in original_url
else f"{original_url}&credentialId={credential.id}"
f"{desired_return_url}?credentialId={credential.id}"
if "?" not in desired_return_url
else f"{desired_return_url}&credentialId={credential.id}"
)
)
class OAuthAdditionalKwargDescription(BaseModel):
name: str
display_name: str
description: str
class OAuthDetails(BaseModel):
oauth_enabled: bool
additional_kwargs: list[OAuthAdditionalKwargDescription]
@router.get("/details/{source}")
def oauth_details(
source: DocumentSource,
_: User = Depends(current_user),
) -> OAuthDetails:
oauth_connectors = _discover_oauth_connectors()
if source not in oauth_connectors:
return OAuthDetails(
oauth_enabled=False,
additional_kwargs=[],
)
connector_cls = oauth_connectors[source]
additional_kwarg_descriptions = []
for key, value in connector_cls.AdditionalOauthKwargs.model_json_schema()[
"properties"
].items():
additional_kwarg_descriptions.append(
OAuthAdditionalKwargDescription(
name=key,
display_name=value.get("title", key),
description=value.get("description", ""),
)
)
return OAuthDetails(
oauth_enabled=True,
additional_kwargs=additional_kwarg_descriptions,
)

View File

@@ -31,7 +31,7 @@ def create_document_set(
db_session: Session = Depends(get_session),
) -> int:
fetch_ee_implementation_or_noop(
"onyx.db.user_group", "validate_user_creation_permissions", None
"onyx.db.user_group", "validate_object_creation_for_user", None
)(
db_session=db_session,
user=user,
@@ -56,7 +56,7 @@ def patch_document_set(
db_session: Session = Depends(get_session),
) -> None:
fetch_ee_implementation_or_noop(
"onyx.db.user_group", "validate_user_creation_permissions", None
"onyx.db.user_group", "validate_object_creation_for_user", None
)(
db_session=db_session,
user=user,

View File

@@ -1,176 +0,0 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Path
from sqlalchemy.orm import Session
from onyx.auth.users import current_user
from onyx.db.chat import get_chat_session_by_id
from onyx.db.engine import get_session
from onyx.db.folder import add_chat_to_folder
from onyx.db.folder import create_folder
from onyx.db.folder import delete_folder
from onyx.db.folder import get_user_folders
from onyx.db.folder import remove_chat_from_folder
from onyx.db.folder import rename_folder
from onyx.db.folder import update_folder_display_priority
from onyx.db.models import User
from onyx.server.features.folder.models import DeleteFolderOptions
from onyx.server.features.folder.models import FolderChatSessionRequest
from onyx.server.features.folder.models import FolderCreationRequest
from onyx.server.features.folder.models import FolderResponse
from onyx.server.features.folder.models import FolderUpdateRequest
from onyx.server.features.folder.models import GetUserFoldersResponse
from onyx.server.models import DisplayPriorityRequest
from onyx.server.query_and_chat.models import ChatSessionDetails
router = APIRouter(prefix="/folder")
@router.get("")
def get_folders(
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> GetUserFoldersResponse:
folders = get_user_folders(
user_id=user.id if user else None,
db_session=db_session,
)
folders.sort()
return GetUserFoldersResponse(
folders=[
FolderResponse(
folder_id=folder.id,
folder_name=folder.name,
display_priority=folder.display_priority,
chat_sessions=[
ChatSessionDetails(
id=chat_session.id,
name=chat_session.description,
persona_id=chat_session.persona_id,
time_created=chat_session.time_created.isoformat(),
shared_status=chat_session.shared_status,
folder_id=folder.id,
)
for chat_session in folder.chat_sessions
if not chat_session.deleted
],
)
for folder in folders
]
)
@router.put("/reorder")
def put_folder_display_priority(
display_priority_request: DisplayPriorityRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
update_folder_display_priority(
user_id=user.id if user else None,
display_priority_map=display_priority_request.display_priority_map,
db_session=db_session,
)
@router.post("")
def create_folder_endpoint(
request: FolderCreationRequest,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> int:
return create_folder(
user_id=user.id if user else None,
folder_name=request.folder_name,
db_session=db_session,
)
@router.patch("/{folder_id}")
def patch_folder_endpoint(
request: FolderUpdateRequest,
folder_id: int = Path(..., description="The ID of the folder to rename"),
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
try:
rename_folder(
user_id=user.id if user else None,
folder_id=folder_id,
folder_name=request.folder_name,
db_session=db_session,
)
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
@router.delete("/{folder_id}")
def delete_folder_endpoint(
request: DeleteFolderOptions,
folder_id: int = Path(..., description="The ID of the folder to delete"),
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
user_id = user.id if user else None
try:
delete_folder(
user_id=user_id,
folder_id=folder_id,
including_chats=request.including_chats,
db_session=db_session,
)
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
@router.post("/{folder_id}/add-chat-session")
def add_chat_to_folder_endpoint(
request: FolderChatSessionRequest,
folder_id: int = Path(
..., description="The ID of the folder in which to add the chat session"
),
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
user_id = user.id if user else None
try:
chat_session = get_chat_session_by_id(
chat_session_id=request.chat_session_id,
user_id=user_id,
db_session=db_session,
)
add_chat_to_folder(
user_id=user.id if user else None,
folder_id=folder_id,
chat_session=chat_session,
db_session=db_session,
)
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
@router.post("/{folder_id}/remove-chat-session/")
def remove_chat_from_folder_endpoint(
request: FolderChatSessionRequest,
folder_id: int = Path(
..., description="The ID of the folder from which to remove the chat session"
),
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
user_id = user.id if user else None
try:
chat_session = get_chat_session_by_id(
chat_session_id=request.chat_session_id,
user_id=user_id,
db_session=db_session,
)
remove_chat_from_folder(
user_id=user_id,
folder_id=folder_id,
chat_session=chat_session,
db_session=db_session,
)
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))

View File

@@ -1,32 +0,0 @@
from uuid import UUID
from pydantic import BaseModel
from onyx.server.query_and_chat.models import ChatSessionDetails
class FolderResponse(BaseModel):
folder_id: int
folder_name: str | None
display_priority: int
chat_sessions: list[ChatSessionDetails]
class GetUserFoldersResponse(BaseModel):
folders: list[FolderResponse]
class FolderCreationRequest(BaseModel):
folder_name: str | None = None
class FolderUpdateRequest(BaseModel):
folder_name: str | None = None
class FolderChatSessionRequest(BaseModel):
chat_session_id: UUID
class DeleteFolderOptions(BaseModel):
including_chats: bool = False

View File

@@ -10,6 +10,7 @@ from pydantic import BaseModel
from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_chat_accesssible_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_limited_user
from onyx.auth.users import current_user
@@ -323,7 +324,7 @@ def get_image_generation_tool(
@basic_router.get("")
def list_personas(
user: User | None = Depends(current_user),
user: User | None = Depends(current_chat_accesssible_user),
db_session: Session = Depends(get_session),
include_deleted: bool = False,
persona_ids: list[int] = Query(None),

View File

@@ -1,6 +1,7 @@
from fastapi import APIRouter
from onyx import __version__
from onyx.auth.users import anonymous_user_enabled
from onyx.auth.users import user_needs_to_be_verified
from onyx.configs.app_configs import AUTH_TYPE
from onyx.server.manage.models import AuthTypeResponse
@@ -18,7 +19,9 @@ def healthcheck() -> StatusResponse:
@router.get("/auth/type")
def get_auth_type() -> AuthTypeResponse:
return AuthTypeResponse(
auth_type=AUTH_TYPE, requires_verification=user_needs_to_be_verified()
auth_type=AUTH_TYPE,
requires_verification=user_needs_to_be_verified(),
anonymous_user_enabled=anonymous_user_enabled(),
)

View File

@@ -7,7 +7,7 @@ from fastapi import Query
from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_user
from onyx.auth.users import current_chat_accesssible_user
from onyx.db.engine import get_session
from onyx.db.llm import fetch_existing_llm_providers
from onyx.db.llm import fetch_provider
@@ -57,7 +57,6 @@ def test_llm_configuration(
)
functions_with_args: list[tuple[Callable, tuple]] = [(test_llm, (llm,))]
if (
test_llm_request.fast_default_model_name
and test_llm_request.fast_default_model_name
@@ -190,7 +189,7 @@ def set_provider_as_default(
@basic_router.get("/provider")
def list_llm_provider_basics(
user: User | None = Depends(current_user),
user: User | None = Depends(current_chat_accesssible_user),
db_session: Session = Depends(get_session),
) -> list[LLMProviderDescriptor]:
return [

Some files were not shown because too many files have changed in this diff Show More