Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-02-17 15:55:45 +00:00)

Compare commits: improve ... cloud_debu (3 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 09e6bd3c9c | |
| | c1803cdd56 | |
| | a5b9c76012 | |
18  .github/pull_request_template.md (vendored)
@@ -6,6 +6,24 @@

[Describe the tests you ran to verify your changes]

## Accepted Risk (provide if relevant)
N/A

## Related Issue(s) (provide if relevant)
N/A

## Mental Checklist:
- All of the automated tests pass
- All PR comments are addressed and marked resolved
- If there are migrations, they have been rebased to latest main
- If there are new dependencies, they are added to the requirements
- If there are new environment variables, they are added to all of the deployment methods
- If there are new APIs that don't require auth, they are added to PUBLIC_ENDPOINT_SPECS
- Docker images build and basic functionalities work
- Author has done a final read through of the PR right before merge

## Backporting (check the box to trigger backport action)
Note: You have to check that the action passes, otherwise resolve the conflicts manually and tag the patches.
- [ ] This PR should be backported (make sure to check that the backport attempt succeeds)
@@ -66,7 +66,6 @@ jobs:
NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
NEXT_PUBLIC_GTM_ENABLED=true
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
# needed due to weird interactions with the builds for different platforms
no-cache: true
labels: ${{ steps.meta.outputs.labels }}
14  .github/workflows/pr-python-connector-tests.yml (vendored)
@@ -26,19 +26,7 @@ env:
GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR }}
# Slab
SLAB_BOT_TOKEN: ${{ secrets.SLAB_BOT_TOKEN }}
# Zendesk
ZENDESK_SUBDOMAIN: ${{ secrets.ZENDESK_SUBDOMAIN }}
ZENDESK_EMAIL: ${{ secrets.ZENDESK_EMAIL }}
ZENDESK_TOKEN: ${{ secrets.ZENDESK_TOKEN }}
# Salesforce
SF_USERNAME: ${{ secrets.SF_USERNAME }}
SF_PASSWORD: ${{ secrets.SF_PASSWORD }}
SF_SECURITY_TOKEN: ${{ secrets.SF_SECURITY_TOKEN }}
# Airtable
AIRTABLE_TEST_BASE_ID: ${{ secrets.AIRTABLE_TEST_BASE_ID }}
AIRTABLE_TEST_TABLE_ID: ${{ secrets.AIRTABLE_TEST_TABLE_ID }}
AIRTABLE_TEST_TABLE_NAME: ${{ secrets.AIRTABLE_TEST_TABLE_NAME }}
AIRTABLE_ACCESS_TOKEN: ${{ secrets.AIRTABLE_ACCESS_TOKEN }}

jobs:
connectors-check:
# See https://runs-on.com/runners/linux/
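For context on how these CI secrets are consumed, here is a minimal, hypothetical sketch of a connector test that reads the Airtable variables from the environment and skips when they are absent. The test module and assertions are assumptions for illustration, not code from this PR.

```python
import os

import pytest

# Connector tests typically read these CI secrets from the environment;
# the names below come from the workflow env block above.
AIRTABLE_ACCESS_TOKEN = os.environ.get("AIRTABLE_ACCESS_TOKEN")
AIRTABLE_TEST_BASE_ID = os.environ.get("AIRTABLE_TEST_BASE_ID")


@pytest.mark.skipif(
    not (AIRTABLE_ACCESS_TOKEN and AIRTABLE_TEST_BASE_ID),
    reason="Airtable credentials not configured in the environment",
)
def test_airtable_credentials_present() -> None:
    # A real connector test would build a client from these values and index a
    # test table; this only checks that the wiring from CI secrets works.
    assert AIRTABLE_ACCESS_TOKEN
    assert AIRTABLE_TEST_BASE_ID
```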
18  README.md
@@ -3,7 +3,7 @@
<a name="readme-top"></a>

<h2 align="center">
<a href="https://www.onyx.app/"> <img width="50%" src="https://github.com/onyx-dot-app/onyx/blob/logo/OnyxLogoCropped.jpg?raw=true)" /></a>
<a href="https://www.onyx.app/"> <img width="50%" src="https://github.com/onyx-dot-app/onyx/blob/logo/LogoOnyx.png?raw=true)" /></a>
</h2>

<p align="center">

@@ -13,7 +13,7 @@
<a href="https://docs.onyx.app/" target="_blank">
<img src="https://img.shields.io/badge/docs-view-blue" alt="Documentation">
</a>
<a href="https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA" target="_blank">
<a href="https://join.slack.com/t/danswer/shared_invite/zt-1w76msxmd-HJHLe3KNFIAIzk_0dSOKaQ" target="_blank">
<img src="https://img.shields.io/badge/slack-join-blue.svg?logo=slack" alt="Slack">
</a>
<a href="https://discord.gg/TDJ59cGV2X" target="_blank">

@@ -24,7 +24,7 @@
</a>
</p>

<strong>[Onyx](https://www.onyx.app/)</strong> (formerly Danswer) is the AI Assistant connected to your company's docs, apps, and people.
<strong>[Onyx](https://www.onyx.app/)</strong> (Formerly Danswer) is the AI Assistant connected to your company's docs, apps, and people.
Onyx provides a Chat interface and plugs into any LLM of your choice. Onyx can be deployed anywhere and for any
scale - on a laptop, on-premise, or to cloud. Since you own the deployment, your user data and chats are fully in your
own control. Onyx is dual Licensed with most of it under MIT license and designed to be modular and easily extensible. The system also comes fully ready

@@ -133,3 +133,15 @@ Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md
## ⭐Star History

[](https://star-history.com/#onyx-dot-app/onyx&Date)

## ✨Contributors

<a href="https://github.com/onyx-dot-app/onyx/graphs/contributors">
<img alt="contributors" src="https://contrib.rocks/image?repo=onyx-dot-app/onyx"/>
</a>

<p align="right" style="font-size: 14px; color: #555; margin-top: 20px;">
<a href="#readme-top" style="text-decoration: none; color: #007bff; font-weight: bold;">
↑ Back to Top ↑
</a>
</p>
1  backend/.gitignore (vendored)
@@ -9,4 +9,3 @@ api_keys.py
vespa-app.zip
dynamic_config_storage/
celerybeat-schedule*
onyx/connectors/salesforce/data/
@@ -4,7 +4,7 @@ from onyx.configs.app_configs import USE_IAM_AUTH
from onyx.configs.app_configs import POSTGRES_HOST
from onyx.configs.app_configs import POSTGRES_PORT
from onyx.configs.app_configs import POSTGRES_USER
from onyx.configs.app_configs import AWS_REGION_NAME
from onyx.configs.app_configs import AWS_REGION
from onyx.db.engine import build_connection_string
from onyx.db.engine import get_all_tenant_ids
from sqlalchemy import event

@@ -120,7 +120,7 @@ def provide_iam_token_for_alembic(
) -> None:
    if USE_IAM_AUTH:
        # Database connection settings
        region = AWS_REGION_NAME
        region = AWS_REGION
        host = POSTGRES_HOST
        port = POSTGRES_PORT
        user = POSTGRES_USER
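As background for the AWS_REGION_NAME to AWS_REGION rename above, here is a minimal sketch of how an RDS IAM auth token is typically generated from these settings. It assumes boto3's RDS client and is illustrative only, not code from this PR.

```python
import boto3


def generate_rds_iam_token(region: str, host: str, port: int, user: str) -> str:
    """Return a short-lived IAM auth token to use as the Postgres password.

    Assumes the runtime role has rds-db:connect permission; illustrative only.
    """
    client = boto3.client("rds", region_name=region)
    return client.generate_db_auth_token(
        DBHostname=host, Port=port, DBUsername=user, Region=region
    )
```

The returned token is short-lived and is passed to the Postgres driver in place of a password when IAM auth is enabled.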
@@ -1,24 +0,0 @@
"""add chunk count to document

Revision ID: 2955778aa44c
Revises: c0aab6edb6dd
Create Date: 2025-01-04 11:39:43.268612

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "2955778aa44c"
down_revision = "c0aab6edb6dd"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column("document", sa.Column("chunk_count", sa.Integer(), nullable=True))


def downgrade() -> None:
    op.drop_column("document", "chunk_count")
@@ -6,7 +6,6 @@ from sqlalchemy.orm import Session
from ee.onyx.db.user_group import delete_user_group
from ee.onyx.db.user_group import fetch_user_group
from ee.onyx.db.user_group import mark_user_group_as_synced
from ee.onyx.db.user_group import prepare_user_group_for_deletion
from onyx.background.celery.apps.app_base import task_logger
from onyx.redis.redis_usergroup import RedisUserGroup
from onyx.utils.logger import setup_logger

@@ -47,20 +46,11 @@ def monitor_usergroup_taskset(

    user_group = fetch_user_group(db_session=db_session, user_group_id=usergroup_id)
    if user_group:
        usergroup_name = user_group.name
        if user_group.is_up_for_deletion:
            # this prepare should have been run when the deletion was scheduled,
            # but run it again to be sure we're ready to go
            mark_user_group_as_synced(db_session, user_group)
            prepare_user_group_for_deletion(db_session, usergroup_id)
            delete_user_group(db_session=db_session, user_group=user_group)
            task_logger.info(
                f"Deleted usergroup: name={usergroup_name} id={usergroup_id}"
            )
            task_logger.info(f"Deleted usergroup. id='{usergroup_id}'")
        else:
            mark_user_group_as_synced(db_session=db_session, user_group=user_group)
            task_logger.info(
                f"Synced usergroup. name={usergroup_name} id={usergroup_id}"
            )
            task_logger.info(f"Synced usergroup. id='{usergroup_id}'")

    rug.reset()
@@ -15,12 +15,6 @@ SAML_CONF_DIR = os.environ.get("SAML_CONF_DIR") or "/app/ee/onyx/configs/saml_co
CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY = int(
    os.environ.get("CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY") or 5 * 60
)
# This is a boolean that determines if anonymous access is public
# Default behavior is to not make the page public and instead add a group
# that contains all the users that we found in Confluence
CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC = (
    os.environ.get("CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC", "").lower() == "true"
)
# In seconds, default is 5 minutes
CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY = int(
    os.environ.get("CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
@@ -2,7 +2,6 @@ import datetime
|
||||
from collections.abc import Sequence
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_
|
||||
from sqlalchemy import case
|
||||
from sqlalchemy import cast
|
||||
from sqlalchemy import Date
|
||||
@@ -15,9 +14,6 @@ from onyx.configs.constants import MessageType
|
||||
from onyx.db.models import ChatMessage
|
||||
from onyx.db.models import ChatMessageFeedback
|
||||
from onyx.db.models import ChatSession
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserRole
|
||||
|
||||
|
||||
def fetch_query_analytics(
|
||||
@@ -238,121 +234,3 @@ def fetch_persona_unique_users(
|
||||
)
|
||||
|
||||
return [tuple(row) for row in db_session.execute(query).all()]
|
||||
|
||||
|
||||
def fetch_assistant_message_analytics(
|
||||
db_session: Session,
|
||||
assistant_id: int,
|
||||
start: datetime.datetime,
|
||||
end: datetime.datetime,
|
||||
) -> list[tuple[int, datetime.date]]:
|
||||
"""
|
||||
Gets the daily message counts for a specific assistant in the given time range.
|
||||
"""
|
||||
query = (
|
||||
select(
|
||||
func.count(ChatMessage.id),
|
||||
cast(ChatMessage.time_sent, Date),
|
||||
)
|
||||
.join(
|
||||
ChatSession,
|
||||
ChatMessage.chat_session_id == ChatSession.id,
|
||||
)
|
||||
.where(
|
||||
or_(
|
||||
ChatMessage.alternate_assistant_id == assistant_id,
|
||||
ChatSession.persona_id == assistant_id,
|
||||
),
|
||||
ChatMessage.time_sent >= start,
|
||||
ChatMessage.time_sent <= end,
|
||||
ChatMessage.message_type == MessageType.ASSISTANT,
|
||||
)
|
||||
.group_by(cast(ChatMessage.time_sent, Date))
|
||||
.order_by(cast(ChatMessage.time_sent, Date))
|
||||
)
|
||||
|
||||
return [tuple(row) for row in db_session.execute(query).all()]
|
||||
|
||||
|
||||
def fetch_assistant_unique_users(
|
||||
db_session: Session,
|
||||
assistant_id: int,
|
||||
start: datetime.datetime,
|
||||
end: datetime.datetime,
|
||||
) -> list[tuple[int, datetime.date]]:
|
||||
"""
|
||||
Gets the daily unique user counts for a specific assistant in the given time range.
|
||||
"""
|
||||
query = (
|
||||
select(
|
||||
func.count(func.distinct(ChatSession.user_id)),
|
||||
cast(ChatMessage.time_sent, Date),
|
||||
)
|
||||
.join(
|
||||
ChatSession,
|
||||
ChatMessage.chat_session_id == ChatSession.id,
|
||||
)
|
||||
.where(
|
||||
or_(
|
||||
ChatMessage.alternate_assistant_id == assistant_id,
|
||||
ChatSession.persona_id == assistant_id,
|
||||
),
|
||||
ChatMessage.time_sent >= start,
|
||||
ChatMessage.time_sent <= end,
|
||||
ChatMessage.message_type == MessageType.ASSISTANT,
|
||||
)
|
||||
.group_by(cast(ChatMessage.time_sent, Date))
|
||||
.order_by(cast(ChatMessage.time_sent, Date))
|
||||
)
|
||||
|
||||
return [tuple(row) for row in db_session.execute(query).all()]
|
||||
|
||||
|
||||
def fetch_assistant_unique_users_total(
|
||||
db_session: Session,
|
||||
assistant_id: int,
|
||||
start: datetime.datetime,
|
||||
end: datetime.datetime,
|
||||
) -> int:
|
||||
"""
|
||||
Gets the total number of distinct users who have sent or received messages from
|
||||
the specified assistant in the given time range.
|
||||
"""
|
||||
query = (
|
||||
select(func.count(func.distinct(ChatSession.user_id)))
|
||||
.select_from(ChatMessage)
|
||||
.join(
|
||||
ChatSession,
|
||||
ChatMessage.chat_session_id == ChatSession.id,
|
||||
)
|
||||
.where(
|
||||
or_(
|
||||
ChatMessage.alternate_assistant_id == assistant_id,
|
||||
ChatSession.persona_id == assistant_id,
|
||||
),
|
||||
ChatMessage.time_sent >= start,
|
||||
ChatMessage.time_sent <= end,
|
||||
ChatMessage.message_type == MessageType.ASSISTANT,
|
||||
)
|
||||
)
|
||||
|
||||
result = db_session.execute(query).scalar()
|
||||
return result if result else 0
|
||||
|
||||
|
||||
# Users can view assistant stats if they created the persona,
|
||||
# or if they are an admin
|
||||
def user_can_view_assistant_stats(
|
||||
db_session: Session, user: User | None, assistant_id: int
|
||||
) -> bool:
|
||||
# If user is None, assume the user is an admin or auth is disabled
|
||||
if user is None or user.role == UserRole.ADMIN:
|
||||
return True
|
||||
|
||||
# Check if the user created the persona
|
||||
stmt = select(Persona).where(
|
||||
and_(Persona.id == assistant_id, Persona.user_id == user.id)
|
||||
)
|
||||
|
||||
persona = db_session.execute(stmt).scalar_one_or_none()
|
||||
return persona is not None
|
||||
|
||||
@@ -122,7 +122,7 @@ def _cleanup_document_set__user_group_relationships__no_commit(
|
||||
)
|
||||
|
||||
|
||||
def validate_object_creation_for_user(
|
||||
def validate_user_creation_permissions(
|
||||
db_session: Session,
|
||||
user: User | None,
|
||||
target_group_ids: list[int] | None = None,
|
||||
@@ -440,108 +440,32 @@ def remove_curator_status__no_commit(db_session: Session, user: User) -> None:
|
||||
_validate_curator_status__no_commit(db_session, [user])
|
||||
|
||||
|
||||
def _validate_curator_relationship_update_requester(
|
||||
db_session: Session,
|
||||
user_group_id: int,
|
||||
user_making_change: User | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
This function validates that the user making the change has the necessary permissions
|
||||
to update the curator relationship for the target user in the given user group.
|
||||
"""
|
||||
|
||||
if user_making_change is None or user_making_change.role == UserRole.ADMIN:
|
||||
return
|
||||
|
||||
# check if the user making the change is a curator in the group they are changing the curator relationship for
|
||||
user_making_change_curator_groups = fetch_user_groups_for_user(
|
||||
db_session=db_session,
|
||||
user_id=user_making_change.id,
|
||||
# only check if the user making the change is a curator if they are a curator
|
||||
# otherwise, they are a global_curator and can update the curator relationship
|
||||
# for any group they are a member of
|
||||
only_curator_groups=user_making_change.role == UserRole.CURATOR,
|
||||
)
|
||||
requestor_curator_group_ids = [
|
||||
group.id for group in user_making_change_curator_groups
|
||||
]
|
||||
if user_group_id not in requestor_curator_group_ids:
|
||||
raise ValueError(
|
||||
f"user making change {user_making_change.email} is not a curator,"
|
||||
f" admin, or global_curator for group '{user_group_id}'"
|
||||
)
|
||||
|
||||
|
||||
def _validate_curator_relationship_update_request(
|
||||
db_session: Session,
|
||||
user_group_id: int,
|
||||
target_user: User,
|
||||
) -> None:
|
||||
"""
|
||||
This function validates that the curator_relationship_update request itself is valid.
|
||||
"""
|
||||
if target_user.role == UserRole.ADMIN:
|
||||
raise ValueError(
|
||||
f"User '{target_user.email}' is an admin and therefore has all permissions "
|
||||
"of a curator. If you'd like this user to only have curator permissions, "
|
||||
"you must update their role to BASIC then assign them to be CURATOR in the "
|
||||
"appropriate groups."
|
||||
)
|
||||
elif target_user.role == UserRole.GLOBAL_CURATOR:
|
||||
raise ValueError(
|
||||
f"User '{target_user.email}' is a global_curator and therefore has all "
|
||||
"permissions of a curator for all groups. If you'd like this user to only "
|
||||
"have curator permissions for a specific group, you must update their role "
|
||||
"to BASIC then assign them to be CURATOR in the appropriate groups."
|
||||
)
|
||||
elif target_user.role not in [UserRole.CURATOR, UserRole.BASIC]:
|
||||
raise ValueError(
|
||||
f"This endpoint can only be used to update the curator relationship for "
|
||||
"users with the CURATOR or BASIC role. \n"
|
||||
f"Target user: {target_user.email} \n"
|
||||
f"Target user role: {target_user.role} \n"
|
||||
)
|
||||
|
||||
# check if the target user is in the group they are changing the curator relationship for
|
||||
requested_user_groups = fetch_user_groups_for_user(
|
||||
db_session=db_session,
|
||||
user_id=target_user.id,
|
||||
only_curator_groups=False,
|
||||
)
|
||||
group_ids = [group.id for group in requested_user_groups]
|
||||
if user_group_id not in group_ids:
|
||||
raise ValueError(
|
||||
f"target user {target_user.email} is not in group '{user_group_id}'"
|
||||
)
|
||||
|
||||
|
||||
def update_user_curator_relationship(
|
||||
db_session: Session,
|
||||
user_group_id: int,
|
||||
set_curator_request: SetCuratorRequest,
|
||||
user_making_change: User | None = None,
|
||||
) -> None:
|
||||
target_user = fetch_user_by_id(db_session, set_curator_request.user_id)
|
||||
if not target_user:
|
||||
user = fetch_user_by_id(db_session, set_curator_request.user_id)
|
||||
if not user:
|
||||
raise ValueError(f"User with id '{set_curator_request.user_id}' not found")
|
||||
|
||||
_validate_curator_relationship_update_request(
|
||||
if user.role == UserRole.ADMIN:
|
||||
raise ValueError(
|
||||
f"User '{user.email}' is an admin and therefore has all permissions "
|
||||
"of a curator. If you'd like this user to only have curator permissions, "
|
||||
"you must update their role to BASIC then assign them to be CURATOR in the "
|
||||
"appropriate groups."
|
||||
)
|
||||
|
||||
requested_user_groups = fetch_user_groups_for_user(
|
||||
db_session=db_session,
|
||||
user_group_id=user_group_id,
|
||||
target_user=target_user,
|
||||
user_id=set_curator_request.user_id,
|
||||
only_curator_groups=False,
|
||||
)
|
||||
|
||||
_validate_curator_relationship_update_requester(
|
||||
db_session=db_session,
|
||||
user_group_id=user_group_id,
|
||||
user_making_change=user_making_change,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"user_making_change={user_making_change.email if user_making_change else 'None'} is "
|
||||
f"updating the curator relationship for user={target_user.email} "
|
||||
f"in group={user_group_id} to is_curator={set_curator_request.is_curator}"
|
||||
)
|
||||
group_ids = [group.id for group in requested_user_groups]
|
||||
if user_group_id not in group_ids:
|
||||
raise ValueError(f"user is not in group '{user_group_id}'")
|
||||
|
||||
relationship_to_update = (
|
||||
db_session.query(User__UserGroup)
|
||||
@@ -562,7 +486,7 @@ def update_user_curator_relationship(
|
||||
)
|
||||
db_session.add(relationship_to_update)
|
||||
|
||||
_validate_curator_status__no_commit(db_session, [target_user])
|
||||
_validate_curator_status__no_commit(db_session, [user])
|
||||
db_session.commit()
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
# This is a group that we use to store all the users that we found in Confluence
|
||||
# Instead of setting a page to public, we just add this group so that the page
|
||||
# is only accessible to users who have confluence accounts.
|
||||
ALL_CONF_EMAILS_GROUP_NAME = "All_Confluence_Users_Found_By_Onyx"
|
||||
@@ -4,8 +4,6 @@ https://confluence.atlassian.com/conf85/check-who-can-view-a-page-1283360557.htm
|
||||
"""
|
||||
from typing import Any
|
||||
|
||||
from ee.onyx.configs.app_configs import CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC
|
||||
from ee.onyx.external_permissions.confluence.constants import ALL_CONF_EMAILS_GROUP_NAME
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.connectors.confluence.connector import ConfluenceConnector
|
||||
@@ -33,32 +31,14 @@ def _get_server_space_permissions(
|
||||
permission_category.get("spacePermissions", [])
|
||||
)
|
||||
|
||||
is_public = False
|
||||
user_names = set()
|
||||
group_names = set()
|
||||
for permission in viewspace_permissions:
|
||||
user_name = permission.get("userName")
|
||||
if user_name:
|
||||
if user_name := permission.get("userName"):
|
||||
user_names.add(user_name)
|
||||
group_name = permission.get("groupName")
|
||||
if group_name:
|
||||
if group_name := permission.get("groupName"):
|
||||
group_names.add(group_name)
|
||||
|
||||
# It seems that if anonymous access is turned on for the site and space,
|
||||
# then the space is publicly accessible.
|
||||
# For confluence server, we make a group that contains all users
|
||||
# that exist in confluence and then just add that group to the space permissions
|
||||
# if anonymous access is turned on for the site and space or we set is_public = True
|
||||
# if they set the env variable CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC to True so
|
||||
# that we can support confluence server deployments that want anonymous access
|
||||
# to be public (we cant test this because its paywalled)
|
||||
if user_name is None and group_name is None:
|
||||
# Defaults to False
|
||||
if CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC:
|
||||
is_public = True
|
||||
else:
|
||||
group_names.add(ALL_CONF_EMAILS_GROUP_NAME)
|
||||
|
||||
user_emails = set()
|
||||
for user_name in user_names:
|
||||
user_email = get_user_email_from_username__server(confluence_client, user_name)
|
||||
@@ -70,7 +50,11 @@ def _get_server_space_permissions(
|
||||
return ExternalAccess(
|
||||
external_user_emails=user_emails,
|
||||
external_user_group_ids=group_names,
|
||||
is_public=is_public,
|
||||
# TODO: Check if the space is publicly accessible
|
||||
# Currently, we assume the space is not public
|
||||
# We need to check if anonymous access is turned on for the site and space
|
||||
# This information is paywalled so it remains unimplemented
|
||||
is_public=False,
|
||||
)
|
||||
|
||||
|
||||
@@ -150,7 +134,7 @@ def _get_space_permissions(
|
||||
|
||||
def _extract_read_access_restrictions(
|
||||
confluence_client: OnyxConfluence, restrictions: dict[str, Any]
|
||||
) -> tuple[set[str], set[str]]:
|
||||
) -> ExternalAccess | None:
|
||||
"""
|
||||
Converts a page's restrictions dict into an ExternalAccess object.
|
||||
If there are no restrictions, then return None
|
||||
@@ -193,57 +177,21 @@ def _extract_read_access_restrictions(
|
||||
group["name"] for group in read_access_group_jsons if group.get("name")
|
||||
]
|
||||
|
||||
return set(read_access_user_emails), set(read_access_group_names)
|
||||
|
||||
|
||||
def _get_all_page_restrictions(
|
||||
confluence_client: OnyxConfluence,
|
||||
perm_sync_data: dict[str, Any],
|
||||
) -> ExternalAccess | None:
|
||||
"""
|
||||
This function gets the restrictions for a page by taking the intersection
|
||||
of the page's restrictions and the restrictions of all the ancestors
|
||||
of the page.
|
||||
If the page/ancestor has no restrictions, then it is ignored (no intersection).
|
||||
If no restrictions are found anywhere, then return None, indicating that the page
|
||||
should inherit the space's restrictions.
|
||||
"""
|
||||
found_user_emails: set[str] = set()
|
||||
found_group_names: set[str] = set()
|
||||
|
||||
found_user_emails, found_group_names = _extract_read_access_restrictions(
|
||||
confluence_client=confluence_client,
|
||||
restrictions=perm_sync_data.get("restrictions", {}),
|
||||
)
|
||||
|
||||
ancestors: list[dict[str, Any]] = perm_sync_data.get("ancestors", [])
|
||||
for ancestor in ancestors:
|
||||
ancestor_user_emails, ancestor_group_names = _extract_read_access_restrictions(
|
||||
confluence_client=confluence_client,
|
||||
restrictions=ancestor.get("restrictions", {}),
|
||||
)
|
||||
if not ancestor_user_emails and not ancestor_group_names:
|
||||
# This ancestor has no restrictions, so it has no effect on
|
||||
# the page's restrictions, so we ignore it
|
||||
continue
|
||||
|
||||
found_user_emails.intersection_update(ancestor_user_emails)
|
||||
found_group_names.intersection_update(ancestor_group_names)
|
||||
|
||||
# If there are no restrictions found, then the page
|
||||
# inherits the space's restrictions so return None
|
||||
if not found_user_emails and not found_group_names:
|
||||
is_space_public = read_access_user_emails == [] and read_access_group_names == []
|
||||
if is_space_public:
|
||||
return None
|
||||
|
||||
return ExternalAccess(
|
||||
external_user_emails=found_user_emails,
|
||||
external_user_group_ids=found_group_names,
|
||||
external_user_emails=set(read_access_user_emails),
|
||||
external_user_group_ids=set(read_access_group_names),
|
||||
# there is no way for a page to be individually public if the space isn't public
|
||||
is_public=False,
|
||||
)
|
||||
|
||||
|
||||
def _fetch_all_page_restrictions(
|
||||
def _fetch_all_page_restrictions_for_space(
|
||||
confluence_client: OnyxConfluence,
|
||||
slim_docs: list[SlimDocument],
|
||||
space_permissions_by_space_key: dict[str, ExternalAccess],
|
||||
@@ -260,11 +208,11 @@ def _fetch_all_page_restrictions(
|
||||
raise ValueError(
|
||||
f"No permission sync data found for document {slim_doc.id}"
|
||||
)
|
||||
|
||||
if restrictions := _get_all_page_restrictions(
|
||||
restrictions = _extract_read_access_restrictions(
|
||||
confluence_client=confluence_client,
|
||||
perm_sync_data=slim_doc.perm_sync_data,
|
||||
):
|
||||
restrictions=slim_doc.perm_sync_data.get("restrictions", {}),
|
||||
)
|
||||
if restrictions:
|
||||
document_restrictions.append(
|
||||
DocExternalAccess(
|
||||
doc_id=slim_doc.id,
|
||||
@@ -353,7 +301,7 @@ def confluence_doc_sync(
|
||||
slim_docs.extend(doc_batch)
|
||||
|
||||
logger.debug("Fetching all page restrictions for space")
|
||||
return _fetch_all_page_restrictions(
|
||||
return _fetch_all_page_restrictions_for_space(
|
||||
confluence_client=confluence_connector.confluence_client,
|
||||
slim_docs=slim_docs,
|
||||
space_permissions_by_space_key=space_permissions_by_space_key,
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
from ee.onyx.db.external_perm import ExternalUserGroup
|
||||
from ee.onyx.external_permissions.confluence.constants import ALL_CONF_EMAILS_GROUP_NAME
|
||||
from onyx.connectors.confluence.onyx_confluence import build_confluence_client
|
||||
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
|
||||
from onyx.connectors.confluence.utils import get_user_email_from_username__server
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@@ -53,7 +53,6 @@ def confluence_group_sync(
|
||||
confluence_client=confluence_client,
|
||||
)
|
||||
onyx_groups: list[ExternalUserGroup] = []
|
||||
all_found_emails = set()
|
||||
for group_id, group_member_emails in group_member_email_map.items():
|
||||
onyx_groups.append(
|
||||
ExternalUserGroup(
|
||||
@@ -61,15 +60,5 @@ def confluence_group_sync(
|
||||
user_emails=list(group_member_emails),
|
||||
)
|
||||
)
|
||||
all_found_emails.update(group_member_emails)
|
||||
|
||||
# This is so that when we find a public confleunce server page, we can
|
||||
# give access to all users only in if they have an email in Confluence
|
||||
if cc_pair.connector.connector_specific_config.get("is_cloud", False):
|
||||
all_found_group = ExternalUserGroup(
|
||||
id=ALL_CONF_EMAILS_GROUP_NAME,
|
||||
user_emails=list(all_found_emails),
|
||||
)
|
||||
onyx_groups.append(all_found_group)
|
||||
|
||||
return onyx_groups
|
||||
|
||||
@@ -40,7 +40,6 @@ from onyx.configs.app_configs import USER_AUTH_SECRET
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.main import get_application as get_application_base
|
||||
from onyx.main import include_auth_router_with_prefix
|
||||
from onyx.main import include_router_with_global_prefix_prepended
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
@@ -63,7 +62,7 @@ def get_application() -> FastAPI:
|
||||
|
||||
if AUTH_TYPE == AuthType.CLOUD:
|
||||
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
|
||||
include_auth_router_with_prefix(
|
||||
include_router_with_global_prefix_prepended(
|
||||
application,
|
||||
create_onyx_oauth_router(
|
||||
oauth_client,
|
||||
@@ -75,17 +74,19 @@ def get_application() -> FastAPI:
|
||||
redirect_url=f"{WEB_DOMAIN}/auth/oauth/callback",
|
||||
),
|
||||
prefix="/auth/oauth",
|
||||
tags=["auth"],
|
||||
)
|
||||
|
||||
# Need basic auth router for `logout` endpoint
|
||||
include_auth_router_with_prefix(
|
||||
include_router_with_global_prefix_prepended(
|
||||
application,
|
||||
fastapi_users.get_logout_router(auth_backend),
|
||||
prefix="/auth",
|
||||
tags=["auth"],
|
||||
)
|
||||
|
||||
if AUTH_TYPE == AuthType.OIDC:
|
||||
include_auth_router_with_prefix(
|
||||
include_router_with_global_prefix_prepended(
|
||||
application,
|
||||
create_onyx_oauth_router(
|
||||
OpenID(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET, OPENID_CONFIG_URL),
|
||||
@@ -96,20 +97,19 @@ def get_application() -> FastAPI:
|
||||
redirect_url=f"{WEB_DOMAIN}/auth/oidc/callback",
|
||||
),
|
||||
prefix="/auth/oidc",
|
||||
tags=["auth"],
|
||||
)
|
||||
|
||||
# need basic auth router for `logout` endpoint
|
||||
include_auth_router_with_prefix(
|
||||
include_router_with_global_prefix_prepended(
|
||||
application,
|
||||
fastapi_users.get_auth_router(auth_backend),
|
||||
prefix="/auth",
|
||||
tags=["auth"],
|
||||
)
|
||||
|
||||
elif AUTH_TYPE == AuthType.SAML:
|
||||
include_auth_router_with_prefix(
|
||||
application,
|
||||
saml_router,
|
||||
)
|
||||
include_router_with_global_prefix_prepended(application, saml_router)
|
||||
|
||||
# RBAC / group access control
|
||||
include_router_with_global_prefix_prepended(application, user_group_router)
|
||||
|
||||
@@ -1,24 +1,17 @@
|
||||
import datetime
|
||||
from collections import defaultdict
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.db.analytics import fetch_assistant_message_analytics
|
||||
from ee.onyx.db.analytics import fetch_assistant_unique_users
|
||||
from ee.onyx.db.analytics import fetch_assistant_unique_users_total
|
||||
from ee.onyx.db.analytics import fetch_onyxbot_analytics
|
||||
from ee.onyx.db.analytics import fetch_per_user_query_analytics
|
||||
from ee.onyx.db.analytics import fetch_persona_message_analytics
|
||||
from ee.onyx.db.analytics import fetch_persona_unique_users
|
||||
from ee.onyx.db.analytics import fetch_query_analytics
|
||||
from ee.onyx.db.analytics import user_can_view_assistant_stats
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.models import User
|
||||
|
||||
@@ -198,74 +191,3 @@ def get_persona_unique_users(
|
||||
)
|
||||
)
|
||||
return unique_user_counts
|
||||
|
||||
|
||||
class AssistantDailyUsageResponse(BaseModel):
|
||||
date: datetime.date
|
||||
total_messages: int
|
||||
total_unique_users: int
|
||||
|
||||
|
||||
class AssistantStatsResponse(BaseModel):
|
||||
daily_stats: List[AssistantDailyUsageResponse]
|
||||
total_messages: int
|
||||
total_unique_users: int
|
||||
|
||||
|
||||
@router.get("/assistant/{assistant_id}/stats")
|
||||
def get_assistant_stats(
|
||||
assistant_id: int,
|
||||
start: datetime.datetime | None = None,
|
||||
end: datetime.datetime | None = None,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> AssistantStatsResponse:
|
||||
"""
|
||||
Returns daily message and unique user counts for a user's assistant,
|
||||
along with the overall total messages and total distinct users.
|
||||
"""
|
||||
start = start or (
|
||||
datetime.datetime.utcnow() - datetime.timedelta(days=_DEFAULT_LOOKBACK_DAYS)
|
||||
)
|
||||
end = end or datetime.datetime.utcnow()
|
||||
|
||||
if not user_can_view_assistant_stats(db_session, user, assistant_id):
|
||||
raise HTTPException(
|
||||
status_code=403, detail="Not allowed to access this assistant's stats."
|
||||
)
|
||||
|
||||
# Pull daily usage from the DB calls
|
||||
messages_data = fetch_assistant_message_analytics(
|
||||
db_session, assistant_id, start, end
|
||||
)
|
||||
unique_users_data = fetch_assistant_unique_users(
|
||||
db_session, assistant_id, start, end
|
||||
)
|
||||
|
||||
# Map each day => (messages, unique_users).
|
||||
daily_messages_map = {date: count for count, date in messages_data}
|
||||
daily_unique_users_map = {date: count for count, date in unique_users_data}
|
||||
all_dates = set(daily_messages_map.keys()) | set(daily_unique_users_map.keys())
|
||||
|
||||
# Merge both sets of metrics by date
|
||||
daily_results: list[AssistantDailyUsageResponse] = []
|
||||
for date in sorted(all_dates):
|
||||
daily_results.append(
|
||||
AssistantDailyUsageResponse(
|
||||
date=date,
|
||||
total_messages=daily_messages_map.get(date, 0),
|
||||
total_unique_users=daily_unique_users_map.get(date, 0),
|
||||
)
|
||||
)
|
||||
|
||||
# Now pull a single total distinct user count across the entire time range
|
||||
total_msgs = sum(d.total_messages for d in daily_results)
|
||||
total_users = fetch_assistant_unique_users_total(
|
||||
db_session, assistant_id, start, end
|
||||
)
|
||||
|
||||
return AssistantStatsResponse(
|
||||
daily_stats=daily_results,
|
||||
total_messages=total_msgs,
|
||||
total_unique_users=total_users,
|
||||
)
|
||||
|
||||
@@ -2,14 +2,15 @@ import logging
|
||||
from collections.abc import Awaitable
|
||||
from collections.abc import Callable
|
||||
|
||||
import jwt
|
||||
from fastapi import FastAPI
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
from fastapi import Response
|
||||
|
||||
from onyx.auth.api_key import extract_tenant_from_api_key_header
|
||||
from onyx.configs.app_configs import USER_AUTH_SECRET
|
||||
from onyx.db.engine import is_valid_schema_name
|
||||
from onyx.redis.redis_pool import retrieve_auth_token_data_from_redis
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
@@ -21,11 +22,11 @@ def add_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> Non
|
||||
request: Request, call_next: Callable[[Request], Awaitable[Response]]
|
||||
) -> Response:
|
||||
try:
|
||||
if MULTI_TENANT:
|
||||
tenant_id = await _get_tenant_id_from_request(request, logger)
|
||||
else:
|
||||
tenant_id = POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
tenant_id = (
|
||||
_get_tenant_id_from_request(request, logger)
|
||||
if MULTI_TENANT
|
||||
else POSTGRES_DEFAULT_SCHEMA
|
||||
)
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
return await call_next(request)
|
||||
|
||||
@@ -34,36 +35,27 @@ def add_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> Non
|
||||
raise
|
||||
|
||||
|
||||
async def _get_tenant_id_from_request(
|
||||
request: Request, logger: logging.LoggerAdapter
|
||||
) -> str:
|
||||
"""
|
||||
Attempt to extract tenant_id from:
|
||||
1) The API key header
|
||||
2) The Redis-based token (stored in Cookie: fastapiusersauth)
|
||||
Fallback: POSTGRES_DEFAULT_SCHEMA
|
||||
"""
|
||||
# Check for API key
|
||||
def _get_tenant_id_from_request(request: Request, logger: logging.LoggerAdapter) -> str:
|
||||
# First check for API key
|
||||
tenant_id = extract_tenant_from_api_key_header(request)
|
||||
if tenant_id:
|
||||
if tenant_id is not None:
|
||||
return tenant_id
|
||||
|
||||
# Check for cookie-based auth
|
||||
token = request.cookies.get("fastapiusersauth")
|
||||
if not token:
|
||||
return POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
try:
|
||||
# Look up token data in Redis
|
||||
token_data = await retrieve_auth_token_data_from_redis(request)
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
USER_AUTH_SECRET,
|
||||
audience=["fastapi-users:auth"],
|
||||
algorithms=["HS256"],
|
||||
)
|
||||
tenant_id_from_payload = payload.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
|
||||
|
||||
if not token_data:
|
||||
logger.debug(
|
||||
"Token data not found or expired in Redis, defaulting to POSTGRES_DEFAULT_SCHEMA"
|
||||
)
|
||||
# Return POSTGRES_DEFAULT_SCHEMA, so non-authenticated requests are sent to the default schema
|
||||
# The CURRENT_TENANT_ID_CONTEXTVAR is initialized with POSTGRES_DEFAULT_SCHEMA,
|
||||
# so we maintain consistency by returning it here when no valid tenant is found.
|
||||
return POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
tenant_id_from_payload = token_data.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
|
||||
|
||||
# Since token_data.get() can return None, ensure we have a string
|
||||
# Since payload.get() can return None, ensure we have a string
|
||||
tenant_id = (
|
||||
str(tenant_id_from_payload)
|
||||
if tenant_id_from_payload is not None
|
||||
@@ -75,6 +67,9 @@ async def _get_tenant_id_from_request(
|
||||
|
||||
return tenant_id
|
||||
|
||||
except jwt.InvalidTokenError:
|
||||
return POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in _get_tenant_id_from_request: {str(e)}")
|
||||
logger.error(f"Unexpected error in set_tenant_id_middleware: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import base64
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import requests
|
||||
@@ -12,29 +10,11 @@ from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLIENT_ID
|
||||
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLIENT_SECRET
|
||||
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_ID
|
||||
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
|
||||
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_ID
|
||||
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_SECRET
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.google_utils.google_auth import get_google_oauth_creds
|
||||
from onyx.connectors.google_utils.google_auth import sanitize_oauth_credentials
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_AUTHENTICATION_METHOD,
|
||||
)
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_DICT_TOKEN_KEY,
|
||||
)
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
|
||||
)
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
GoogleOAuthAuthenticationMethod,
|
||||
)
|
||||
from onyx.db.credentials import create_credential
|
||||
from onyx.db.engine import get_current_tenant_id
|
||||
from onyx.db.engine import get_session
|
||||
@@ -82,7 +62,14 @@ class SlackOAuth:
|
||||
|
||||
@classmethod
|
||||
def generate_oauth_url(cls, state: str) -> str:
|
||||
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
|
||||
url = (
|
||||
f"https://slack.com/oauth/v2/authorize"
|
||||
f"?client_id={cls.CLIENT_ID}"
|
||||
f"&redirect_uri={cls.REDIRECT_URI}"
|
||||
f"&scope={cls.BOT_SCOPE}"
|
||||
f"&state={state}"
|
||||
)
|
||||
return url
|
||||
|
||||
@classmethod
|
||||
def generate_dev_oauth_url(cls, state: str) -> str:
|
||||
@@ -90,14 +77,10 @@ class SlackOAuth:
|
||||
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
|
||||
"""
|
||||
|
||||
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
|
||||
|
||||
@classmethod
|
||||
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
|
||||
url = (
|
||||
f"https://slack.com/oauth/v2/authorize"
|
||||
f"?client_id={cls.CLIENT_ID}"
|
||||
f"&redirect_uri={redirect_uri}"
|
||||
f"&redirect_uri={cls.DEV_REDIRECT_URI}"
|
||||
f"&scope={cls.BOT_SCOPE}"
|
||||
f"&state={state}"
|
||||
)
|
||||
@@ -119,151 +102,82 @@ class SlackOAuth:
|
||||
return session
|
||||
|
||||
|
||||
class ConfluenceCloudOAuth:
|
||||
"""work in progress"""
|
||||
# Work in progress
|
||||
# class ConfluenceCloudOAuth:
|
||||
# """work in progress"""
|
||||
|
||||
# https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps/
|
||||
# # https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps/
|
||||
|
||||
class OAuthSession(BaseModel):
|
||||
"""Stored in redis to be looked up on callback"""
|
||||
# class OAuthSession(BaseModel):
|
||||
# """Stored in redis to be looked up on callback"""
|
||||
|
||||
email: str
|
||||
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
|
||||
# email: str
|
||||
# redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
|
||||
|
||||
CLIENT_ID = OAUTH_CONFLUENCE_CLIENT_ID
|
||||
CLIENT_SECRET = OAUTH_CONFLUENCE_CLIENT_SECRET
|
||||
TOKEN_URL = "https://auth.atlassian.com/oauth/token"
|
||||
# CLIENT_ID = OAUTH_CONFLUENCE_CLIENT_ID
|
||||
# CLIENT_SECRET = OAUTH_CONFLUENCE_CLIENT_SECRET
|
||||
# TOKEN_URL = "https://auth.atlassian.com/oauth/token"
|
||||
|
||||
# All read scopes per https://developer.atlassian.com/cloud/confluence/scopes-for-oauth-2-3LO-and-forge-apps/
|
||||
CONFLUENCE_OAUTH_SCOPE = (
|
||||
"read:confluence-props%20"
|
||||
"read:confluence-content.all%20"
|
||||
"read:confluence-content.summary%20"
|
||||
"read:confluence-content.permission%20"
|
||||
"read:confluence-user%20"
|
||||
"read:confluence-groups%20"
|
||||
"readonly:content.attachment:confluence"
|
||||
)
|
||||
# # All read scopes per https://developer.atlassian.com/cloud/confluence/scopes-for-oauth-2-3LO-and-forge-apps/
|
||||
# CONFLUENCE_OAUTH_SCOPE = (
|
||||
# "read:confluence-props%20"
|
||||
# "read:confluence-content.all%20"
|
||||
# "read:confluence-content.summary%20"
|
||||
# "read:confluence-content.permission%20"
|
||||
# "read:confluence-user%20"
|
||||
# "read:confluence-groups%20"
|
||||
# "readonly:content.attachment:confluence"
|
||||
# )
|
||||
|
||||
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/confluence/oauth/callback"
|
||||
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
|
||||
# REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/confluence/oauth/callback"
|
||||
# DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
|
||||
|
||||
# eventually for Confluence Data Center
|
||||
# oauth_url = (
|
||||
# f"http://localhost:8090/rest/oauth/v2/authorize?client_id={CONFLUENCE_OAUTH_CLIENT_ID}"
|
||||
# f"&scope={CONFLUENCE_OAUTH_SCOPE_2}"
|
||||
# f"&redirect_uri={redirectme_uri}"
|
||||
# )
|
||||
# # eventually for Confluence Data Center
|
||||
# # oauth_url = (
|
||||
# # f"http://localhost:8090/rest/oauth/v2/authorize?client_id={CONFLUENCE_OAUTH_CLIENT_ID}"
|
||||
# # f"&scope={CONFLUENCE_OAUTH_SCOPE_2}"
|
||||
# # f"&redirect_uri={redirectme_uri}"
|
||||
# # )
|
||||
|
||||
@classmethod
|
||||
def generate_oauth_url(cls, state: str) -> str:
|
||||
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
|
||||
# @classmethod
|
||||
# def generate_oauth_url(cls, state: str) -> str:
|
||||
# return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
|
||||
|
||||
@classmethod
|
||||
def generate_dev_oauth_url(cls, state: str) -> str:
|
||||
"""dev mode workaround for localhost testing
|
||||
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
|
||||
"""
|
||||
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
|
||||
# @classmethod
|
||||
# def generate_dev_oauth_url(cls, state: str) -> str:
|
||||
# """dev mode workaround for localhost testing
|
||||
# - https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
|
||||
# """
|
||||
# return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
|
||||
|
||||
@classmethod
|
||||
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
|
||||
url = (
|
||||
"https://auth.atlassian.com/authorize"
|
||||
f"?audience=api.atlassian.com"
|
||||
f"&client_id={cls.CLIENT_ID}"
|
||||
f"&redirect_uri={redirect_uri}"
|
||||
f"&scope={cls.CONFLUENCE_OAUTH_SCOPE}"
|
||||
f"&state={state}"
|
||||
"&response_type=code"
|
||||
"&prompt=consent"
|
||||
)
|
||||
return url
|
||||
# @classmethod
|
||||
# def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
|
||||
# url = (
|
||||
# "https://auth.atlassian.com/authorize"
|
||||
# f"?audience=api.atlassian.com"
|
||||
# f"&client_id={cls.CLIENT_ID}"
|
||||
# f"&redirect_uri={redirect_uri}"
|
||||
# f"&scope={cls.CONFLUENCE_OAUTH_SCOPE}"
|
||||
# f"&state={state}"
|
||||
# "&response_type=code"
|
||||
# "&prompt=consent"
|
||||
# )
|
||||
# return url
|
||||
|
||||
@classmethod
|
||||
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
|
||||
"""Temporary state to store in redis. to be looked up on auth response.
|
||||
Returns a json string.
|
||||
"""
|
||||
session = ConfluenceCloudOAuth.OAuthSession(
|
||||
email=email, redirect_on_success=redirect_on_success
|
||||
)
|
||||
return session.model_dump_json()
|
||||
# @classmethod
|
||||
# def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
|
||||
# """Temporary state to store in redis. to be looked up on auth response.
|
||||
# Returns a json string.
|
||||
# """
|
||||
# session = ConfluenceCloudOAuth.OAuthSession(
|
||||
# email=email, redirect_on_success=redirect_on_success
|
||||
# )
|
||||
# return session.model_dump_json()
|
||||
|
||||
@classmethod
|
||||
def parse_session(cls, session_json: str) -> SlackOAuth.OAuthSession:
|
||||
session = SlackOAuth.OAuthSession.model_validate_json(session_json)
|
||||
return session
|
||||
|
||||
|
||||
class GoogleDriveOAuth:
|
||||
# https://developers.google.com/identity/protocols/oauth2
|
||||
# https://developers.google.com/identity/protocols/oauth2/web-server
|
||||
|
||||
class OAuthSession(BaseModel):
|
||||
"""Stored in redis to be looked up on callback"""
|
||||
|
||||
email: str
|
||||
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
|
||||
|
||||
CLIENT_ID = OAUTH_GOOGLE_DRIVE_CLIENT_ID
|
||||
CLIENT_SECRET = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
|
||||
|
||||
TOKEN_URL = "https://oauth2.googleapis.com/token"
|
||||
|
||||
# SCOPE is per https://docs.onyx.app/connectors/google-drive
|
||||
# TODO: Merge with or use google_utils.GOOGLE_SCOPES
|
||||
SCOPE = (
|
||||
"https://www.googleapis.com/auth/drive.readonly%20"
|
||||
"https://www.googleapis.com/auth/drive.metadata.readonly%20"
|
||||
"https://www.googleapis.com/auth/admin.directory.user.readonly%20"
|
||||
"https://www.googleapis.com/auth/admin.directory.group.readonly"
|
||||
)
|
||||
|
||||
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/google-drive/oauth/callback"
|
||||
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
|
||||
|
||||
@classmethod
|
||||
def generate_oauth_url(cls, state: str) -> str:
|
||||
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
|
||||
|
||||
@classmethod
|
||||
def generate_dev_oauth_url(cls, state: str) -> str:
|
||||
"""dev mode workaround for localhost testing
|
||||
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
|
||||
"""
|
||||
|
||||
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
|
||||
|
||||
@classmethod
|
||||
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
|
||||
# without prompt=consent, a refresh token is only issued the first time the user approves
|
||||
url = (
|
||||
f"https://accounts.google.com/o/oauth2/v2/auth"
|
||||
f"?client_id={cls.CLIENT_ID}"
|
||||
f"&redirect_uri={redirect_uri}"
|
||||
"&response_type=code"
|
||||
f"&scope={cls.SCOPE}"
|
||||
"&access_type=offline"
|
||||
f"&state={state}"
|
||||
"&prompt=consent"
|
||||
)
|
||||
return url
|
||||
|
||||
@classmethod
|
||||
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
|
||||
"""Temporary state to store in redis. to be looked up on auth response.
|
||||
Returns a json string.
|
||||
"""
|
||||
session = GoogleDriveOAuth.OAuthSession(
|
||||
email=email, redirect_on_success=redirect_on_success
|
||||
)
|
||||
return session.model_dump_json()
|
||||
|
||||
@classmethod
|
||||
def parse_session(cls, session_json: str) -> OAuthSession:
|
||||
session = GoogleDriveOAuth.OAuthSession.model_validate_json(session_json)
|
||||
return session
|
||||
# @classmethod
|
||||
# def parse_session(cls, session_json: str) -> SlackOAuth.OAuthSession:
|
||||
# session = SlackOAuth.OAuthSession.model_validate_json(session_json)
|
||||
# return session
|
||||
|
||||
|
||||
@router.post("/prepare-authorization-request")
|
||||
@@ -278,11 +192,8 @@ def prepare_authorization_request(
|
||||
Example: https://www.oauth.com/oauth2-servers/authorization/the-authorization-request/
|
||||
"""
|
||||
|
||||
# create random oauth state param for security and to retrieve user data later
|
||||
oauth_uuid = uuid.uuid4()
|
||||
oauth_uuid_str = str(oauth_uuid)
|
||||
|
||||
# urlsafe b64 encode the uuid for the oauth url
|
||||
oauth_state = (
|
||||
base64.urlsafe_b64encode(oauth_uuid.bytes).rstrip(b"=").decode("utf-8")
|
||||
)
|
||||
@@ -292,11 +203,6 @@ def prepare_authorization_request(
|
||||
session = SlackOAuth.session_dump_json(
|
||||
email=user.email, redirect_on_success=redirect_on_success
|
||||
)
|
||||
elif connector == DocumentSource.GOOGLE_DRIVE:
|
||||
oauth_url = GoogleDriveOAuth.generate_oauth_url(oauth_state)
|
||||
session = GoogleDriveOAuth.session_dump_json(
|
||||
email=user.email, redirect_on_success=redirect_on_success
|
||||
)
|
||||
# elif connector == DocumentSource.CONFLUENCE:
|
||||
# oauth_url = ConfluenceCloudOAuth.generate_oauth_url(oauth_state)
|
||||
# session = ConfluenceCloudOAuth.session_dump_json(
|
||||
@@ -304,6 +210,8 @@ def prepare_authorization_request(
|
||||
# )
|
||||
# elif connector == DocumentSource.JIRA:
|
||||
# oauth_url = JiraCloudOAuth.generate_dev_oauth_url(oauth_state)
|
||||
# elif connector == DocumentSource.GOOGLE_DRIVE:
|
||||
# oauth_url = GoogleDriveOAuth.generate_dev_oauth_url(oauth_state)
|
||||
else:
|
||||
oauth_url = None
|
||||
|
||||
@@ -315,7 +223,6 @@ def prepare_authorization_request(
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# store important session state to retrieve when the user is redirected back
|
||||
# 10 min is the max we want an oauth flow to be valid
|
||||
r.set(f"da_oauth:{oauth_uuid_str}", session, ex=600)
|
||||
|
||||
@@ -514,116 +421,3 @@ def handle_slack_oauth_callback(
|
||||
# "redirect_on_success": session.redirect_on_success,
|
||||
# }
|
||||
# )
|
||||
|
||||
|
||||
@router.post("/connector/google-drive/callback")
|
||||
def handle_google_drive_oauth_callback(
|
||||
code: str,
|
||||
state: str,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> JSONResponse:
|
||||
if not GoogleDriveOAuth.CLIENT_ID or not GoogleDriveOAuth.CLIENT_SECRET:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Google Drive client ID or client secret is not configured.",
|
||||
)
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# recover the state
|
||||
padded_state = state + "=" * (
|
||||
-len(state) % 4
|
||||
) # Add padding back (Base64 decoding requires padding)
|
||||
uuid_bytes = base64.urlsafe_b64decode(
|
||||
padded_state
|
||||
) # Decode the Base64 string back to bytes
|
||||
|
||||
# Convert bytes back to a UUID
|
||||
oauth_uuid = uuid.UUID(bytes=uuid_bytes)
|
||||
oauth_uuid_str = str(oauth_uuid)
|
||||
|
||||
r_key = f"da_oauth:{oauth_uuid_str}"
|
||||
|
||||
session_json_bytes = cast(bytes, r.get(r_key))
|
||||
if not session_json_bytes:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Google Drive OAuth failed - OAuth state key not found: key={r_key}",
|
||||
)
|
||||
|
||||
session_json = session_json_bytes.decode("utf-8")
|
||||
try:
|
||||
session = GoogleDriveOAuth.parse_session(session_json)
|
||||
|
||||
# Exchange the authorization code for an access token
|
||||
response = requests.post(
|
||||
GoogleDriveOAuth.TOKEN_URL,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
data={
|
||||
"client_id": GoogleDriveOAuth.CLIENT_ID,
|
||||
"client_secret": GoogleDriveOAuth.CLIENT_SECRET,
|
||||
"code": code,
|
||||
"redirect_uri": GoogleDriveOAuth.REDIRECT_URI,
|
||||
"grant_type": "authorization_code",
|
||||
},
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
authorization_response: dict[str, Any] = response.json()
|
||||
|
||||
# the connector wants us to store the json in its authorized_user_info format
|
||||
# returned from OAuthCredentials.get_authorized_user_info().
|
||||
# So refresh immediately via get_google_oauth_creds with the params filled in
|
||||
# from fields in authorization_response to get the json we need
|
||||
authorized_user_info = {}
|
||||
authorized_user_info["client_id"] = OAUTH_GOOGLE_DRIVE_CLIENT_ID
|
||||
authorized_user_info["client_secret"] = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
|
||||
authorized_user_info["refresh_token"] = authorization_response["refresh_token"]
|
||||
|
||||
token_json_str = json.dumps(authorized_user_info)
|
||||
oauth_creds = get_google_oauth_creds(
|
||||
token_json_str=token_json_str, source=DocumentSource.GOOGLE_DRIVE
|
||||
)
|
||||
if not oauth_creds:
|
||||
raise RuntimeError("get_google_oauth_creds returned None.")
|
||||
|
||||
# save off the credentials
|
||||
oauth_creds_sanitized_json_str = sanitize_oauth_credentials(oauth_creds)
|
||||
|
||||
credential_dict: dict[str, str] = {}
|
||||
credential_dict[DB_CREDENTIALS_DICT_TOKEN_KEY] = oauth_creds_sanitized_json_str
|
||||
credential_dict[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] = session.email
|
||||
credential_dict[
|
||||
DB_CREDENTIALS_AUTHENTICATION_METHOD
|
||||
] = GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE.value
|
||||
|
||||
credential_info = CredentialBase(
|
||||
credential_json=credential_dict,
|
||||
admin_public=True,
|
||||
source=DocumentSource.GOOGLE_DRIVE,
|
||||
name="OAuth (interactive)",
|
||||
)
|
||||
|
||||
create_credential(credential_info, user, db_session)
|
||||
except Exception as e:
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"success": False,
|
||||
"message": f"An error occurred during Google Drive OAuth: {str(e)}",
|
||||
},
|
||||
)
|
||||
finally:
|
||||
r.delete(r_key)
|
||||
|
||||
# return the result
|
||||
return JSONResponse(
|
||||
content={
|
||||
"success": True,
|
||||
"message": "Google Drive OAuth completed successfully.",
|
||||
"redirect_on_success": session.redirect_on_success,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -13,8 +13,9 @@ from ee.onyx.db.usage_export import get_all_empty_chat_message_entries
from ee.onyx.db.usage_export import write_usage_report
from ee.onyx.server.reporting.usage_export_models import UsageReportMetadata
from ee.onyx.server.reporting.usage_export_models import UserSkeleton
from onyx.auth.schemas import UserStatus
from onyx.configs.constants import FileOrigin
from onyx.db.users import get_all_users
from onyx.db.users import list_users
from onyx.file_store.constants import MAX_IN_MEMORY_SIZE
from onyx.file_store.file_store import FileStore
from onyx.file_store.file_store import get_default_file_store

@@ -83,15 +84,15 @@ def generate_user_report(
        max_size=MAX_IN_MEMORY_SIZE, mode="w+"
    ) as temp_file:
        csvwriter = csv.writer(temp_file, delimiter=",")
        csvwriter.writerow(["user_id", "is_active"])
        csvwriter.writerow(["user_id", "status"])

        users = get_all_users(db_session)
        users = list_users(db_session)
        for user in users:
            user_skeleton = UserSkeleton(
                user_id=str(user.id),
                is_active=user.is_active,
                status=UserStatus.LIVE if user.is_active else UserStatus.DEACTIVATED,
            )
            csvwriter.writerow([user_skeleton.user_id, user_skeleton.is_active])
            csvwriter.writerow([user_skeleton.user_id, user_skeleton.status])

        temp_file.seek(0)
        file_store.save_file(
@@ -4,6 +4,8 @@ from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.auth.schemas import UserStatus
|
||||
|
||||
|
||||
class FlowType(str, Enum):
|
||||
CHAT = "chat"
|
||||
@@ -20,7 +22,7 @@ class ChatMessageSkeleton(BaseModel):
|
||||
|
||||
class UserSkeleton(BaseModel):
|
||||
user_id: str
|
||||
is_active: bool
|
||||
status: UserStatus
|
||||
|
||||
|
||||
class UsageReportMetadata(BaseModel):
|
||||
|
||||
@@ -3,7 +3,6 @@ from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Response
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.auth.users import current_cloud_superuser
|
||||
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
|
||||
@@ -13,23 +12,15 @@ from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
|
||||
from ee.onyx.server.tenants.models import BillingInformation
|
||||
from ee.onyx.server.tenants.models import ImpersonateRequest
|
||||
from ee.onyx.server.tenants.models import ProductGatingRequest
|
||||
from ee.onyx.server.tenants.provisioning import delete_user_from_control_plane
|
||||
from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
|
||||
from ee.onyx.server.tenants.user_mapping import remove_all_users_from_tenant
|
||||
from ee.onyx.server.tenants.user_mapping import remove_users_from_tenant
|
||||
from onyx.auth.users import auth_backend
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import get_redis_strategy
|
||||
from onyx.auth.users import get_jwt_strategy
|
||||
from onyx.auth.users import User
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.db.auth import get_user_count
|
||||
from onyx.db.engine import get_current_tenant_id
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.notification import create_notification
|
||||
from onyx.db.users import delete_user_from_db
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.server.manage.models import UserByEmail
|
||||
from onyx.server.settings.store import load_settings
|
||||
from onyx.server.settings.store import store_settings
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -112,7 +103,7 @@ async def impersonate_user(
|
||||
)
|
||||
if user_to_impersonate is None:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
token = await get_redis_strategy().write_token(user_to_impersonate)
|
||||
token = await get_jwt_strategy().write_token(user_to_impersonate)
|
||||
|
||||
response = await auth_backend.transport.get_login_response(token)
|
||||
response.set_cookie(
|
||||
@@ -123,48 +114,3 @@ async def impersonate_user(
|
||||
samesite="lax",
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
@router.post("/leave-organization")
|
||||
async def leave_organization(
|
||||
user_email: UserByEmail,
|
||||
current_user: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str = Depends(get_current_tenant_id),
|
||||
) -> None:
|
||||
if current_user is None or current_user.email != user_email.user_email:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="You can only leave the organization as yourself"
|
||||
)
|
||||
|
||||
user_to_delete = get_user_by_email(user_email.user_email, db_session)
|
||||
if user_to_delete is None:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
num_admin_users = await get_user_count(only_admin_users=True)
|
||||
|
||||
should_delete_tenant = num_admin_users == 1
|
||||
|
||||
if should_delete_tenant:
|
||||
logger.info(
|
||||
"Last admin user is leaving the organization. Deleting tenant from control plane."
|
||||
)
|
||||
try:
|
||||
await delete_user_from_control_plane(tenant_id, user_to_delete.email)
|
||||
logger.debug("User deleted from control plane")
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to delete user from control plane for tenant {tenant_id}: {e}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to remove user from control plane: {str(e)}",
|
||||
)
|
||||
|
||||
db_session.expunge(user_to_delete)
|
||||
delete_user_from_db(user_to_delete, db_session)
|
||||
|
||||
if should_delete_tenant:
|
||||
remove_all_users_from_tenant(tenant_id)
|
||||
else:
|
||||
remove_users_from_tenant([user_to_delete.email], tenant_id)
|
||||
|
||||
@@ -46,7 +46,6 @@ def register_tenant_users(tenant_id: str, number_of_users: int) -> stripe.Subscr
|
||||
"""
|
||||
Send a request to the control service to register the number of users for a tenant.
|
||||
"""
|
||||
|
||||
if not STRIPE_PRICE_ID:
|
||||
raise Exception("STRIPE_PRICE_ID is not set")
|
||||
|
||||
|
||||
@@ -39,8 +39,3 @@ class TenantCreationPayload(BaseModel):
|
||||
tenant_id: str
|
||||
email: str
|
||||
referral_source: str | None = None
|
||||
|
||||
|
||||
class TenantDeletionPayload(BaseModel):
|
||||
tenant_id: str
|
||||
email: str
|
||||
|
||||
@@ -15,7 +15,6 @@ from ee.onyx.configs.app_configs import HUBSPOT_TRACKING_URL
|
||||
from ee.onyx.configs.app_configs import OPENAI_DEFAULT_API_KEY
|
||||
from ee.onyx.server.tenants.access import generate_data_plane_token
|
||||
from ee.onyx.server.tenants.models import TenantCreationPayload
|
||||
from ee.onyx.server.tenants.models import TenantDeletionPayload
|
||||
from ee.onyx.server.tenants.schema_management import create_schema_if_not_exists
|
||||
from ee.onyx.server.tenants.schema_management import drop_schema
|
||||
from ee.onyx.server.tenants.schema_management import run_alembic_migrations
|
||||
@@ -186,7 +185,6 @@ async def rollback_tenant_provisioning(tenant_id: str) -> None:
|
||||
try:
|
||||
# Drop the tenant's schema to rollback provisioning
|
||||
drop_schema(tenant_id)
|
||||
|
||||
# Remove tenant mapping
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
db_session.query(UserTenantMapping).filter(
|
||||
@@ -322,26 +320,3 @@ async def submit_to_hubspot(
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"Failed to submit to HubSpot: {response.text}")
|
||||
|
||||
|
||||
async def delete_user_from_control_plane(tenant_id: str, email: str) -> None:
|
||||
token = generate_data_plane_token()
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
payload = TenantDeletionPayload(tenant_id=tenant_id, email=email)
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.delete(
|
||||
f"{CONTROL_PLANE_API_BASE_URL}/tenants/delete",
|
||||
headers=headers,
|
||||
json=payload.model_dump(),
|
||||
) as response:
|
||||
print(response)
|
||||
if response.status != 200:
|
||||
error_text = await response.text()
|
||||
logger.error(f"Control plane tenant creation failed: {error_text}")
|
||||
raise Exception(
|
||||
f"Failed to delete tenant on control plane: {error_text}"
|
||||
)
|
||||
|
||||
@@ -68,11 +68,3 @@ def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
|
||||
f"Failed to remove users from tenant {tenant_id}: {str(e)}"
|
||||
)
|
||||
db_session.rollback()
|
||||
|
||||
|
||||
def remove_all_users_from_tenant(tenant_id: str) -> None:
|
||||
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
|
||||
db_session.query(UserTenantMapping).filter(
|
||||
UserTenantMapping.tenant_id == tenant_id
|
||||
).delete()
|
||||
db_session.commit()
|
||||
|
||||
@@ -83,7 +83,7 @@ def patch_user_group(
|
||||
def set_user_curator(
|
||||
user_group_id: int,
|
||||
set_curator_request: SetCuratorRequest,
|
||||
user: User | None = Depends(current_curator_or_admin_user),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
try:
|
||||
@@ -91,7 +91,6 @@ def set_user_curator(
|
||||
db_session=db_session,
|
||||
user_group_id=user_group_id,
|
||||
set_curator_request=set_curator_request,
|
||||
user_making_change=user,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error(f"Error setting user curator: {e}")
|
||||
|
||||
@@ -10,7 +10,6 @@ logger = setup_logger()
|
||||
|
||||
|
||||
def posthog_on_error(error: Any, items: Any) -> None:
|
||||
"""Log any PostHog delivery errors."""
|
||||
logger.error(f"PostHog error: {error}, items: {items}")
|
||||
|
||||
|
||||
@@ -25,10 +24,15 @@ posthog = Posthog(
|
||||
def event_telemetry(
|
||||
distinct_id: str, event: str, properties: dict | None = None
|
||||
) -> None:
|
||||
"""Capture and send an event to PostHog, flushing immediately."""
|
||||
logger.info(f"Capturing PostHog event: {distinct_id} {event} {properties}")
|
||||
logger.info(f"Capturing Posthog event: {distinct_id} {event} {properties}")
|
||||
print("API KEY", POSTHOG_API_KEY)
|
||||
print("HOST", POSTHOG_HOST)
|
||||
try:
|
||||
posthog.capture(distinct_id, event, properties)
|
||||
print(type(distinct_id))
|
||||
print(type(event))
|
||||
print(type(properties))
|
||||
response = posthog.capture(distinct_id, event, properties)
|
||||
posthog.flush()
|
||||
print(response)
|
||||
except Exception as e:
|
||||
logger.error(f"Error capturing PostHog event: {e}")
|
||||
logger.error(f"Error capturing Posthog event: {e}")
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from types import TracebackType
|
||||
from typing import cast
|
||||
from typing import Optional
|
||||
@@ -321,6 +320,8 @@ async def embed_text(
|
||||
api_url: str | None,
|
||||
api_version: str | None,
|
||||
) -> list[Embedding]:
|
||||
logger.info(f"Embedding {len(texts)} texts with provider: {provider_type}")
|
||||
|
||||
if not all(texts):
|
||||
logger.error("Empty strings provided for embedding")
|
||||
raise ValueError("Empty strings are not allowed for embedding.")
|
||||
@@ -329,17 +330,8 @@ async def embed_text(
|
||||
logger.error("No texts provided for embedding")
|
||||
raise ValueError("No texts provided for embedding.")
|
||||
|
||||
start = time.monotonic()
|
||||
|
||||
total_chars = 0
|
||||
for text in texts:
|
||||
total_chars += len(text)
|
||||
|
||||
if provider_type is not None:
|
||||
logger.info(
|
||||
f"Embedding {len(texts)} texts with {total_chars} total characters with provider: {provider_type}"
|
||||
)
|
||||
|
||||
logger.debug(f"Using cloud provider {provider_type} for embedding")
|
||||
if api_key is None:
|
||||
logger.error("API key not provided for cloud model")
|
||||
raise RuntimeError("API key not provided for cloud model")
|
||||
@@ -371,16 +363,8 @@ async def embed_text(
|
||||
logger.error(error_message)
|
||||
raise ValueError(error_message)
|
||||
|
||||
elapsed = time.monotonic() - start
|
||||
logger.info(
|
||||
f"Successfully embedded {len(texts)} texts with {total_chars} total characters "
|
||||
f"with provider {provider_type} in {elapsed:.2f}"
|
||||
)
|
||||
elif model_name is not None:
|
||||
logger.info(
|
||||
f"Embedding {len(texts)} texts with {total_chars} total characters with local model: {model_name}"
|
||||
)
|
||||
|
||||
logger.debug(f"Using local model {model_name} for embedding")
|
||||
prefixed_texts = [f"{prefix}{text}" for text in texts] if prefix else texts
|
||||
|
||||
local_model = get_embedding_model(
|
||||
@@ -398,17 +382,13 @@ async def embed_text(
|
||||
for embedding in embeddings_vectors
|
||||
]
|
||||
|
||||
elapsed = time.monotonic() - start
|
||||
logger.info(
|
||||
f"Successfully embedded {len(texts)} texts with {total_chars} total characters "
|
||||
f"with local model {model_name} in {elapsed:.2f}"
|
||||
)
|
||||
else:
|
||||
logger.error("Neither model name nor provider specified for embedding")
|
||||
raise ValueError(
|
||||
"Either model name or provider must be provided to run embeddings."
|
||||
)
|
||||
|
||||
logger.info(f"Successfully embedded {len(texts)} texts")
|
||||
return embeddings
|
||||
|
||||
|
||||
@@ -460,8 +440,7 @@ async def process_embed_request(
|
||||
) -> EmbedResponse:
|
||||
if not embed_request.texts:
|
||||
raise HTTPException(status_code=400, detail="No texts to be embedded")
|
||||
|
||||
if not all(embed_request.texts):
|
||||
elif not all(embed_request.texts):
|
||||
raise ValueError("Empty strings are not allowed for embedding.")
|
||||
|
||||
try:
|
||||
@@ -492,12 +471,9 @@ async def process_embed_request(
|
||||
detail=str(e),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Error during embedding process: provider={embed_request.provider_type} model={embed_request.model_name}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Error during embedding process: {e}"
|
||||
)
|
||||
exception_detail = f"Error during embedding process:\n{str(e)}"
|
||||
logger.exception(exception_detail)
|
||||
raise HTTPException(status_code=500, detail=exception_detail)
|
||||
|
||||
|
||||
@router.post("/cross-encoder-scores")
|
||||
|
||||
@@ -44,7 +44,6 @@ def _move_files_recursively(source: Path, dest: Path, overwrite: bool = False) -
|
||||
the files in the existing huggingface cache that don't exist in the temp
|
||||
huggingface cache.
|
||||
"""
|
||||
|
||||
for item in source.iterdir():
|
||||
target_path = dest / item.relative_to(source)
|
||||
if item.is_dir():
|
||||
|
||||
@@ -1,83 +0,0 @@
|
||||
import smtplib
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from email.mime.text import MIMEText
|
||||
from textwrap import dedent
|
||||
|
||||
from onyx.configs.app_configs import EMAIL_CONFIGURED
|
||||
from onyx.configs.app_configs import EMAIL_FROM
|
||||
from onyx.configs.app_configs import SMTP_PASS
|
||||
from onyx.configs.app_configs import SMTP_PORT
|
||||
from onyx.configs.app_configs import SMTP_SERVER
|
||||
from onyx.configs.app_configs import SMTP_USER
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.db.models import User
|
||||
|
||||
|
||||
def send_email(
|
||||
user_email: str,
|
||||
subject: str,
|
||||
body: str,
|
||||
mail_from: str = EMAIL_FROM,
|
||||
) -> None:
|
||||
if not EMAIL_CONFIGURED:
|
||||
raise ValueError("Email is not configured.")
|
||||
|
||||
msg = MIMEMultipart()
|
||||
msg["Subject"] = subject
|
||||
msg["To"] = user_email
|
||||
if mail_from:
|
||||
msg["From"] = mail_from
|
||||
|
||||
msg.attach(MIMEText(body))
|
||||
|
||||
try:
|
||||
with smtplib.SMTP(SMTP_SERVER, SMTP_PORT) as s:
|
||||
s.starttls()
|
||||
s.login(SMTP_USER, SMTP_PASS)
|
||||
s.send_message(msg)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
def send_user_email_invite(user_email: str, current_user: User) -> None:
|
||||
subject = "Invitation to Join Onyx Organization"
|
||||
body = dedent(
|
||||
f"""\
|
||||
Hello,
|
||||
|
||||
You have been invited to join an organization on Onyx.
|
||||
|
||||
To join the organization, please visit the following link:
|
||||
|
||||
{WEB_DOMAIN}/auth/signup?email={user_email}
|
||||
|
||||
You'll be asked to set a password or login with Google to complete your registration.
|
||||
|
||||
Best regards,
|
||||
The Onyx Team
|
||||
"""
|
||||
)
|
||||
|
||||
send_email(user_email, subject, body, current_user.email)
|
||||
|
||||
|
||||
def send_forgot_password_email(
|
||||
user_email: str,
|
||||
token: str,
|
||||
mail_from: str = EMAIL_FROM,
|
||||
) -> None:
|
||||
subject = "Onyx Forgot Password"
|
||||
link = f"{WEB_DOMAIN}/auth/reset-password?token={token}"
|
||||
body = f"Click the following link to reset your password: {link}"
|
||||
send_email(user_email, subject, body, mail_from)
|
||||
|
||||
|
||||
def send_user_verification_email(
|
||||
user_email: str,
|
||||
token: str,
|
||||
mail_from: str = EMAIL_FROM,
|
||||
) -> None:
|
||||
subject = "Onyx Email Verification"
|
||||
link = f"{WEB_DOMAIN}/auth/verify-email?token={token}"
|
||||
body = f"Click the following link to verify your email address: {link}"
|
||||
send_email(user_email, subject, body, mail_from)
|
||||
@@ -30,16 +30,13 @@ def load_no_auth_user_preferences(store: KeyValueStore) -> UserPreferences:
|
||||
)
|
||||
|
||||
|
||||
def fetch_no_auth_user(
|
||||
store: KeyValueStore, *, anonymous_user_enabled: bool | None = None
|
||||
) -> UserInfo:
|
||||
def fetch_no_auth_user(store: KeyValueStore) -> UserInfo:
|
||||
return UserInfo(
|
||||
id=NO_AUTH_USER_ID,
|
||||
email=NO_AUTH_USER_EMAIL,
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
is_verified=True,
|
||||
role=UserRole.BASIC if anonymous_user_enabled else UserRole.ADMIN,
|
||||
role=UserRole.ADMIN,
|
||||
preferences=load_no_auth_user_preferences(store),
|
||||
is_anonymous_user=anonymous_user_enabled,
|
||||
)
|
||||
|
||||
@@ -33,6 +33,12 @@ class UserRole(str, Enum):
|
||||
]
|
||||
|
||||
|
||||
class UserStatus(str, Enum):
|
||||
LIVE = "live"
|
||||
INVITED = "invited"
|
||||
DEACTIVATED = "deactivated"
|
||||
|
||||
|
||||
class UserRead(schemas.BaseUser[uuid.UUID]):
|
||||
role: UserRole
|
||||
|
||||
@@ -43,7 +49,4 @@ class UserCreate(schemas.BaseUserCreate):
|
||||
|
||||
|
||||
class UserUpdate(schemas.BaseUserUpdate):
|
||||
"""
|
||||
Role updates are not allowed through the user update endpoint for security reasons
|
||||
Role changes should be handled through a separate, admin-only process
|
||||
"""
|
||||
role: UserRole
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import json
|
||||
import secrets
|
||||
import smtplib
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from email.mime.text import MIMEText
|
||||
from typing import cast
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
@@ -31,8 +32,10 @@ from fastapi_users import schemas
|
||||
from fastapi_users import UUIDIDMixin
|
||||
from fastapi_users.authentication import AuthenticationBackend
|
||||
from fastapi_users.authentication import CookieTransport
|
||||
from fastapi_users.authentication import RedisStrategy
|
||||
from fastapi_users.authentication import JWTStrategy
|
||||
from fastapi_users.authentication import Strategy
|
||||
from fastapi_users.authentication.strategy.db import AccessTokenDatabase
|
||||
from fastapi_users.authentication.strategy.db import DatabaseStrategy
|
||||
from fastapi_users.exceptions import UserAlreadyExists
|
||||
from fastapi_users.jwt import decode_jwt
|
||||
from fastapi_users.jwt import generate_jwt
|
||||
@@ -46,22 +49,23 @@ from httpx_oauth.integrations.fastapi import OAuth2AuthorizeCallback
|
||||
from httpx_oauth.oauth2 import BaseOAuth2
|
||||
from httpx_oauth.oauth2 import OAuth2Token
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from onyx.auth.api_key import get_hashed_api_key_from_request
|
||||
from onyx.auth.email_utils import send_forgot_password_email
|
||||
from onyx.auth.email_utils import send_user_verification_email
|
||||
from onyx.auth.invited_users import get_invited_users
|
||||
from onyx.auth.schemas import UserCreate
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.auth.schemas import UserUpdate
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
from onyx.configs.app_configs import DISABLE_AUTH
|
||||
from onyx.configs.app_configs import EMAIL_CONFIGURED
|
||||
from onyx.configs.app_configs import REDIS_AUTH_EXPIRE_TIME_SECONDS
|
||||
from onyx.configs.app_configs import REDIS_AUTH_KEY_PREFIX
|
||||
from onyx.configs.app_configs import EMAIL_FROM
|
||||
from onyx.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
|
||||
from onyx.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
|
||||
from onyx.configs.app_configs import SMTP_PASS
|
||||
from onyx.configs.app_configs import SMTP_PORT
|
||||
from onyx.configs.app_configs import SMTP_SERVER
|
||||
from onyx.configs.app_configs import SMTP_USER
|
||||
from onyx.configs.app_configs import TRACK_EXTERNAL_IDP_EXPIRY
|
||||
from onyx.configs.app_configs import USER_AUTH_SECRET
|
||||
from onyx.configs.app_configs import VALID_EMAIL_DOMAINS
|
||||
@@ -70,10 +74,10 @@ from onyx.configs.constants import AuthType
|
||||
from onyx.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
||||
from onyx.configs.constants import DANSWER_API_KEY_PREFIX
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.configs.constants import PASSWORD_SPECIAL_CHARS
|
||||
from onyx.configs.constants import UNNAMED_KEY_PLACEHOLDER
|
||||
from onyx.db.api_key import fetch_user_for_api_key
|
||||
from onyx.db.auth import get_access_token_db
|
||||
from onyx.db.auth import get_default_admin_user_emails
|
||||
from onyx.db.auth import get_user_count
|
||||
from onyx.db.auth import get_user_db
|
||||
@@ -81,11 +85,11 @@ from onyx.db.auth import SQLAlchemyUserAdminDB
|
||||
from onyx.db.engine import get_async_session
|
||||
from onyx.db.engine import get_async_session_with_tenant
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.models import AccessToken
|
||||
from onyx.db.models import OAuthAccount
|
||||
from onyx.db.models import User
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.redis.redis_pool import get_async_redis_connection
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.utils import BasicAuthenticationError
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.telemetry import create_milestone_and_report
|
||||
from onyx.utils.telemetry import optional_telemetry
|
||||
@@ -99,11 +103,6 @@ from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class BasicAuthenticationError(HTTPException):
|
||||
def __init__(self, detail: str):
|
||||
super().__init__(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
|
||||
|
||||
|
||||
def is_user_admin(user: User | None) -> bool:
|
||||
if AUTH_TYPE == AuthType.DISABLED:
|
||||
return True
|
||||
@@ -144,20 +143,6 @@ def user_needs_to_be_verified() -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def anonymous_user_enabled() -> bool:
|
||||
if MULTI_TENANT:
|
||||
return False
|
||||
|
||||
redis_client = get_redis_client(tenant_id=None)
|
||||
value = redis_client.get(OnyxRedisLocks.ANONYMOUS_USER_ENABLED)
|
||||
|
||||
if value is None:
|
||||
return False
|
||||
|
||||
assert isinstance(value, bytes)
|
||||
return int(value.decode("utf-8")) == 1
|
||||
|
||||
|
||||
def verify_email_is_invited(email: str) -> None:
|
||||
whitelist = get_invited_users()
|
||||
if not whitelist:
|
||||
@@ -208,6 +193,30 @@ def verify_email_domain(email: str) -> None:
|
||||
)
|
||||
|
||||
|
||||
def send_user_verification_email(
|
||||
user_email: str,
|
||||
token: str,
|
||||
mail_from: str = EMAIL_FROM,
|
||||
) -> None:
|
||||
msg = MIMEMultipart()
|
||||
msg["Subject"] = "Onyx Email Verification"
|
||||
msg["To"] = user_email
|
||||
if mail_from:
|
||||
msg["From"] = mail_from
|
||||
|
||||
link = f"{WEB_DOMAIN}/auth/verify-email?token={token}"
|
||||
|
||||
body = MIMEText(f"Click the following link to verify your email address: {link}")
|
||||
msg.attach(body)
|
||||
|
||||
with smtplib.SMTP(SMTP_SERVER, SMTP_PORT) as s:
|
||||
s.starttls()
|
||||
# If credentials fails with gmail, check (You need an app password, not just the basic email password)
|
||||
# https://support.google.com/accounts/answer/185833?sjid=8512343437447396151-NA
|
||||
s.login(SMTP_USER, SMTP_PASS)
|
||||
s.send_message(msg)
|
||||
|
||||
|
||||
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
reset_password_token_secret = USER_AUTH_SECRET
|
||||
verification_token_secret = USER_AUTH_SECRET
|
||||
@@ -272,6 +281,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
if not user.role.is_web_login() and user_create.role.is_web_login():
|
||||
user_update = UserUpdate(
|
||||
password=user_create.password,
|
||||
role=user_create.role,
|
||||
is_verified=user_create.is_verified,
|
||||
)
|
||||
user = await self.update(user_update, user)
|
||||
@@ -395,9 +405,11 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
|
||||
# Explicitly set the Postgres schema for this session to ensure
|
||||
# OAuth account creation happens in the correct tenant schema
|
||||
await db_session.execute(text(f'SET search_path = "{tenant_id}"'))
|
||||
|
||||
# Add OAuth account
|
||||
await self.user_db.add_oauth_account(user, oauth_account_dict)
|
||||
|
||||
await self.on_after_register(user, request)
|
||||
|
||||
else:
|
||||
@@ -416,6 +428,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
|
||||
# NOTE: Most IdPs have very short expiry times, and we don't want to force the user to
|
||||
# re-authenticate that frequently, so by default this is disabled
|
||||
|
||||
if expires_at and TRACK_EXTERNAL_IDP_EXPIRY:
|
||||
oidc_expiry = datetime.fromtimestamp(expires_at, tz=timezone.utc)
|
||||
await self.user_db.update(
|
||||
@@ -493,15 +506,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
async def on_after_forgot_password(
|
||||
self, user: User, token: str, request: Optional[Request] = None
|
||||
) -> None:
|
||||
if not EMAIL_CONFIGURED:
|
||||
logger.error(
|
||||
"Email is not configured. Please configure email in the admin panel"
|
||||
)
|
||||
raise HTTPException(
|
||||
status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
"Your admin has not enbaled this feature.",
|
||||
)
|
||||
send_forgot_password_email(user.email, token)
|
||||
logger.notice(f"User {user.id} has forgot their password. Reset token: {token}")
|
||||
|
||||
async def on_after_request_verify(
|
||||
self, user: User, token: str, request: Optional[Request] = None
|
||||
@@ -578,70 +583,51 @@ cookie_transport = CookieTransport(
|
||||
)
|
||||
|
||||
|
||||
def get_redis_strategy() -> RedisStrategy:
|
||||
return TenantAwareRedisStrategy()
|
||||
|
||||
|
||||
class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]):
|
||||
"""
|
||||
A custom strategy that fetches the actual async Redis connection inside each method.
|
||||
We do NOT pass a synchronous or "coroutine" redis object to the constructor.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
lifetime_seconds: Optional[int] = REDIS_AUTH_EXPIRE_TIME_SECONDS,
|
||||
key_prefix: str = REDIS_AUTH_KEY_PREFIX,
|
||||
):
|
||||
self.lifetime_seconds = lifetime_seconds
|
||||
self.key_prefix = key_prefix
|
||||
|
||||
async def write_token(self, user: User) -> str:
|
||||
redis = await get_async_redis_connection()
|
||||
|
||||
# This strategy is used to add tenant_id to the JWT token
|
||||
class TenantAwareJWTStrategy(JWTStrategy):
|
||||
async def _create_token_data(self, user: User, impersonate: bool = False) -> dict:
|
||||
tenant_id = await fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.provisioning",
|
||||
"get_or_provision_tenant",
|
||||
async_return_default_schema,
|
||||
)(email=user.email)
|
||||
)(
|
||||
email=user.email,
|
||||
)
|
||||
|
||||
token_data = {
|
||||
data = {
|
||||
"sub": str(user.id),
|
||||
"aud": self.token_audience,
|
||||
"tenant_id": tenant_id,
|
||||
}
|
||||
token = secrets.token_urlsafe()
|
||||
await redis.set(
|
||||
f"{self.key_prefix}{token}",
|
||||
json.dumps(token_data),
|
||||
ex=self.lifetime_seconds,
|
||||
return data
|
||||
|
||||
async def write_token(self, user: User) -> str:
|
||||
data = await self._create_token_data(user)
|
||||
return generate_jwt(
|
||||
data, self.encode_key, self.lifetime_seconds, algorithm=self.algorithm
|
||||
)
|
||||
return token
|
||||
|
||||
async def read_token(
|
||||
self, token: Optional[str], user_manager: BaseUserManager[User, uuid.UUID]
|
||||
) -> Optional[User]:
|
||||
redis = await get_async_redis_connection()
|
||||
token_data_str = await redis.get(f"{self.key_prefix}{token}")
|
||||
if not token_data_str:
|
||||
return None
|
||||
|
||||
try:
|
||||
token_data = json.loads(token_data_str)
|
||||
user_id = token_data["sub"]
|
||||
parsed_id = user_manager.parse_id(user_id)
|
||||
return await user_manager.get(parsed_id)
|
||||
except (exceptions.UserNotExists, exceptions.InvalidID, KeyError):
|
||||
return None
|
||||
def get_jwt_strategy() -> TenantAwareJWTStrategy:
|
||||
return TenantAwareJWTStrategy(
|
||||
secret=USER_AUTH_SECRET,
|
||||
lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS,
|
||||
)
|
||||
|
||||
async def destroy_token(self, token: str, user: User) -> None:
|
||||
"""Properly delete the token from async redis."""
|
||||
redis = await get_async_redis_connection()
|
||||
await redis.delete(f"{self.key_prefix}{token}")
|
||||
|
||||
def get_database_strategy(
|
||||
access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db),
|
||||
) -> DatabaseStrategy:
|
||||
return DatabaseStrategy(
|
||||
access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS # type: ignore
|
||||
)
|
||||
|
||||
|
||||
auth_backend = AuthenticationBackend(
|
||||
name="redis", transport=cookie_transport, get_strategy=get_redis_strategy
|
||||
)
|
||||
name="jwt" if MULTI_TENANT else "database",
|
||||
transport=cookie_transport,
|
||||
get_strategy=get_jwt_strategy if MULTI_TENANT else get_database_strategy, # type: ignore
|
||||
) # type: ignore
|
||||
|
||||
|
||||
class FastAPIUserWithLogoutRouter(FastAPIUsers[models.UP, models.ID]):
|
||||
@@ -727,36 +713,30 @@ async def double_check_user(
|
||||
user: User | None,
|
||||
optional: bool = DISABLE_AUTH,
|
||||
include_expired: bool = False,
|
||||
allow_anonymous_access: bool = False,
|
||||
) -> User | None:
|
||||
if optional:
|
||||
return user
|
||||
|
||||
if user is not None:
|
||||
# If user attempted to authenticate, verify them, do not default
|
||||
# to anonymous access if it fails.
|
||||
if user_needs_to_be_verified() and not user.is_verified:
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User is not verified.",
|
||||
)
|
||||
|
||||
if (
|
||||
user.oidc_expiry
|
||||
and user.oidc_expiry < datetime.now(timezone.utc)
|
||||
and not include_expired
|
||||
):
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User's OIDC token has expired.",
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
if allow_anonymous_access:
|
||||
return None
|
||||
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User is not authenticated.",
|
||||
)
|
||||
if user is None:
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User is not authenticated.",
|
||||
)
|
||||
|
||||
if user_needs_to_be_verified() and not user.is_verified:
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User is not verified.",
|
||||
)
|
||||
|
||||
if (
|
||||
user.oidc_expiry
|
||||
and user.oidc_expiry < datetime.now(timezone.utc)
|
||||
and not include_expired
|
||||
):
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User's OIDC token has expired.",
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
async def current_user_with_expired_token(
|
||||
@@ -771,14 +751,6 @@ async def current_limited_user(
|
||||
return await double_check_user(user)
|
||||
|
||||
|
||||
async def current_chat_accesssible_user(
|
||||
user: User | None = Depends(optional_user),
|
||||
) -> User | None:
|
||||
return await double_check_user(
|
||||
user, allow_anonymous_access=anonymous_user_enabled()
|
||||
)
|
||||
|
||||
|
||||
async def current_user(
|
||||
user: User | None = Depends(optional_user),
|
||||
) -> User | None:
|
||||
|
||||
@@ -335,10 +335,6 @@ def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
if not celery_is_worker_primary(sender):
|
||||
return
|
||||
|
||||
if not hasattr(sender, "primary_worker_lock"):
|
||||
# primary_worker_lock will not exist when MULTI_TENANT is True
|
||||
return
|
||||
|
||||
if not sender.primary_worker_lock:
|
||||
return
|
||||
|
||||
@@ -418,21 +414,11 @@ def on_setup_logging(
|
||||
task_logger.setLevel(loglevel)
|
||||
task_logger.propagate = False
|
||||
|
||||
# hide celery task received spam
|
||||
# e.g. "Task check_for_pruning[a1e96171-0ba8-4e00-887b-9fbf7442eab3] received"
|
||||
# Hide celery task received and succeeded/failed messages
|
||||
strategy.logger.setLevel(logging.WARNING)
|
||||
|
||||
# uncomment this to hide celery task succeeded/failed spam
|
||||
# e.g. "Task check_for_pruning[a1e96171-0ba8-4e00-887b-9fbf7442eab3] succeeded in 0.03137450001668185s: None"
|
||||
trace.logger.setLevel(logging.WARNING)
|
||||
|
||||
|
||||
def set_task_finished_log_level(logLevel: int) -> None:
|
||||
"""call this to override the setLevel in on_setup_logging. We are interested
|
||||
in the task timings in the cloud but it can be spammy for self hosted."""
|
||||
trace.logger.setLevel(logLevel)
|
||||
|
||||
|
||||
class TenantContextFilter(logging.Filter):
|
||||
|
||||
"""Logging filter to inject tenant ID into the logger's name."""
|
||||
|
||||
@@ -60,12 +60,7 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
|
||||
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
|
||||
|
||||
# rkuo: been seeing transient connection exceptions here, so upping the connection count
|
||||
# from just concurrency/concurrency to concurrency/concurrency*2
|
||||
SqlEngine.init_engine(
|
||||
pool_size=sender.concurrency, max_overflow=sender.concurrency * 2
|
||||
)
|
||||
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=sender.concurrency)
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import logging
|
||||
import multiprocessing
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
@@ -89,12 +88,12 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa(sender, **kwargs)
|
||||
|
||||
logger.info("Running as the primary celery worker.")
|
||||
|
||||
# Less startup checks in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
logger.info("Running as the primary celery worker.")
|
||||
|
||||
# This is singleton work that should be done on startup exactly once
|
||||
# by the primary worker. This is unnecessary in the multi tenant scenario
|
||||
r = get_redis_client(tenant_id=None)
|
||||
@@ -195,10 +194,6 @@ def on_setup_logging(
|
||||
) -> None:
|
||||
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
|
||||
|
||||
# this can be spammy, so just enable it in the cloud for now
|
||||
if MULTI_TENANT:
|
||||
app_base.set_task_finished_log_level(logging.INFO)
|
||||
|
||||
|
||||
class HubPeriodicTask(bootsteps.StartStopStep):
|
||||
"""Regularly reacquires the primary worker lock outside of the task queue.
|
||||
@@ -286,6 +281,5 @@ celery_app.autodiscover_tasks(
|
||||
"onyx.background.celery.tasks.pruning",
|
||||
"onyx.background.celery.tasks.shared",
|
||||
"onyx.background.celery.tasks.vespa",
|
||||
"onyx.background.celery.tasks.llm_model_update",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -3,54 +3,12 @@ import json
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from celery import Celery
|
||||
from redis import Redis
|
||||
|
||||
from onyx.background.celery.configs.base import CELERY_SEPARATOR
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
|
||||
|
||||
def celery_get_unacked_length(r: Redis) -> int:
|
||||
"""Checking the unacked queue is useful because a non-zero length tells us there
|
||||
may be prefetched tasks.
|
||||
|
||||
There can be other tasks in here besides indexing tasks, so this is mostly useful
|
||||
just to see if the task count is non zero.
|
||||
|
||||
ref: https://blog.hikaru.run/2022/08/29/get-waiting-tasks-count-in-celery.html
|
||||
"""
|
||||
length = cast(int, r.hlen("unacked"))
|
||||
return length
|
||||
|
||||
|
||||
def celery_get_unacked_task_ids(queue: str, r: Redis) -> set[str]:
|
||||
"""Gets the set of task id's matching the given queue in the unacked hash.
|
||||
|
||||
Unacked entries belonging to the indexing queue are "prefetched", so this gives
|
||||
us crucial visibility as to what tasks are in that state.
|
||||
"""
|
||||
tasks: set[str] = set()
|
||||
|
||||
for _, v in r.hscan_iter("unacked"):
|
||||
v_bytes = cast(bytes, v)
|
||||
v_str = v_bytes.decode("utf-8")
|
||||
task = json.loads(v_str)
|
||||
|
||||
task_description = task[0]
|
||||
task_queue = task[2]
|
||||
|
||||
if task_queue != queue:
|
||||
continue
|
||||
|
||||
task_id = task_description.get("headers", {}).get("id")
|
||||
if not task_id:
|
||||
continue
|
||||
|
||||
# if the queue matches and we see the task_id, add it
|
||||
tasks.add(task_id)
|
||||
return tasks
|
||||
|
||||
|
||||
def celery_get_queue_length(queue: str, r: Redis) -> int:
|
||||
"""This is a redis specific way to get the length of a celery queue.
|
||||
It is priority aware and knows how to count across the multiple redis lists
|
||||
@@ -89,74 +47,3 @@ def celery_find_task(task_id: str, queue: str, r: Redis) -> int:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def celery_inspect_get_workers(name_filter: str | None, app: Celery) -> list[str]:
|
||||
"""Returns a list of current workers containing name_filter, or all workers if
|
||||
name_filter is None.
|
||||
|
||||
We've empirically discovered that the celery inspect API is potentially unstable
|
||||
and may hang or return empty results when celery is under load. Suggest using this
|
||||
more to debug and troubleshoot than in production code.
|
||||
"""
|
||||
worker_names: list[str] = []
|
||||
|
||||
# filter for and create an indexing specific inspect object
|
||||
inspect = app.control.inspect()
|
||||
workers: dict[str, Any] = inspect.ping() # type: ignore
|
||||
if workers:
|
||||
for worker_name in list(workers.keys()):
|
||||
# if the name filter not set, return all worker names
|
||||
if not name_filter:
|
||||
worker_names.append(worker_name)
|
||||
continue
|
||||
|
||||
# if the name filter is set, return only worker names that contain the name filter
|
||||
if name_filter not in worker_name:
|
||||
continue
|
||||
|
||||
worker_names.append(worker_name)
|
||||
|
||||
return worker_names
|
||||
|
||||
|
||||
def celery_inspect_get_reserved(worker_names: list[str], app: Celery) -> set[str]:
|
||||
"""Returns a list of reserved tasks on the specified workers.
|
||||
|
||||
We've empirically discovered that the celery inspect API is potentially unstable
|
||||
and may hang or return empty results when celery is under load. Suggest using this
|
||||
more to debug and troubleshoot than in production code.
|
||||
"""
|
||||
reserved_task_ids: set[str] = set()
|
||||
|
||||
inspect = app.control.inspect(destination=worker_names)
|
||||
|
||||
# get the list of reserved tasks
|
||||
reserved_tasks: dict[str, list] | None = inspect.reserved() # type: ignore
|
||||
if reserved_tasks:
|
||||
for _, task_list in reserved_tasks.items():
|
||||
for task in task_list:
|
||||
reserved_task_ids.add(task["id"])
|
||||
|
||||
return reserved_task_ids
|
||||
|
||||
|
||||
def celery_inspect_get_active(worker_names: list[str], app: Celery) -> set[str]:
|
||||
"""Returns a list of active tasks on the specified workers.
|
||||
|
||||
We've empirically discovered that the celery inspect API is potentially unstable
|
||||
and may hang or return empty results when celery is under load. Suggest using this
|
||||
more to debug and troubleshoot than in production code.
|
||||
"""
|
||||
active_task_ids: set[str] = set()
|
||||
|
||||
inspect = app.control.inspect(destination=worker_names)
|
||||
|
||||
# get the list of reserved tasks
|
||||
active_tasks: dict[str, list] | None = inspect.active() # type: ignore
|
||||
if active_tasks:
|
||||
for _, task_list in active_tasks.items():
|
||||
for task in task_list:
|
||||
active_task_ids.add(task["id"])
|
||||
|
||||
return active_task_ids
|
||||
|
||||
@@ -16,14 +16,6 @@ result_expires = shared_config.result_expires # 86400 seconds is the default
|
||||
task_default_priority = shared_config.task_default_priority
|
||||
task_acks_late = shared_config.task_acks_late
|
||||
|
||||
# Indexing worker specific ... this lets us track the transition to STARTED in redis
|
||||
# We don't currently rely on this but it has the potential to be useful and
|
||||
# indexing tasks are not high volume
|
||||
|
||||
# we don't turn this on yet because celery occasionally runs tasks more than once
|
||||
# which means a duplicate run might change the task state unexpectedly
|
||||
# task_track_started = True
|
||||
|
||||
worker_concurrency = CELERY_WORKER_INDEXING_CONCURRENCY
|
||||
worker_pool = "threads"
|
||||
worker_prefetch_multiplier = 1
|
||||
|
||||
@@ -1,16 +1,9 @@
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
from onyx.configs.app_configs import LLM_MODEL_UPDATE_API_URL
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
|
||||
# choosing 15 minutes because it roughly gives us enough time to process many tasks
|
||||
# we might be able to reduce this greatly if we can run a unified
|
||||
# loop across all tenants rather than tasks per tenant
|
||||
|
||||
BEAT_EXPIRES_DEFAULT = 15 * 60 # 15 minutes (in seconds)
|
||||
|
||||
# we set expires because it isn't necessary to queue up these tasks
|
||||
# it's only important that they run relatively regularly
|
||||
tasks_to_schedule = [
|
||||
@@ -20,7 +13,7 @@ tasks_to_schedule = [
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
"expires": 60,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -29,7 +22,7 @@ tasks_to_schedule = [
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
"expires": 60,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -38,7 +31,7 @@ tasks_to_schedule = [
|
||||
"schedule": timedelta(seconds=15),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
"expires": 60,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -47,7 +40,7 @@ tasks_to_schedule = [
|
||||
"schedule": timedelta(seconds=15),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
"expires": 60,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -56,7 +49,7 @@ tasks_to_schedule = [
|
||||
"schedule": timedelta(seconds=3600),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOWEST,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
"expires": 60,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -65,7 +58,7 @@ tasks_to_schedule = [
|
||||
"schedule": timedelta(seconds=5),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
"expires": 60,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -74,7 +67,7 @@ tasks_to_schedule = [
|
||||
"schedule": timedelta(seconds=30),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
"expires": 60,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -83,25 +76,11 @@ tasks_to_schedule = [
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
"expires": 60,
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
# Only add the LLM model update task if the API URL is configured
|
||||
if LLM_MODEL_UPDATE_API_URL:
|
||||
tasks_to_schedule.append(
|
||||
{
|
||||
"name": "check-for-llm-model-update",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_LLM_MODEL_UPDATE,
|
||||
"schedule": timedelta(hours=1), # Check every hour
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def get_tasks_to_schedule() -> list[dict[str, Any]]:
|
||||
return tasks_to_schedule
|
||||
|
||||
@@ -34,9 +34,7 @@ class TaskDependencyError(RuntimeError):
|
||||
trail=False,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_connector_deletion_task(
|
||||
self: Task, *, tenant_id: str | None
|
||||
) -> bool | None:
|
||||
def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> None:
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
@@ -47,7 +45,7 @@ def check_for_connector_deletion_task(
|
||||
try:
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return None
|
||||
return
|
||||
|
||||
# collect cc_pair_ids
|
||||
cc_pair_ids: list[int] = []
|
||||
@@ -83,8 +81,6 @@ def check_for_connector_deletion_task(
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def try_generate_document_cc_pair_cleanup_tasks(
|
||||
app: Celery,
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import time
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from time import sleep
|
||||
from uuid import uuid4
|
||||
|
||||
from celery import Celery
|
||||
@@ -20,7 +18,6 @@ from onyx.access.models import DocExternalAccess
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from onyx.configs.constants import DocumentSource
|
||||
@@ -91,10 +88,10 @@ def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> b
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> bool | None:
|
||||
def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> None:
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
lock_beat = r.lock(
|
||||
OnyxRedisLocks.CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
@@ -102,7 +99,7 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> bool
|
||||
try:
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return None
|
||||
return
|
||||
|
||||
# get all cc pairs that need to be synced
|
||||
cc_pair_ids_to_sync: list[int] = []
|
||||
@@ -131,8 +128,6 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> bool
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def try_creating_permissions_sync_task(
|
||||
app: Celery,
|
||||
@@ -224,43 +219,6 @@ def connector_permission_sync_generator_task(
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# this wait is needed to avoid a race condition where
|
||||
# the primary worker sends the task and it is immediately executed
|
||||
# before the primary worker can finalize the fence
|
||||
start = time.monotonic()
|
||||
while True:
|
||||
if time.monotonic() - start > CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT:
|
||||
raise ValueError(
|
||||
f"connector_permission_sync_generator_task - timed out waiting for fence to be ready: "
|
||||
f"fence={redis_connector.permissions.fence_key}"
|
||||
)
|
||||
|
||||
if not redis_connector.permissions.fenced: # The fence must exist
|
||||
raise ValueError(
|
||||
f"connector_permission_sync_generator_task - fence not found: "
|
||||
f"fence={redis_connector.permissions.fence_key}"
|
||||
)
|
||||
|
||||
payload = redis_connector.permissions.payload # The payload must exist
|
||||
if not payload:
|
||||
raise ValueError(
|
||||
"connector_permission_sync_generator_task: payload invalid or not found"
|
||||
)
|
||||
|
||||
if payload.celery_task_id is None:
|
||||
logger.info(
|
||||
f"connector_permission_sync_generator_task - Waiting for fence: "
|
||||
f"fence={redis_connector.permissions.fence_key}"
|
||||
)
|
||||
sleep(1)
|
||||
continue
|
||||
|
||||
logger.info(
|
||||
f"connector_permission_sync_generator_task - Fence found, continuing...: "
|
||||
f"fence={redis_connector.permissions.fence_key}"
|
||||
)
|
||||
break
|
||||
|
||||
lock: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX
|
||||
+ f"_{redis_connector.id}",
|
||||
@@ -296,11 +254,8 @@ def connector_permission_sync_generator_task(
|
||||
if not payload:
|
||||
raise ValueError(f"No fence payload found: cc_pair={cc_pair_id}")
|
||||
|
||||
new_payload = RedisConnectorPermissionSyncPayload(
|
||||
started=datetime.now(timezone.utc),
|
||||
celery_task_id=payload.celery_task_id,
|
||||
)
|
||||
redis_connector.permissions.set_fence(new_payload)
|
||||
payload.started = datetime.now(timezone.utc)
|
||||
redis_connector.permissions.set_fence(payload)
|
||||
|
||||
document_external_accesses: list[DocExternalAccess] = doc_sync_func(cc_pair)
|
||||
|
||||
|
||||
@@ -94,10 +94,10 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool | None:
|
||||
def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
lock_beat = r.lock(
|
||||
OnyxRedisLocks.CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
@@ -105,7 +105,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool
|
||||
try:
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return None
|
||||
return
|
||||
|
||||
cc_pair_ids_to_sync: list[int] = []
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
@@ -149,8 +149,6 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def try_creating_external_group_sync_task(
|
||||
app: Celery,
|
||||
@@ -164,7 +162,7 @@ def try_creating_external_group_sync_task(
|
||||
|
||||
LOCK_TIMEOUT = 30
|
||||
|
||||
lock: RedisLock = r.lock(
|
||||
lock = r.lock(
|
||||
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_external_group_sync_tasks",
|
||||
timeout=LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from http import HTTPStatus
|
||||
from time import sleep
|
||||
from typing import Any
|
||||
|
||||
import redis
|
||||
import sentry_sdk
|
||||
@@ -19,12 +18,10 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_find_task
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.background.indexing.job_client import SimpleJobClient
|
||||
from onyx.background.indexing.run_indexing import run_indexing_entrypoint
|
||||
from onyx.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
|
||||
from onyx.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from onyx.configs.constants import DocumentSource
|
||||
@@ -32,7 +29,6 @@ from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import OnyxRedisSignals
from onyx.db.connector import mark_ccpair_with_indexing_trigger
from onyx.db.connector_credential_pair import fetch_connector_credential_pairs
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
@@ -62,7 +58,6 @@ from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_index import RedisConnectorIndex
from onyx.redis.redis_connector_index import RedisConnectorIndexPayload
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import redis_lock_dump
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import global_version
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
@@ -74,18 +69,14 @@ logger = setup_logger()


class IndexingCallback(IndexingHeartbeatInterface):
    PARENT_CHECK_INTERVAL = 60

    def __init__(
        self,
        parent_pid: int,
        stop_key: str,
        generator_progress_key: str,
        redis_lock: RedisLock,
        redis_client: Redis,
    ):
        super().__init__()
        self.parent_pid = parent_pid
        self.redis_lock: RedisLock = redis_lock
        self.stop_key: str = stop_key
        self.generator_progress_key: str = generator_progress_key
@@ -95,54 +86,26 @@ class IndexingCallback(IndexingHeartbeatInterface):

        self.last_tag: str = "IndexingCallback.__init__"
        self.last_lock_reacquire: datetime = datetime.now(timezone.utc)
        self.last_lock_monotonic = time.monotonic()

        self.last_parent_check = time.monotonic()

    def should_stop(self) -> bool:
        if self.redis_client.exists(self.stop_key):
            return True

        return False

    def progress(self, tag: str, amount: int) -> None:
        # rkuo: this shouldn't be necessary yet because we spawn the process this runs inside
        # with daemon = True. It seems likely some indexing tasks will need to spawn other processes eventually
        # so leave this code in until we're ready to test it.

        # if self.parent_pid:
        #     # check if the parent pid is alive so we aren't running as a zombie
        #     now = time.monotonic()
        #     if now - self.last_parent_check > IndexingCallback.PARENT_CHECK_INTERVAL:
        #         try:
        #             # this is unintuitive, but it checks if the parent pid is still running
        #             os.kill(self.parent_pid, 0)
        #         except Exception:
        #             logger.exception("IndexingCallback - parent pid check exceptioned")
        #             raise
        #         self.last_parent_check = now

        try:
            current_time = time.monotonic()
            if current_time - self.last_lock_monotonic >= (
                CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
            ):
                self.redis_lock.reacquire()
                self.last_lock_reacquire = datetime.now(timezone.utc)
                self.last_lock_monotonic = time.monotonic()

            self.redis_lock.reacquire()
            self.last_tag = tag
            self.last_lock_reacquire = datetime.now(timezone.utc)
        except LockError:
            logger.exception(
                f"IndexingCallback - lock.reacquire exceptioned: "
                f"IndexingCallback - lock.reacquire exceptioned. "
                f"lock_timeout={self.redis_lock.timeout} "
                f"start={self.started} "
                f"last_tag={self.last_tag} "
                f"last_reacquired={self.last_lock_reacquire} "
                f"now={datetime.now(timezone.utc)}"
            )

            redis_lock_dump(self.redis_lock, self.redis_client)
            raise

        self.redis_client.incrby(self.generator_progress_key, amount)
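One side of the `progress()` hunk above calls `reacquire()` on every progress tick, while the other throttles it to once per quarter of the lock timeout. A minimal sketch of that throttling pattern, assuming a plain redis-py lock and an illustrative wrapper class (the class name and the 120-second default are assumptions, not part of the diff):

```python
import time
from datetime import datetime, timezone

from redis.lock import Lock as RedisLock


class HeartbeatLock:
    """Illustrative wrapper: refresh a Redis lock at most once per timeout/4."""

    def __init__(self, lock: RedisLock, timeout: int = 120) -> None:
        self.lock = lock
        self.timeout = timeout
        self.last_reacquire_monotonic = time.monotonic()
        self.last_reacquire: datetime = datetime.now(timezone.utc)

    def beat(self) -> None:
        now = time.monotonic()
        # only touch Redis when a quarter of the lock timeout has elapsed
        if now - self.last_reacquire_monotonic >= self.timeout / 4:
            self.lock.reacquire()  # resets the lock TTL; raises LockError if the lock was lost
            self.last_reacquire = datetime.now(timezone.utc)
            self.last_reacquire_monotonic = now
```

The trade-off is one Redis round trip per quarter-timeout instead of one per progress call, at the cost of a slightly staler lock TTL.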
@@ -212,7 +175,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:

    # we need to use celery's redis client to access its redis data
    # (which lives on a different db number)
    redis_client_celery: Redis = self.app.broker_connection().channel().client  # type: ignore
    # redis_client_celery: Redis = self.app.broker_connection().channel().client  # type: ignore

    lock_beat: RedisLock = redis_client.lock(
        OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK,
@@ -334,7 +297,6 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
        # Fail any index attempts in the DB that don't have fences
        # This shouldn't ever happen!
        with get_session_with_tenant(tenant_id) as db_session:
            lock_beat.reacquire()
            unfenced_attempt_ids = get_unfenced_index_attempt_ids(
                db_session, redis_client
            )
@@ -356,20 +318,23 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
                    attempt.id, db_session, failure_reason=failure_reason
                )

        # we want to run this less frequently than the overall task
        if not redis_client.exists(OnyxRedisSignals.VALIDATE_INDEXING_FENCES):
            lock_beat.reacquire()
            # clear any indexing fences that don't have associated celery tasks in progress
            # tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
            # or be currently executing
            try:
                validate_indexing_fences(
                    tenant_id, self.app, redis_client, redis_client_celery, lock_beat
                )
            except Exception:
                task_logger.exception("Exception while validating indexing fences")
        # rkuo: The following code logically appears to work, but the celery inspect code may be unstable
        # turning off for the moment to see if it helps cloud stability

            redis_client.set(OnyxRedisSignals.VALIDATE_INDEXING_FENCES, 1, ex=60)
        # we want to run this less frequently than the overall task
        # if not redis_client.exists(OnyxRedisSignals.VALIDATE_INDEXING_FENCES):
        #     # clear any indexing fences that don't have associated celery tasks in progress
        #     # tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
        #     # or be currently executing
        #     try:
        #         task_logger.info("Validating indexing fences...")
        #         validate_indexing_fences(
        #             tenant_id, self.app, redis_client, redis_client_celery, lock_beat
        #         )
        #     except Exception:
        #         task_logger.exception("Exception while validating indexing fences")

        #     redis_client.set(OnyxRedisSignals.VALIDATE_INDEXING_FENCES, 1, ex=60)

    except SoftTimeLimitExceeded:
        task_logger.info(
@@ -386,10 +351,9 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
                "check_for_indexing - Lock not owned on completion: "
                f"tenant={tenant_id}"
            )
            redis_lock_dump(lock_beat, redis_client)

    time_elapsed = time.monotonic() - time_start
    task_logger.debug(f"check_for_indexing finished: elapsed={time_elapsed:.2f}")
    task_logger.info(f"check_for_indexing finished: elapsed={time_elapsed:.2f}")
    return tasks_created

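The `VALIDATE_INDEXING_FENCES` signal used in this hunk is a simple Redis-based throttle: the expensive validation runs only when the key is absent, and setting it with `ex=60` suppresses further runs for a minute. A minimal sketch of that pattern (the key name and interval here are illustrative):

```python
from collections.abc import Callable

from redis import Redis


def run_at_most_every(
    r: Redis, signal_key: str, interval_secs: int, fn: Callable[[], None]
) -> bool:
    """Run fn() only if signal_key is not set, then arm the signal for interval_secs."""
    if r.exists(signal_key):
        return False  # ran recently; skip this beat

    fn()
    # the TTL makes the throttle self-clearing even if nothing ever deletes the key
    r.set(signal_key, 1, ex=interval_secs)
    return True
```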
@@ -400,9 +364,46 @@ def validate_indexing_fences(
    r_celery: Redis,
    lock_beat: RedisLock,
) -> None:
    reserved_indexing_tasks = celery_get_unacked_task_ids(
        OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
    )
    reserved_indexing_tasks: set[str] = set()
    active_indexing_tasks: set[str] = set()
    indexing_worker_names: list[str] = []

    # filter for and create an indexing specific inspect object
    inspect = celery_app.control.inspect()
    workers: dict[str, Any] = inspect.ping()  # type: ignore
    if not workers:
        raise ValueError("No workers found!")

    for worker_name in list(workers.keys()):
        if "indexing" in worker_name:
            indexing_worker_names.append(worker_name)

    if len(indexing_worker_names) == 0:
        raise ValueError("No indexing workers found!")

    inspect_indexing = celery_app.control.inspect(destination=indexing_worker_names)

    # NOTE: each dict entry is a map of worker name to a list of tasks
    # we want sets for reserved task and active task id's to optimize
    # subsequent validation lookups

    # get the list of reserved tasks
    reserved_tasks: dict[str, list] | None = inspect_indexing.reserved()  # type: ignore
    if reserved_tasks is None:
        raise ValueError("inspect_indexing.reserved() returned None!")

    for _, task_list in reserved_tasks.items():
        for task in task_list:
            reserved_indexing_tasks.add(task["id"])

    # get the list of active tasks
    active_tasks: dict[str, list] | None = inspect_indexing.active()  # type: ignore
    if active_tasks is None:
        raise ValueError("inspect_indexing.active() returned None!")

    for _, task_list in active_tasks.items():
        for task in task_list:
            active_indexing_tasks.add(task["id"])

    # validate all existing indexing jobs
    for key_bytes in r.scan_iter(RedisConnectorIndex.FENCE_PREFIX + "*"):
@@ -412,6 +413,7 @@ def validate_indexing_fences(
                tenant_id,
                key_bytes,
                reserved_indexing_tasks,
                active_indexing_tasks,
                r_celery,
                db_session,
            )
@@ -422,6 +424,7 @@ def validate_indexing_fence(
    tenant_id: str | None,
    key_bytes: bytes,
    reserved_tasks: set[str],
    active_tasks: set[str],
    r_celery: Redis,
    db_session: Session,
) -> None:
@@ -431,15 +434,11 @@ def validate_indexing_fence(
    gives the help.

    How this works:
    1. This function renews the active signal with a 5 minute TTL under the following conditions
    1. Active signal is renewed with a 5 minute TTL
        1.1 When the fence is created
        1.2. When the task is seen in the redis queue
        1.3. When the task is seen in the reserved / prefetched list

    2. Externally, the active signal is renewed when:
        2.1. The fence is created
        2.2. The indexing watchdog checks the spawned task.

    3. The TTL allows us to get through the transitions on fence startup
        1.3. When the task is seen in the reserved or active list for a worker
    2. The TTL allows us to get through the transitions on fence startup
    and when the task starts executing.

    More TTL clarification: it is seemingly impossible to exactly query Celery for
@@ -467,8 +466,6 @@ def validate_indexing_fence(

    redis_connector = RedisConnector(tenant_id, cc_pair_id)
    redis_connector_index = redis_connector.new_index(search_settings_id)

    # check to see if the fence/payload exists
    if not redis_connector_index.fenced:
        return

@@ -504,24 +501,24 @@ def validate_indexing_fence(
        redis_connector_index.set_active()
        return

    if payload.celery_task_id in active_tasks:
        # the celery task is active (aka currently executing)
        redis_connector_index.set_active()
        return

    # we may want to enable this check if using the active task list somehow isn't good enough
    # if redis_connector_index.generator_locked():
    #     logger.info(f"{payload.celery_task_id} is currently executing.")

    # if we get here, we didn't find any direct indication that the associated celery tasks exist,
    # but they still might be there due to gaps in our ability to check states during transitions
    # Checking the active signal safeguards us against these transition periods
    # (which has a duration that allows us to bridge those gaps)
    # we didn't find any direct indication that associated celery tasks exist, but they still might be there
    # due to gaps in our ability to check states during transitions
    # Rely on the active signal (which has a duration that allows us to bridge those gaps)
    if redis_connector_index.active():
        return

    # celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
    logger.warning(
        f"validate_indexing_fence - Resetting fence because no associated celery tasks were found: "
        f"index_attempt={payload.index_attempt_id} "
        f"cc_pair={cc_pair_id} "
        f"search_settings={search_settings_id} "
        f"fence={fence_key}"
        f"validate_indexing_fence - Resetting fence because no associated celery tasks were found: fence={fence_key}"
    )
    if payload.index_attempt_id:
        try:
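Both versions of the docstring above describe the same judgement call: a fence is considered alive if its task id shows up in a worker's reserved or active lists, and a short-TTL "active" key bridges the transition windows where Celery cannot be queried exactly. A condensed sketch of that decision flow; the helper and key names are illustrative, not the module's API:

```python
from redis import Redis

ACTIVE_TTL = 5 * 60  # bridge gaps where Celery state can't be observed exactly


def fence_is_alive(
    r: Redis,
    active_key: str,
    task_id: str,
    reserved_tasks: set[str],
    active_tasks: set[str],
) -> bool:
    # direct evidence: the task is prefetched by or currently executing on a worker
    if task_id in reserved_tasks or task_id in active_tasks:
        r.set(active_key, 1, ex=ACTIVE_TTL)  # renew the active signal
        return True

    # no direct evidence, but the active signal may still be within its TTL
    # (covers fence creation and task-startup transitions)
    return bool(r.exists(active_key))
```

If this returns False, the caller can safely reset the fence and fail the orphaned index attempt, exactly the cleanup path the hunk above logs.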
@@ -786,6 +783,7 @@ def connector_indexing_proxy_task(
        return

    task_logger.info(
        f"Indexing proxy - spawn succeeded: attempt={index_attempt_id} "
        f"Indexing watchdog - spawn succeeded: attempt={index_attempt_id} "
        f"cc_pair={cc_pair_id} "
        f"search_settings={search_settings_id}"
@@ -797,58 +795,6 @@ def connector_indexing_proxy_task(
    while True:
        sleep(5)

        # renew active signal
        redis_connector_index.set_active()

        # if the job is done, clean up and break
        if job.done():
            try:
                if job.status == "error":
                    ignore_exitcode = False

                    exit_code: int | None = None
                    if job.process:
                        exit_code = job.process.exitcode

                    # seeing odd behavior where spawned tasks usually return exit code 1 in the cloud,
                    # even though logging clearly indicates successful completion
                    # to work around this, we ignore the job error state if the completion signal is OK
                    status_int = redis_connector_index.get_completion()
                    if status_int:
                        status_enum = HTTPStatus(status_int)
                        if status_enum == HTTPStatus.OK:
                            ignore_exitcode = True

                    if not ignore_exitcode:
                        raise RuntimeError("Spawned task exceptioned.")

                    task_logger.warning(
                        "Indexing watchdog - spawned task has non-zero exit code "
                        "but completion signal is OK. Continuing...: "
                        f"attempt={index_attempt_id} "
                        f"tenant={tenant_id} "
                        f"cc_pair={cc_pair_id} "
                        f"search_settings={search_settings_id} "
                        f"exit_code={exit_code}"
                    )
            except Exception:
                task_logger.error(
                    "Indexing watchdog - spawned task exceptioned: "
                    f"attempt={index_attempt_id} "
                    f"tenant={tenant_id} "
                    f"cc_pair={cc_pair_id} "
                    f"search_settings={search_settings_id} "
                    f"exit_code={exit_code} "
                    f"error={job.exception()}"
                )

                raise
            finally:
                job.release()

            break

        # if a termination signal is detected, clean up and break
        if self.request.id and redis_connector_index.terminating(self.request.id):
            task_logger.warning(
                "Indexing watchdog - termination signal detected: "
@@ -875,33 +821,75 @@ def connector_indexing_proxy_task(
                f"search_settings={search_settings_id}"
            )

            job.cancel()
            job.cancel()

            break

        # if the spawned task is still running, restart the check once again
        # if the index attempt is not in a finished status
        try:
            with get_session_with_tenant(tenant_id) as db_session:
                index_attempt = get_index_attempt(
                    db_session=db_session, index_attempt_id=index_attempt_id
        if not job.done():
            # if the spawned task is still running, restart the check once again
            # if the index attempt is not in a finished status
            try:
                with get_session_with_tenant(tenant_id) as db_session:
                    index_attempt = get_index_attempt(
                        db_session=db_session, index_attempt_id=index_attempt_id
                    )

                    if not index_attempt:
                        continue

                    if not index_attempt.is_finished():
                        continue
            except Exception:
                # if the DB exceptioned, just restart the check.
                # polling the index attempt status doesn't need to be strongly consistent
                logger.exception(
                    "Indexing watchdog - transient exception looking up index attempt: "
                    f"attempt={index_attempt_id} "
                    f"tenant={tenant_id} "
                    f"cc_pair={cc_pair_id} "
                    f"search_settings={search_settings_id}"
                )
                continue

        if job.status == "error":
            ignore_exitcode = False

            exit_code: int | None = None
            if job.process:
                exit_code = job.process.exitcode

            # seeing odd behavior where spawned tasks usually return exit code 1 in the cloud,
            # even though logging clearly indicates that they completed successfully
            # to work around this, we ignore the job error state if the completion signal is OK
            status_int = redis_connector_index.get_completion()
            if status_int:
                status_enum = HTTPStatus(status_int)
                if status_enum == HTTPStatus.OK:
                    ignore_exitcode = True

            if ignore_exitcode:
                task_logger.warning(
                    "Indexing watchdog - spawned task has non-zero exit code "
                    "but completion signal is OK. Continuing...: "
                    f"attempt={index_attempt_id} "
                    f"tenant={tenant_id} "
                    f"cc_pair={cc_pair_id} "
                    f"search_settings={search_settings_id} "
                    f"exit_code={exit_code}"
                )
            else:
                task_logger.error(
                    "Indexing watchdog - spawned task exceptioned: "
                    f"attempt={index_attempt_id} "
                    f"tenant={tenant_id} "
                    f"cc_pair={cc_pair_id} "
                    f"search_settings={search_settings_id} "
                    f"exit_code={exit_code} "
                    f"error={job.exception()}"
                )

                if not index_attempt:
                    continue

                if not index_attempt.is_finished():
                    continue
        except Exception:
            # if the DB exceptioned, just restart the check.
            # polling the index attempt status doesn't need to be strongly consistent
            logger.exception(
                "Indexing watchdog - transient exception looking up index attempt: "
                f"attempt={index_attempt_id} "
                f"tenant={tenant_id} "
                f"cc_pair={cc_pair_id} "
                f"search_settings={search_settings_id}"
            )
            continue
        job.release()
        break

    task_logger.info(
        f"Indexing watchdog - finished: attempt={index_attempt_id} "
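Both variants of the watchdog above share the same shape: poll the spawned process every few seconds, renew the short-TTL active signal, and when the job finishes, reconcile a non-zero exit code against the Redis completion signal before deciding it really failed. A stripped-down sketch of that loop, with `Protocol` stand-ins for the real job and Redis helper objects (everything here is illustrative):

```python
import time
from http import HTTPStatus
from typing import Protocol


class SpawnedJob(Protocol):
    status: str

    def done(self) -> bool: ...
    def release(self) -> None: ...


class IndexSignals(Protocol):
    def set_active(self) -> None: ...
    def get_completion(self) -> int | None: ...


def watchdog_loop(job: SpawnedJob, signals: IndexSignals, poll_secs: int = 5) -> None:
    """Poll a spawned job, renew the active signal, and reconcile exit state on completion."""
    while True:
        time.sleep(poll_secs)
        signals.set_active()  # keep the short-TTL "alive" key fresh

        if not job.done():
            continue

        if job.status == "error":
            # the cloud runtime sometimes reports a non-zero exit for successful children,
            # so the Redis completion signal wins over the raw job status
            status_int = signals.get_completion()
            if status_int is None or HTTPStatus(status_int) != HTTPStatus.OK:
                job.release()
                raise RuntimeError("Spawned task exceptioned.")

        job.release()
        break
```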
@@ -930,7 +918,7 @@ def connector_indexing_task_wrapper(
            tenant_id,
            is_ee,
        )
    except Exception:
    except:
        logger.exception(
            f"connector_indexing_task exceptioned: "
            f"tenant={tenant_id} "
@@ -938,14 +926,7 @@ def connector_indexing_task_wrapper(
            f"cc_pair={cc_pair_id} "
            f"search_settings={search_settings_id}"
        )

        # There is a cloud related bug outside of our code
        # where spawned tasks return with an exit code of 1.
        # Unfortunately, exceptions also return with an exit code of 1,
        # so just raising an exception isn't informative
        # Exiting with 255 makes it possible to distinguish between normal exits
        # and exceptions.
        sys.exit(255)
        raise

    return result

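The comment block and `sys.exit(255)` on one side of this hunk explain the workaround: in the cloud runtime both a normal exit and an uncaught exception can surface as exit code 1, so a distinct exit code is the only way for the parent watchdog to tell them apart. A toy illustration of that idea (the module and function names here are made up):

```python
import sys


def run_child_work() -> None:
    """Placeholder for the real spawned-task body."""
    ...


if __name__ == "__main__":
    try:
        run_child_work()
        # normal completion: the interpreter exits 0 (or 1 in the buggy cloud runtime)
    except Exception:
        # exit with a code the watchdog cannot confuse with a normal exit
        sys.exit(255)
```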
@@ -1017,17 +998,7 @@ def connector_indexing_task(
            f"fence={redis_connector.stop.fence_key}"
        )

    # this wait is needed to avoid a race condition where
    # the primary worker sends the task and it is immediately executed
    # before the primary worker can finalize the fence
    start = time.monotonic()
    while True:
        if time.monotonic() - start > CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT:
            raise ValueError(
                f"connector_indexing_task - timed out waiting for fence to be ready: "
                f"fence={redis_connector.permissions.fence_key}"
            )

    if not redis_connector_index.fenced:  # The fence must exist
        raise ValueError(
            f"connector_indexing_task - fence not found: fence={redis_connector_index.fence_key}"
@@ -1068,9 +1039,7 @@ def connector_indexing_task(
    if not acquired:
        logger.warning(
            f"Indexing task already running, exiting...: "
            f"index_attempt={index_attempt_id} "
            f"cc_pair={cc_pair_id} "
            f"search_settings={search_settings_id}"
            f"index_attempt={index_attempt_id} cc_pair={cc_pair_id} search_settings={search_settings_id}"
        )
        return None

@@ -1106,7 +1075,6 @@ def connector_indexing_task(

        # define a callback class
        callback = IndexingCallback(
            os.getppid(),
            redis_connector.stop.fence_key,
            redis_connector_index.generator_progress_key,
            lock,
@@ -1140,19 +1108,8 @@ def connector_indexing_task(
            f"search_settings={search_settings_id}"
        )
        if attempt_found:
            try:
                with get_session_with_tenant(tenant_id) as db_session:
                    mark_attempt_failed(
                        index_attempt_id, db_session, failure_reason=str(e)
                    )
            except Exception:
                logger.exception(
                    "Indexing watchdog - transient exception looking up index attempt: "
                    f"attempt={index_attempt_id} "
                    f"tenant={tenant_id} "
                    f"cc_pair={cc_pair_id} "
                    f"search_settings={search_settings_id}"
                )
            with get_session_with_tenant(tenant_id) as db_session:
                mark_attempt_failed(index_attempt_id, db_session, failure_reason=str(e))

        raise e
    finally:

@@ -1,105 +0,0 @@
from typing import Any

import requests
from celery import shared_task
from celery import Task

from onyx.background.celery.apps.app_base import task_logger
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.app_configs import LLM_MODEL_UPDATE_API_URL
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.engine import get_session_with_tenant
from onyx.db.models import LLMProvider


def _process_model_list_response(model_list_json: Any) -> list[str]:
    # Handle case where response is wrapped in a "data" field
    if isinstance(model_list_json, dict) and "data" in model_list_json:
        model_list_json = model_list_json["data"]

    if not isinstance(model_list_json, list):
        raise ValueError(
            f"Invalid response from API - expected list, got {type(model_list_json)}"
        )

    # Handle both string list and object list cases
    model_names: list[str] = []
    for item in model_list_json:
        if isinstance(item, str):
            model_names.append(item)
        elif isinstance(item, dict) and "model_name" in item:
            model_names.append(item["model_name"])
        else:
            raise ValueError(
                f"Invalid item in model list - expected string or dict with model_name, got {type(item)}"
            )

    return model_names


@shared_task(
    name=OnyxCeleryTask.CHECK_FOR_LLM_MODEL_UPDATE,
    soft_time_limit=JOB_TIMEOUT,
    trail=False,
    bind=True,
)
def check_for_llm_model_update(self: Task, *, tenant_id: str | None) -> bool | None:
    if not LLM_MODEL_UPDATE_API_URL:
        raise ValueError("LLM model update API URL not configured")

    # First fetch the models from the API
    try:
        response = requests.get(LLM_MODEL_UPDATE_API_URL)
        response.raise_for_status()
        available_models = _process_model_list_response(response.json())
        task_logger.info(f"Found available models: {available_models}")

    except Exception:
        task_logger.exception("Failed to fetch models from API.")
        return None

    # Then update the database with the fetched models
    with get_session_with_tenant(tenant_id) as db_session:
        # Get the default LLM provider
        default_provider = (
            db_session.query(LLMProvider)
            .filter(LLMProvider.is_default_provider.is_(True))
            .first()
        )

        if not default_provider:
            task_logger.warning("No default LLM provider found")
            return None

        # log change if any
        old_models = set(default_provider.model_names or [])
        new_models = set(available_models)
        added_models = new_models - old_models
        removed_models = old_models - new_models

        if added_models:
            task_logger.info(f"Adding models: {sorted(added_models)}")
        if removed_models:
            task_logger.info(f"Removing models: {sorted(removed_models)}")

        # Update the provider's model list
        default_provider.model_names = available_models
        # if the default model is no longer available, set it to the first model in the list
        if default_provider.default_model_name not in available_models:
            task_logger.info(
                f"Default model {default_provider.default_model_name} not "
                f"available, setting to first model in list: {available_models[0]}"
            )
            default_provider.default_model_name = available_models[0]
        if default_provider.fast_default_model_name not in available_models:
            task_logger.info(
                f"Fast default model {default_provider.fast_default_model_name} "
                f"not available, setting to first model in list: {available_models[0]}"
            )
            default_provider.fast_default_model_name = available_models[0]
        db_session.commit()

        if added_models or removed_models:
            task_logger.info("Updated model list for default provider.")

        return True
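The file deleted above accepted the model list in two shapes: a bare list, or an OpenAI-style `{"data": [...]}` wrapper, with entries given either as strings or as objects carrying a `model_name` key. A quick usage sketch of that contract; the payload values are invented examples:

```python
# both of these hypothetical payloads normalize to ["gpt-4o", "gpt-4o-mini"]
bare_list = ["gpt-4o", "gpt-4o-mini"]
wrapped = {"data": [{"model_name": "gpt-4o"}, {"model_name": "gpt-4o-mini"}]}

assert _process_model_list_response(bare_list) == ["gpt-4o", "gpt-4o-mini"]
assert _process_model_list_response(wrapped) == ["gpt-4o", "gpt-4o-mini"]

# anything else (a dict without "data", or non-string items without "model_name")
# raises ValueError, so a malformed endpoint fails loudly instead of silently
```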
@@ -81,10 +81,10 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
    soft_time_limit=JOB_TIMEOUT,
    bind=True,
)
def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
def check_for_pruning(self: Task, *, tenant_id: str | None) -> None:
    r = get_redis_client(tenant_id=tenant_id)

    lock_beat: RedisLock = r.lock(
    lock_beat = r.lock(
        OnyxRedisLocks.CHECK_PRUNE_BEAT_LOCK,
        timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
    )
@@ -92,7 +92,7 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
    try:
        # these tasks should never overlap
        if not lock_beat.acquire(blocking=False):
            return None
            return

        cc_pair_ids: list[int] = []
        with get_session_with_tenant(tenant_id) as db_session:
@@ -127,8 +127,6 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
        if lock_beat.owned():
            lock_beat.release()

    return True


def try_creating_prune_generator_task(
    celery_app: Celery,
@@ -285,7 +283,6 @@ def connector_pruning_generator_task(
        )

        callback = IndexingCallback(
            0,
            redis_connector.stop.fence_key,
            redis_connector.prune.generator_progress_key,
            lock,

@@ -1,4 +1,3 @@
import random
import time
import traceback
from datetime import datetime
@@ -21,7 +20,6 @@ from tenacity import RetryError
from onyx.access.access import get_access_for_document
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
from onyx.background.celery.tasks.shared.tasks import LIGHT_SOFT_TIME_LIMIT
from onyx.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
@@ -69,7 +67,6 @@ from onyx.redis.redis_connector_index import RedisConnectorIndex
from onyx.redis.redis_connector_prune import RedisConnectorPrune
from onyx.redis.redis_document_set import RedisDocumentSet
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import redis_lock_dump
from onyx.redis.redis_usergroup import RedisUserGroup
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_versioned_implementation
@@ -78,7 +75,6 @@ from onyx.utils.variable_functionality import (
)
from onyx.utils.variable_functionality import global_version
from onyx.utils.variable_functionality import noop_fallback
from shared_configs.configs import MULTI_TENANT

logger = setup_logger()

@@ -91,7 +87,7 @@ logger = setup_logger()
    trail=False,
    bind=True,
)
def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | None:
def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> None:
    """Runs periodically to check if any document needs syncing.
    Generates sets of tasks for Celery if syncing is needed."""
    time_start = time.monotonic()
@@ -106,7 +102,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
    try:
        # these tasks should never overlap
        if not lock_beat.acquire(blocking=False):
            return None
            return

        with get_session_with_tenant(tenant_id) as db_session:
            try_generate_stale_document_sync_tasks(
@@ -114,7 +110,6 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
            )

        # region document set scan
        lock_beat.reacquire()
        document_set_ids: list[int] = []
        with get_session_with_tenant(tenant_id) as db_session:
            # check if any document sets are not synced
@@ -126,7 +121,6 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
                document_set_ids.append(document_set.id)

        for document_set_id in document_set_ids:
            lock_beat.reacquire()
            with get_session_with_tenant(tenant_id) as db_session:
                try_generate_document_set_sync_tasks(
                    self.app, document_set_id, db_session, r, lock_beat, tenant_id
@@ -135,8 +129,6 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No

        # check if any user groups are not synced
        if global_version.is_ee_version():
            lock_beat.reacquire()

            try:
                fetch_user_groups = fetch_versioned_implementation(
                    "onyx.db.user_group", "fetch_user_groups"
@@ -156,7 +148,6 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
                usergroup_ids.append(usergroup.id)

            for usergroup_id in usergroup_ids:
                lock_beat.reacquire()
                with get_session_with_tenant(tenant_id) as db_session:
                    try_generate_user_group_sync_tasks(
                        self.app, usergroup_id, db_session, r, lock_beat, tenant_id
@@ -171,16 +162,10 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
    finally:
        if lock_beat.owned():
            lock_beat.release()
        else:
            task_logger.error(
                "check_for_vespa_sync_task - Lock not owned on completion: "
                f"tenant={tenant_id}"
            )
            redis_lock_dump(lock_beat, r)

    time_elapsed = time.monotonic() - time_start
    task_logger.debug(f"check_for_vespa_sync_task finished: elapsed={time_elapsed:.2f}")
    return True
    task_logger.info(f"check_for_vespa_sync_task finished: elapsed={time_elapsed:.2f}")
    return


def try_generate_stale_document_sync_tasks(
@@ -651,23 +636,15 @@ def monitor_ccpair_indexing_taskset(
    if not payload:
        return

    elapsed_started_str = None
    if payload.started:
        elapsed_started = datetime.now(timezone.utc) - payload.started
        elapsed_started_str = f"{elapsed_started.total_seconds():.2f}"

    elapsed_submitted = datetime.now(timezone.utc) - payload.submitted

    progress = redis_connector_index.get_progress()
    if progress is not None:
        task_logger.info(
            f"Connector indexing progress: "
            f"attempt={payload.index_attempt_id} "
            f"cc_pair={cc_pair_id} "
            f"Connector indexing progress: cc_pair={cc_pair_id} "
            f"search_settings={search_settings_id} "
            f"progress={progress} "
            f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
            f"elapsed_started={elapsed_started_str}"
            f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
        )

    if payload.index_attempt_id is None or payload.celery_task_id is None:
@@ -738,14 +715,11 @@ def monitor_ccpair_indexing_taskset(
        status_enum = HTTPStatus(status_int)

    task_logger.info(
        f"Connector indexing finished: "
        f"attempt={payload.index_attempt_id} "
        f"cc_pair={cc_pair_id} "
        f"Connector indexing finished: cc_pair={cc_pair_id} "
        f"search_settings={search_settings_id} "
        f"progress={progress} "
        f"status={status_enum.name} "
        f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
        f"elapsed_started={elapsed_started_str}"
        f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
    )

    redis_connector_index.reset()
@@ -762,13 +736,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:

    Returns True if the task actually did work, False if it exited early to prevent overlap
    """
    task_logger.info(f"monitor_vespa_sync starting: tenant={tenant_id}")

    time_start = time.monotonic()

    timings: dict[str, float] = {}
    timings["start"] = time_start

    r = get_redis_client(tenant_id=tenant_id)

    lock_beat: RedisLock = r.lock(
@@ -779,76 +747,50 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
    try:
        # prevent overlapping tasks
        if not lock_beat.acquire(blocking=False):
            task_logger.info("monitor_vespa_sync exiting due to overlap")
            return False

        # print current queue lengths
        phase_start = time.monotonic()
        # we don't need every tenant polling redis for this info.
        if not MULTI_TENANT or random.randint(1, 100) == 100:
            r_celery = self.app.broker_connection().channel().client  # type: ignore
            n_celery = celery_get_queue_length("celery", r_celery)
            n_indexing = celery_get_queue_length(
                OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
            )
            n_sync = celery_get_queue_length(
                OnyxCeleryQueues.VESPA_METADATA_SYNC, r_celery
            )
            n_deletion = celery_get_queue_length(
                OnyxCeleryQueues.CONNECTOR_DELETION, r_celery
            )
            n_pruning = celery_get_queue_length(
                OnyxCeleryQueues.CONNECTOR_PRUNING, r_celery
            )
            n_permissions_sync = celery_get_queue_length(
                OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC, r_celery
            )
            n_external_group_sync = celery_get_queue_length(
                OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC, r_celery
            )
            n_permissions_upsert = celery_get_queue_length(
                OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT, r_celery
            )
        r_celery = self.app.broker_connection().channel().client  # type: ignore
        n_celery = celery_get_queue_length("celery", r_celery)
        n_indexing = celery_get_queue_length(
            OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
        )
        n_sync = celery_get_queue_length(OnyxCeleryQueues.VESPA_METADATA_SYNC, r_celery)
        n_deletion = celery_get_queue_length(
            OnyxCeleryQueues.CONNECTOR_DELETION, r_celery
        )
        n_pruning = celery_get_queue_length(
            OnyxCeleryQueues.CONNECTOR_PRUNING, r_celery
        )
        n_permissions_sync = celery_get_queue_length(
            OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC, r_celery
        )

        prefetched = celery_get_unacked_task_ids(
            OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
        )
        task_logger.info(
            f"Queue lengths: celery={n_celery} "
            f"indexing={n_indexing} "
            f"sync={n_sync} "
            f"deletion={n_deletion} "
            f"pruning={n_pruning} "
            f"permissions_sync={n_permissions_sync} "
        )

            task_logger.info(
                f"Queue lengths: celery={n_celery} "
                f"indexing={n_indexing} "
                f"indexing_prefetched={len(prefetched)} "
                f"sync={n_sync} "
                f"deletion={n_deletion} "
                f"pruning={n_pruning} "
                f"permissions_sync={n_permissions_sync} "
                f"external_group_sync={n_external_group_sync} "
                f"permissions_upsert={n_permissions_upsert} "
            )
        timings["queues"] = time.monotonic() - phase_start

        # scan and monitor activity to completion
        phase_start = time.monotonic()
        lock_beat.reacquire()
        if r.exists(RedisConnectorCredentialPair.get_fence_key()):
            monitor_connector_taskset(r)
        timings["connector"] = time.monotonic() - phase_start

        phase_start = time.monotonic()
        lock_beat.reacquire()
        for key_bytes in r.scan_iter(RedisConnectorDelete.FENCE_PREFIX + "*"):
            lock_beat.reacquire()
            monitor_connector_deletion_taskset(tenant_id, key_bytes, r)

        timings["connector_deletion"] = time.monotonic() - phase_start

        phase_start = time.monotonic()
        lock_beat.reacquire()
        for key_bytes in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
            lock_beat.reacquire()
            with get_session_with_tenant(tenant_id) as db_session:
                monitor_document_set_taskset(tenant_id, key_bytes, r, db_session)
        timings["document_set"] = time.monotonic() - phase_start

        phase_start = time.monotonic()
        lock_beat.reacquire()
        for key_bytes in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
            lock_beat.reacquire()
            monitor_usergroup_taskset = fetch_versioned_implementation_with_fallback(
@@ -858,29 +800,29 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
            )
            with get_session_with_tenant(tenant_id) as db_session:
                monitor_usergroup_taskset(tenant_id, key_bytes, r, db_session)
        timings["usergroup"] = time.monotonic() - phase_start

        phase_start = time.monotonic()
        lock_beat.reacquire()
        for key_bytes in r.scan_iter(RedisConnectorPrune.FENCE_PREFIX + "*"):
            lock_beat.reacquire()
            with get_session_with_tenant(tenant_id) as db_session:
                monitor_ccpair_pruning_taskset(tenant_id, key_bytes, r, db_session)
        timings["pruning"] = time.monotonic() - phase_start

        phase_start = time.monotonic()
        lock_beat.reacquire()
        for key_bytes in r.scan_iter(RedisConnectorIndex.FENCE_PREFIX + "*"):
            lock_beat.reacquire()
            with get_session_with_tenant(tenant_id) as db_session:
                monitor_ccpair_indexing_taskset(tenant_id, key_bytes, r, db_session)
        timings["indexing"] = time.monotonic() - phase_start

        phase_start = time.monotonic()
        lock_beat.reacquire()
        for key_bytes in r.scan_iter(RedisConnectorPermissionSync.FENCE_PREFIX + "*"):
            lock_beat.reacquire()
            with get_session_with_tenant(tenant_id) as db_session:
                monitor_ccpair_permissions_taskset(tenant_id, key_bytes, r, db_session)

        timings["permissions"] = time.monotonic() - phase_start
        # uncomment for debugging if needed
        # r_celery = celery_app.broker_connection().channel().client
        # length = celery_get_queue_length(OnyxCeleryQueues.VESPA_METADATA_SYNC, r_celery)
        # task_logger.warning(f"queue={OnyxCeleryQueues.VESPA_METADATA_SYNC} length={length}")
    except SoftTimeLimitExceeded:
        task_logger.info(
            "Soft time limit exceeded, task is being terminated gracefully."
@@ -888,21 +830,6 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
    finally:
        if lock_beat.owned():
            lock_beat.release()
        else:
            t = timings
            task_logger.error(
                "monitor_vespa_sync - Lock not owned on completion: "
                f"tenant={tenant_id} "
                f"queues={t.get('queues')} "
                f"connector={t.get('connector')} "
                f"connector_deletion={t.get('connector_deletion')} "
                f"document_set={t.get('document_set')} "
                f"usergroup={t.get('usergroup')} "
                f"pruning={t.get('pruning')} "
                f"indexing={t.get('indexing')} "
                f"permissions={t.get('permissions')}"
            )
            redis_lock_dump(lock_beat, r)

    time_elapsed = time.monotonic() - time_start
    task_logger.info(f"monitor_vespa_sync finished: elapsed={time_elapsed:.2f}")
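One side of the `monitor_vespa_sync` hunk threads a `timings` dict through the task so each monitoring phase records its own elapsed time, and the totals are dumped when the beat lock turns out not to be owned at the end. A minimal sketch of that phase-timing pattern:

```python
import time
from collections.abc import Iterator
from contextlib import contextmanager

timings: dict[str, float] = {}


@contextmanager
def timed_phase(name: str) -> Iterator[None]:
    """Record how long the wrapped phase took, keyed by name."""
    start = time.monotonic()
    try:
        yield
    finally:
        timings[name] = time.monotonic() - start


with timed_phase("queues"):
    time.sleep(0.01)  # stand-in for collecting queue lengths

with timed_phase("indexing"):
    time.sleep(0.01)  # stand-in for scanning indexing fences

print(timings)  # e.g. {"queues": 0.01, "indexing": 0.01}
```

Recording per-phase durations makes the "lock not owned on completion" error actionable, since it shows which phase blew through the lock timeout.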
@@ -955,13 +882,6 @@ def vespa_metadata_sync_task(
            # the sync might repeat again later
            mark_document_as_synced(document_id, db_session)

        redis_syncing_key = RedisConnectorCredentialPair.make_redis_syncing_key(
            document_id
        )
        r = get_redis_client(tenant_id=tenant_id)
        r.delete(redis_syncing_key)
        # r.hdel(RedisConnectorCredentialPair.SYNCING_HASH, document_id)

        task_logger.info(f"doc={document_id} action=sync chunks={chunks_affected}")
    except SoftTimeLimitExceeded:
        task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")

@@ -14,7 +14,6 @@ from onyx.configs.app_configs import POLL_CONNECTOR_OFFSET
from onyx.configs.constants import MilestoneRecordType
from onyx.connectors.connector_runner import ConnectorRunner
from onyx.connectors.factory import instantiate_connector
from onyx.connectors.models import Document
from onyx.connectors.models import IndexAttemptMetadata
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import get_last_successful_attempt_time
@@ -91,35 +90,6 @@ def _get_connector_runner(
    )


def strip_null_characters(doc_batch: list[Document]) -> list[Document]:
    cleaned_batch = []
    for doc in doc_batch:
        cleaned_doc = doc.model_copy()

        if "\x00" in cleaned_doc.id:
            logger.warning(f"NUL characters found in document ID: {cleaned_doc.id}")
            cleaned_doc.id = cleaned_doc.id.replace("\x00", "")

        if "\x00" in cleaned_doc.semantic_identifier:
            logger.warning(
                f"NUL characters found in document semantic identifier: {cleaned_doc.semantic_identifier}"
            )
            cleaned_doc.semantic_identifier = cleaned_doc.semantic_identifier.replace(
                "\x00", ""
            )

        for section in cleaned_doc.sections:
            if section.link and "\x00" in section.link:
                logger.warning(
                    f"NUL characters found in document link for document: {cleaned_doc.id}"
                )
                section.link = section.link.replace("\x00", "")

        cleaned_batch.append(cleaned_doc)

    return cleaned_batch

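The hunk headers indicate that `strip_null_characters` is removed in this diff. The function existed because NUL (`\x00`) bytes occasionally show up in scraped documents and cannot be stored in Postgres `text` columns. If a connector still emits them, the core of what was being done is just a targeted string replace; a tiny illustrative sanitizer (not part of the diff) looks like this:

```python
def strip_nul(value: str) -> str:
    """Remove NUL bytes, which Postgres text columns cannot store."""
    return value.replace("\x00", "") if "\x00" in value else value


# usage sketch on a hypothetical document-like object:
# doc.id = strip_nul(doc.id)
# doc.semantic_identifier = strip_nul(doc.semantic_identifier)
```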
class ConnectorStopSignal(Exception):
    """A custom exception used to signal a stop in processing."""

@@ -268,9 +238,7 @@ def _run_indexing(
            )

            batch_description = []

            doc_batch_cleaned = strip_null_characters(doc_batch)
            for doc in doc_batch_cleaned:
            for doc in doc_batch:
                batch_description.append(doc.to_short_descriptor())

                doc_size = 0
@@ -290,15 +258,15 @@ def _run_indexing(

            # real work happens here!
            new_docs, total_batch_chunks = indexing_pipeline(
                document_batch=doc_batch_cleaned,
                document_batch=doc_batch,
                index_attempt_metadata=index_attempt_md,
            )

            batch_num += 1
            net_doc_change += new_docs
            chunk_count += total_batch_chunks
            document_count += len(doc_batch_cleaned)
            all_connector_doc_ids.update(doc.id for doc in doc_batch_cleaned)
            document_count += len(doc_batch)
            all_connector_doc_ids.update(doc.id for doc in doc_batch)

            # commit transaction so that the `update` below begins
            # with a brand new transaction. Postgres uses the start
@@ -308,7 +276,7 @@ def _run_indexing(
            db_session.commit()

            if callback:
                callback.progress("_run_indexing", len(doc_batch_cleaned))
                callback.progress("_run_indexing", len(doc_batch))

            # This new value is updated every batch, so UI can refresh per batch update
            update_docs_indexed(

@@ -22,9 +22,7 @@ from onyx.chat.stream_processing.answer_response_handler import (
from onyx.chat.stream_processing.answer_response_handler import (
    DummyAnswerResponseHandler,
)
from onyx.chat.stream_processing.utils import (
    map_document_id_order,
)
from onyx.chat.stream_processing.utils import map_document_id_order
from onyx.chat.tool_handling.tool_response_handler import ToolResponseHandler
from onyx.file_store.utils import InMemoryChatFile
from onyx.llm.interfaces import LLM
@@ -208,9 +206,9 @@ class Answer:
        # + figure out what the next LLM call should be
        tool_call_handler = ToolResponseHandler(current_llm_call.tools)

        final_search_results, displayed_search_results = SearchTool.get_search_result(
        search_result, displayed_search_results_map = SearchTool.get_search_result(
            current_llm_call
        ) or ([], [])
        ) or ([], {})

        # Quotes are no longer supported
        # answer_handler: AnswerResponseHandler
@@ -226,9 +224,9 @@ class Answer:
        # else:
        #     raise ValueError("No answer style config provided")
        answer_handler = CitationResponseHandler(
            context_docs=final_search_results,
            final_doc_id_to_rank_map=map_document_id_order(final_search_results),
            display_doc_id_to_rank_map=map_document_id_order(displayed_search_results),
            context_docs=search_result,
            doc_id_to_rank_map=map_document_id_order(search_result),
            display_doc_order_dict=displayed_search_results_map,
        )

        response_handler_manager = LLMResponseHandlerManager(

@@ -37,22 +37,22 @@ class CitationResponseHandler(AnswerResponseHandler):
    def __init__(
        self,
        context_docs: list[LlmDoc],
        final_doc_id_to_rank_map: DocumentIdOrderMapping,
        display_doc_id_to_rank_map: DocumentIdOrderMapping,
        doc_id_to_rank_map: DocumentIdOrderMapping,
        display_doc_order_dict: dict[str, int],
    ):
        self.context_docs = context_docs
        self.final_doc_id_to_rank_map = final_doc_id_to_rank_map
        self.display_doc_id_to_rank_map = display_doc_id_to_rank_map
        self.doc_id_to_rank_map = doc_id_to_rank_map
        self.display_doc_order_dict = display_doc_order_dict
        self.citation_processor = CitationProcessor(
            context_docs=self.context_docs,
            final_doc_id_to_rank_map=self.final_doc_id_to_rank_map,
            display_doc_id_to_rank_map=self.display_doc_id_to_rank_map,
            doc_id_to_rank_map=self.doc_id_to_rank_map,
            display_doc_order_dict=self.display_doc_order_dict,
        )
        self.processed_text = ""
        self.citations: list[CitationInfo] = []

        # TODO remove this after citation issue is resolved
        logger.debug(f"Document to ranking map {self.final_doc_id_to_rank_map}")
        logger.debug(f"Document to ranking map {self.doc_id_to_rank_map}")

    def handle_response_part(
        self,

@@ -21,19 +21,20 @@ class CitationProcessor:
    def __init__(
        self,
        context_docs: list[LlmDoc],
        final_doc_id_to_rank_map: DocumentIdOrderMapping,
        display_doc_id_to_rank_map: DocumentIdOrderMapping,
        doc_id_to_rank_map: DocumentIdOrderMapping,
        display_doc_order_dict: dict[str, int],
        stop_stream: str | None = STOP_STREAM_PAT,
    ):
        self.context_docs = context_docs
        self.final_doc_id_to_rank_map = final_doc_id_to_rank_map
        self.display_doc_id_to_rank_map = display_doc_id_to_rank_map
        self.doc_id_to_rank_map = doc_id_to_rank_map
        self.stop_stream = stop_stream
        self.final_order_mapping = final_doc_id_to_rank_map.order_mapping
        self.display_order_mapping = display_doc_id_to_rank_map.order_mapping
        self.order_mapping = doc_id_to_rank_map.order_mapping
        self.display_doc_order_dict = (
            display_doc_order_dict  # original order of docs to displayed to user
        )
        self.llm_out = ""
        self.max_citation_num = len(context_docs)
        self.citation_order: list[int] = []  # order of citations in the LLM output
        self.citation_order: list[int] = []
        self.curr_segment = ""
        self.cited_inds: set[int] = set()
        self.hold = ""
@@ -92,31 +93,29 @@ class CitationProcessor:

            if 1 <= numerical_value <= self.max_citation_num:
                context_llm_doc = self.context_docs[numerical_value - 1]
                final_citation_num = self.final_order_mapping[
                    context_llm_doc.document_id
                ]
                real_citation_num = self.order_mapping[context_llm_doc.document_id]

                if final_citation_num not in self.citation_order:
                    self.citation_order.append(final_citation_num)
                if real_citation_num not in self.citation_order:
                    self.citation_order.append(real_citation_num)

                citation_order_idx = (
                    self.citation_order.index(final_citation_num) + 1
                target_citation_num = (
                    self.citation_order.index(real_citation_num) + 1
                )

                # get the value that was displayed to user, should always
                # be in the display_doc_order_dict. But check anyways
                if context_llm_doc.document_id in self.display_order_mapping:
                    displayed_citation_num = self.display_order_mapping[
                if context_llm_doc.document_id in self.display_doc_order_dict:
                    displayed_citation_num = self.display_doc_order_dict[
                        context_llm_doc.document_id
                    ]
                else:
                    displayed_citation_num = final_citation_num
                    displayed_citation_num = real_citation_num
                    logger.warning(
                        f"Doc {context_llm_doc.document_id} not in display_doc_order_dict. Used LLM citation number instead."
                    )

                # Skip consecutive citations of the same work
                if final_citation_num in self.current_citations:
                if target_citation_num in self.current_citations:
                    start, end = citation.span()
                    real_start = length_to_add + start
                    diff = end - start
@@ -135,8 +134,8 @@ class CitationProcessor:
                        doc_id = int(match.group(1))
                        context_llm_doc = self.context_docs[doc_id - 1]
                        yield CitationInfo(
                            # citation_num is now the number post initial ranking, i.e. as displayed to user
                            citation_num=displayed_citation_num,
                            # stay with the original for now (order of LLM cites)
                            citation_num=target_citation_num,
                            document_id=context_llm_doc.document_id,
                        )
                    except Exception as e:
@@ -152,13 +151,13 @@ class CitationProcessor:
                    link = context_llm_doc.link

                    self.past_cite_count = len(self.llm_out)
                    self.current_citations.append(final_citation_num)
                    self.current_citations.append(target_citation_num)

                    if citation_order_idx not in self.cited_inds:
                        self.cited_inds.add(citation_order_idx)
                    if target_citation_num not in self.cited_inds:
                        self.cited_inds.add(target_citation_num)
                        yield CitationInfo(
                            # citation number is now the one that was displayed to user
                            citation_num=displayed_citation_num,
                            # stay with the original for now (order of LLM cites)
                            citation_num=target_citation_num,
                            document_id=context_llm_doc.document_id,
                        )

@@ -168,6 +167,7 @@ class CitationProcessor:
                        self.curr_segment = (
                            self.curr_segment[: start + length_to_add]
                            + f"[[{displayed_citation_num}]]({link})"  # use the value that was displayed to user
                            # + f"[[{target_citation_num}]]({link})"
                            + self.curr_segment[end + length_to_add :]
                        )
                        length_to_add += len(self.curr_segment) - prev_length
@@ -176,6 +176,7 @@ class CitationProcessor:
                        self.curr_segment = (
                            self.curr_segment[: start + length_to_add]
                            + f"[[{displayed_citation_num}]]()"  # use the value that was displayed to user
                            # + f"[[{target_citation_num}]]()"
                            + self.curr_segment[end + length_to_add :]
                        )
                        length_to_add += len(self.curr_segment) - prev_length

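Both versions of `CitationProcessor` in this hunk grapple with the same mapping problem: the LLM cites documents by their position in the prompt context, while the user saw a different ordering in the UI, so the emitted citation number has to be translated, with a fallback to the LLM-order number when a document is missing from the display mapping. A compressed sketch of that remapping; the function and argument names are illustrative:

```python
def remap_citation(
    llm_num: int,
    context_doc_ids: list[str],
    display_order: dict[str, int],
) -> int:
    """Map a 1-based citation number from the LLM to the number shown to the user."""
    doc_id = context_doc_ids[llm_num - 1]
    # prefer the displayed ordering; fall back to the LLM's own numbering
    return display_order.get(doc_id, llm_num)


# e.g. the LLM cites [2], but that document was shown to the user as result #5
print(remap_citation(2, ["doc_a", "doc_b"], {"doc_b": 5}))  # -> 5
```

The two sides of the diff differ mainly in which number is attached to the streamed `CitationInfo` (the displayed number versus the LLM-order number); the in-text `[[n]](link)` markers use the displayed number in both.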
@@ -54,17 +54,10 @@ MASK_CREDENTIAL_PREFIX = (
    os.environ.get("MASK_CREDENTIAL_PREFIX", "True").lower() != "false"
)

REDIS_AUTH_EXPIRE_TIME_SECONDS = int(
    os.environ.get("REDIS_AUTH_EXPIRE_TIME_SECONDS") or 86400 * 7
)  # 7 days

SESSION_EXPIRE_TIME_SECONDS = int(
    os.environ.get("SESSION_EXPIRE_TIME_SECONDS") or 86400 * 7
)  # 7 days

# Default request timeout, mostly used by connectors
REQUEST_TIMEOUT_SECONDS = int(os.environ.get("REQUEST_TIMEOUT_SECONDS") or 60)

# set `VALID_EMAIL_DOMAINS` to a comma seperated list of domains in order to
# restrict access to Onyx to only users with emails from those domains.
# E.g. `VALID_EMAIL_DOMAINS=example.com,example.org` will restrict Onyx
@@ -99,7 +92,6 @@ SMTP_SERVER = os.environ.get("SMTP_SERVER") or "smtp.gmail.com"
SMTP_PORT = int(os.environ.get("SMTP_PORT") or "587")
SMTP_USER = os.environ.get("SMTP_USER", "your-email@gmail.com")
SMTP_PASS = os.environ.get("SMTP_PASS", "your-gmail-password")
EMAIL_CONFIGURED = all([SMTP_SERVER, SMTP_USER, SMTP_PASS])
EMAIL_FROM = os.environ.get("EMAIL_FROM") or SMTP_USER

# If set, Onyx will listen to the `expires_at` returned by the identity
@@ -153,7 +145,7 @@ POSTGRES_PASSWORD = urllib.parse.quote_plus(
POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost"
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432"
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
AWS_REGION_NAME = os.environ.get("AWS_REGION_NAME") or "us-east-2"
AWS_REGION = os.environ.get("AWS_REGION") or "us-east-2"

POSTGRES_API_SERVER_POOL_SIZE = int(
    os.environ.get("POSTGRES_API_SERVER_POOL_SIZE") or 40
@@ -192,27 +184,6 @@ REDIS_HOST = os.environ.get("REDIS_HOST") or "localhost"
REDIS_PORT = int(os.environ.get("REDIS_PORT", 6379))
REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD") or ""


REDIS_AUTH_KEY_PREFIX = "fastapi_users_token:"


# Rate limiting for auth endpoints
RATE_LIMIT_WINDOW_SECONDS: int | None = None
_rate_limit_window_seconds_str = os.environ.get("RATE_LIMIT_WINDOW_SECONDS")
if _rate_limit_window_seconds_str is not None:
    try:
        RATE_LIMIT_WINDOW_SECONDS = int(_rate_limit_window_seconds_str)
    except ValueError:
        pass

RATE_LIMIT_MAX_REQUESTS: int | None = None
_rate_limit_max_requests_str = os.environ.get("RATE_LIMIT_MAX_REQUESTS")
if _rate_limit_max_requests_str is not None:
    try:
        RATE_LIMIT_MAX_REQUESTS = int(_rate_limit_max_requests_str)
    except ValueError:
        pass

# Used for general redis things
REDIS_DB_NUMBER = int(os.environ.get("REDIS_DB_NUMBER", 0))

@@ -376,17 +347,12 @@ GITLAB_CONNECTOR_INCLUDE_CODE_FILES = (
    os.environ.get("GITLAB_CONNECTOR_INCLUDE_CODE_FILES", "").lower() == "true"
)

# Typically set to http://localhost:3000 for OAuth connector development
CONNECTOR_LOCALHOST_OVERRIDE = os.getenv("CONNECTOR_LOCALHOST_OVERRIDE")

# Egnyte specific configs
EGNYTE_LOCALHOST_OVERRIDE = os.getenv("EGNYTE_LOCALHOST_OVERRIDE")
EGNYTE_BASE_DOMAIN = os.getenv("EGNYTE_DOMAIN")
EGNYTE_CLIENT_ID = os.getenv("EGNYTE_CLIENT_ID")
EGNYTE_CLIENT_SECRET = os.getenv("EGNYTE_CLIENT_SECRET")

# Linear specific configs
LINEAR_CLIENT_ID = os.getenv("LINEAR_CLIENT_ID")
LINEAR_CLIENT_SECRET = os.getenv("LINEAR_CLIENT_SECRET")

DASK_JOB_CLIENT_ENABLED = (
    os.environ.get("DASK_JOB_CLIENT_ENABLED", "").lower() == "true"
)
@@ -537,9 +503,6 @@ try:
except json.JSONDecodeError:
    pass

# LLM Model Update API endpoint
LLM_MODEL_UPDATE_API_URL = os.environ.get("LLM_MODEL_UPDATE_API_URL")

#####
# Enterprise Edition Configs
#####
@@ -579,6 +542,7 @@ CONTROL_PLANE_API_BASE_URL = os.environ.get(
# JWT configuration
JWT_ALGORITHM = "HS256"


#####
# API Key Configs
#####

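The rate-limit settings touched in this hunk follow a pattern this config module uses repeatedly: read an environment variable, try to coerce it to `int`, and silently fall back when it is unset or malformed. A small helper capturing that pattern (the helper itself is illustrative; the module writes the logic out inline):

```python
import os


def env_int(name: str, default: int | None = None) -> int | None:
    """Parse an integer env var, returning default when unset or malformed."""
    raw = os.environ.get(name)
    if raw is None:
        return default
    try:
        return int(raw)
    except ValueError:
        return default


# e.g. RATE_LIMIT_WINDOW_SECONDS = env_int("RATE_LIMIT_WINDOW_SECONDS")
```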
@@ -36,8 +36,6 @@ DISABLED_GEN_AI_MSG = (

DEFAULT_PERSONA_ID = 0

DEFAULT_CC_PAIR_ID = 1

# Postgres connection constants for application_name
POSTGRES_WEB_APP_NAME = "web"
POSTGRES_INDEXER_APP_NAME = "indexer"
@@ -76,19 +74,13 @@ KV_ENTERPRISE_SETTINGS_KEY = "onyx_enterprise_settings"
KV_CUSTOM_ANALYTICS_SCRIPT_KEY = "__custom_analytics_script__"
KV_DOCUMENTS_SEEDED_KEY = "documents_seeded"

# NOTE: we use this timeout / 4 in various places to refresh a lock
# might be worth separating this timeout into separate timeouts for each situation
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT = 120

CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT = 60
CELERY_PRIMARY_WORKER_LOCK_TIMEOUT = 120

# needs to be long enough to cover the maximum time it takes to download an object
# if we can get callbacks as object bytes download, we could lower this a lot.
CELERY_INDEXING_LOCK_TIMEOUT = 3 * 60 * 60  # 60 min

# how long a task should wait for associated fence to be ready
CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT = 5 * 60  # 5 min

# needs to be long enough to cover the maximum time it takes to download an object
# if we can get callbacks as object bytes download, we could lower this a lot.
CELERY_PRUNING_LOCK_TIMEOUT = 300  # 5 min
@@ -142,11 +134,9 @@ class DocumentSource(str, Enum):
    OCI_STORAGE = "oci_storage"
    XENFORO = "xenforo"
    NOT_APPLICABLE = "not_applicable"
    DISCORD = "discord"
    FRESHDESK = "freshdesk"
    FIREFLIES = "fireflies"
    EGNYTE = "egnyte"
    AIRTABLE = "airtable"


DocumentSourceRequiringTenantContext: list[DocumentSource] = [DocumentSource.FILE]
@@ -250,7 +240,6 @@ class OnyxCeleryQueues:
    VESPA_METADATA_SYNC = "vespa_metadata_sync"
    DOC_PERMISSIONS_UPSERT = "doc_permissions_upsert"
    CONNECTOR_DELETION = "connector_deletion"
    LLM_MODEL_UPDATE = "llm_model_update"

    # Heavy queue
    CONNECTOR_PRUNING = "connector_pruning"
@@ -284,7 +273,6 @@ class OnyxRedisLocks:

    SLACK_BOT_LOCK = "da_lock:slack_bot"
    SLACK_BOT_HEARTBEAT_PREFIX = "da_heartbeat:slack_bot"
    ANONYMOUS_USER_ENABLED = "anonymous_user_enabled"


class OnyxRedisSignals:
@@ -306,7 +294,6 @@ class OnyxCeleryTask:
    CHECK_FOR_PRUNING = "check_for_pruning"
    CHECK_FOR_DOC_PERMISSIONS_SYNC = "check_for_doc_permissions_sync"
    CHECK_FOR_EXTERNAL_GROUP_SYNC = "check_for_external_group_sync"
    CHECK_FOR_LLM_MODEL_UPDATE = "check_for_llm_model_update"
    MONITOR_VESPA_SYNC = "monitor_vespa_sync"
    KOMBU_MESSAGE_CLEANUP_TASK = "kombu_message_cleanup_task"
    CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK = (

@@ -1,289 +0,0 @@
from io import BytesIO
from typing import Any

import requests
from pyairtable import Api as AirtableApi
from pyairtable.api.types import RecordDict
from pyairtable.models.schema import TableSchema
from retry import retry

from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.utils.logger import setup_logger

logger = setup_logger()

# NOTE: all are made lowercase to avoid case sensitivity issues
# these are the field types that are considered metadata rather
# than sections
_METADATA_FIELD_TYPES = {
    "singlecollaborator",
    "collaborator",
    "createdby",
    "singleselect",
    "multipleselects",
    "checkbox",
    "date",
    "datetime",
    "email",
    "phone",
    "url",
    "number",
    "currency",
    "duration",
    "percent",
    "rating",
    "createdtime",
    "lastmodifiedtime",
    "autonumber",
    "rollup",
    "lookup",
    "count",
    "formula",
    "date",
}


class AirtableClientNotSetUpError(PermissionError):
    def __init__(self) -> None:
        super().__init__("Airtable Client is not set up, was load_credentials called?")


class AirtableConnector(LoadConnector):
    def __init__(
        self,
        base_id: str,
        table_name_or_id: str,
        batch_size: int = INDEX_BATCH_SIZE,
    ) -> None:
        self.base_id = base_id
        self.table_name_or_id = table_name_or_id
        self.batch_size = batch_size
        self.airtable_client: AirtableApi | None = None

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        self.airtable_client = AirtableApi(credentials["airtable_access_token"])
        return None

def _get_field_value(self, field_info: Any, field_type: str) -> list[str]:
|
||||
"""
|
||||
Extract value(s) from a field regardless of its type.
|
||||
Returns either a single string or list of strings for attachments.
|
||||
"""
|
||||
if field_info is None:
|
||||
return []
|
||||
|
||||
# skip references to other records for now (would need to do another
|
||||
# request to get the actual record name/type)
|
||||
# TODO: support this
|
||||
if field_type == "multipleRecordLinks":
|
||||
return []
|
||||
|
||||
if field_type == "multipleAttachments":
|
||||
attachment_texts: list[str] = []
|
||||
for attachment in field_info:
|
||||
url = attachment.get("url")
|
||||
filename = attachment.get("filename", "")
|
||||
if not url:
|
||||
continue
|
||||
|
||||
@retry(
|
||||
tries=5,
|
||||
delay=1,
|
||||
backoff=2,
|
||||
max_delay=10,
|
||||
)
|
||||
def get_attachment_with_retry(url: str) -> bytes | None:
|
||||
attachment_response = requests.get(url)
|
||||
if attachment_response.status_code == 200:
|
||||
return attachment_response.content
|
||||
return None
|
||||
|
||||
attachment_content = get_attachment_with_retry(url)
|
||||
if attachment_content:
|
||||
try:
|
||||
file_ext = get_file_ext(filename)
|
||||
attachment_text = extract_file_text(
|
||||
BytesIO(attachment_content),
|
||||
filename,
|
||||
break_on_unprocessable=False,
|
||||
extension=file_ext,
|
||||
)
|
||||
if attachment_text:
|
||||
attachment_texts.append(f"{filename}:\n{attachment_text}")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to process attachment {filename}: {str(e)}"
|
||||
)
|
||||
return attachment_texts
|
||||
|
||||
if field_type in ["singleCollaborator", "collaborator", "createdBy"]:
|
||||
combined = []
|
||||
collab_name = field_info.get("name")
|
||||
collab_email = field_info.get("email")
|
||||
if collab_name:
|
||||
combined.append(collab_name)
|
||||
if collab_email:
|
||||
combined.append(f"({collab_email})")
|
||||
return [" ".join(combined) if combined else str(field_info)]
|
||||
|
||||
if isinstance(field_info, list):
|
||||
return [str(item) for item in field_info]
|
||||
|
||||
return [str(field_info)]
|
||||
|
||||
def _should_be_metadata(self, field_type: str) -> bool:
|
||||
"""Determine if a field type should be treated as metadata."""
|
||||
return field_type.lower() in _METADATA_FIELD_TYPES
|
||||
|
||||
def _process_field(
|
||||
self,
|
||||
field_name: str,
|
||||
field_info: Any,
|
||||
field_type: str,
|
||||
table_id: str,
|
||||
record_id: str,
|
||||
) -> tuple[list[Section], dict[str, Any]]:
|
||||
"""
|
||||
Process a single Airtable field and return sections or metadata.
|
||||
|
||||
Args:
|
||||
field_name: Name of the field
|
||||
field_info: Raw field information from Airtable
|
||||
field_type: Airtable field type
|
||||
|
||||
Returns:
|
||||
(list of Sections, dict of metadata)
|
||||
"""
|
||||
if field_info is None:
|
||||
return [], {}
|
||||
|
||||
# Get the value(s) for the field
|
||||
field_values = self._get_field_value(field_info, field_type)
|
||||
if len(field_values) == 0:
|
||||
return [], {}
|
||||
|
||||
# Determine if it should be metadata or a section
|
||||
if self._should_be_metadata(field_type):
|
||||
if len(field_values) > 1:
|
||||
return [], {field_name: field_values}
|
||||
return [], {field_name: field_values[0]}
|
||||
|
||||
# Otherwise, create relevant sections
|
||||
sections = [
|
||||
Section(
|
||||
link=f"https://airtable.com/{self.base_id}/{table_id}/{record_id}",
|
||||
text=(
|
||||
f"{field_name}:\n"
|
||||
"------------------------\n"
|
||||
f"{text}\n"
|
||||
"------------------------"
|
||||
),
|
||||
)
|
||||
for text in field_values
|
||||
]
|
||||
return sections, {}
|
||||
|
||||
def _process_record(
|
||||
self,
|
||||
record: RecordDict,
|
||||
table_schema: TableSchema,
|
||||
primary_field_name: str | None,
|
||||
) -> Document:
|
||||
"""Process a single Airtable record into a Document.
|
||||
|
||||
Args:
|
||||
record: The Airtable record to process
|
||||
table_schema: Schema information for the table
|
||||
table_name: Name of the table
|
||||
table_id: ID of the table
|
||||
primary_field_name: Name of the primary field, if any
|
||||
|
||||
Returns:
|
||||
Document object representing the record
|
||||
"""
|
||||
table_id = table_schema.id
|
||||
table_name = table_schema.name
|
||||
record_id = record["id"]
|
||||
fields = record["fields"]
|
||||
sections: list[Section] = []
|
||||
metadata: dict[str, Any] = {}
|
||||
|
||||
# Get primary field value if it exists
|
||||
primary_field_value = (
|
||||
fields.get(primary_field_name) if primary_field_name else None
|
||||
)
|
||||
|
||||
for field_schema in table_schema.fields:
|
||||
field_name = field_schema.name
|
||||
field_val = fields.get(field_name)
|
||||
field_type = field_schema.type
|
||||
|
||||
field_sections, field_metadata = self._process_field(
|
||||
field_name=field_name,
|
||||
field_info=field_val,
|
||||
field_type=field_type,
|
||||
table_id=table_id,
|
||||
record_id=record_id,
|
||||
)
|
||||
|
||||
sections.extend(field_sections)
|
||||
metadata.update(field_metadata)
|
||||
|
||||
semantic_id = (
|
||||
f"{table_name}: {primary_field_value}"
|
||||
if primary_field_value
|
||||
else table_name
|
||||
)
|
||||
|
||||
return Document(
|
||||
id=f"airtable__{record_id}",
|
||||
sections=sections,
|
||||
source=DocumentSource.AIRTABLE,
|
||||
semantic_identifier=semantic_id,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
"""
|
||||
Fetch all records from the table.
|
||||
|
||||
NOTE: Airtable does not support filtering by time updated, so
|
||||
we have to fetch all records every time.
|
||||
"""
|
||||
if not self.airtable_client:
|
||||
raise AirtableClientNotSetUpError()
|
||||
|
||||
table = self.airtable_client.table(self.base_id, self.table_name_or_id)
|
||||
records = table.all()
|
||||
|
||||
table_schema = table.schema()
|
||||
primary_field_name = None
|
||||
|
||||
# Find a primary field from the schema
|
||||
for field in table_schema.fields:
|
||||
if field.id == table_schema.primary_field_id:
|
||||
primary_field_name = field.name
|
||||
break
|
||||
|
||||
record_documents: list[Document] = []
|
||||
for record in records:
|
||||
document = self._process_record(
|
||||
record=record,
|
||||
table_schema=table_schema,
|
||||
primary_field_name=primary_field_name,
|
||||
)
|
||||
record_documents.append(document)
|
||||
|
||||
if len(record_documents) >= self.batch_size:
|
||||
yield record_documents
|
||||
record_documents = []
|
||||
|
||||
if record_documents:
|
||||
yield record_documents
|
||||
@@ -52,29 +52,10 @@ _RESTRICTIONS_EXPANSION_FIELDS = [
|
||||
"space",
|
||||
"restrictions.read.restrictions.user",
|
||||
"restrictions.read.restrictions.group",
|
||||
"ancestors.restrictions.read.restrictions.user",
|
||||
"ancestors.restrictions.read.restrictions.group",
|
||||
]
|
||||
|
||||
_SLIM_DOC_BATCH_SIZE = 5000
|
||||
|
||||
_ATTACHMENT_EXTENSIONS_TO_FILTER_OUT = [
|
||||
"png",
|
||||
"jpg",
|
||||
"jpeg",
|
||||
"gif",
|
||||
"mp4",
|
||||
"mov",
|
||||
"mp3",
|
||||
"wav",
|
||||
]
|
||||
_FULL_EXTENSION_FILTER_STRING = "".join(
|
||||
[
|
||||
f" and title!~'*.{extension}'"
|
||||
for extension in _ATTACHMENT_EXTENSIONS_TO_FILTER_OUT
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
def __init__(
|
||||
@@ -83,7 +64,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
is_cloud: bool,
|
||||
space: str = "",
|
||||
page_id: str = "",
|
||||
index_recursively: bool = False,
|
||||
index_recursively: bool = True,
|
||||
cql_query: str | None = None,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE,
|
||||
@@ -101,25 +82,23 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
# Remove trailing slash from wiki_base if present
|
||||
self.wiki_base = wiki_base.rstrip("/")
|
||||
|
||||
"""
|
||||
If nothing is provided, we default to fetching all pages
|
||||
Only one or none of the following options should be specified so
|
||||
the order shouldn't matter
|
||||
However, we use elif to ensure that only of the following is enforced
|
||||
"""
|
||||
base_cql_page_query = "type=page"
|
||||
# if nothing is provided, we will fetch all pages
|
||||
cql_page_query = "type=page"
|
||||
if cql_query:
|
||||
base_cql_page_query = cql_query
|
||||
# if a cql_query is provided, we will use it to fetch the pages
|
||||
cql_page_query = cql_query
|
||||
elif page_id:
|
||||
# if a cql_query is not provided, we will use the page_id to fetch the page
|
||||
if index_recursively:
|
||||
base_cql_page_query += f" and (ancestor='{page_id}' or id='{page_id}')"
|
||||
cql_page_query += f" and ancestor='{page_id}'"
|
||||
else:
|
||||
base_cql_page_query += f" and id='{page_id}'"
|
||||
cql_page_query += f" and id='{page_id}'"
|
||||
elif space:
|
||||
uri_safe_space = quote(space)
|
||||
base_cql_page_query += f" and space='{uri_safe_space}'"
|
||||
# if no cql_query or page_id is provided, we will use the space to fetch the pages
|
||||
cql_page_query += f" and space='{quote(space)}'"
|
||||
|
||||
self.base_cql_page_query = base_cql_page_query
|
||||
self.cql_page_query = cql_page_query
|
||||
self.cql_time_filter = ""
|
||||
|
||||
self.cql_label_filter = ""
|
||||
if labels_to_skip:
|
||||
@@ -147,33 +126,6 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
)
|
||||
return None
|
||||
|
||||
def _construct_page_query(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> str:
|
||||
page_query = self.base_cql_page_query + self.cql_label_filter
|
||||
|
||||
# Add time filters
|
||||
if start:
|
||||
formatted_start_time = datetime.fromtimestamp(
|
||||
start, tz=self.timezone
|
||||
).strftime("%Y-%m-%d %H:%M")
|
||||
page_query += f" and lastmodified >= '{formatted_start_time}'"
|
||||
if end:
|
||||
formatted_end_time = datetime.fromtimestamp(end, tz=self.timezone).strftime(
|
||||
"%Y-%m-%d %H:%M"
|
||||
)
|
||||
page_query += f" and lastmodified <= '{formatted_end_time}'"
|
||||
|
||||
return page_query
|
||||
|
||||
def _construct_attachment_query(self, confluence_page_id: str) -> str:
|
||||
attachment_query = f"type=attachment and container='{confluence_page_id}'"
|
||||
attachment_query += self.cql_label_filter
|
||||
attachment_query += _FULL_EXTENSION_FILTER_STRING
|
||||
return attachment_query
|
||||
|
||||
def _get_comment_string_for_page_id(self, page_id: str) -> str:
|
||||
comment_string = ""
|
||||
|
||||
@@ -253,15 +205,11 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
metadata=doc_metadata,
|
||||
)
|
||||
|
||||
def _fetch_document_batches(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> GenerateDocumentsOutput:
|
||||
def _fetch_document_batches(self) -> GenerateDocumentsOutput:
|
||||
doc_batch: list[Document] = []
|
||||
confluence_page_ids: list[str] = []
|
||||
|
||||
page_query = self._construct_page_query(start, end)
|
||||
page_query = self.cql_page_query + self.cql_label_filter + self.cql_time_filter
|
||||
logger.debug(f"page_query: {page_query}")
|
||||
# Fetch pages as Documents
|
||||
for page in self.confluence_client.paginated_cql_retrieval(
|
||||
@@ -280,10 +228,11 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
|
||||
# Fetch attachments as Documents
|
||||
for confluence_page_id in confluence_page_ids:
|
||||
attachment_query = self._construct_attachment_query(confluence_page_id)
|
||||
attachment_cql = f"type=attachment and container='{confluence_page_id}'"
|
||||
attachment_cql += self.cql_label_filter
|
||||
# TODO: maybe should add time filter as well?
|
||||
for attachment in self.confluence_client.paginated_cql_retrieval(
|
||||
cql=attachment_query,
|
||||
cql=attachment_cql,
|
||||
expand=",".join(_ATTACHMENT_EXPANSION_FIELDS),
|
||||
):
|
||||
doc = self._convert_object_to_document(attachment)
|
||||
@@ -299,12 +248,17 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
return self._fetch_document_batches()
|
||||
|
||||
def poll_source(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> GenerateDocumentsOutput:
|
||||
return self._fetch_document_batches(start, end)
|
||||
def poll_source(self, start: float, end: float) -> GenerateDocumentsOutput:
|
||||
# Add time filters
|
||||
formatted_start_time = datetime.fromtimestamp(start, tz=self.timezone).strftime(
|
||||
"%Y-%m-%d %H:%M"
|
||||
)
|
||||
formatted_end_time = datetime.fromtimestamp(end, tz=self.timezone).strftime(
|
||||
"%Y-%m-%d %H:%M"
|
||||
)
|
||||
self.cql_time_filter = f" and lastmodified >= '{formatted_start_time}'"
|
||||
self.cql_time_filter += f" and lastmodified <= '{formatted_end_time}'"
|
||||
return self._fetch_document_batches()
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
self,
|
||||
@@ -315,7 +269,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
|
||||
restrictions_expand = ",".join(_RESTRICTIONS_EXPANSION_FIELDS)
|
||||
|
||||
page_query = self.base_cql_page_query + self.cql_label_filter
|
||||
page_query = self.cql_page_query + self.cql_label_filter
|
||||
for page in self.confluence_client.cql_paginate_all_expansions(
|
||||
cql=page_query,
|
||||
expand=restrictions_expand,
|
||||
@@ -325,11 +279,9 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
# These will be used by doc_sync.py to sync permissions
|
||||
page_restrictions = page.get("restrictions")
|
||||
page_space_key = page.get("space", {}).get("key")
|
||||
page_ancestors = page.get("ancestors", [])
|
||||
page_perm_sync_data = {
|
||||
"restrictions": page_restrictions or {},
|
||||
"space_key": page_space_key,
|
||||
"ancestors": page_ancestors or [],
|
||||
}
|
||||
|
||||
doc_metadata_list.append(
|
||||
@@ -342,9 +294,10 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
perm_sync_data=page_perm_sync_data,
|
||||
)
|
||||
)
|
||||
attachment_query = self._construct_attachment_query(page["id"])
|
||||
attachment_cql = f"type=attachment and container='{page['id']}'"
|
||||
attachment_cql += self.cql_label_filter
|
||||
for attachment in self.confluence_client.cql_paginate_all_expansions(
|
||||
cql=attachment_query,
|
||||
cql=attachment_cql,
|
||||
expand=restrictions_expand,
|
||||
limit=_SLIM_DOC_BATCH_SIZE,
|
||||
):
|
||||
|
||||
@@ -6,7 +6,6 @@ from typing import TypeVar
|
||||
|
||||
from dateutil.parser import parse
|
||||
|
||||
from onyx.configs.app_configs import CONNECTOR_LOCALHOST_OVERRIDE
|
||||
from onyx.configs.constants import IGNORE_FOR_QA
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.utils.text_processing import is_valid_email
|
||||
@@ -72,10 +71,3 @@ def process_in_batches(
|
||||
|
||||
def get_metadata_keys_to_ignore() -> list[str]:
|
||||
return [IGNORE_FOR_QA]
|
||||
|
||||
|
||||
def get_oauth_callback_uri(base_domain: str, connector_id: str) -> str:
|
||||
if CONNECTOR_LOCALHOST_OVERRIDE:
|
||||
# Used for development
|
||||
base_domain = CONNECTOR_LOCALHOST_OVERRIDE
|
||||
return f"{base_domain.strip('/')}/connector/oauth/callback/{connector_id}"
|
||||
|
||||
@@ -1,321 +0,0 @@
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterable
|
||||
from collections.abc import Iterable
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
|
||||
from discord import Client
|
||||
from discord.channel import TextChannel
|
||||
from discord.channel import Thread
|
||||
from discord.enums import MessageType
|
||||
from discord.flags import Intents
|
||||
from discord.message import Message as DiscordMessage
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
_DISCORD_DOC_ID_PREFIX = "DISCORD_"
|
||||
_SNIPPET_LENGTH = 30
|
||||
|
||||
|
||||
def _convert_message_to_document(
|
||||
message: DiscordMessage,
|
||||
sections: list[Section],
|
||||
) -> Document:
|
||||
"""
|
||||
Convert a discord message to a document
|
||||
Sections are collected before calling this function because it relies on async
|
||||
calls to fetch the thread history if there is one
|
||||
"""
|
||||
|
||||
metadata: dict[str, str | list[str]] = {}
|
||||
semantic_substring = ""
|
||||
|
||||
# Only messages from TextChannels will make it here but we have to check for it anyways
|
||||
if isinstance(message.channel, TextChannel) and (
|
||||
channel_name := message.channel.name
|
||||
):
|
||||
metadata["Channel"] = channel_name
|
||||
semantic_substring += f" in Channel: #{channel_name}"
|
||||
|
||||
# Single messages dont have a title
|
||||
title = ""
|
||||
|
||||
# If there is a thread, add more detail to the metadata, title, and semantic identifier
|
||||
if isinstance(message.channel, Thread):
|
||||
# Threads do have a title
|
||||
title = message.channel.name
|
||||
|
||||
# If its a thread, update the metadata, title, and semantic_substring
|
||||
metadata["Thread"] = title
|
||||
|
||||
# Add more detail to the semantic identifier if available
|
||||
semantic_substring += f" in Thread: {title}"
|
||||
|
||||
snippet: str = (
|
||||
message.content[:_SNIPPET_LENGTH].rstrip() + "..."
|
||||
if len(message.content) > _SNIPPET_LENGTH
|
||||
else message.content
|
||||
)
|
||||
|
||||
semantic_identifier = f"{message.author.name} said{semantic_substring}: {snippet}"
|
||||
|
||||
return Document(
|
||||
id=f"{_DISCORD_DOC_ID_PREFIX}{message.id}",
|
||||
source=DocumentSource.DISCORD,
|
||||
semantic_identifier=semantic_identifier,
|
||||
doc_updated_at=message.edited_at,
|
||||
title=title,
|
||||
sections=sections,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
|
||||
async def _fetch_filtered_channels(
|
||||
discord_client: Client,
|
||||
server_ids: list[int] | None,
|
||||
channel_names: list[str] | None,
|
||||
) -> list[TextChannel]:
|
||||
filtered_channels: list[TextChannel] = []
|
||||
|
||||
for channel in discord_client.get_all_channels():
|
||||
if not channel.permissions_for(channel.guild.me).read_message_history:
|
||||
continue
|
||||
if not isinstance(channel, TextChannel):
|
||||
continue
|
||||
if server_ids and len(server_ids) > 0 and channel.guild.id not in server_ids:
|
||||
continue
|
||||
if channel_names and channel.name not in channel_names:
|
||||
continue
|
||||
filtered_channels.append(channel)
|
||||
|
||||
logger.info(f"Found {len(filtered_channels)} channels for the authenticated user")
|
||||
return filtered_channels
|
||||
|
||||
|
||||
async def _fetch_documents_from_channel(
|
||||
channel: TextChannel,
|
||||
start_time: datetime | None,
|
||||
end_time: datetime | None,
|
||||
) -> AsyncIterable[Document]:
|
||||
# Discord's epoch starts at 2015-01-01
|
||||
discord_epoch = datetime(2015, 1, 1, tzinfo=timezone.utc)
|
||||
if start_time and start_time < discord_epoch:
|
||||
start_time = discord_epoch
|
||||
|
||||
async for channel_message in channel.history(
|
||||
after=start_time,
|
||||
before=end_time,
|
||||
):
|
||||
# Skip messages that are not the default type
|
||||
if channel_message.type != MessageType.default:
|
||||
continue
|
||||
|
||||
sections: list[Section] = [
|
||||
Section(
|
||||
text=channel_message.content,
|
||||
link=channel_message.jump_url,
|
||||
)
|
||||
]
|
||||
|
||||
yield _convert_message_to_document(channel_message, sections)
|
||||
|
||||
for active_thread in channel.threads:
|
||||
async for thread_message in active_thread.history(
|
||||
after=start_time,
|
||||
before=end_time,
|
||||
):
|
||||
# Skip messages that are not the default type
|
||||
if thread_message.type != MessageType.default:
|
||||
continue
|
||||
|
||||
sections = [
|
||||
Section(
|
||||
text=thread_message.content,
|
||||
link=thread_message.jump_url,
|
||||
)
|
||||
]
|
||||
|
||||
yield _convert_message_to_document(thread_message, sections)
|
||||
|
||||
async for archived_thread in channel.archived_threads():
|
||||
async for thread_message in archived_thread.history(
|
||||
after=start_time,
|
||||
before=end_time,
|
||||
):
|
||||
# Skip messages that are not the default type
|
||||
if thread_message.type != MessageType.default:
|
||||
continue
|
||||
|
||||
sections = [
|
||||
Section(
|
||||
text=thread_message.content,
|
||||
link=thread_message.jump_url,
|
||||
)
|
||||
]
|
||||
|
||||
yield _convert_message_to_document(thread_message, sections)
|
||||
|
||||
|
||||
def _manage_async_retrieval(
|
||||
token: str,
|
||||
requested_start_date_string: str,
|
||||
channel_names: list[str],
|
||||
server_ids: list[int],
|
||||
start: datetime | None = None,
|
||||
end: datetime | None = None,
|
||||
) -> Iterable[Document]:
|
||||
# parse requested_start_date_string to datetime
|
||||
pull_date: datetime | None = (
|
||||
datetime.strptime(requested_start_date_string, "%Y-%m-%d").replace(
|
||||
tzinfo=timezone.utc
|
||||
)
|
||||
if requested_start_date_string
|
||||
else None
|
||||
)
|
||||
|
||||
# Set start_time to the later of start and pull_date, or whichever is provided
|
||||
start_time = max(filter(None, [start, pull_date])) if start or pull_date else None
|
||||
|
||||
end_time: datetime | None = end
|
||||
|
||||
async def _async_fetch() -> AsyncIterable[Document]:
|
||||
intents = Intents.default()
|
||||
intents.message_content = True
|
||||
async with Client(intents=intents) as discord_client:
|
||||
asyncio.create_task(discord_client.start(token))
|
||||
await discord_client.wait_until_ready()
|
||||
|
||||
filtered_channels: list[TextChannel] = await _fetch_filtered_channels(
|
||||
discord_client=discord_client,
|
||||
server_ids=server_ids,
|
||||
channel_names=channel_names,
|
||||
)
|
||||
|
||||
for channel in filtered_channels:
|
||||
async for doc in _fetch_documents_from_channel(
|
||||
channel=channel,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
):
|
||||
yield doc
|
||||
|
||||
def run_and_yield() -> Iterable[Document]:
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
# Get the async generator
|
||||
async_gen = _async_fetch()
|
||||
# Convert to AsyncIterator
|
||||
async_iter = async_gen.__aiter__()
|
||||
while True:
|
||||
try:
|
||||
# Create a coroutine by calling anext with the async iterator
|
||||
next_coro = anext(async_iter)
|
||||
# Run the coroutine to get the next document
|
||||
doc = loop.run_until_complete(next_coro)
|
||||
yield doc
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
return run_and_yield()
|
||||
|
||||
|
||||
class DiscordConnector(PollConnector, LoadConnector):
|
||||
def __init__(
|
||||
self,
|
||||
server_ids: list[str] = [],
|
||||
channel_names: list[str] = [],
|
||||
# YYYY-MM-DD
|
||||
start_date: str | None = None,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
):
|
||||
self.batch_size = batch_size
|
||||
self.channel_names: list[str] = channel_names if channel_names else []
|
||||
self.server_ids: list[int] = (
|
||||
[int(server_id) for server_id in server_ids] if server_ids else []
|
||||
)
|
||||
self._discord_bot_token: str | None = None
|
||||
self.requested_start_date_string: str = start_date or ""
|
||||
|
||||
@property
|
||||
def discord_bot_token(self) -> str:
|
||||
if self._discord_bot_token is None:
|
||||
raise ConnectorMissingCredentialError("Discord")
|
||||
return self._discord_bot_token
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
self._discord_bot_token = credentials["discord_bot_token"]
|
||||
return None
|
||||
|
||||
def _manage_doc_batching(
|
||||
self,
|
||||
start: datetime | None = None,
|
||||
end: datetime | None = None,
|
||||
) -> GenerateDocumentsOutput:
|
||||
doc_batch = []
|
||||
for doc in _manage_async_retrieval(
|
||||
token=self.discord_bot_token,
|
||||
requested_start_date_string=self.requested_start_date_string,
|
||||
channel_names=self.channel_names,
|
||||
server_ids=self.server_ids,
|
||||
start=start,
|
||||
end=end,
|
||||
):
|
||||
doc_batch.append(doc)
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
yield doc_batch
|
||||
doc_batch = []
|
||||
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
return self._manage_doc_batching(
|
||||
datetime.fromtimestamp(start, tz=timezone.utc),
|
||||
datetime.fromtimestamp(end, tz=timezone.utc),
|
||||
)
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
return self._manage_doc_batching(None, None)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
import time
|
||||
|
||||
end = time.time()
|
||||
# 1 day
|
||||
start = end - 24 * 60 * 60 * 1
|
||||
# "1,2,3"
|
||||
server_ids: str | None = os.environ.get("server_ids", None)
|
||||
# "channel1,channel2"
|
||||
channel_names: str | None = os.environ.get("channel_names", None)
|
||||
|
||||
connector = DiscordConnector(
|
||||
server_ids=server_ids.split(",") if server_ids else [],
|
||||
channel_names=channel_names.split(",") if channel_names else [],
|
||||
start_date=os.environ.get("start_date", None),
|
||||
)
|
||||
connector.load_credentials(
|
||||
{"discord_bot_token": os.environ.get("discord_bot_token")}
|
||||
)
|
||||
|
||||
for doc_batch in connector.poll_source(start, end):
|
||||
for doc in doc_batch:
|
||||
print(doc)
|
||||
@@ -190,7 +190,7 @@ class DiscourseConnector(PollConnector):
|
||||
start: datetime,
|
||||
end: datetime,
|
||||
) -> GenerateDocumentsOutput:
|
||||
page = 0
|
||||
page = 1
|
||||
while topic_ids := self._get_latest_topics(start, end, page):
|
||||
doc_batch: list[Document] = []
|
||||
for topic_id in topic_ids:
|
||||
|
||||
@@ -3,19 +3,20 @@ import os
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from logging import Logger
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import IO
|
||||
from urllib.parse import quote
|
||||
|
||||
from pydantic import Field
|
||||
import requests
|
||||
from retry import retry
|
||||
|
||||
from onyx.configs.app_configs import EGNYTE_BASE_DOMAIN
|
||||
from onyx.configs.app_configs import EGNYTE_CLIENT_ID
|
||||
from onyx.configs.app_configs import EGNYTE_CLIENT_SECRET
|
||||
from onyx.configs.app_configs import EGNYTE_LOCALHOST_OVERRIDE
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||
get_oauth_callback_uri,
|
||||
)
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import OAuthConnector
|
||||
@@ -32,13 +33,53 @@ from onyx.file_processing.extract_file_text import is_text_file_extension
|
||||
from onyx.file_processing.extract_file_text import is_valid_file_ext
|
||||
from onyx.file_processing.extract_file_text import read_text_file
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.retry_wrapper import request_with_retries
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_EGNYTE_API_BASE = "https://{domain}.egnyte.com/pubapi/v1"
|
||||
_EGNYTE_APP_BASE = "https://{domain}.egnyte.com"
|
||||
_TIMEOUT = 60
|
||||
|
||||
|
||||
def _request_with_retries(
|
||||
method: str,
|
||||
url: str,
|
||||
data: dict[str, Any] | None = None,
|
||||
headers: dict[str, Any] | None = None,
|
||||
params: dict[str, Any] | None = None,
|
||||
timeout: int = _TIMEOUT,
|
||||
stream: bool = False,
|
||||
tries: int = 8,
|
||||
delay: float = 1,
|
||||
backoff: float = 2,
|
||||
) -> requests.Response:
|
||||
@retry(tries=tries, delay=delay, backoff=backoff, logger=cast(Logger, logger))
|
||||
def _make_request() -> requests.Response:
|
||||
response = requests.request(
|
||||
method,
|
||||
url,
|
||||
data=data,
|
||||
headers=headers,
|
||||
params=params,
|
||||
timeout=timeout,
|
||||
stream=stream,
|
||||
)
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except requests.exceptions.HTTPError as e:
|
||||
if e.response.status_code != 403:
|
||||
logger.exception(
|
||||
f"Failed to call Egnyte API.\n"
|
||||
f"URL: {url}\n"
|
||||
f"Headers: {headers}\n"
|
||||
f"Data: {data}\n"
|
||||
f"Params: {params}"
|
||||
)
|
||||
raise e
|
||||
return response
|
||||
|
||||
return _make_request()
|
||||
|
||||
|
||||
def _parse_last_modified(last_modified: str) -> datetime:
|
||||
@@ -125,15 +166,6 @@ def _process_egnyte_file(
|
||||
|
||||
|
||||
class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||
class AdditionalOauthKwargs(OAuthConnector.AdditionalOauthKwargs):
|
||||
egnyte_domain: str = Field(
|
||||
title="Egnyte Domain",
|
||||
description=(
|
||||
"The domain for the Egnyte instance "
|
||||
"(e.g. 'company' for company.egnyte.com)"
|
||||
),
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
folder_path: str | None = None,
|
||||
@@ -149,20 +181,18 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||
return DocumentSource.EGNYTE
|
||||
|
||||
@classmethod
|
||||
def oauth_authorization_url(
|
||||
cls,
|
||||
base_domain: str,
|
||||
state: str,
|
||||
additional_kwargs: dict[str, str],
|
||||
) -> str:
|
||||
def oauth_authorization_url(cls, base_domain: str, state: str) -> str:
|
||||
if not EGNYTE_CLIENT_ID:
|
||||
raise ValueError("EGNYTE_CLIENT_ID environment variable must be set")
|
||||
if not EGNYTE_BASE_DOMAIN:
|
||||
raise ValueError("EGNYTE_DOMAIN environment variable must be set")
|
||||
|
||||
oauth_kwargs = cls.AdditionalOauthKwargs(**additional_kwargs)
|
||||
if EGNYTE_LOCALHOST_OVERRIDE:
|
||||
base_domain = EGNYTE_LOCALHOST_OVERRIDE
|
||||
|
||||
callback_uri = get_oauth_callback_uri(base_domain, "egnyte")
|
||||
callback_uri = f"{base_domain.strip('/')}/connector/oauth/callback/egnyte"
|
||||
return (
|
||||
f"https://{oauth_kwargs.egnyte_domain}.egnyte.com/puboauth/token"
|
||||
f"https://{EGNYTE_BASE_DOMAIN}.egnyte.com/puboauth/token"
|
||||
f"?client_id={EGNYTE_CLIENT_ID}"
|
||||
f"&redirect_uri={callback_uri}"
|
||||
f"&scope=Egnyte.filesystem"
|
||||
@@ -171,23 +201,17 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def oauth_code_to_token(
|
||||
cls,
|
||||
base_domain: str,
|
||||
code: str,
|
||||
additional_kwargs: dict[str, str],
|
||||
) -> dict[str, Any]:
|
||||
def oauth_code_to_token(cls, base_domain: str, code: str) -> dict[str, Any]:
|
||||
if not EGNYTE_CLIENT_ID:
|
||||
raise ValueError("EGNYTE_CLIENT_ID environment variable must be set")
|
||||
if not EGNYTE_CLIENT_SECRET:
|
||||
raise ValueError("EGNYTE_CLIENT_SECRET environment variable must be set")
|
||||
|
||||
oauth_kwargs = cls.AdditionalOauthKwargs(**additional_kwargs)
|
||||
if not EGNYTE_BASE_DOMAIN:
|
||||
raise ValueError("EGNYTE_DOMAIN environment variable must be set")
|
||||
|
||||
# Exchange code for token
|
||||
url = f"https://{oauth_kwargs.egnyte_domain}.egnyte.com/puboauth/token"
|
||||
redirect_uri = get_oauth_callback_uri(base_domain, "egnyte")
|
||||
|
||||
url = f"https://{EGNYTE_BASE_DOMAIN}.egnyte.com/puboauth/token"
|
||||
redirect_uri = f"{EGNYTE_LOCALHOST_OVERRIDE or base_domain}/connector/oauth/callback/egnyte"
|
||||
data = {
|
||||
"client_id": EGNYTE_CLIENT_ID,
|
||||
"client_secret": EGNYTE_CLIENT_SECRET,
|
||||
@@ -198,7 +222,7 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||
}
|
||||
headers = {"Content-Type": "application/x-www-form-urlencoded"}
|
||||
|
||||
response = request_with_retries(
|
||||
response = _request_with_retries(
|
||||
method="POST",
|
||||
url=url,
|
||||
data=data,
|
||||
@@ -212,7 +236,7 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||
|
||||
token_data = response.json()
|
||||
return {
|
||||
"domain": oauth_kwargs.egnyte_domain,
|
||||
"domain": EGNYTE_BASE_DOMAIN,
|
||||
"access_token": token_data["access_token"],
|
||||
}
|
||||
|
||||
@@ -236,10 +260,9 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||
"list_content": True,
|
||||
}
|
||||
|
||||
url_encoded_path = quote(path or "")
|
||||
url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs/{url_encoded_path}"
|
||||
response = request_with_retries(
|
||||
method="GET", url=url, headers=headers, params=params
|
||||
url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs/{path or ''}"
|
||||
response = _request_with_retries(
|
||||
method="GET", url=url, headers=headers, params=params, timeout=_TIMEOUT
|
||||
)
|
||||
if not response.ok:
|
||||
raise RuntimeError(f"Failed to fetch files from Egnyte: {response.text}")
|
||||
@@ -292,12 +315,12 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.access_token}",
|
||||
}
|
||||
url_encoded_path = quote(file["path"])
|
||||
url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs-content/{url_encoded_path}"
|
||||
response = request_with_retries(
|
||||
url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs-content/{file['path']}"
|
||||
response = _request_with_retries(
|
||||
method="GET",
|
||||
url=url,
|
||||
headers=headers,
|
||||
timeout=_TIMEOUT,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -5,14 +5,12 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import DocumentSourceRequiringTenantContext
|
||||
from onyx.connectors.airtable.airtable_connector import AirtableConnector
|
||||
from onyx.connectors.asana.connector import AsanaConnector
|
||||
from onyx.connectors.axero.connector import AxeroConnector
|
||||
from onyx.connectors.blob.connector import BlobStorageConnector
|
||||
from onyx.connectors.bookstack.connector import BookstackConnector
|
||||
from onyx.connectors.clickup.connector import ClickupConnector
|
||||
from onyx.connectors.confluence.connector import ConfluenceConnector
|
||||
from onyx.connectors.discord.connector import DiscordConnector
|
||||
from onyx.connectors.discourse.connector import DiscourseConnector
|
||||
from onyx.connectors.document360.connector import Document360Connector
|
||||
from onyx.connectors.dropbox.connector import DropboxConnector
|
||||
@@ -102,11 +100,9 @@ def identify_connector_class(
|
||||
DocumentSource.GOOGLE_CLOUD_STORAGE: BlobStorageConnector,
|
||||
DocumentSource.OCI_STORAGE: BlobStorageConnector,
|
||||
DocumentSource.XENFORO: XenforoConnector,
|
||||
DocumentSource.DISCORD: DiscordConnector,
|
||||
DocumentSource.FRESHDESK: FreshdeskConnector,
|
||||
DocumentSource.FIREFLIES: FirefliesConnector,
|
||||
DocumentSource.EGNYTE: EgnyteConnector,
|
||||
DocumentSource.AIRTABLE: AirtableConnector,
|
||||
}
|
||||
connector_by_source = connector_map.get(source, {})
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ from typing import Dict
|
||||
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
|
||||
from googleapiclient.errors import HttpError # type: ignore
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
@@ -250,36 +249,17 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
return new_creds_dict
|
||||
|
||||
def _get_all_user_emails(self) -> list[str]:
|
||||
"""
|
||||
List all user emails if we are on a Google Workspace domain.
|
||||
If the domain is gmail.com, or if we attempt to call the Admin SDK and
|
||||
get a 404, fall back to using the single user.
|
||||
"""
|
||||
|
||||
try:
|
||||
admin_service = get_admin_service(self.creds, self.primary_admin_email)
|
||||
emails = []
|
||||
for user in execute_paginated_retrieval(
|
||||
retrieval_function=admin_service.users().list,
|
||||
list_key="users",
|
||||
fields=USER_FIELDS,
|
||||
domain=self.google_domain,
|
||||
):
|
||||
if email := user.get("primaryEmail"):
|
||||
emails.append(email)
|
||||
return emails
|
||||
|
||||
except HttpError as e:
|
||||
if e.resp.status == 404:
|
||||
logger.warning(
|
||||
"Received 404 from Admin SDK; this may indicate a personal Gmail account "
|
||||
"with no Workspace domain. Falling back to single user."
|
||||
)
|
||||
return [self.primary_admin_email]
|
||||
raise
|
||||
|
||||
except Exception:
|
||||
raise
|
||||
admin_service = get_admin_service(self.creds, self.primary_admin_email)
|
||||
emails = []
|
||||
for user in execute_paginated_retrieval(
|
||||
retrieval_function=admin_service.users().list,
|
||||
list_key="users",
|
||||
fields=USER_FIELDS,
|
||||
domain=self.google_domain,
|
||||
):
|
||||
if email := user.get("primaryEmail"):
|
||||
emails.append(email)
|
||||
return emails
|
||||
|
||||
def _fetch_threads(
|
||||
self,
|
||||
|
||||
@@ -8,7 +8,6 @@ from typing import cast
|
||||
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
|
||||
from googleapiclient.errors import HttpError # type: ignore
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.app_configs import MAX_FILE_SIZE_BYTES
|
||||
@@ -21,7 +20,6 @@ from onyx.connectors.google_drive.file_retrieval import crawl_folders_for_files
|
||||
from onyx.connectors.google_drive.file_retrieval import get_all_files_for_oauth
|
||||
from onyx.connectors.google_drive.file_retrieval import get_all_files_in_my_drive
|
||||
from onyx.connectors.google_drive.file_retrieval import get_files_in_shared_drive
|
||||
from onyx.connectors.google_drive.file_retrieval import get_root_folder_id
|
||||
from onyx.connectors.google_drive.models import GoogleDriveFileType
|
||||
from onyx.connectors.google_utils.google_auth import get_google_creds
|
||||
from onyx.connectors.google_utils.google_utils import execute_paginated_retrieval
|
||||
@@ -43,7 +41,6 @@ from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.retry_wrapper import retry_builder
|
||||
|
||||
logger = setup_logger()
|
||||
# TODO: Improve this by using the batch utility: https://googleapis.github.io/google-api-python-client/docs/batch.html
|
||||
@@ -289,30 +286,13 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
logger.info(f"Impersonating user {user_email}")
|
||||
|
||||
drive_service = get_drive_service(self.creds, user_email)
|
||||
|
||||
# validate that the user has access to the drive APIs by performing a simple
|
||||
# request and checking for a 401
|
||||
try:
|
||||
retry_builder()(get_root_folder_id)(drive_service)
|
||||
except HttpError as e:
|
||||
if e.status_code == 401:
|
||||
# fail gracefully, let the other impersonations continue
|
||||
# one user without access shouldn't block the entire connector
|
||||
logger.exception(
|
||||
f"User '{user_email}' does not have access to the drive APIs."
|
||||
)
|
||||
return
|
||||
raise
|
||||
|
||||
# if we are including my drives, try to get the current user's my
|
||||
# drive if any of the following are true:
|
||||
# - include_my_drives is true
|
||||
# - the current user's email is in the requested emails
|
||||
if self.include_my_drives or user_email in self._requested_my_drive_emails:
|
||||
logger.info(f"Getting all files in my drive as '{user_email}'")
|
||||
yield from get_all_files_in_my_drive(
|
||||
service=drive_service,
|
||||
update_traversed_ids_func=self._update_traversed_parent_ids,
|
||||
@@ -323,7 +303,6 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
|
||||
remaining_drive_ids = filtered_drive_ids - self._retrieved_ids
|
||||
for drive_id in remaining_drive_ids:
|
||||
logger.info(f"Getting files in shared drive '{drive_id}' as '{user_email}'")
|
||||
yield from get_files_in_shared_drive(
|
||||
service=drive_service,
|
||||
drive_id=drive_id,
|
||||
@@ -335,7 +314,6 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
|
||||
remaining_folders = filtered_folder_ids - self._retrieved_ids
|
||||
for folder_id in remaining_folders:
|
||||
logger.info(f"Getting files in folder '{folder_id}' as '{user_email}'")
|
||||
yield from crawl_folders_for_files(
|
||||
service=drive_service,
|
||||
parent_id=folder_id,
|
||||
@@ -366,15 +344,6 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
elif self.include_shared_drives:
|
||||
drive_ids_to_retrieve = all_drive_ids
|
||||
|
||||
# checkpoint - we've found all users and drives, now time to actually start
|
||||
# fetching stuff
|
||||
logger.info(f"Found {len(all_org_emails)} users to impersonate")
|
||||
logger.debug(f"Users: {all_org_emails}")
|
||||
logger.info(f"Found {len(drive_ids_to_retrieve)} drives to retrieve")
|
||||
logger.debug(f"Drives: {drive_ids_to_retrieve}")
|
||||
logger.info(f"Found {len(folder_ids_to_retrieve)} folders to retrieve")
|
||||
logger.debug(f"Folders: {folder_ids_to_retrieve}")
|
||||
|
||||
# Process users in parallel using ThreadPoolExecutor
|
||||
with ThreadPoolExecutor(max_workers=10) as executor:
|
||||
future_to_email = {
|
||||
@@ -411,13 +380,6 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
drive_service = get_drive_service(self.creds, self.primary_admin_email)
|
||||
|
||||
if self.include_files_shared_with_me or self.include_my_drives:
|
||||
logger.info(
|
||||
f"Getting shared files/my drive files for OAuth "
|
||||
f"with include_files_shared_with_me={self.include_files_shared_with_me}, "
|
||||
f"include_my_drives={self.include_my_drives}, "
|
||||
f"include_shared_drives={self.include_shared_drives}."
|
||||
f"Using '{self.primary_admin_email}' as the account."
|
||||
)
|
||||
yield from get_all_files_for_oauth(
|
||||
service=drive_service,
|
||||
include_files_shared_with_me=self.include_files_shared_with_me,
|
||||
@@ -450,9 +412,6 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
drive_ids_to_retrieve = all_drive_ids
|
||||
|
||||
for drive_id in drive_ids_to_retrieve:
|
||||
logger.info(
|
||||
f"Getting files in shared drive '{drive_id}' as '{self.primary_admin_email}'"
|
||||
)
|
||||
yield from get_files_in_shared_drive(
|
||||
service=drive_service,
|
||||
drive_id=drive_id,
|
||||
@@ -466,9 +425,6 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
# that could be folders.
|
||||
remaining_folders = folder_ids_to_retrieve - self._retrieved_ids
|
||||
for folder_id in remaining_folders:
|
||||
logger.info(
|
||||
f"Getting files in folder '{folder_id}' as '{self.primary_admin_email}'"
|
||||
)
|
||||
yield from crawl_folders_for_files(
|
||||
service=drive_service,
|
||||
parent_id=folder_id,
|
||||
|
||||
@@ -2,8 +2,6 @@ import abc
|
||||
from collections.abc import Iterator
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import SlimDocument
|
||||
@@ -68,10 +66,6 @@ class SlimConnector(BaseConnector):
|
||||
|
||||
|
||||
class OAuthConnector(BaseConnector):
|
||||
class AdditionalOauthKwargs(BaseModel):
|
||||
# if overridden, all fields should be str type
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def oauth_id(cls) -> DocumentSource:
|
||||
@@ -79,22 +73,12 @@ class OAuthConnector(BaseConnector):
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def oauth_authorization_url(
|
||||
cls,
|
||||
base_domain: str,
|
||||
state: str,
|
||||
additional_kwargs: dict[str, str],
|
||||
) -> str:
|
||||
def oauth_authorization_url(cls, base_domain: str, state: str) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def oauth_code_to_token(
|
||||
cls,
|
||||
base_domain: str,
|
||||
code: str,
|
||||
additional_kwargs: dict[str, str],
|
||||
) -> dict[str, Any]:
|
||||
def oauth_code_to_token(cls, base_domain: str, code: str) -> dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
|
||||
@@ -7,23 +7,16 @@ from typing import cast
|
||||
import requests
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.app_configs import LINEAR_CLIENT_ID
|
||||
from onyx.configs.app_configs import LINEAR_CLIENT_SECRET
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||
get_oauth_callback_uri,
|
||||
)
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import OAuthConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.retry_wrapper import request_with_retries
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -64,7 +57,7 @@ def _make_query(request_body: dict[str, Any], api_key: str) -> requests.Response
|
||||
)
|
||||
|
||||
|
||||
class LinearConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||
class LinearConnector(LoadConnector, PollConnector):
|
||||
def __init__(
|
||||
self,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
@@ -72,68 +65,8 @@ class LinearConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||
self.batch_size = batch_size
|
||||
self.linear_api_key: str | None = None
|
||||
|
||||
@classmethod
|
||||
def oauth_id(cls) -> DocumentSource:
|
||||
return DocumentSource.LINEAR
|
||||
|
||||
@classmethod
|
||||
def oauth_authorization_url(
|
||||
cls, base_domain: str, state: str, additional_kwargs: dict[str, str]
|
||||
) -> str:
|
||||
if not LINEAR_CLIENT_ID:
|
||||
raise ValueError("LINEAR_CLIENT_ID environment variable must be set")
|
||||
|
||||
callback_uri = get_oauth_callback_uri(base_domain, DocumentSource.LINEAR.value)
|
||||
return (
|
||||
f"https://linear.app/oauth/authorize"
|
||||
f"?client_id={LINEAR_CLIENT_ID}"
|
||||
f"&redirect_uri={callback_uri}"
|
||||
f"&response_type=code"
|
||||
f"&scope=read"
|
||||
f"&state={state}"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def oauth_code_to_token(
|
||||
cls, base_domain: str, code: str, additional_kwargs: dict[str, str]
|
||||
) -> dict[str, Any]:
|
||||
data = {
|
||||
"code": code,
|
||||
"redirect_uri": get_oauth_callback_uri(
|
||||
base_domain, DocumentSource.LINEAR.value
|
||||
),
|
||||
"client_id": LINEAR_CLIENT_ID,
|
||||
"client_secret": LINEAR_CLIENT_SECRET,
|
||||
"grant_type": "authorization_code",
|
||||
}
|
||||
headers = {"Content-Type": "application/x-www-form-urlencoded"}
|
||||
|
||||
response = request_with_retries(
|
||||
method="POST",
|
||||
url="https://api.linear.app/oauth/token",
|
||||
data=data,
|
||||
headers=headers,
|
||||
backoff=0,
|
||||
delay=0.1,
|
||||
)
|
||||
if not response.ok:
|
||||
raise RuntimeError(f"Failed to exchange code for token: {response.text}")
|
||||
|
||||
token_data = response.json()
|
||||
|
||||
return {
|
||||
"access_token": token_data["access_token"],
|
||||
}
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
if "linear_api_key" in credentials:
|
||||
self.linear_api_key = cast(str, credentials["linear_api_key"])
|
||||
elif "access_token" in credentials:
|
||||
self.linear_api_key = "Bearer " + cast(str, credentials["access_token"])
|
||||
else:
|
||||
# May need to handle case in the future if the OAuth flow expires
|
||||
raise ConnectorMissingCredentialError("Linear")
|
||||
|
||||
self.linear_api_key = cast(str, credentials["linear_api_key"])
|
||||
return None
|
||||
|
||||
def _process_issues(
|
||||
|
||||
@@ -101,11 +101,8 @@ class DocumentBase(BaseModel):
|
||||
source: DocumentSource | None = None
|
||||
semantic_identifier: str # displayed in the UI as the main identifier for the doc
|
||||
metadata: dict[str, str | list[str]]
|
||||
|
||||
# UTC time
|
||||
doc_updated_at: datetime | None = None
|
||||
chunk_count: int | None = None
|
||||
|
||||
# Owner, creator, etc.
|
||||
primary_owners: list[BasicExpertInfo] | None = None
|
||||
# Assignee, space owner, etc.
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
import os
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
|
||||
from simple_salesforce import Salesforce
|
||||
from simple_salesforce import SFType
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
@@ -15,25 +19,24 @@ from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.salesforce.doc_conversion import extract_section
|
||||
from onyx.connectors.salesforce.salesforce_calls import fetch_all_csvs_in_parallel
|
||||
from onyx.connectors.salesforce.salesforce_calls import get_all_children_of_sf_type
|
||||
from onyx.connectors.salesforce.sqlite_functions import get_affected_parent_ids_by_type
|
||||
from onyx.connectors.salesforce.sqlite_functions import get_child_ids
|
||||
from onyx.connectors.salesforce.sqlite_functions import get_record
|
||||
from onyx.connectors.salesforce.sqlite_functions import init_db
|
||||
from onyx.connectors.salesforce.sqlite_functions import update_sf_db_with_csv
|
||||
from onyx.connectors.salesforce.utils import SalesforceObject
|
||||
from onyx.connectors.salesforce.utils import extract_dict_text
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
# TODO: this connector does not work well at large scales
|
||||
# the large query against a large Salesforce instance has been reported to take 1.5 hours.
|
||||
# Additionally it seems to eat up more memory over time if the connection is long running (again a scale issue).
|
||||
|
||||
|
||||
DEFAULT_PARENT_OBJECT_TYPES = ["Account"]
|
||||
MAX_QUERY_LENGTH = 10000 # max query length is 20,000 characters
|
||||
ID_PREFIX = "SALESFORCE_"
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
_DEFAULT_PARENT_OBJECT_TYPES = ["Account"]
|
||||
_ID_PREFIX = "SALESFORCE_"
|
||||
|
||||
|
||||
class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -41,170 +44,200 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
requested_objects: list[str] = [],
|
||||
) -> None:
|
||||
self.batch_size = batch_size
|
||||
self._sf_client: Salesforce | None = None
|
||||
self.sf_client: Salesforce | None = None
|
||||
self.parent_object_list = (
|
||||
[obj.capitalize() for obj in requested_objects]
|
||||
if requested_objects
|
||||
else _DEFAULT_PARENT_OBJECT_TYPES
|
||||
else DEFAULT_PARENT_OBJECT_TYPES
|
||||
)
|
||||
|
||||
def load_credentials(
|
||||
self,
|
||||
credentials: dict[str, Any],
|
||||
) -> dict[str, Any] | None:
|
||||
self._sf_client = Salesforce(
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
self.sf_client = Salesforce(
|
||||
username=credentials["sf_username"],
|
||||
password=credentials["sf_password"],
|
||||
security_token=credentials["sf_security_token"],
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
@property
|
||||
def sf_client(self) -> Salesforce:
|
||||
if self._sf_client is None:
|
||||
def _get_sf_type_object_json(self, type_name: str) -> Any:
|
||||
if self.sf_client is None:
|
||||
raise ConnectorMissingCredentialError("Salesforce")
|
||||
return self._sf_client
|
||||
sf_object = SFType(
|
||||
type_name, self.sf_client.session_id, self.sf_client.sf_instance
|
||||
)
|
||||
return sf_object.describe()
|
||||
|
||||
def _extract_primary_owners(
|
||||
self, sf_object: SalesforceObject
|
||||
) -> list[BasicExpertInfo] | None:
|
||||
object_dict = sf_object.data
|
||||
if not (last_modified_by_id := object_dict.get("LastModifiedById")):
|
||||
return None
|
||||
if not (last_modified_by := get_record(last_modified_by_id)):
|
||||
return None
|
||||
if not (last_modified_by_name := last_modified_by.data.get("Name")):
|
||||
return None
|
||||
primary_owners = [BasicExpertInfo(display_name=last_modified_by_name)]
|
||||
return primary_owners
|
||||
def _get_name_from_id(self, id: str) -> str:
|
||||
if self.sf_client is None:
|
||||
raise ConnectorMissingCredentialError("Salesforce")
|
||||
try:
|
||||
user_object_info = self.sf_client.query(
|
||||
f"SELECT Name FROM User WHERE Id = '{id}'"
|
||||
)
|
||||
name = user_object_info.get("Records", [{}])[0].get("Name", "Null User")
|
||||
return name
|
||||
except Exception:
|
||||
logger.warning(f"Couldnt find name for object id: {id}")
|
||||
return "Null User"
|
||||
|
||||
def _convert_object_instance_to_document(
|
||||
self, sf_object: SalesforceObject
|
||||
self, object_dict: dict[str, Any]
|
||||
) -> Document:
|
||||
object_dict = sf_object.data
|
||||
salesforce_id = object_dict["Id"]
|
||||
onyx_salesforce_id = f"{_ID_PREFIX}{salesforce_id}"
|
||||
base_url = f"https://{self.sf_client.sf_instance}"
|
||||
extracted_doc_updated_at = time_str_to_utc(object_dict["LastModifiedDate"])
|
||||
extracted_semantic_identifier = object_dict.get("Name", "Unknown Object")
|
||||
if self.sf_client is None:
|
||||
raise ConnectorMissingCredentialError("Salesforce")
|
||||
|
||||
sections = [extract_section(sf_object, base_url)]
|
||||
for id in get_child_ids(sf_object.id):
|
||||
if not (child_object := get_record(id)):
|
||||
continue
|
||||
sections.append(extract_section(child_object, base_url))
|
||||
salesforce_id = object_dict["Id"]
|
||||
onyx_salesforce_id = f"{ID_PREFIX}{salesforce_id}"
|
||||
extracted_link = f"https://{self.sf_client.sf_instance}/{salesforce_id}"
|
||||
extracted_doc_updated_at = time_str_to_utc(object_dict["LastModifiedDate"])
|
||||
extracted_object_text = extract_dict_text(object_dict)
|
||||
extracted_semantic_identifier = object_dict.get("Name", "Unknown Object")
|
||||
extracted_primary_owners = [
|
||||
BasicExpertInfo(
|
||||
display_name=self._get_name_from_id(object_dict["LastModifiedById"])
|
||||
)
|
||||
]
|
||||
|
||||
doc = Document(
|
||||
id=onyx_salesforce_id,
|
||||
sections=sections,
|
||||
sections=[Section(link=extracted_link, text=extracted_object_text)],
|
||||
source=DocumentSource.SALESFORCE,
|
||||
semantic_identifier=extracted_semantic_identifier,
|
||||
doc_updated_at=extracted_doc_updated_at,
|
||||
primary_owners=self._extract_primary_owners(sf_object),
|
||||
primary_owners=extracted_primary_owners,
|
||||
metadata={},
|
||||
)
|
||||
return doc
|
||||
|
||||
def _is_valid_child_object(self, child_relationship: dict) -> bool:
|
||||
if self.sf_client is None:
|
||||
raise ConnectorMissingCredentialError("Salesforce")
|
||||
|
||||
if not child_relationship["childSObject"]:
|
||||
return False
|
||||
if not child_relationship["relationshipName"]:
|
||||
return False
|
||||
|
||||
sf_type = child_relationship["childSObject"]
|
||||
object_description = self._get_sf_type_object_json(sf_type)
|
||||
if not object_description["queryable"]:
|
||||
return False
|
||||
|
||||
try:
|
||||
query = f"SELECT Count() FROM {sf_type} LIMIT 1"
|
||||
result = self.sf_client.query(query)
|
||||
if result["totalSize"] == 0:
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.warning(f"Object type {sf_type} doesn't support query: {e}")
|
||||
return False
|
||||
|
||||
if child_relationship["field"]:
|
||||
if child_relationship["field"] == "RelatedToId":
|
||||
return False
|
||||
else:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _get_all_children_of_sf_type(self, sf_type: str) -> list[dict]:
|
||||
if self.sf_client is None:
|
||||
raise ConnectorMissingCredentialError("Salesforce")
|
||||
|
||||
object_description = self._get_sf_type_object_json(sf_type)
|
||||
|
||||
children_objects: list[dict] = []
|
||||
for child_relationship in object_description["childRelationships"]:
|
||||
if self._is_valid_child_object(child_relationship):
|
||||
children_objects.append(
|
||||
{
|
||||
"relationship_name": child_relationship["relationshipName"],
|
||||
"object_type": child_relationship["childSObject"],
|
||||
}
|
||||
)
|
||||
return children_objects
|
||||
|
||||
def _get_all_fields_for_sf_type(self, sf_type: str) -> list[str]:
|
||||
if self.sf_client is None:
|
||||
raise ConnectorMissingCredentialError("Salesforce")
|
||||
|
||||
object_description = self._get_sf_type_object_json(sf_type)
|
||||
|
||||
fields = [
|
||||
field.get("name")
|
||||
for field in object_description["fields"]
|
||||
if field.get("type", "base64") != "base64"
|
||||
]
|
||||
|
||||
return fields
|
||||
|
||||
    def _generate_query_per_parent_type(self, parent_sf_type: str) -> Iterator[str]:
        """
        This function takes in an object_type and generates query(s) designed to grab
        information associated to objects of that type.
        It does that by getting all the fields of the parent object type.
        Then it gets all the child objects of that object type and all the fields of
        those children as well.
        """
        parent_fields = self._get_all_fields_for_sf_type(parent_sf_type)
        child_sf_types = self._get_all_children_of_sf_type(parent_sf_type)

        query = f"SELECT {', '.join(parent_fields)}"
        for child_object_dict in child_sf_types:
            fields = self._get_all_fields_for_sf_type(child_object_dict["object_type"])
            query_addition = f", \n(SELECT {', '.join(fields)} FROM {child_object_dict['relationship_name']})"

            if len(query_addition) + len(query) > MAX_QUERY_LENGTH:
                query += f"\n FROM {parent_sf_type}"
                yield query
                query = "SELECT Id" + query_addition
            else:
                query += query_addition

        query += f"\n FROM {parent_sf_type}"

        yield query

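# Illustration only (not part of the original diff): the rough shape of the SOQL
# this generator yields, trimmed to a few assumed fields for an Account parent
# with a single Cases child relationship.
#
#   SELECT Id, Name, LastModifiedDate,
#   (SELECT Id, Subject FROM Cases)
#    FROM Account
#
# When appending another child sub-select would push the text past
# MAX_QUERY_LENGTH, the current query is yielded and a new one is started with
# "SELECT Id" so every chunk can still be keyed back to the parent record.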
def _fetch_from_salesforce(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
start: datetime | None = None,
|
||||
end: datetime | None = None,
|
||||
) -> GenerateDocumentsOutput:
|
||||
init_db()
|
||||
all_object_types: set[str] = set(self.parent_object_list)
|
||||
if self.sf_client is None:
|
||||
raise ConnectorMissingCredentialError("Salesforce")
|
||||
|
||||
logger.info(f"Starting with {len(self.parent_object_list)} parent object types")
|
||||
logger.debug(f"Parent object types: {self.parent_object_list}")
|
||||
|
||||
# This takes like 20 seconds
|
||||
doc_batch: list[Document] = []
|
||||
for parent_object_type in self.parent_object_list:
|
||||
child_types = get_all_children_of_sf_type(
|
||||
self.sf_client, parent_object_type
|
||||
)
|
||||
all_object_types.update(child_types)
|
||||
logger.debug(
|
||||
f"Found {len(child_types)} child types for {parent_object_type}"
|
||||
)
|
||||
logger.debug(f"Processing: {parent_object_type}")
|
||||
|
||||
logger.info(f"Found total of {len(all_object_types)} object types to fetch")
|
||||
logger.debug(f"All object types: {all_object_types}")
|
||||
query_results: dict = {}
|
||||
for query in self._generate_query_per_parent_type(parent_object_type):
|
||||
if start is not None and end is not None:
|
||||
if start and start.tzinfo is None:
|
||||
start = start.replace(tzinfo=timezone.utc)
|
||||
if end and end.tzinfo is None:
|
||||
end = end.replace(tzinfo=timezone.utc)
|
||||
query += f" WHERE LastModifiedDate > {start.isoformat()} AND LastModifiedDate < {end.isoformat()}"
|
||||
|
||||
# checkpoint - we've found all object types, now time to fetch the data
|
||||
logger.info("Starting to fetch CSVs for all object types")
|
||||
# This takes like 30 minutes first time and <2 minutes for updates
|
||||
object_type_to_csv_path = fetch_all_csvs_in_parallel(
|
||||
sf_client=self.sf_client,
|
||||
object_types=all_object_types,
|
||||
start=start,
|
||||
end=end,
|
||||
)
|
||||
query_result = self.sf_client.query_all(query)
|
||||
|
||||
updated_ids: set[str] = set()
|
||||
# This takes like 10 seconds
|
||||
# This is for testing the rest of the functionality if data has
|
||||
# already been fetched and put in sqlite
|
||||
# from onyx.connectors.salesforce.sf_db.sqlite_functions import find_ids_by_type
|
||||
# for object_type in self.parent_object_list:
|
||||
# updated_ids.update(list(find_ids_by_type(object_type)))
|
||||
for record_dict in query_result["records"]:
|
||||
query_results.setdefault(record_dict["Id"], {}).update(record_dict)
|
||||
|
||||
# This takes 10-70 minutes first time (idk why the range is so big)
|
||||
total_types = len(object_type_to_csv_path)
|
||||
logger.info(f"Starting to process {total_types} object types")
|
||||
|
||||
for i, (object_type, csv_paths) in enumerate(
|
||||
object_type_to_csv_path.items(), 1
|
||||
):
|
||||
logger.info(f"Processing object type {object_type} ({i}/{total_types})")
|
||||
# If path is None, it means it failed to fetch the csv
|
||||
if csv_paths is None:
|
||||
continue
|
||||
# Go through each csv path and use it to update the db
|
||||
for csv_path in csv_paths:
|
||||
logger.debug(f"Updating {object_type} with {csv_path}")
|
||||
new_ids = update_sf_db_with_csv(
|
||||
object_type=object_type,
|
||||
csv_download_path=csv_path,
|
||||
)
|
||||
updated_ids.update(new_ids)
|
||||
logger.debug(
|
||||
f"Added {len(new_ids)} new/updated records for {object_type}"
|
||||
)
|
||||
# Remove the csv file after it has been used
|
||||
# to successfully update the db
|
||||
os.remove(csv_path)
|
||||
|
||||
logger.info(f"Found {len(updated_ids)} total updated records")
|
||||
logger.info(
|
||||
f"Starting to process parent objects of types: {self.parent_object_list}"
|
||||
)
|
||||
|
||||
docs_to_yield: list[Document] = []
|
||||
docs_processed = 0
|
||||
# Takes 15-20 seconds per batch
|
||||
for parent_type, parent_id_batch in get_affected_parent_ids_by_type(
|
||||
updated_ids=list(updated_ids),
|
||||
parent_types=self.parent_object_list,
|
||||
):
|
||||
logger.info(
|
||||
f"Processing batch of {len(parent_id_batch)} {parent_type} objects"
|
||||
f"Number of {parent_object_type} Objects processed: {len(query_results)}"
|
||||
)
|
||||
for parent_id in parent_id_batch:
|
||||
if not (parent_object := get_record(parent_id, parent_type)):
|
||||
logger.warning(
|
||||
f"Failed to get parent object {parent_id} for {parent_type}"
|
||||
)
|
||||
continue
|
||||
|
||||
docs_to_yield.append(
|
||||
self._convert_object_instance_to_document(parent_object)
|
||||
for combined_object_dict in query_results.values():
|
||||
doc_batch.append(
|
||||
self._convert_object_instance_to_document(combined_object_dict)
|
||||
)
|
||||
docs_processed += 1
|
||||
|
||||
if len(docs_to_yield) >= self.batch_size:
|
||||
yield docs_to_yield
|
||||
docs_to_yield = []
|
||||
|
||||
yield docs_to_yield
|
||||
if len(doc_batch) > self.batch_size:
|
||||
yield doc_batch
|
||||
doc_batch = []
|
||||
yield doc_batch
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
return self._fetch_from_salesforce()
|
||||
@@ -212,20 +245,26 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
return self._fetch_from_salesforce(start=start, end=end)
|
||||
if self.sf_client is None:
|
||||
raise ConnectorMissingCredentialError("Salesforce")
|
||||
start_datetime = datetime.utcfromtimestamp(start)
|
||||
end_datetime = datetime.utcfromtimestamp(end)
|
||||
return self._fetch_from_salesforce(start=start_datetime, end=end_datetime)
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
if self.sf_client is None:
|
||||
raise ConnectorMissingCredentialError("Salesforce")
|
||||
doc_metadata_list: list[SlimDocument] = []
|
||||
for parent_object_type in self.parent_object_list:
|
||||
query = f"SELECT Id FROM {parent_object_type}"
|
||||
query_result = self.sf_client.query_all(query)
|
||||
doc_metadata_list.extend(
|
||||
SlimDocument(
|
||||
id=f"{_ID_PREFIX}{instance_dict.get('Id', '')}",
|
||||
id=f"{ID_PREFIX}{instance_dict.get('Id', '')}",
|
||||
perm_sync_data={},
|
||||
)
|
||||
for instance_dict in query_result["records"]
|
||||
@@ -235,9 +274,9 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import time
|
||||
|
||||
connector = SalesforceConnector(requested_objects=["Account"])
|
||||
connector = SalesforceConnector(
|
||||
requested_objects=os.environ["REQUESTED_OBJECTS"].split(",")
|
||||
)
|
||||
|
||||
connector.load_credentials(
|
||||
{
|
||||
@@ -246,20 +285,5 @@ if __name__ == "__main__":
|
||||
"sf_security_token": os.environ["SF_SECURITY_TOKEN"],
|
||||
}
|
||||
)
|
||||
start_time = time.time()
|
||||
doc_count = 0
|
||||
section_count = 0
|
||||
text_count = 0
|
||||
for doc_batch in connector.load_from_state():
|
||||
doc_count += len(doc_batch)
|
||||
print(f"doc_count: {doc_count}")
|
||||
for doc in doc_batch:
|
||||
section_count += len(doc.sections)
|
||||
for section in doc.sections:
|
||||
text_count += len(section.text)
|
||||
end_time = time.time()
|
||||
|
||||
print(f"Doc count: {doc_count}")
|
||||
print(f"Section count: {section_count}")
|
||||
print(f"Text count: {text_count}")
|
||||
print(f"Time taken: {end_time - start_time}")
|
||||
document_batches = connector.load_from_state()
|
||||
print(next(document_batches))
|
||||
|
||||
@@ -1,156 +0,0 @@
import re
from collections import OrderedDict

from onyx.connectors.models import Section
from onyx.connectors.salesforce.utils import SalesforceObject

# All of these types of keys are handled by specific fields in the doc
# conversion process (E.g. URLs) or are not useful for the user (E.g. UUIDs)
_SF_JSON_FILTER = r"Id$|Date$|stamp$|url$"


def _clean_salesforce_dict(data: dict | list) -> dict | list:
    """Clean and transform Salesforce API response data by recursively:
    1. Extracting records from the response if present
    2. Merging attributes into the main dictionary
    3. Filtering out keys matching certain patterns (Id, Date, stamp, url)
    4. Removing '__c' suffix from custom field names
    5. Removing None values and empty containers

    Args:
        data: A dictionary or list from Salesforce API response

    Returns:
        Cleaned dictionary or list with transformed keys and filtered values
    """
    if isinstance(data, dict):
        if "records" in data.keys():
            data = data["records"]
    if isinstance(data, dict):
        if "attributes" in data.keys():
            if isinstance(data["attributes"], dict):
                data.update(data.pop("attributes"))

    if isinstance(data, dict):
        filtered_dict = {}
        for key, value in data.items():
            if not re.search(_SF_JSON_FILTER, key, re.IGNORECASE):
                # remove the custom object indicator for display
                if "__c" in key:
                    key = key[:-3]
                if isinstance(value, (dict, list)):
                    filtered_value = _clean_salesforce_dict(value)
                    # Only add non-empty dictionaries or lists
                    if filtered_value:
                        filtered_dict[key] = filtered_value
                elif value is not None:
                    filtered_dict[key] = value
        return filtered_dict
    elif isinstance(data, list):
        filtered_list = []
        for item in data:
            if isinstance(item, (dict, list)):
                filtered_item = _clean_salesforce_dict(item)
                # Only add non-empty dictionaries or lists
                if filtered_item:
                    filtered_list.append(filtered_item)
            elif item is not None:
                filtered_list.append(item)
        return filtered_list
    else:
        return data

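# Illustration only (not part of the original diff): a minimal before/after for
# _clean_salesforce_dict on a made-up record. Keys matching _SF_JSON_FILTER
# (Id, *Date, *stamp, *url) and None values are dropped, "attributes" is merged
# in, and the "__c" suffix is stripped.
#
#   raw = {
#       "attributes": {"type": "Account"},
#       "Id": "001bm00000fd9Z3AAI",
#       "Name": "Acme Inc.",
#       "LastModifiedDate": "2024-01-01T00:00:00.000+0000",
#       "Region__c": "EMEA",
#       "Description": None,
#   }
#   _clean_salesforce_dict(raw)
#   -> {"Name": "Acme Inc.", "Region": "EMEA", "type": "Account"}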
def _json_to_natural_language(data: dict | list, indent: int = 0) -> str:
    """Convert a nested dictionary or list into a human-readable string format.

    Recursively traverses the data structure and formats it with:
    - Key-value pairs on separate lines
    - Nested structures indented for readability
    - Lists and dictionaries handled with appropriate formatting

    Args:
        data: The dictionary or list to convert
        indent: Number of spaces to indent (default: 0)

    Returns:
        A formatted string representation of the data structure
    """
    result = []
    indent_str = " " * indent

    if isinstance(data, dict):
        for key, value in data.items():
            if isinstance(value, (dict, list)):
                result.append(f"{indent_str}{key}:")
                result.append(_json_to_natural_language(value, indent + 2))
            else:
                result.append(f"{indent_str}{key}: {value}")
    elif isinstance(data, list):
        for item in data:
            result.append(_json_to_natural_language(item, indent + 2))

    return "\n".join(result)

def _extract_dict_text(raw_dict: dict) -> str:
    """Extract text from a Salesforce API response dictionary by:
    1. Cleaning the dictionary
    2. Converting the cleaned dictionary to natural language
    """
    processed_dict = _clean_salesforce_dict(raw_dict)
    natural_language_for_dict = _json_to_natural_language(processed_dict)
    return natural_language_for_dict


def extract_section(salesforce_object: SalesforceObject, base_url: str) -> Section:
    return Section(
        text=_extract_dict_text(salesforce_object.data),
        link=f"{base_url}/{salesforce_object.id}",
    )

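# Illustration only (not part of the original diff): a minimal input/output pair
# for _extract_dict_text; the nested dict below is a made-up example, not real
# connector data.
#
#   _extract_dict_text({"Name": "Acme Inc.", "BillingAddress": {"City": "New York"}})
#   -> "Name: Acme Inc.\nBillingAddress:\n  City: New York"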
def _field_value_is_child_object(field_value: dict) -> bool:
|
||||
"""
|
||||
Checks if the field value is a child object.
|
||||
"""
|
||||
return (
|
||||
isinstance(field_value, OrderedDict)
|
||||
and "records" in field_value.keys()
|
||||
and isinstance(field_value["records"], list)
|
||||
and len(field_value["records"]) > 0
|
||||
and "Id" in field_value["records"][0].keys()
|
||||
)
|
||||
|
||||
|
||||
def _extract_sections(salesforce_object: dict, base_url: str) -> list[Section]:
|
||||
"""
|
||||
This goes through the salesforce_object and extracts the top level fields as a Section.
|
||||
It also goes through the child objects and extracts them as Sections.
|
||||
"""
|
||||
top_level_dict = {}
|
||||
|
||||
child_object_sections = []
|
||||
for field_name, field_value in salesforce_object.items():
|
||||
# If the field value is not a child object, add it to the top level dict
|
||||
# to turn into text for the top level section
|
||||
if not _field_value_is_child_object(field_value):
|
||||
top_level_dict[field_name] = field_value
|
||||
continue
|
||||
|
||||
# If the field value is a child object, extract the child objects and add them as sections
|
||||
for record in field_value["records"]:
|
||||
child_object_id = record["Id"]
|
||||
child_object_sections.append(
|
||||
Section(
|
||||
text=f"Child Object(s): {field_name}\n{_extract_dict_text(record)}",
|
||||
link=f"{base_url}/{child_object_id}",
|
||||
)
|
||||
)
|
||||
|
||||
top_level_id = salesforce_object["Id"]
|
||||
top_level_section = Section(
|
||||
text=_extract_dict_text(top_level_dict),
|
||||
link=f"{base_url}/{top_level_id}",
|
||||
)
|
||||
return [top_level_section, *child_object_sections]
|
||||
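# Illustration only (not part of the original diff): for a hypothetical Account
# dict whose "Cases" relationship field carries two child records,
# _extract_sections returns one Section built from the parent's own fields
# followed by one Section per child record, each linked by record Id. The
# instance URL below is a placeholder.
#
#   sections = _extract_sections(account_dict, "https://instance.my.salesforce.com")
#   # sections[0]  -> text of the Account's scalar fields, link ".../<Account Id>"
#   # sections[1:] -> "Child Object(s): Cases" text, one per child, link ".../<Case Id>"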
@@ -1,210 +0,0 @@
|
||||
import os
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pytz import UTC
|
||||
from simple_salesforce import Salesforce
|
||||
from simple_salesforce import SFType
|
||||
from simple_salesforce.bulk2 import SFBulk2Handler
|
||||
from simple_salesforce.bulk2 import SFBulk2Type
|
||||
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.salesforce.sqlite_functions import has_at_least_one_object_of_type
|
||||
from onyx.connectors.salesforce.utils import get_object_type_path
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _build_time_filter_for_salesforce(
    start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None
) -> str:
    if start is None or end is None:
        return ""
    start_datetime = datetime.fromtimestamp(start, UTC)
    end_datetime = datetime.fromtimestamp(end, UTC)
    return (
        f" WHERE LastModifiedDate > {start_datetime.isoformat()} "
        f"AND LastModifiedDate < {end_datetime.isoformat()}"
    )

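# Illustration only (not part of the original diff): the WHERE clause this
# helper returns for a one-day window; the epoch values below correspond to
# 2024-01-01 and 2024-01-02 UTC.
#
#   _build_time_filter_for_salesforce(1704067200, 1704153600)
#   -> " WHERE LastModifiedDate > 2024-01-01T00:00:00+00:00 "
#      "AND LastModifiedDate < 2024-01-02T00:00:00+00:00"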
def _get_sf_type_object_json(sf_client: Salesforce, type_name: str) -> Any:
|
||||
sf_object = SFType(type_name, sf_client.session_id, sf_client.sf_instance)
|
||||
return sf_object.describe()
|
||||
|
||||
|
||||
def _is_valid_child_object(
|
||||
sf_client: Salesforce, child_relationship: dict[str, Any]
|
||||
) -> bool:
|
||||
if not child_relationship["childSObject"]:
|
||||
return False
|
||||
if not child_relationship["relationshipName"]:
|
||||
return False
|
||||
|
||||
sf_type = child_relationship["childSObject"]
|
||||
object_description = _get_sf_type_object_json(sf_client, sf_type)
|
||||
if not object_description["queryable"]:
|
||||
return False
|
||||
|
||||
if child_relationship["field"]:
|
||||
if child_relationship["field"] == "RelatedToId":
|
||||
return False
|
||||
else:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def get_all_children_of_sf_type(sf_client: Salesforce, sf_type: str) -> set[str]:
|
||||
object_description = _get_sf_type_object_json(sf_client, sf_type)
|
||||
|
||||
child_object_types = set()
|
||||
for child_relationship in object_description["childRelationships"]:
|
||||
if _is_valid_child_object(sf_client, child_relationship):
|
||||
logger.debug(
|
||||
f"Found valid child object {child_relationship['childSObject']}"
|
||||
)
|
||||
child_object_types.add(child_relationship["childSObject"])
|
||||
return child_object_types
|
||||
|
||||
|
||||
def _get_all_queryable_fields_of_sf_type(
|
||||
sf_client: Salesforce,
|
||||
sf_type: str,
|
||||
) -> list[str]:
|
||||
object_description = _get_sf_type_object_json(sf_client, sf_type)
|
||||
fields: list[dict[str, Any]] = object_description["fields"]
|
||||
valid_fields: set[str] = set()
|
||||
compound_field_names: set[str] = set()
|
||||
for field in fields:
|
||||
if compound_field_name := field.get("compoundFieldName"):
|
||||
compound_field_names.add(compound_field_name)
|
||||
if field.get("type", "base64") == "base64":
|
||||
continue
|
||||
if field_name := field.get("name"):
|
||||
valid_fields.add(field_name)
|
||||
|
||||
return list(valid_fields - compound_field_names)
|
||||
|
||||
|
||||
def _check_if_object_type_is_empty(sf_client: Salesforce, sf_type: str) -> bool:
|
||||
"""
|
||||
Send a small query to check if the object type is empty so we don't
|
||||
perform extra bulk queries
|
||||
"""
|
||||
try:
|
||||
query = f"SELECT Count() FROM {sf_type} LIMIT 1"
|
||||
result = sf_client.query(query)
|
||||
if result["totalSize"] == 0:
|
||||
return False
|
||||
except Exception as e:
|
||||
if "OPERATION_TOO_LARGE" not in str(e):
|
||||
logger.warning(f"Object type {sf_type} doesn't support query: {e}")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _check_for_existing_csvs(sf_type: str) -> list[str] | None:
|
||||
# Check if the csv already exists
|
||||
if os.path.exists(get_object_type_path(sf_type)):
|
||||
existing_csvs = [
|
||||
os.path.join(get_object_type_path(sf_type), f)
|
||||
for f in os.listdir(get_object_type_path(sf_type))
|
||||
if f.endswith(".csv")
|
||||
]
|
||||
# If the csv already exists, return the path
|
||||
# This is likely due to a previous run that failed
|
||||
# after downloading the csv but before the data was
|
||||
# written to the db
|
||||
if existing_csvs:
|
||||
return existing_csvs
|
||||
return None
|
||||
|
||||
|
||||
def _build_bulk_query(sf_client: Salesforce, sf_type: str, time_filter: str) -> str:
|
||||
queryable_fields = _get_all_queryable_fields_of_sf_type(sf_client, sf_type)
|
||||
query = f"SELECT {', '.join(queryable_fields)} FROM {sf_type}{time_filter}"
|
||||
return query
|
||||
|
||||
|
||||
def _bulk_retrieve_from_salesforce(
|
||||
sf_client: Salesforce,
|
||||
sf_type: str,
|
||||
time_filter: str,
|
||||
) -> tuple[str, list[str] | None]:
|
||||
if not _check_if_object_type_is_empty(sf_client, sf_type):
|
||||
return sf_type, None
|
||||
|
||||
if existing_csvs := _check_for_existing_csvs(sf_type):
|
||||
return sf_type, existing_csvs
|
||||
|
||||
query = _build_bulk_query(sf_client, sf_type, time_filter)
|
||||
|
||||
bulk_2_handler = SFBulk2Handler(
|
||||
session_id=sf_client.session_id,
|
||||
bulk2_url=sf_client.bulk2_url,
|
||||
proxies=sf_client.proxies,
|
||||
session=sf_client.session,
|
||||
)
|
||||
bulk_2_type = SFBulk2Type(
|
||||
object_name=sf_type,
|
||||
bulk2_url=bulk_2_handler.bulk2_url,
|
||||
headers=bulk_2_handler.headers,
|
||||
session=bulk_2_handler.session,
|
||||
)
|
||||
|
||||
logger.info(f"Downloading {sf_type}")
|
||||
logger.info(f"Query: {query}")
|
||||
|
||||
try:
|
||||
# This downloads the file to a file in the target path with a random name
|
||||
results = bulk_2_type.download(
|
||||
query=query,
|
||||
path=get_object_type_path(sf_type),
|
||||
max_records=1000000,
|
||||
)
|
||||
all_download_paths = [result["file"] for result in results]
|
||||
logger.info(f"Downloaded {sf_type} to {all_download_paths}")
|
||||
return sf_type, all_download_paths
|
||||
except Exception as e:
|
||||
logger.info(f"Failed to download salesforce csv for object type {sf_type}: {e}")
|
||||
return sf_type, None
|
||||
|
||||
|
||||
def fetch_all_csvs_in_parallel(
|
||||
sf_client: Salesforce,
|
||||
object_types: set[str],
|
||||
start: SecondsSinceUnixEpoch | None,
|
||||
end: SecondsSinceUnixEpoch | None,
|
||||
) -> dict[str, list[str] | None]:
|
||||
"""
|
||||
Fetches all the csvs in parallel for the given object types
|
||||
Returns a dict of (sf_type, full_download_path)
|
||||
"""
|
||||
time_filter = _build_time_filter_for_salesforce(start, end)
|
||||
time_filter_for_each_object_type = {}
|
||||
# We do this outside of the thread pool executor because this requires
|
||||
# a database connection and we don't want to block the thread pool
|
||||
# executor from running
|
||||
for sf_type in object_types:
|
||||
"""Only add time filter if there is at least one object of the type
|
||||
in the database. We aren't worried about partially completed object update runs
|
||||
because this occurs after we check for existing csvs which covers this case"""
|
||||
if has_at_least_one_object_of_type(sf_type):
|
||||
time_filter_for_each_object_type[sf_type] = time_filter
|
||||
else:
|
||||
time_filter_for_each_object_type[sf_type] = ""
|
||||
|
||||
# Run the bulk retrieve in parallel
|
||||
with ThreadPoolExecutor() as executor:
|
||||
results = executor.map(
|
||||
lambda object_type: _bulk_retrieve_from_salesforce(
|
||||
sf_client=sf_client,
|
||||
sf_type=object_type,
|
||||
time_filter=time_filter_for_each_object_type[object_type],
|
||||
),
|
||||
object_types,
|
||||
)
|
||||
return dict(results)
|
||||
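# Illustration only (not part of the original diff): a minimal sketch of how the
# connector is expected to call this helper; sf_client is assumed to be an
# already-authenticated simple_salesforce.Salesforce instance.
#
#   object_type_to_csv_paths = fetch_all_csvs_in_parallel(
#       sf_client=sf_client,
#       object_types={"Account", "Contact", "Case"},
#       start=None,  # no time filter -> full export
#       end=None,
#   )
#   # e.g. {"Account": ["<target path>/Account/<random name>.csv"], "Case": None, ...}
#   # where None means the type was empty or the bulk download failed.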
@@ -1,209 +0,0 @@
|
||||
import csv
|
||||
import shelve
|
||||
|
||||
from onyx.connectors.salesforce.shelve_stuff.shelve_utils import (
|
||||
get_child_to_parent_shelf_path,
|
||||
)
|
||||
from onyx.connectors.salesforce.shelve_stuff.shelve_utils import get_id_type_shelf_path
|
||||
from onyx.connectors.salesforce.shelve_stuff.shelve_utils import get_object_shelf_path
|
||||
from onyx.connectors.salesforce.shelve_stuff.shelve_utils import (
|
||||
get_parent_to_child_shelf_path,
|
||||
)
|
||||
from onyx.connectors.salesforce.utils import SalesforceObject
|
||||
from onyx.connectors.salesforce.utils import validate_salesforce_id
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _update_relationship_shelves(
|
||||
child_id: str,
|
||||
parent_ids: set[str],
|
||||
) -> None:
|
||||
"""Update the relationship shelf when a record is updated."""
|
||||
try:
|
||||
# Convert child_id to string once
|
||||
str_child_id = str(child_id)
|
||||
|
||||
# First update child to parent mapping
|
||||
with shelve.open(
|
||||
get_child_to_parent_shelf_path(),
|
||||
flag="c",
|
||||
protocol=None,
|
||||
writeback=True,
|
||||
) as child_to_parent_db:
|
||||
old_parent_ids = set(child_to_parent_db.get(str_child_id, []))
|
||||
child_to_parent_db[str_child_id] = list(parent_ids)
|
||||
|
||||
# Calculate differences outside the next context manager
|
||||
parent_ids_to_remove = old_parent_ids - parent_ids
|
||||
parent_ids_to_add = parent_ids - old_parent_ids
|
||||
|
||||
# Only sync once at the end
|
||||
child_to_parent_db.sync()
|
||||
|
||||
# Then update parent to child mapping in a single transaction
|
||||
if not parent_ids_to_remove and not parent_ids_to_add:
|
||||
return
|
||||
with shelve.open(
|
||||
get_parent_to_child_shelf_path(),
|
||||
flag="c",
|
||||
protocol=None,
|
||||
writeback=True,
|
||||
) as parent_to_child_db:
|
||||
# Process all removals first
|
||||
for parent_id in parent_ids_to_remove:
|
||||
str_parent_id = str(parent_id)
|
||||
existing_children = set(parent_to_child_db.get(str_parent_id, []))
|
||||
if str_child_id in existing_children:
|
||||
existing_children.remove(str_child_id)
|
||||
parent_to_child_db[str_parent_id] = list(existing_children)
|
||||
|
||||
# Then process all additions
|
||||
for parent_id in parent_ids_to_add:
|
||||
str_parent_id = str(parent_id)
|
||||
existing_children = set(parent_to_child_db.get(str_parent_id, []))
|
||||
existing_children.add(str_child_id)
|
||||
parent_to_child_db[str_parent_id] = list(existing_children)
|
||||
|
||||
# Single sync at the end
|
||||
parent_to_child_db.sync()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating relationship shelves: {e}")
|
||||
logger.error(f"Child ID: {child_id}, Parent IDs: {parent_ids}")
|
||||
raise
|
||||
|
||||
|
||||
def get_child_ids(parent_id: str) -> set[str]:
|
||||
"""Get all child IDs for a given parent ID.
|
||||
|
||||
Args:
|
||||
parent_id: The ID of the parent object
|
||||
|
||||
Returns:
|
||||
A set of child object IDs
|
||||
"""
|
||||
with shelve.open(get_parent_to_child_shelf_path()) as parent_to_child_db:
|
||||
return set(parent_to_child_db.get(parent_id, []))
|
||||
|
||||
|
||||
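# Illustration only (not part of the original diff): both relationship shelves
# are plain id -> list-of-ids mappings. After indexing a hypothetical Contact
# row whose AccountId column points at an Account, they would hold roughly:
#
#   child-to-parent shelf: {"003bm00000EjHCjAAN": ["001bm00000fd9Z3AAI"]}
#   parent-to-child shelf: {"001bm00000fd9Z3AAI": ["003bm00000EjHCjAAN"]}
#
#   get_child_ids("001bm00000fd9Z3AAI")  # -> {"003bm00000EjHCjAAN"}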
def update_sf_db_with_csv(
|
||||
object_type: str,
|
||||
csv_download_path: str,
|
||||
) -> list[str]:
|
||||
"""Update the SF DB with a CSV file using shelve storage."""
|
||||
updated_ids = []
|
||||
shelf_path = get_object_shelf_path(object_type)
|
||||
|
||||
# First read the CSV to get all the data
|
||||
with open(csv_download_path, "r", newline="", encoding="utf-8") as f:
|
||||
reader = csv.DictReader(f)
|
||||
for row in reader:
|
||||
id = row["Id"]
|
||||
parent_ids = set()
|
||||
field_to_remove: set[str] = set()
|
||||
# Update relationship shelves for any parent references
|
||||
for field, value in row.items():
|
||||
if validate_salesforce_id(value) and field != "Id":
|
||||
parent_ids.add(value)
|
||||
field_to_remove.add(field)
|
||||
if not value:
|
||||
field_to_remove.add(field)
|
||||
_update_relationship_shelves(id, parent_ids)
|
||||
for field in field_to_remove:
|
||||
# We use this to extract the Primary Owner later
|
||||
if field != "LastModifiedById":
|
||||
del row[field]
|
||||
|
||||
# Update the main object shelf
|
||||
with shelve.open(shelf_path) as object_type_db:
|
||||
object_type_db[id] = row
|
||||
# Update the ID-to-type mapping shelf
|
||||
with shelve.open(get_id_type_shelf_path()) as id_type_db:
|
||||
id_type_db[id] = object_type
|
||||
|
||||
updated_ids.append(id)
|
||||
|
||||
# os.remove(csv_download_path)
|
||||
return updated_ids
|
||||
|
||||
|
||||
def get_type_from_id(object_id: str) -> str | None:
|
||||
"""Get the type of an object from its ID."""
|
||||
# Look up the object type from the ID-to-type mapping
|
||||
with shelve.open(get_id_type_shelf_path()) as id_type_db:
|
||||
if object_id not in id_type_db:
|
||||
logger.warning(f"Object ID {object_id} not found in ID-to-type mapping")
|
||||
return None
|
||||
return id_type_db[object_id]
|
||||
|
||||
|
||||
def get_record(
|
||||
object_id: str, object_type: str | None = None
|
||||
) -> SalesforceObject | None:
|
||||
"""
|
||||
Retrieve the record and return it as a SalesforceObject.
|
||||
The object type will be looked up from the ID-to-type mapping shelf.
|
||||
"""
|
||||
if object_type is None:
|
||||
if not (object_type := get_type_from_id(object_id)):
|
||||
return None
|
||||
|
||||
shelf_path = get_object_shelf_path(object_type)
|
||||
with shelve.open(shelf_path) as db:
|
||||
if object_id not in db:
|
||||
logger.warning(f"Object ID {object_id} not found in {shelf_path}")
|
||||
return None
|
||||
data = db[object_id]
|
||||
return SalesforceObject(
|
||||
id=object_id,
|
||||
type=object_type,
|
||||
data=data,
|
||||
)
|
||||
|
||||
|
||||
def find_ids_by_type(object_type: str) -> list[str]:
|
||||
"""
|
||||
Find all object IDs for rows of the specified type.
|
||||
"""
|
||||
shelf_path = get_object_shelf_path(object_type)
|
||||
try:
|
||||
with shelve.open(shelf_path) as db:
|
||||
return list(db.keys())
|
||||
except FileNotFoundError:
|
||||
return []
|
||||
|
||||
|
||||
def get_affected_parent_ids_by_type(
|
||||
updated_ids: set[str], parent_types: list[str]
|
||||
) -> dict[str, set[str]]:
|
||||
"""Get IDs of objects that are of the specified parent types and are either in the updated_ids
|
||||
or have children in the updated_ids.
|
||||
|
||||
Args:
|
||||
updated_ids: List of IDs that were updated
|
||||
parent_types: List of object types to filter by
|
||||
|
||||
Returns:
|
||||
A dictionary of IDs that match the criteria
|
||||
"""
|
||||
affected_ids_by_type: dict[str, set[str]] = {}
|
||||
|
||||
# Check each updated ID
|
||||
for updated_id in updated_ids:
|
||||
# Add the ID itself if it's of a parent type
|
||||
updated_type = get_type_from_id(updated_id)
|
||||
if updated_type in parent_types:
|
||||
affected_ids_by_type.setdefault(updated_type, set()).add(updated_id)
|
||||
continue
|
||||
|
||||
# Get parents of this ID and add them if they're of a parent type
|
||||
with shelve.open(get_child_to_parent_shelf_path()) as child_to_parent_db:
|
||||
parent_ids = child_to_parent_db.get(updated_id, [])
|
||||
for parent_id in parent_ids:
|
||||
parent_type = get_type_from_id(parent_id)
|
||||
if parent_type in parent_types:
|
||||
affected_ids_by_type.setdefault(parent_type, set()).add(parent_id)
|
||||
|
||||
return affected_ids_by_type
|
||||
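# Illustration only (not part of the original diff): assuming the relationship
# shelves already link the Contact above to its Account, an update to just the
# Contact rolls up to the parent Account id:
#
#   get_affected_parent_ids_by_type(
#       updated_ids={"003bm00000EjHCjAAN"},
#       parent_types=["Account"],
#   )
#   # -> {"Account": {"001bm00000fd9Z3AAI"}}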
@@ -1,29 +0,0 @@
|
||||
import os
|
||||
|
||||
from onyx.connectors.salesforce.utils import BASE_DATA_PATH
|
||||
from onyx.connectors.salesforce.utils import get_object_type_path
|
||||
|
||||
|
||||
def get_object_shelf_path(object_type: str) -> str:
|
||||
"""Get the path to the shelf file for a specific object type."""
|
||||
base_path = get_object_type_path(object_type)
|
||||
os.makedirs(base_path, exist_ok=True)
|
||||
return os.path.join(base_path, "data.shelf")
|
||||
|
||||
|
||||
def get_id_type_shelf_path() -> str:
|
||||
"""Get the path to the ID-to-type mapping shelf."""
|
||||
os.makedirs(BASE_DATA_PATH, exist_ok=True)
|
||||
return os.path.join(BASE_DATA_PATH, "id_type_mapping.shelf.4g")
|
||||
|
||||
|
||||
def get_parent_to_child_shelf_path() -> str:
|
||||
"""Get the path to the parent-to-child mapping shelf."""
|
||||
os.makedirs(BASE_DATA_PATH, exist_ok=True)
|
||||
return os.path.join(BASE_DATA_PATH, "parent_to_child_mapping.shelf.4g")
|
||||
|
||||
|
||||
def get_child_to_parent_shelf_path() -> str:
|
||||
"""Get the path to the child-to-parent mapping shelf."""
|
||||
os.makedirs(BASE_DATA_PATH, exist_ok=True)
|
||||
return os.path.join(BASE_DATA_PATH, "child_to_parent_mapping.shelf.4g")
|
||||
@@ -1,737 +0,0 @@
|
||||
import csv
|
||||
import os
|
||||
import shutil
|
||||
|
||||
from onyx.connectors.salesforce.shelve_stuff.shelve_functions import find_ids_by_type
|
||||
from onyx.connectors.salesforce.shelve_stuff.shelve_functions import (
|
||||
get_affected_parent_ids_by_type,
|
||||
)
|
||||
from onyx.connectors.salesforce.shelve_stuff.shelve_functions import get_child_ids
|
||||
from onyx.connectors.salesforce.shelve_stuff.shelve_functions import get_record
|
||||
from onyx.connectors.salesforce.shelve_stuff.shelve_functions import (
|
||||
update_sf_db_with_csv,
|
||||
)
|
||||
from onyx.connectors.salesforce.utils import BASE_DATA_PATH
|
||||
from onyx.connectors.salesforce.utils import get_object_type_path
|
||||
|
||||
_VALID_SALESFORCE_IDS = [
|
||||
"001bm00000fd9Z3AAI",
|
||||
"001bm00000fdYTdAAM",
|
||||
"001bm00000fdYTeAAM",
|
||||
"001bm00000fdYTfAAM",
|
||||
"001bm00000fdYTgAAM",
|
||||
"001bm00000fdYThAAM",
|
||||
"001bm00000fdYTiAAM",
|
||||
"001bm00000fdYTjAAM",
|
||||
"001bm00000fdYTkAAM",
|
||||
"001bm00000fdYTlAAM",
|
||||
"001bm00000fdYTmAAM",
|
||||
"001bm00000fdYTnAAM",
|
||||
"001bm00000fdYToAAM",
|
||||
"500bm00000XoOxtAAF",
|
||||
"500bm00000XoOxuAAF",
|
||||
"500bm00000XoOxvAAF",
|
||||
"500bm00000XoOxwAAF",
|
||||
"500bm00000XoOxxAAF",
|
||||
"500bm00000XoOxyAAF",
|
||||
"500bm00000XoOxzAAF",
|
||||
"500bm00000XoOy0AAF",
|
||||
"500bm00000XoOy1AAF",
|
||||
"500bm00000XoOy2AAF",
|
||||
"500bm00000XoOy3AAF",
|
||||
"500bm00000XoOy4AAF",
|
||||
"500bm00000XoOy5AAF",
|
||||
"500bm00000XoOy6AAF",
|
||||
"500bm00000XoOy7AAF",
|
||||
"500bm00000XoOy8AAF",
|
||||
"500bm00000XoOy9AAF",
|
||||
"500bm00000XoOyAAAV",
|
||||
"500bm00000XoOyBAAV",
|
||||
"500bm00000XoOyCAAV",
|
||||
"500bm00000XoOyDAAV",
|
||||
"500bm00000XoOyEAAV",
|
||||
"500bm00000XoOyFAAV",
|
||||
"500bm00000XoOyGAAV",
|
||||
"500bm00000XoOyHAAV",
|
||||
"500bm00000XoOyIAAV",
|
||||
"003bm00000EjHCjAAN",
|
||||
"003bm00000EjHCkAAN",
|
||||
"003bm00000EjHClAAN",
|
||||
"003bm00000EjHCmAAN",
|
||||
"003bm00000EjHCnAAN",
|
||||
"003bm00000EjHCoAAN",
|
||||
"003bm00000EjHCpAAN",
|
||||
"003bm00000EjHCqAAN",
|
||||
"003bm00000EjHCrAAN",
|
||||
"003bm00000EjHCsAAN",
|
||||
"003bm00000EjHCtAAN",
|
||||
"003bm00000EjHCuAAN",
|
||||
"003bm00000EjHCvAAN",
|
||||
"003bm00000EjHCwAAN",
|
||||
"003bm00000EjHCxAAN",
|
||||
"003bm00000EjHCyAAN",
|
||||
"003bm00000EjHCzAAN",
|
||||
"003bm00000EjHD0AAN",
|
||||
"003bm00000EjHD1AAN",
|
||||
"003bm00000EjHD2AAN",
|
||||
"550bm00000EXc2tAAD",
|
||||
"006bm000006kyDpAAI",
|
||||
"006bm000006kyDqAAI",
|
||||
"006bm000006kyDrAAI",
|
||||
"006bm000006kyDsAAI",
|
||||
"006bm000006kyDtAAI",
|
||||
"006bm000006kyDuAAI",
|
||||
"006bm000006kyDvAAI",
|
||||
"006bm000006kyDwAAI",
|
||||
"006bm000006kyDxAAI",
|
||||
"006bm000006kyDyAAI",
|
||||
"006bm000006kyDzAAI",
|
||||
"006bm000006kyE0AAI",
|
||||
"006bm000006kyE1AAI",
|
||||
"006bm000006kyE2AAI",
|
||||
"006bm000006kyE3AAI",
|
||||
"006bm000006kyE4AAI",
|
||||
"006bm000006kyE5AAI",
|
||||
"006bm000006kyE6AAI",
|
||||
"006bm000006kyE7AAI",
|
||||
"006bm000006kyE8AAI",
|
||||
"006bm000006kyE9AAI",
|
||||
"006bm000006kyEAAAY",
|
||||
"006bm000006kyEBAAY",
|
||||
"006bm000006kyECAAY",
|
||||
"006bm000006kyEDAAY",
|
||||
"006bm000006kyEEAAY",
|
||||
"006bm000006kyEFAAY",
|
||||
"006bm000006kyEGAAY",
|
||||
"006bm000006kyEHAAY",
|
||||
"006bm000006kyEIAAY",
|
||||
"006bm000006kyEJAAY",
|
||||
"005bm000009zy0TAAQ",
|
||||
"005bm000009zy25AAA",
|
||||
"005bm000009zy26AAA",
|
||||
"005bm000009zy28AAA",
|
||||
"005bm000009zy29AAA",
|
||||
"005bm000009zy2AAAQ",
|
||||
"005bm000009zy2BAAQ",
|
||||
]
|
||||
|
||||
|
||||
def clear_sf_db() -> None:
|
||||
"""
|
||||
Clears the SF DB by deleting all files in the data directory.
|
||||
"""
|
||||
shutil.rmtree(BASE_DATA_PATH)
|
||||
|
||||
|
||||
def create_csv_file(
|
||||
object_type: str, records: list[dict], filename: str = "test_data.csv"
|
||||
) -> None:
|
||||
"""
|
||||
Creates a CSV file for the given object type and records.
|
||||
|
||||
Args:
|
||||
object_type: The Salesforce object type (e.g. "Account", "Contact")
|
||||
records: List of dictionaries containing the record data
|
||||
filename: Name of the CSV file to create (default: test_data.csv)
|
||||
"""
|
||||
if not records:
|
||||
return
|
||||
|
||||
# Get all unique fields from records
|
||||
fields: set[str] = set()
|
||||
for record in records:
|
||||
fields.update(record.keys())
|
||||
fields = set(sorted(list(fields))) # Sort for consistent order
|
||||
|
||||
# Create CSV file
|
||||
csv_path = os.path.join(get_object_type_path(object_type), filename)
|
||||
with open(csv_path, "w", newline="", encoding="utf-8") as f:
|
||||
writer = csv.DictWriter(f, fieldnames=fields)
|
||||
writer.writeheader()
|
||||
for record in records:
|
||||
writer.writerow(record)
|
||||
|
||||
# Update the database with the CSV
|
||||
update_sf_db_with_csv(object_type, csv_path)
|
||||
|
||||
|
||||
def create_csv_with_example_data() -> None:
|
||||
"""
|
||||
Creates CSV files with example data, organized by object type.
|
||||
"""
|
||||
example_data: dict[str, list[dict]] = {
|
||||
"Account": [
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[0],
|
||||
"Name": "Acme Inc.",
|
||||
"BillingCity": "New York",
|
||||
"Industry": "Technology",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[1],
|
||||
"Name": "Globex Corp",
|
||||
"BillingCity": "Los Angeles",
|
||||
"Industry": "Manufacturing",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[2],
|
||||
"Name": "Initech",
|
||||
"BillingCity": "Austin",
|
||||
"Industry": "Software",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[3],
|
||||
"Name": "TechCorp Solutions",
|
||||
"BillingCity": "San Francisco",
|
||||
"Industry": "Software",
|
||||
"AnnualRevenue": 5000000,
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[4],
|
||||
"Name": "BioMed Research",
|
||||
"BillingCity": "Boston",
|
||||
"Industry": "Healthcare",
|
||||
"AnnualRevenue": 12000000,
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[5],
|
||||
"Name": "Green Energy Co",
|
||||
"BillingCity": "Portland",
|
||||
"Industry": "Energy",
|
||||
"AnnualRevenue": 8000000,
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[6],
|
||||
"Name": "DataFlow Analytics",
|
||||
"BillingCity": "Seattle",
|
||||
"Industry": "Technology",
|
||||
"AnnualRevenue": 3000000,
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[7],
|
||||
"Name": "Cloud Nine Services",
|
||||
"BillingCity": "Denver",
|
||||
"Industry": "Cloud Computing",
|
||||
"AnnualRevenue": 7000000,
|
||||
},
|
||||
],
|
||||
"Contact": [
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[40],
|
||||
"FirstName": "John",
|
||||
"LastName": "Doe",
|
||||
"Email": "john.doe@acme.com",
|
||||
"Title": "CEO",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[41],
|
||||
"FirstName": "Jane",
|
||||
"LastName": "Smith",
|
||||
"Email": "jane.smith@acme.com",
|
||||
"Title": "CTO",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[42],
|
||||
"FirstName": "Bob",
|
||||
"LastName": "Johnson",
|
||||
"Email": "bob.j@globex.com",
|
||||
"Title": "Sales Director",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[43],
|
||||
"FirstName": "Sarah",
|
||||
"LastName": "Chen",
|
||||
"Email": "sarah.chen@techcorp.com",
|
||||
"Title": "Product Manager",
|
||||
"Phone": "415-555-0101",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[44],
|
||||
"FirstName": "Michael",
|
||||
"LastName": "Rodriguez",
|
||||
"Email": "m.rodriguez@biomed.com",
|
||||
"Title": "Research Director",
|
||||
"Phone": "617-555-0202",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[45],
|
||||
"FirstName": "Emily",
|
||||
"LastName": "Green",
|
||||
"Email": "emily.g@greenenergy.com",
|
||||
"Title": "Sustainability Lead",
|
||||
"Phone": "503-555-0303",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[46],
|
||||
"FirstName": "David",
|
||||
"LastName": "Kim",
|
||||
"Email": "david.kim@dataflow.com",
|
||||
"Title": "Data Scientist",
|
||||
"Phone": "206-555-0404",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[47],
|
||||
"FirstName": "Rachel",
|
||||
"LastName": "Taylor",
|
||||
"Email": "r.taylor@cloudnine.com",
|
||||
"Title": "Cloud Architect",
|
||||
"Phone": "303-555-0505",
|
||||
},
|
||||
],
|
||||
"Opportunity": [
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[62],
|
||||
"Name": "Acme Server Upgrade",
|
||||
"Amount": 50000,
|
||||
"Stage": "Prospecting",
|
||||
"CloseDate": "2024-06-30",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[63],
|
||||
"Name": "Globex Manufacturing Line",
|
||||
"Amount": 150000,
|
||||
"Stage": "Negotiation",
|
||||
"CloseDate": "2024-03-15",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[64],
|
||||
"Name": "Initech Software License",
|
||||
"Amount": 75000,
|
||||
"Stage": "Closed Won",
|
||||
"CloseDate": "2024-01-30",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[65],
|
||||
"Name": "TechCorp AI Implementation",
|
||||
"Amount": 250000,
|
||||
"Stage": "Needs Analysis",
|
||||
"CloseDate": "2024-08-15",
|
||||
"Probability": 60,
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[66],
|
||||
"Name": "BioMed Lab Equipment",
|
||||
"Amount": 500000,
|
||||
"Stage": "Value Proposition",
|
||||
"CloseDate": "2024-09-30",
|
||||
"Probability": 75,
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[67],
|
||||
"Name": "Green Energy Solar Project",
|
||||
"Amount": 750000,
|
||||
"Stage": "Proposal",
|
||||
"CloseDate": "2024-07-15",
|
||||
"Probability": 80,
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[68],
|
||||
"Name": "DataFlow Analytics Platform",
|
||||
"Amount": 180000,
|
||||
"Stage": "Negotiation",
|
||||
"CloseDate": "2024-05-30",
|
||||
"Probability": 90,
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[69],
|
||||
"Name": "Cloud Nine Infrastructure",
|
||||
"Amount": 300000,
|
||||
"Stage": "Qualification",
|
||||
"CloseDate": "2024-10-15",
|
||||
"Probability": 40,
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
# Create CSV files for each object type
|
||||
for object_type, records in example_data.items():
|
||||
create_csv_file(object_type, records)
|
||||
|
||||
|
||||
def test_query() -> None:
|
||||
"""
|
||||
Tests querying functionality by verifying:
|
||||
1. All expected Account IDs are found
|
||||
2. Each Account's data matches what was inserted
|
||||
"""
|
||||
# Expected test data for verification
|
||||
expected_accounts: dict[str, dict[str, str | int]] = {
|
||||
_VALID_SALESFORCE_IDS[0]: {
|
||||
"Name": "Acme Inc.",
|
||||
"BillingCity": "New York",
|
||||
"Industry": "Technology",
|
||||
},
|
||||
_VALID_SALESFORCE_IDS[1]: {
|
||||
"Name": "Globex Corp",
|
||||
"BillingCity": "Los Angeles",
|
||||
"Industry": "Manufacturing",
|
||||
},
|
||||
_VALID_SALESFORCE_IDS[2]: {
|
||||
"Name": "Initech",
|
||||
"BillingCity": "Austin",
|
||||
"Industry": "Software",
|
||||
},
|
||||
_VALID_SALESFORCE_IDS[3]: {
|
||||
"Name": "TechCorp Solutions",
|
||||
"BillingCity": "San Francisco",
|
||||
"Industry": "Software",
|
||||
"AnnualRevenue": 5000000,
|
||||
},
|
||||
_VALID_SALESFORCE_IDS[4]: {
|
||||
"Name": "BioMed Research",
|
||||
"BillingCity": "Boston",
|
||||
"Industry": "Healthcare",
|
||||
"AnnualRevenue": 12000000,
|
||||
},
|
||||
_VALID_SALESFORCE_IDS[5]: {
|
||||
"Name": "Green Energy Co",
|
||||
"BillingCity": "Portland",
|
||||
"Industry": "Energy",
|
||||
"AnnualRevenue": 8000000,
|
||||
},
|
||||
_VALID_SALESFORCE_IDS[6]: {
|
||||
"Name": "DataFlow Analytics",
|
||||
"BillingCity": "Seattle",
|
||||
"Industry": "Technology",
|
||||
"AnnualRevenue": 3000000,
|
||||
},
|
||||
_VALID_SALESFORCE_IDS[7]: {
|
||||
"Name": "Cloud Nine Services",
|
||||
"BillingCity": "Denver",
|
||||
"Industry": "Cloud Computing",
|
||||
"AnnualRevenue": 7000000,
|
||||
},
|
||||
}
|
||||
|
||||
# Get all Account IDs
|
||||
account_ids = find_ids_by_type("Account")
|
||||
|
||||
# Verify we found all expected accounts
|
||||
assert len(account_ids) == len(
|
||||
expected_accounts
|
||||
), f"Expected {len(expected_accounts)} accounts, found {len(account_ids)}"
|
||||
assert set(account_ids) == set(
|
||||
expected_accounts.keys()
|
||||
), "Found account IDs don't match expected IDs"
|
||||
|
||||
# Verify each account's data
|
||||
for acc_id in account_ids:
|
||||
combined = get_record(acc_id)
|
||||
assert combined is not None, f"Could not find account {acc_id}"
|
||||
|
||||
expected = expected_accounts[acc_id]
|
||||
|
||||
# Verify account data matches
|
||||
for key, value in expected.items():
|
||||
value = str(value)
|
||||
assert (
|
||||
combined.data[key] == value
|
||||
), f"Account {acc_id} field {key} expected {value}, got {combined.data[key]}"
|
||||
|
||||
print("All query tests passed successfully!")
|
||||
|
||||
|
||||
def test_upsert() -> None:
|
||||
"""
|
||||
Tests upsert functionality by:
|
||||
1. Updating an existing account
|
||||
2. Creating a new account
|
||||
3. Verifying both operations were successful
|
||||
"""
|
||||
# Create CSV for updating an existing account and adding a new one
|
||||
update_data: list[dict[str, str | int]] = [
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[0],
|
||||
"Name": "Acme Inc. Updated",
|
||||
"BillingCity": "New York",
|
||||
"Industry": "Technology",
|
||||
"Description": "Updated company info",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[2],
|
||||
"Name": "New Company Inc.",
|
||||
"BillingCity": "Miami",
|
||||
"Industry": "Finance",
|
||||
"AnnualRevenue": 1000000,
|
||||
},
|
||||
]
|
||||
|
||||
create_csv_file("Account", update_data, "update_data.csv")
|
||||
|
||||
# Verify the update worked
|
||||
updated_record = get_record(_VALID_SALESFORCE_IDS[0])
|
||||
assert updated_record is not None, "Updated record not found"
|
||||
assert updated_record.data["Name"] == "Acme Inc. Updated", "Name not updated"
|
||||
assert (
|
||||
updated_record.data["Description"] == "Updated company info"
|
||||
), "Description not added"
|
||||
|
||||
# Verify the new record was created
|
||||
new_record = get_record(_VALID_SALESFORCE_IDS[2])
|
||||
assert new_record is not None, "New record not found"
|
||||
assert new_record.data["Name"] == "New Company Inc.", "New record name incorrect"
|
||||
assert new_record.data["AnnualRevenue"] == "1000000", "New record revenue incorrect"
|
||||
|
||||
print("All upsert tests passed successfully!")
|
||||
|
||||
|
||||
def test_relationships() -> None:
|
||||
"""
|
||||
Tests relationship shelf updates and queries by:
|
||||
1. Creating test data with relationships
|
||||
2. Verifying the relationships are correctly stored
|
||||
3. Testing relationship queries
|
||||
"""
|
||||
# Create test data for each object type
|
||||
test_data: dict[str, list[dict[str, str | int]]] = {
|
||||
"Case": [
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[13],
|
||||
"AccountId": _VALID_SALESFORCE_IDS[0],
|
||||
"Subject": "Test Case 1",
|
||||
},
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[14],
|
||||
"AccountId": _VALID_SALESFORCE_IDS[0],
|
||||
"Subject": "Test Case 2",
|
||||
},
|
||||
],
|
||||
"Contact": [
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[48],
|
||||
"AccountId": _VALID_SALESFORCE_IDS[0],
|
||||
"FirstName": "Test",
|
||||
"LastName": "Contact",
|
||||
}
|
||||
],
|
||||
"Opportunity": [
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[62],
|
||||
"AccountId": _VALID_SALESFORCE_IDS[0],
|
||||
"Name": "Test Opportunity",
|
||||
"Amount": 100000,
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
# Create and update CSV files for each object type
|
||||
for object_type, records in test_data.items():
|
||||
create_csv_file(object_type, records, "relationship_test.csv")
|
||||
|
||||
# Test relationship queries
|
||||
# All these objects should be children of Acme Inc.
|
||||
child_ids = get_child_ids(_VALID_SALESFORCE_IDS[0])
|
||||
assert len(child_ids) == 4, f"Expected 4 child objects, found {len(child_ids)}"
|
||||
assert _VALID_SALESFORCE_IDS[13] in child_ids, "Case 1 not found in relationship"
|
||||
assert _VALID_SALESFORCE_IDS[14] in child_ids, "Case 2 not found in relationship"
|
||||
assert _VALID_SALESFORCE_IDS[48] in child_ids, "Contact not found in relationship"
|
||||
assert (
|
||||
_VALID_SALESFORCE_IDS[62] in child_ids
|
||||
), "Opportunity not found in relationship"
|
||||
|
||||
# Test querying relationships for a different account (should be empty)
|
||||
other_account_children = get_child_ids(_VALID_SALESFORCE_IDS[1])
|
||||
assert (
|
||||
len(other_account_children) == 0
|
||||
), "Expected no children for different account"
|
||||
|
||||
print("All relationship tests passed successfully!")
|
||||
|
||||
|
||||
def test_account_with_children() -> None:
|
||||
"""
|
||||
Tests querying all accounts and retrieving their child objects.
|
||||
This test verifies that:
|
||||
1. All accounts can be retrieved
|
||||
2. Child objects are correctly linked
|
||||
3. Child object data is complete and accurate
|
||||
"""
|
||||
# First get all account IDs
|
||||
account_ids = find_ids_by_type("Account")
|
||||
assert len(account_ids) > 0, "No accounts found"
|
||||
|
||||
# For each account, get its children and verify the data
|
||||
for account_id in account_ids:
|
||||
account = get_record(account_id)
|
||||
assert account is not None, f"Could not find account {account_id}"
|
||||
|
||||
# Get all child objects
|
||||
child_ids = get_child_ids(account_id)
|
||||
|
||||
# For Acme Inc., verify specific relationships
|
||||
if account_id == _VALID_SALESFORCE_IDS[0]: # Acme Inc.
|
||||
assert (
|
||||
len(child_ids) == 4
|
||||
), f"Expected 4 children for Acme Inc., found {len(child_ids)}"
|
||||
|
||||
# Get all child records
|
||||
child_records = []
|
||||
for child_id in child_ids:
|
||||
child_record = get_record(child_id)
|
||||
if child_record is not None:
|
||||
child_records.append(child_record)
|
||||
# Verify Cases
|
||||
cases = [r for r in child_records if r.type == "Case"]
|
||||
assert (
|
||||
len(cases) == 2
|
||||
), f"Expected 2 cases for Acme Inc., found {len(cases)}"
|
||||
case_subjects = {case.data["Subject"] for case in cases}
|
||||
assert "Test Case 1" in case_subjects, "Test Case 1 not found"
|
||||
assert "Test Case 2" in case_subjects, "Test Case 2 not found"
|
||||
|
||||
# Verify Contacts
|
||||
contacts = [r for r in child_records if r.type == "Contact"]
|
||||
assert (
|
||||
len(contacts) == 1
|
||||
), f"Expected 1 contact for Acme Inc., found {len(contacts)}"
|
||||
contact = contacts[0]
|
||||
assert contact.data["FirstName"] == "Test", "Contact first name mismatch"
|
||||
assert contact.data["LastName"] == "Contact", "Contact last name mismatch"
|
||||
|
||||
# Verify Opportunities
|
||||
opportunities = [r for r in child_records if r.type == "Opportunity"]
|
||||
assert (
|
||||
len(opportunities) == 1
|
||||
), f"Expected 1 opportunity for Acme Inc., found {len(opportunities)}"
|
||||
opportunity = opportunities[0]
|
||||
assert (
|
||||
opportunity.data["Name"] == "Test Opportunity"
|
||||
), "Opportunity name mismatch"
|
||||
assert opportunity.data["Amount"] == "100000", "Opportunity amount mismatch"
|
||||
|
||||
print("All account with children tests passed successfully!")
|
||||
|
||||
|
||||
def test_relationship_updates() -> None:
    """
    Tests that relationships are properly updated when a child object's parent reference changes.
    This test verifies:
    1. Initial relationship is created correctly
    2. When parent reference is updated, old relationship is removed
    3. New relationship is created correctly
    """
    # Create initial test data - Contact linked to Acme Inc.
    initial_contact = [
        {
            "Id": _VALID_SALESFORCE_IDS[40],
            "AccountId": _VALID_SALESFORCE_IDS[0],
            "FirstName": "Test",
            "LastName": "Contact",
        }
    ]
    create_csv_file("Contact", initial_contact, "initial_contact.csv")

    # Verify initial relationship
    acme_children = get_child_ids(_VALID_SALESFORCE_IDS[0])
    assert (
        _VALID_SALESFORCE_IDS[40] in acme_children
    ), "Initial relationship not created"

    # Update contact to be linked to Globex Corp instead
    updated_contact = [
        {
            "Id": _VALID_SALESFORCE_IDS[40],
            "AccountId": _VALID_SALESFORCE_IDS[1],
            "FirstName": "Test",
            "LastName": "Contact",
        }
    ]
    create_csv_file("Contact", updated_contact, "updated_contact.csv")

    # Verify old relationship is removed
    acme_children = get_child_ids(_VALID_SALESFORCE_IDS[0])
    assert (
        _VALID_SALESFORCE_IDS[40] not in acme_children
    ), "Old relationship not removed"

    # Verify new relationship is created
    globex_children = get_child_ids(_VALID_SALESFORCE_IDS[1])
    assert _VALID_SALESFORCE_IDS[40] in globex_children, "New relationship not created"

    print("All relationship update tests passed successfully!")


def test_get_affected_parent_ids() -> None:
    """
    Tests get_affected_parent_ids functionality by verifying:
    1. IDs that are directly in the parent_types list are included
    2. IDs that have children in the updated_ids list are included
    3. IDs that are neither of the above are not included
    """
    # Create test data with relationships
    test_data = {
        "Account": [
            {
                "Id": _VALID_SALESFORCE_IDS[0],
                "Name": "Parent Account 1",
            },
            {
                "Id": _VALID_SALESFORCE_IDS[1],
                "Name": "Parent Account 2",
            },
            {
                "Id": _VALID_SALESFORCE_IDS[2],
                "Name": "Not Affected Account",
            },
        ],
        "Contact": [
            {
                "Id": _VALID_SALESFORCE_IDS[40],
                "AccountId": _VALID_SALESFORCE_IDS[0],
                "FirstName": "Child",
                "LastName": "Contact",
            }
        ],
    }

    # Create and update CSV files for test data
    for object_type, records in test_data.items():
        create_csv_file(object_type, records)

    # Test Case 1: Account directly in updated_ids and parent_types
    updated_ids = {_VALID_SALESFORCE_IDS[1]}  # Parent Account 2
    parent_types = ["Account"]
    affected_ids = get_affected_parent_ids_by_type(updated_ids, parent_types)
    assert _VALID_SALESFORCE_IDS[1] in affected_ids, "Direct parent ID not included"

    # Test Case 2: Account with child in updated_ids
    updated_ids = {_VALID_SALESFORCE_IDS[40]}  # Child Contact
    parent_types = ["Account"]
    affected_ids = get_affected_parent_ids_by_type(updated_ids, parent_types)
    assert (
        _VALID_SALESFORCE_IDS[0] in affected_ids
    ), "Parent of updated child not included"

    # Test Case 3: Both direct and indirect affects
    updated_ids = {_VALID_SALESFORCE_IDS[1], _VALID_SALESFORCE_IDS[40]}  # Both cases
    parent_types = ["Account"]
    affected_ids = get_affected_parent_ids_by_type(updated_ids, parent_types)
    assert len(affected_ids) == 2, "Expected exactly two affected parent IDs"
    assert _VALID_SALESFORCE_IDS[0] in affected_ids, "Parent of child not included"
    assert _VALID_SALESFORCE_IDS[1] in affected_ids, "Direct parent ID not included"
    assert (
        _VALID_SALESFORCE_IDS[2] not in affected_ids
    ), "Unaffected ID incorrectly included"

    # Test Case 4: No matches
    updated_ids = {_VALID_SALESFORCE_IDS[40]}  # Child Contact
    parent_types = ["Opportunity"]  # Wrong type
    affected_ids = get_affected_parent_ids_by_type(updated_ids, parent_types)
    assert len(affected_ids) == 0, "Should return empty list when no matches"

    print("All get_affected_parent_ids tests passed successfully!")


def main_build() -> None:
    clear_sf_db()
    create_csv_with_example_data()
    test_query()
    test_upsert()
    test_relationships()
    test_account_with_children()
    test_relationship_updates()
    test_get_affected_parent_ids()


if __name__ == "__main__":
    main_build()
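
Note for the assertions above: get_affected_parent_ids_by_type (defined in the deleted sqlite_functions module below) yields (parent_type, affected_ids) tuples rather than a flat collection, so the tests presumably operate on a flattened result. A minimal sketch of such a helper, assuming that generator interface:

# Hypothetical helper (not in the repo): flattens the (parent_type, affected_ids)
# tuples yielded by get_affected_parent_ids_by_type into one set of IDs so that
# membership checks like the ones above work directly.
def collect_affected_ids(updated_ids: list[str], parent_types: list[str]) -> set[str]:
    affected: set[str] = set()
    for _parent_type, ids in get_affected_parent_ids_by_type(updated_ids, parent_types):
        affected.update(ids)
    return affected
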
@@ -1,386 +0,0 @@
import csv
import json
import os
import sqlite3
from collections.abc import Iterator
from contextlib import contextmanager

from onyx.connectors.salesforce.utils import get_sqlite_db_path
from onyx.connectors.salesforce.utils import SalesforceObject
from onyx.connectors.salesforce.utils import validate_salesforce_id
from onyx.utils.logger import setup_logger
from shared_configs.utils import batch_list

logger = setup_logger()


@contextmanager
def get_db_connection(
    isolation_level: str | None = None,
) -> Iterator[sqlite3.Connection]:
    """Get a database connection with proper isolation level and error handling.

    Args:
        isolation_level: SQLite isolation level. None = default "DEFERRED",
            can be "IMMEDIATE" or "EXCLUSIVE" for more strict isolation.
    """
    # 60 second timeout for locks
    conn = sqlite3.connect(get_sqlite_db_path(), timeout=60.0)

    if isolation_level is not None:
        conn.isolation_level = isolation_level
    try:
        yield conn
    except Exception:
        conn.rollback()
        raise
    finally:
        conn.close()


def init_db() -> None:
    """Initialize the SQLite database with required tables if they don't exist."""
    if os.path.exists(get_sqlite_db_path()):
        return

    # Create database directory if it doesn't exist
    os.makedirs(os.path.dirname(get_sqlite_db_path()), exist_ok=True)

    with get_db_connection("EXCLUSIVE") as conn:
        cursor = conn.cursor()

        # Enable WAL mode for better concurrent access and write performance
        cursor.execute("PRAGMA journal_mode=WAL")
        cursor.execute("PRAGMA synchronous=NORMAL")
        cursor.execute("PRAGMA temp_store=MEMORY")
        cursor.execute("PRAGMA cache_size=-2000000")  # Use 2GB memory for cache

        # Main table for storing Salesforce objects
        cursor.execute(
            """
            CREATE TABLE IF NOT EXISTS salesforce_objects (
                id TEXT PRIMARY KEY,
                object_type TEXT NOT NULL,
                data TEXT NOT NULL, -- JSON serialized data
                last_modified INTEGER DEFAULT (strftime('%s', 'now')) -- Add timestamp for better cache management
            ) WITHOUT ROWID -- Optimize for primary key lookups
            """
        )

        # Table for parent-child relationships with covering index
        cursor.execute(
            """
            CREATE TABLE IF NOT EXISTS relationships (
                child_id TEXT NOT NULL,
                parent_id TEXT NOT NULL,
                PRIMARY KEY (child_id, parent_id)
            ) WITHOUT ROWID -- Optimize for primary key lookups
            """
        )

        # New table for caching parent-child relationships with object types
        cursor.execute(
            """
            CREATE TABLE IF NOT EXISTS relationship_types (
                child_id TEXT NOT NULL,
                parent_id TEXT NOT NULL,
                parent_type TEXT NOT NULL,
                PRIMARY KEY (child_id, parent_id, parent_type)
            ) WITHOUT ROWID
            """
        )

        # Always recreate indexes to ensure they exist
        cursor.execute("DROP INDEX IF EXISTS idx_object_type")
        cursor.execute("DROP INDEX IF EXISTS idx_parent_id")
        cursor.execute("DROP INDEX IF EXISTS idx_child_parent")
        cursor.execute("DROP INDEX IF EXISTS idx_object_type_id")
        cursor.execute("DROP INDEX IF EXISTS idx_relationship_types_lookup")

        # Create covering indexes for common queries
        cursor.execute(
            """
            CREATE INDEX idx_object_type
            ON salesforce_objects(object_type, id)
            WHERE object_type IS NOT NULL
            """
        )

        cursor.execute(
            """
            CREATE INDEX idx_parent_id
            ON relationships(parent_id, child_id)
            """
        )

        cursor.execute(
            """
            CREATE INDEX idx_child_parent
            ON relationships(child_id)
            WHERE child_id IS NOT NULL
            """
        )

        # New composite index for fast parent type lookups
        cursor.execute(
            """
            CREATE INDEX idx_relationship_types_lookup
            ON relationship_types(parent_type, child_id, parent_id)
            """
        )

        # Analyze tables to help query planner
        cursor.execute("ANALYZE relationships")
        cursor.execute("ANALYZE salesforce_objects")
        cursor.execute("ANALYZE relationship_types")

        conn.commit()


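A minimal usage sketch of the get_db_connection context manager above (the query itself is illustrative; the table and columns are the ones created in init_db). Plain reads use the default DEFERRED isolation, while writers in this module pass "IMMEDIATE" to take the write lock up front:

# Illustrative only, not part of the original module.
def count_objects_of_type(object_type: str) -> int:
    with get_db_connection() as conn:
        cursor = conn.cursor()
        cursor.execute(
            "SELECT COUNT(*) FROM salesforce_objects WHERE object_type = ?",
            (object_type,),
        )
        return cursor.fetchone()[0]
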
def _update_relationship_tables(
    conn: sqlite3.Connection, child_id: str, parent_ids: set[str]
) -> None:
    """Update the relationship tables when a record is updated.

    Args:
        conn: The database connection to use (must be in a transaction)
        child_id: The ID of the child record
        parent_ids: Set of parent IDs to link to
    """
    try:
        cursor = conn.cursor()

        # Get existing parent IDs
        cursor.execute(
            "SELECT parent_id FROM relationships WHERE child_id = ?", (child_id,)
        )
        old_parent_ids = {row[0] for row in cursor.fetchall()}

        # Calculate differences
        parent_ids_to_remove = old_parent_ids - parent_ids
        parent_ids_to_add = parent_ids - old_parent_ids

        # Remove old relationships
        if parent_ids_to_remove:
            cursor.executemany(
                "DELETE FROM relationships WHERE child_id = ? AND parent_id = ?",
                [(child_id, pid) for pid in parent_ids_to_remove],
            )
            # Also remove from relationship_types
            cursor.executemany(
                "DELETE FROM relationship_types WHERE child_id = ? AND parent_id = ?",
                [(child_id, pid) for pid in parent_ids_to_remove],
            )

        # Add new relationships
        if parent_ids_to_add:
            # First add to relationships table
            cursor.executemany(
                "INSERT INTO relationships (child_id, parent_id) VALUES (?, ?)",
                [(child_id, pid) for pid in parent_ids_to_add],
            )

            # Then get the types of the parent objects and add to relationship_types
            for parent_id in parent_ids_to_add:
                cursor.execute(
                    "SELECT object_type FROM salesforce_objects WHERE id = ?",
                    (parent_id,),
                )
                result = cursor.fetchone()
                if result:
                    parent_type = result[0]
                    cursor.execute(
                        """
                        INSERT INTO relationship_types (child_id, parent_id, parent_type)
                        VALUES (?, ?, ?)
                        """,
                        (child_id, parent_id, parent_type),
                    )

    except Exception as e:
        logger.error(f"Error updating relationship tables: {e}")
        logger.error(f"Child ID: {child_id}, Parent IDs: {parent_ids}")
        raise


def update_sf_db_with_csv(object_type: str, csv_download_path: str) -> list[str]:
    """Update the SF DB with a CSV file using SQLite storage."""
    updated_ids = []

    # Use IMMEDIATE to get a write lock at the start of the transaction
    with get_db_connection("IMMEDIATE") as conn:
        cursor = conn.cursor()

        with open(csv_download_path, "r", newline="", encoding="utf-8") as f:
            reader = csv.DictReader(f)
            for row in reader:
                if "Id" not in row:
                    logger.warning(
                        f"Row {row} does not have an Id field in {csv_download_path}"
                    )
                    continue
                id = row["Id"]
                parent_ids = set()
                field_to_remove: set[str] = set()

                # Process relationships and clean data
                for field, value in row.items():
                    if validate_salesforce_id(value) and field != "Id":
                        parent_ids.add(value)
                        field_to_remove.add(field)
                    if not value:
                        field_to_remove.add(field)

                # Remove unwanted fields
                for field in field_to_remove:
                    if field != "LastModifiedById":
                        del row[field]

                # Update main object data
                cursor.execute(
                    """
                    INSERT OR REPLACE INTO salesforce_objects (id, object_type, data)
                    VALUES (?, ?, ?)
                    """,
                    (id, object_type, json.dumps(row)),
                )

                # Update relationships using the same connection
                _update_relationship_tables(conn, id, parent_ids)
                updated_ids.append(id)

        conn.commit()

    return updated_ids


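For orientation, an illustrative (made-up) input and call for update_sf_db_with_csv above; any non-"Id" column holding a valid 18-character Salesforce ID is treated as a parent reference and moved into the relationships tables, while the remaining fields are stored as JSON:

#   Id,AccountId,FirstName,LastName
#   003...AAA,001...AAA,Test,Contact
#
# updated = update_sf_db_with_csv("Contact", "/tmp/contacts.csv")  # path is hypothetical
# print(updated)  # -> list of the Ids that were inserted or replaced
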
def get_child_ids(parent_id: str) -> set[str]:
    """Get all child IDs for a given parent ID."""
    with get_db_connection() as conn:
        cursor = conn.cursor()

        # Force index usage with INDEXED BY
        cursor.execute(
            "SELECT child_id FROM relationships INDEXED BY idx_parent_id WHERE parent_id = ?",
            (parent_id,),
        )
        child_ids = {row[0] for row in cursor.fetchall()}
        return child_ids


def get_type_from_id(object_id: str) -> str | None:
    """Get the type of an object from its ID."""
    with get_db_connection() as conn:
        cursor = conn.cursor()
        cursor.execute(
            "SELECT object_type FROM salesforce_objects WHERE id = ?", (object_id,)
        )
        result = cursor.fetchone()
        if not result:
            logger.warning(f"Object ID {object_id} not found")
            return None
        return result[0]


def get_record(
    object_id: str, object_type: str | None = None
) -> SalesforceObject | None:
    """Retrieve the record and return it as a SalesforceObject."""
    if object_type is None:
        object_type = get_type_from_id(object_id)
        if not object_type:
            return None

    with get_db_connection() as conn:
        cursor = conn.cursor()
        cursor.execute("SELECT data FROM salesforce_objects WHERE id = ?", (object_id,))
        result = cursor.fetchone()
        if not result:
            logger.warning(f"Object ID {object_id} not found")
            return None

        data = json.loads(result[0])
        return SalesforceObject(id=object_id, type=object_type, data=data)


def find_ids_by_type(object_type: str) -> list[str]:
    """Find all object IDs for rows of the specified type."""
    with get_db_connection() as conn:
        cursor = conn.cursor()
        cursor.execute(
            "SELECT id FROM salesforce_objects WHERE object_type = ?", (object_type,)
        )
        return [row[0] for row in cursor.fetchall()]


def get_affected_parent_ids_by_type(
    updated_ids: list[str],
    parent_types: list[str],
    batch_size: int = 500,
) -> Iterator[tuple[str, set[str]]]:
    """Get IDs of objects that are of the specified parent types and are either in the
    updated_ids or have children in the updated_ids. Yields tuples of (parent_type, affected_ids).
    """
    # SQLite typically has a limit of 999 variables
    updated_ids_batches = batch_list(updated_ids, batch_size)
    updated_parent_ids: set[str] = set()

    with get_db_connection() as conn:
        cursor = conn.cursor()

        for batch_ids in updated_ids_batches:
            id_placeholders = ",".join(["?" for _ in batch_ids])

            for parent_type in parent_types:
                affected_ids: set[str] = set()

                # Get directly updated objects of parent types - using index on object_type
                cursor.execute(
                    f"""
                    SELECT id FROM salesforce_objects
                    WHERE id IN ({id_placeholders})
                    AND object_type = ?
                    """,
                    batch_ids + [parent_type],
                )
                affected_ids.update(row[0] for row in cursor.fetchall())

                # Get parent objects of updated objects - using optimized relationship_types table
                cursor.execute(
                    f"""
                    SELECT DISTINCT parent_id
                    FROM relationship_types
                    INDEXED BY idx_relationship_types_lookup
                    WHERE parent_type = ?
                    AND child_id IN ({id_placeholders})
                    """,
                    [parent_type] + batch_ids,
                )
                affected_ids.update(row[0] for row in cursor.fetchall())

                # Remove any parent IDs that have already been processed
                new_affected_ids = affected_ids - updated_parent_ids
                # Add the new affected IDs to the set of updated parent IDs
                updated_parent_ids.update(new_affected_ids)

                if new_affected_ids:
                    yield parent_type, new_affected_ids


def has_at_least_one_object_of_type(object_type: str) -> bool:
    """Check if there is at least one object of the specified type in the database.

    Args:
        object_type: The Salesforce object type to check

    Returns:
        bool: True if at least one object exists, False otherwise
    """
    with get_db_connection() as conn:
        cursor = conn.cursor()
        cursor.execute(
            "SELECT COUNT(*) FROM salesforce_objects WHERE object_type = ?",
            (object_type,),
        )
        count = cursor.fetchone()[0]
        return count > 0
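
The batching in get_affected_parent_ids_by_type above exists because SQLite caps the number of bound parameters per statement (commonly 999), so each IN (...) list is kept to at most batch_size placeholders. A rough sketch of what batch_list is assumed to do:

# Assumed behavior of shared_configs.utils.batch_list: a plain slicer.
def batch_list_sketch(items: list[str], batch_size: int) -> list[list[str]]:
    return [items[i : i + batch_size] for i in range(0, len(items), batch_size)]
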
@@ -1,72 +1,66 @@
import os
from dataclasses import dataclass
from typing import Any
import re
from typing import Union

SF_JSON_FILTER = r"Id$|Date$|stamp$|url$"


@dataclass
class SalesforceObject:
    id: str
    type: str
    data: dict[str, Any]
def _clean_salesforce_dict(data: Union[dict, list]) -> Union[dict, list]:
    if isinstance(data, dict):
        if "records" in data.keys():
            data = data["records"]
    if isinstance(data, dict):
        if "attributes" in data.keys():
            if isinstance(data["attributes"], dict):
                data.update(data.pop("attributes"))

    def to_dict(self) -> dict[str, Any]:
        return {
            "ID": self.id,
            "Type": self.type,
            "Data": self.data,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "SalesforceObject":
        return cls(
            id=data["Id"],
            type=data["Type"],
            data=data,
        )
    if isinstance(data, dict):
        filtered_dict = {}
        for key, value in data.items():
            if not re.search(SF_JSON_FILTER, key, re.IGNORECASE):
                if "__c" in key:  # remove the custom object indicator for display
                    key = key[:-3]
                if isinstance(value, (dict, list)):
                    filtered_value = _clean_salesforce_dict(value)
                    if filtered_value:  # Only add non-empty dictionaries or lists
                        filtered_dict[key] = filtered_value
                elif value is not None:
                    filtered_dict[key] = value
        return filtered_dict
    elif isinstance(data, list):
        filtered_list = []
        for item in data:
            if isinstance(item, (dict, list)):
                filtered_item = _clean_salesforce_dict(item)
                if filtered_item:  # Only add non-empty dictionaries or lists
                    filtered_list.append(filtered_item)
            elif item is not None:
                filtered_list.append(filtered_item)
        return filtered_list
    else:
        return data


# This defines the base path for all data files relative to this file
# AKA BE CAREFUL WHEN MOVING THIS FILE
BASE_DATA_PATH = os.path.join(os.path.dirname(__file__), "data")
def _json_to_natural_language(data: Union[dict, list], indent: int = 0) -> str:
    result = []
    indent_str = " " * indent

    if isinstance(data, dict):
        for key, value in data.items():
            if isinstance(value, (dict, list)):
                result.append(f"{indent_str}{key}:")
                result.append(_json_to_natural_language(value, indent + 2))
            else:
                result.append(f"{indent_str}{key}: {value}")
    elif isinstance(data, list):
        for item in data:
            result.append(_json_to_natural_language(item, indent))
    else:
        result.append(f"{indent_str}{data}")

    return "\n".join(result)


def get_sqlite_db_path() -> str:
    """Get the path to the sqlite db file."""
    return os.path.join(BASE_DATA_PATH, "salesforce_db.sqlite")


def get_object_type_path(object_type: str) -> str:
    """Get the directory path for a specific object type."""
    type_dir = os.path.join(BASE_DATA_PATH, object_type)
    os.makedirs(type_dir, exist_ok=True)
    return type_dir


_CHECKSUM_CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ012345"
_LOOKUP = {format(i, "05b"): _CHECKSUM_CHARS[i] for i in range(32)}


def validate_salesforce_id(salesforce_id: str) -> bool:
    """Validate the checksum portion of an 18-character Salesforce ID.

    Args:
        salesforce_id: An 18-character Salesforce ID

    Returns:
        bool: True if the checksum is valid, False otherwise
    """
    if len(salesforce_id) != 18:
        return False

    chunks = [salesforce_id[0:5], salesforce_id[5:10], salesforce_id[10:15]]

    checksum = salesforce_id[15:18]
    calculated_checksum = ""

    for chunk in chunks:
        result_string = "".join(
            "1" if char.isupper() else "0" for char in reversed(chunk)
        )
        calculated_checksum += _LOOKUP[result_string]

    return checksum == calculated_checksum
def extract_dict_text(raw_dict: dict) -> str:
    processed_dict = _clean_salesforce_dict(raw_dict)
    natural_language_dict = _json_to_natural_language(processed_dict)
    return natural_language_dict
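
A sketch derived from validate_salesforce_id above: the last three characters of an 18-character Salesforce ID encode the uppercase pattern of each 5-character chunk of the 15-character case-sensitive ID. The helper below is not part of the module; it just inverts the check using the same _LOOKUP table:

# Illustrative only: derive the 3-character suffix that validate_salesforce_id verifies.
def to_18_char_id(salesforce_id_15: str) -> str:
    assert len(salesforce_id_15) == 15
    suffix = ""
    for start in (0, 5, 10):
        chunk = salesforce_id_15[start : start + 5]
        # reversed uppercase pattern -> 5-bit string -> checksum character
        bits = "".join("1" if c.isupper() else "0" for c in reversed(chunk))
        suffix += _LOOKUP[bits]
    return salesforce_id_15 + suffix
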
@@ -264,6 +264,24 @@ class SlackTextCleaner:
        message = message.replace("<!everyone>", "@everyone")
        return message

    @staticmethod
    def replace_links(message: str) -> str:
        """Replaces slack links e.g. `<URL>` -> `URL` and `<URL|DISPLAY>` -> `DISPLAY`"""
        # Find user IDs in the message
        possible_link_matches = re.findall(r"<(.*?)>", message)
        for possible_link in possible_link_matches:
            if not possible_link:
                continue
            # Special slack patterns that aren't for links
            if possible_link[0] not in ["#", "@", "!"]:
                link_display = (
                    possible_link
                    if "|" not in possible_link
                    else possible_link.split("|")[1]
                )
                message = message.replace(f"<{possible_link}>", link_display)
        return message

    @staticmethod
    def replace_special_catchall(message: str) -> str:
        """Replaces pattern of <!something|another-thing> with another-thing

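Illustrative behavior of replace_links above (example strings only): angle-bracket links lose their wrapper, piped links keep only the display text, and special patterns starting with #, @ or ! are left untouched.

# SlackTextCleaner.replace_links("see <https://onyx.app|the docs> and <#C123>")
#   -> "see the docs and <#C123>"
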
@@ -33,7 +33,6 @@ from onyx.file_processing.extract_file_text import read_pdf_file
from onyx.file_processing.html_utils import web_html_cleanup
from onyx.utils.logger import setup_logger
from onyx.utils.sitemap import list_pages_for_site
from shared_configs.configs import MULTI_TENANT

logger = setup_logger()

@@ -242,12 +241,6 @@ class WebConnector(LoadConnector):
            self.to_visit_list = extract_urls_from_sitemap(_ensure_valid_url(base_url))

        elif web_connector_type == WEB_CONNECTOR_VALID_SETTINGS.UPLOAD:
            # Explicitly check if running in multi-tenant mode to prevent potential security risks
            if MULTI_TENANT:
                raise ValueError(
                    "Upload input for web connector is not supported in cloud environments"
                )

            logger.warning(
                "This is not a UI supported Web Connector flow, "
                "are you sure you want to do this?"

@@ -10,21 +10,17 @@ from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
    time_str_to_utc,
)
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import SlimDocument
from onyx.file_processing.html_utils import parse_html_page_basic
from onyx.utils.retry_wrapper import retry_builder


MAX_PAGE_SIZE = 30  # Zendesk API maximum
_SLIM_BATCH_SIZE = 1000


class ZendeskCredentialsNotSetUpError(PermissionError):
@@ -44,13 +40,6 @@ class ZendeskClient:
        response = requests.get(
            f"{self.base_url}/{endpoint}", auth=self.auth, params=params
        )

        if response.status_code == 429:
            retry_after = response.headers.get("Retry-After")
            if retry_after is not None:
                # Sleep for the duration indicated by the Retry-After header
                time.sleep(int(retry_after))

        response.raise_for_status()
        return response.json()

@@ -276,7 +265,7 @@ def _ticket_to_document(
    )


class ZendeskConnector(LoadConnector, PollConnector, SlimConnector):
class ZendeskConnector(LoadConnector, PollConnector):
    def __init__(
        self,
        batch_size: int = INDEX_BATCH_SIZE,
@@ -401,43 +390,6 @@ class ZendeskConnector(LoadConnector, PollConnector, SlimConnector):
        if doc_batch:
            yield doc_batch

    def retrieve_all_slim_documents(
        self,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
    ) -> GenerateSlimDocumentOutput:
        slim_doc_batch: list[SlimDocument] = []
        if self.content_type == "articles":
            articles = _get_articles(
                self.client, start_time=int(start) if start else None
            )
            for article in articles:
                slim_doc_batch.append(
                    SlimDocument(
                        id=f"article:{article['id']}",
                    )
                )
                if len(slim_doc_batch) >= _SLIM_BATCH_SIZE:
                    yield slim_doc_batch
                    slim_doc_batch = []
        elif self.content_type == "tickets":
            tickets = _get_tickets(
                self.client, start_time=int(start) if start else None
            )
            for ticket in tickets:
                slim_doc_batch.append(
                    SlimDocument(
                        id=f"zendesk_ticket_{ticket['id']}",
                    )
                )
                if len(slim_doc_batch) >= _SLIM_BATCH_SIZE:
                    yield slim_doc_batch
                    slim_doc_batch = []
        else:
            raise ValueError(f"Unsupported content_type: {self.content_type}")
        if slim_doc_batch:
            yield slim_doc_batch


if __name__ == "__main__":
    import os

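A sketch of the rate-limit handling shown in the ZendeskClient hunk above: on HTTP 429, Zendesk returns a Retry-After header with the number of seconds to wait. The single retry below is illustrative only and is not the connector's actual retry_builder logic:

# Assumes a client object with base_url and auth attributes, as in the hunk above.
def get_with_rate_limit(client: "ZendeskClient", endpoint: str) -> dict:
    response = requests.get(f"{client.base_url}/{endpoint}", auth=client.auth)
    if response.status_code == 429:
        retry_after = response.headers.get("Retry-After")
        if retry_after is not None:
            # Wait the server-specified duration, then retry once.
            time.sleep(int(retry_after))
            response = requests.get(f"{client.base_url}/{endpoint}", auth=client.auth)
    response.raise_for_status()
    return response.json()
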
@@ -54,11 +54,9 @@ def get_total_users_count(db_session: Session) -> int:
    return user_count + invited_users


async def get_user_count(only_admin_users: bool = False) -> int:
async def get_user_count() -> int:
    async with get_async_session_with_tenant() as session:
        stmt = select(func.count(User.id))
        if only_admin_users:
            stmt = stmt.where(User.role == UserRole.ADMIN)
        result = await session.execute(stmt)
        user_count = result.scalar()
        if user_count is None:

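Illustrative SQLAlchemy usage matching the removed only_admin_users branch above (session handling is assumed to come from get_async_session_with_tenant, as in the surrounding code):

# stmt = select(func.count(User.id)).where(User.role == UserRole.ADMIN)
# admin_count = (await session.execute(stmt)).scalar() or 0
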
@@ -7,7 +7,6 @@ from sqlalchemy import exists
|
||||
from sqlalchemy import Select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import aliased
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
@@ -91,22 +90,15 @@ def get_connector_credential_pairs(
|
||||
user: User | None = None,
|
||||
get_editable: bool = True,
|
||||
ids: list[int] | None = None,
|
||||
eager_load_connector: bool = False,
|
||||
) -> list[ConnectorCredentialPair]:
|
||||
stmt = select(ConnectorCredentialPair).distinct()
|
||||
|
||||
if eager_load_connector:
|
||||
stmt = stmt.options(joinedload(ConnectorCredentialPair.connector))
|
||||
|
||||
stmt = _add_user_filters(stmt, user, get_editable)
|
||||
|
||||
if not include_disabled:
|
||||
stmt = stmt.where(
|
||||
ConnectorCredentialPair.status == ConnectorCredentialPairStatus.ACTIVE
|
||||
)
|
||||
) # noqa
|
||||
if ids:
|
||||
stmt = stmt.where(ConnectorCredentialPair.id.in_(ids))
|
||||
|
||||
return list(db_session.scalars(stmt).all())
|
||||
|
||||
|
||||
@@ -318,9 +310,6 @@ def associate_default_cc_pair(db_session: Session) -> None:
|
||||
if existing_association is not None:
|
||||
return
|
||||
|
||||
# DefaultCCPair has id 1 since it is the first CC pair created
|
||||
# It is DEFAULT_CC_PAIR_ID, but can't set it explicitly because it messed with the
|
||||
# auto-incrementing id
|
||||
association = ConnectorCredentialPair(
|
||||
connector_id=0,
|
||||
credential_id=0,
|
||||
@@ -361,12 +350,7 @@ def add_credential_to_connector(
|
||||
last_successful_index_time: datetime | None = None,
|
||||
) -> StatusResponse:
|
||||
connector = fetch_connector_by_id(connector_id, db_session)
|
||||
credential = fetch_credential_by_id(
|
||||
credential_id,
|
||||
user,
|
||||
db_session,
|
||||
get_editable=False,
|
||||
)
|
||||
credential = fetch_credential_by_id(credential_id, user, db_session)
|
||||
|
||||
if connector is None:
|
||||
raise HTTPException(status_code=404, detail="Connector does not exist")
|
||||
@@ -443,12 +427,7 @@ def remove_credential_from_connector(
|
||||
db_session: Session,
|
||||
) -> StatusResponse[int]:
|
||||
connector = fetch_connector_by_id(connector_id, db_session)
|
||||
credential = fetch_credential_by_id(
|
||||
credential_id,
|
||||
user,
|
||||
db_session,
|
||||
get_editable=False,
|
||||
)
|
||||
credential = fetch_credential_by_id(credential_id, user, db_session)
|
||||
|
||||
if connector is None:
|
||||
raise HTTPException(status_code=404, detail="Connector does not exist")
|
||||
|
||||
@@ -86,7 +86,7 @@ def _add_user_filters(
|
||||
"""
|
||||
Filter Credentials by:
|
||||
- if the user is in the user_group that owns the Credential
|
||||
- if the user is a curator, they must also have a curator relationship
|
||||
- if the user is not a global_curator, they must also have a curator relationship
|
||||
to the user_group
|
||||
- if editing is being done, we also filter out Credentials that are owned by groups
|
||||
that the user isn't a curator for
|
||||
@@ -97,7 +97,6 @@ def _add_user_filters(
|
||||
where_clause = User__UserGroup.user_id == user.id
|
||||
if user.role == UserRole.CURATOR:
|
||||
where_clause &= User__UserGroup.is_curator == True # noqa: E712
|
||||
|
||||
if get_editable:
|
||||
user_groups = select(User__UserGroup.user_group_id).where(
|
||||
User__UserGroup.user_id == user.id
|
||||
@@ -153,16 +152,10 @@ def fetch_credential_by_id(
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
assume_admin: bool = False,
|
||||
get_editable: bool = True,
|
||||
) -> Credential | None:
|
||||
stmt = select(Credential).distinct()
|
||||
stmt = stmt.where(Credential.id == credential_id)
|
||||
stmt = _add_user_filters(
|
||||
stmt=stmt,
|
||||
user=user,
|
||||
assume_admin=assume_admin,
|
||||
get_editable=get_editable,
|
||||
)
|
||||
stmt = _add_user_filters(stmt, user, assume_admin=assume_admin)
|
||||
result = db_session.execute(stmt)
|
||||
credential = result.scalar_one_or_none()
|
||||
return credential
|
||||
|
||||
@@ -416,18 +416,6 @@ def update_docs_last_modified__no_commit(
|
||||
doc.last_modified = now
|
||||
|
||||
|
||||
def update_docs_chunk_count__no_commit(
|
||||
document_ids: list[str],
|
||||
doc_id_to_chunk_count: dict[str, int],
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
documents_to_update = (
|
||||
db_session.query(DbDocument).filter(DbDocument.id.in_(document_ids)).all()
|
||||
)
|
||||
for doc in documents_to_update:
|
||||
doc.chunk_count = doc_id_to_chunk_count[doc.id]
|
||||
|
||||
|
||||
def mark_document_as_modified(
|
||||
document_id: str,
|
||||
db_session: Session,
|
||||
@@ -624,25 +612,3 @@ def get_document(
|
||||
stmt = select(DbDocument).where(DbDocument.id == document_id)
|
||||
doc: DbDocument | None = db_session.execute(stmt).scalar_one_or_none()
|
||||
return doc
|
||||
|
||||
|
||||
def fetch_chunk_counts_for_documents(
|
||||
document_ids: list[str],
|
||||
db_session: Session,
|
||||
) -> list[tuple[str, int | None]]:
|
||||
"""
|
||||
Return a list of (document_id, chunk_count) tuples.
|
||||
Note: chunk_count might be None if not set in DB,
|
||||
so we declare it as Optional[int].
|
||||
"""
|
||||
stmt = select(DbDocument.id, DbDocument.chunk_count).where(
|
||||
DbDocument.id.in_(document_ids)
|
||||
)
|
||||
|
||||
# results is a list of 'Row' objects, each containing two columns
|
||||
results = db_session.execute(stmt).all()
|
||||
|
||||
# If DbDocument.id is guaranteed to be a string, you can just do row.id;
|
||||
# otherwise cast to str if you need to be sure it's a string:
|
||||
return [(str(row[0]), row[1]) for row in results]
|
||||
# or row.id, row.chunk_count if they are named attributes in your ORM model
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import contextlib
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import ssl
|
||||
@@ -15,6 +14,7 @@ from typing import ContextManager
|
||||
|
||||
import asyncpg # type: ignore
|
||||
import boto3
|
||||
import jwt
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
from sqlalchemy import event
|
||||
@@ -27,7 +27,7 @@ from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from onyx.configs.app_configs import AWS_REGION_NAME
|
||||
from onyx.configs.app_configs import AWS_REGION
|
||||
from onyx.configs.app_configs import LOG_POSTGRES_CONN_COUNTS
|
||||
from onyx.configs.app_configs import LOG_POSTGRES_LATENCY
|
||||
from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_OVERFLOW
|
||||
@@ -40,9 +40,9 @@ from onyx.configs.app_configs import POSTGRES_POOL_PRE_PING
|
||||
from onyx.configs.app_configs import POSTGRES_POOL_RECYCLE
|
||||
from onyx.configs.app_configs import POSTGRES_PORT
|
||||
from onyx.configs.app_configs import POSTGRES_USER
|
||||
from onyx.configs.app_configs import USER_AUTH_SECRET
|
||||
from onyx.configs.constants import POSTGRES_UNKNOWN_APP_NAME
|
||||
from onyx.configs.constants import SSL_CERT_FILE
|
||||
from onyx.redis.redis_pool import retrieve_auth_token_data_from_redis
|
||||
from onyx.server.utils import BasicAuthenticationError
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
@@ -273,7 +273,7 @@ async def get_async_connection() -> Any:
|
||||
port = POSTGRES_PORT
|
||||
user = POSTGRES_USER
|
||||
db = POSTGRES_DB
|
||||
token = get_iam_auth_token(host, port, user, AWS_REGION_NAME)
|
||||
token = get_iam_auth_token(host, port, user, AWS_REGION)
|
||||
|
||||
# asyncpg requires 'ssl="require"' if SSL needed
|
||||
return await asyncpg.connect(
|
||||
@@ -315,40 +315,38 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
|
||||
host = POSTGRES_HOST
|
||||
port = POSTGRES_PORT
|
||||
user = POSTGRES_USER
|
||||
token = get_iam_auth_token(host, port, user, AWS_REGION_NAME)
|
||||
token = get_iam_auth_token(host, port, user, AWS_REGION)
|
||||
cparams["password"] = token
|
||||
cparams["ssl"] = ssl_context
|
||||
|
||||
return _ASYNC_ENGINE
|
||||
|
||||
|
||||
async def get_current_tenant_id(request: Request) -> str:
|
||||
def get_current_tenant_id(request: Request) -> str:
|
||||
if not MULTI_TENANT:
|
||||
tenant_id = POSTGRES_DEFAULT_SCHEMA
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
return tenant_id
|
||||
|
||||
token = request.cookies.get("fastapiusersauth")
|
||||
if not token:
|
||||
current_value = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
return current_value
|
||||
|
||||
try:
|
||||
# Look up token data in Redis
|
||||
token_data = await retrieve_auth_token_data_from_redis(request)
|
||||
|
||||
if not token_data:
|
||||
current_value = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
logger.debug(
|
||||
f"Token data not found or expired in Redis, defaulting to {current_value}"
|
||||
)
|
||||
return current_value
|
||||
|
||||
tenant_id = token_data.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
|
||||
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
USER_AUTH_SECRET,
|
||||
audience=["fastapi-users:auth"],
|
||||
algorithms=["HS256"],
|
||||
)
|
||||
tenant_id = payload.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
|
||||
if not is_valid_schema_name(tenant_id):
|
||||
raise HTTPException(status_code=400, detail="Invalid tenant ID format")
|
||||
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
return tenant_id
|
||||
except json.JSONDecodeError:
|
||||
logger.error("Error decoding token data from Redis")
|
||||
return POSTGRES_DEFAULT_SCHEMA
|
||||
except jwt.InvalidTokenError:
|
||||
return CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in get_current_tenant_id: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
@@ -370,23 +368,9 @@ async def get_async_session_with_tenant(
|
||||
bind=engine, expire_on_commit=False, class_=AsyncSession
|
||||
) # type: ignore
|
||||
|
||||
async def _set_search_path(session: AsyncSession, tenant_id: str) -> None:
|
||||
await session.execute(text(f'SET search_path = "{tenant_id}"'))
|
||||
|
||||
async with async_session_factory() as session:
|
||||
# Register an event listener that is called whenever a new transaction starts
|
||||
@event.listens_for(session.sync_session, "after_begin")
|
||||
def after_begin(session_: Any, transaction: Any, connection: Any) -> None:
|
||||
# Because the event is sync, we can't directly await here.
|
||||
# Instead we queue up an asyncio task to ensures
|
||||
# the next statement sets the search_path
|
||||
session_.do_orm_execute = lambda state: connection.exec_driver_sql(
|
||||
f'SET search_path = "{tenant_id}"'
|
||||
)
|
||||
|
||||
try:
|
||||
await _set_search_path(session, tenant_id)
|
||||
|
||||
await session.execute(text(f'SET search_path = "{tenant_id}"'))
|
||||
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
|
||||
await session.execute(
|
||||
text(
|
||||
@@ -541,6 +525,6 @@ def provide_iam_token(dialect: Any, conn_rec: Any, cargs: Any, cparams: Any) ->
|
||||
host = POSTGRES_HOST
|
||||
port = POSTGRES_PORT
|
||||
user = POSTGRES_USER
|
||||
region = os.getenv("AWS_REGION_NAME", "us-east-2")
|
||||
region = os.getenv("AWS_REGION", "us-east-2")
|
||||
# Configure for psycopg2 with IAM token
|
||||
configure_psycopg2_iam_auth(cparams, host, port, user, region)
|
||||
|
||||
@@ -54,7 +54,6 @@ from onyx.db.enums import IndexingStatus
|
||||
from onyx.db.enums import IndexModelStatus
|
||||
from onyx.db.enums import TaskStatus
|
||||
from onyx.db.pydantic_type import PydanticType
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.special_types import JSON_ro
|
||||
from onyx.file_store.models import FileDescriptor
|
||||
from onyx.llm.override_models import LLMOverride
|
||||
@@ -66,8 +65,6 @@ from onyx.utils.headers import HeaderItemDict
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
from shared_configs.enums import RerankerProvider
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
__abstract__ = True
|
||||
@@ -75,8 +72,6 @@ class Base(DeclarativeBase):
|
||||
|
||||
class EncryptedString(TypeDecorator):
|
||||
impl = LargeBinary
|
||||
# This type's behavior is fully deterministic and doesn't depend on any external factors.
|
||||
cache_ok = True
|
||||
|
||||
def process_bind_param(self, value: str | None, dialect: Dialect) -> bytes | None:
|
||||
if value is not None:
|
||||
@@ -91,8 +86,6 @@ class EncryptedString(TypeDecorator):
|
||||
|
||||
class EncryptedJson(TypeDecorator):
|
||||
impl = LargeBinary
|
||||
# This type's behavior is fully deterministic and doesn't depend on any external factors.
|
||||
cache_ok = True
|
||||
|
||||
def process_bind_param(self, value: dict | None, dialect: Dialect) -> bytes | None:
|
||||
if value is not None:
|
||||
@@ -109,21 +102,6 @@ class EncryptedJson(TypeDecorator):
|
||||
return value
|
||||
|
||||
|
||||
class NullFilteredString(TypeDecorator):
|
||||
impl = String
|
||||
# This type's behavior is fully deterministic and doesn't depend on any external factors.
|
||||
cache_ok = True
|
||||
|
||||
def process_bind_param(self, value: str | None, dialect: Dialect) -> str | None:
|
||||
if value is not None and "\x00" in value:
|
||||
logger.warning(f"NUL characters found in value: {value}")
|
||||
return value.replace("\x00", "")
|
||||
return value
|
||||
|
||||
def process_result_value(self, value: str | None, dialect: Dialect) -> str | None:
|
||||
return value
|
||||
|
||||
|
||||
"""
|
||||
Auth/Authz (users, permissions, access) Tables
|
||||
"""
|
||||
@@ -473,16 +451,16 @@ class Document(Base):
|
||||
|
||||
# this should correspond to the ID of the document
|
||||
# (as is passed around in Onyx)
|
||||
id: Mapped[str] = mapped_column(NullFilteredString, primary_key=True)
|
||||
id: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
from_ingestion_api: Mapped[bool] = mapped_column(
|
||||
Boolean, default=False, nullable=True
|
||||
)
|
||||
# 0 for neutral, positive for mostly endorse, negative for mostly reject
|
||||
boost: Mapped[int] = mapped_column(Integer, default=DEFAULT_BOOST)
|
||||
hidden: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
semantic_id: Mapped[str] = mapped_column(NullFilteredString)
|
||||
semantic_id: Mapped[str] = mapped_column(String)
|
||||
# First Section's link
|
||||
link: Mapped[str | None] = mapped_column(NullFilteredString, nullable=True)
|
||||
link: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
# The updated time is also used as a measure of the last successful state of the doc
|
||||
# pulled from the source (to help skip reindexing already updated docs in case of
|
||||
@@ -494,10 +472,6 @@ class Document(Base):
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
|
||||
# Number of chunks in the document (in Vespa)
|
||||
# Only null for documents indexed prior to this change
|
||||
chunk_count: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
|
||||
# last time any vespa relevant row metadata or the doc changed.
|
||||
# does not include last_synced
|
||||
last_modified: Mapped[datetime.datetime | None] = mapped_column(
|
||||
@@ -508,7 +482,6 @@ class Document(Base):
|
||||
last_synced: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True, index=True
|
||||
)
|
||||
|
||||
# The following are not attached to User because the account/email may not be known
|
||||
# within Onyx
|
||||
# Something like the document creator
|
||||
|
||||
@@ -99,9 +99,6 @@ def _add_user_filters(
|
||||
return stmt.where(where_clause)
|
||||
|
||||
|
||||
# fetch_persona_by_id is used to fetch a persona by its ID. It is used to fetch a persona by its ID.
|
||||
|
||||
|
||||
def fetch_persona_by_id(
|
||||
db_session: Session, persona_id: int, user: User | None, get_editable: bool = True
|
||||
) -> Persona:
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
@@ -7,20 +6,9 @@ from fastapi_users.password import PasswordHelper
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql import expression
|
||||
from sqlalchemy.sql.elements import ColumnElement
|
||||
from sqlalchemy.sql.elements import KeyedColumnElement
|
||||
|
||||
from onyx.auth.invited_users import get_invited_users
|
||||
from onyx.auth.invited_users import write_invited_users
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.db.api_key import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
||||
from onyx.db.models import DocumentSet__User
|
||||
from onyx.db.models import Persona__User
|
||||
from onyx.db.models import SamlAccount
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
|
||||
|
||||
def validate_user_role_update(requested_role: UserRole, current_role: UserRole) -> None:
|
||||
@@ -95,10 +83,8 @@ def validate_user_role_update(requested_role: UserRole, current_role: UserRole)
|
||||
)
|
||||
|
||||
|
||||
def get_all_users(
|
||||
db_session: Session,
|
||||
email_filter_string: str | None = None,
|
||||
include_external: bool = False,
|
||||
def list_users(
|
||||
db_session: Session, email_filter_string: str = "", include_external: bool = False
|
||||
) -> Sequence[User]:
|
||||
"""List all users. No pagination as of now, as the # of users
|
||||
is assumed to be relatively small (<< 1 million)"""
|
||||
@@ -109,7 +95,7 @@ def get_all_users(
|
||||
if not include_external:
|
||||
where_clause.append(User.role != UserRole.EXT_PERM_USER)
|
||||
|
||||
if email_filter_string is not None:
|
||||
if email_filter_string:
|
||||
where_clause.append(User.email.ilike(f"%{email_filter_string}%")) # type: ignore
|
||||
|
||||
stmt = stmt.where(*where_clause)
|
||||
@@ -117,101 +103,13 @@ def get_all_users(
|
||||
return db_session.scalars(stmt).unique().all()
|
||||
|
||||
|
||||
def _get_accepted_user_where_clause(
|
||||
email_filter_string: str | None = None,
|
||||
roles_filter: list[UserRole] = [],
|
||||
include_external: bool = False,
|
||||
is_active_filter: bool | None = None,
|
||||
) -> list[ColumnElement[bool]]:
|
||||
"""
|
||||
Generates a SQLAlchemy where clause for filtering users based on the provided parameters.
|
||||
This is used to build the filters for the function that retrieves the users for the users table in the admin panel.
|
||||
|
||||
Parameters:
|
||||
- email_filter_string: A substring to filter user emails. Only users whose emails contain this substring will be included.
|
||||
- is_active_filter: When True, only active users will be included. When False, only inactive users will be included.
|
||||
- roles_filter: A list of user roles to filter by. Only users with roles in this list will be included.
|
||||
- include_external: If False, external permissioned users will be excluded.
|
||||
|
||||
Returns:
|
||||
- list: A list of conditions to be used in a SQLAlchemy query to filter users.
|
||||
"""
|
||||
|
||||
# Access table columns directly via __table__.c to get proper SQLAlchemy column types
|
||||
# This ensures type checking works correctly for SQL operations like ilike, endswith, and is_
|
||||
email_col: KeyedColumnElement[Any] = User.__table__.c.email
|
||||
is_active_col: KeyedColumnElement[Any] = User.__table__.c.is_active
|
||||
|
||||
where_clause: list[ColumnElement[bool]] = [
|
||||
expression.not_(email_col.endswith(DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN))
|
||||
]
|
||||
|
||||
if not include_external:
|
||||
where_clause.append(User.role != UserRole.EXT_PERM_USER)
|
||||
|
||||
if email_filter_string is not None:
|
||||
where_clause.append(email_col.ilike(f"%{email_filter_string}%"))
|
||||
|
||||
if roles_filter:
|
||||
where_clause.append(User.role.in_(roles_filter))
|
||||
|
||||
if is_active_filter is not None:
|
||||
where_clause.append(is_active_col.is_(is_active_filter))
|
||||
|
||||
return where_clause
|
||||
|
||||
|
||||
def get_page_of_filtered_users(
|
||||
db_session: Session,
|
||||
page_size: int,
|
||||
page_num: int,
|
||||
email_filter_string: str | None = None,
|
||||
is_active_filter: bool | None = None,
|
||||
roles_filter: list[UserRole] = [],
|
||||
include_external: bool = False,
|
||||
) -> Sequence[User]:
|
||||
users_stmt = select(User)
|
||||
|
||||
where_clause = _get_accepted_user_where_clause(
|
||||
email_filter_string=email_filter_string,
|
||||
roles_filter=roles_filter,
|
||||
include_external=include_external,
|
||||
is_active_filter=is_active_filter,
|
||||
)
|
||||
# Apply pagination
|
||||
users_stmt = users_stmt.offset((page_num) * page_size).limit(page_size)
|
||||
# Apply filtering
|
||||
users_stmt = users_stmt.where(*where_clause)
|
||||
|
||||
return db_session.scalars(users_stmt).unique().all()
|
||||
|
||||
|
||||
def get_total_filtered_users_count(
|
||||
db_session: Session,
|
||||
email_filter_string: str | None = None,
|
||||
is_active_filter: bool | None = None,
|
||||
roles_filter: list[UserRole] = [],
|
||||
include_external: bool = False,
|
||||
) -> int:
|
||||
where_clause = _get_accepted_user_where_clause(
|
||||
email_filter_string=email_filter_string,
|
||||
roles_filter=roles_filter,
|
||||
include_external=include_external,
|
||||
is_active_filter=is_active_filter,
|
||||
)
|
||||
total_count_stmt = select(func.count()).select_from(User)
|
||||
# Apply filtering
|
||||
total_count_stmt = total_count_stmt.where(*where_clause)
|
||||
|
||||
return db_session.scalar(total_count_stmt) or 0
|
||||
|
||||
|
||||
def get_user_by_email(email: str, db_session: Session) -> User | None:
|
||||
user = (
|
||||
db_session.query(User)
|
||||
.filter(func.lower(User.email) == func.lower(email))
|
||||
.first()
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
@@ -287,43 +185,3 @@ def batch_add_ext_perm_user_if_not_exists(
|
||||
db_session.commit()
|
||||
|
||||
return found_users + new_users
|
||||
|
||||
|
||||
def delete_user_from_db(
|
||||
user_to_delete: User,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
for oauth_account in user_to_delete.oauth_accounts:
|
||||
db_session.delete(oauth_account)
|
||||
|
||||
fetch_ee_implementation_or_noop(
|
||||
"onyx.db.external_perm",
|
||||
"delete_user__ext_group_for_user__no_commit",
|
||||
)(
|
||||
db_session=db_session,
|
||||
user_id=user_to_delete.id,
|
||||
)
|
||||
db_session.query(SamlAccount).filter(
|
||||
SamlAccount.user_id == user_to_delete.id
|
||||
).delete()
|
||||
db_session.query(DocumentSet__User).filter(
|
||||
DocumentSet__User.user_id == user_to_delete.id
|
||||
).delete()
|
||||
db_session.query(Persona__User).filter(
|
||||
Persona__User.user_id == user_to_delete.id
|
||||
).delete()
|
||||
db_session.query(User__UserGroup).filter(
|
||||
User__UserGroup.user_id == user_to_delete.id
|
||||
).delete()
|
||||
db_session.delete(user_to_delete)
|
||||
db_session.commit()
|
||||
|
||||
# NOTE: edge case may exist with race conditions
|
||||
# with this `invited user` scheme generally.
|
||||
user_emails = get_invited_users()
|
||||
remaining_users = [
|
||||
remaining_user_email
|
||||
for remaining_user_email in user_emails
|
||||
if remaining_user_email != user_to_delete.email
|
||||
]
|
||||
write_invited_users(remaining_users)
|
||||
|
||||
@@ -1,13 +1,12 @@
|
||||
import math
|
||||
import uuid
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.db.search_settings import get_secondary_search_settings
|
||||
from onyx.document_index.interfaces import EnrichedDocumentIndexingInfo
|
||||
from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
from onyx.indexing.models import IndexChunk
|
||||
|
||||
|
||||
DEFAULT_BATCH_SIZE = 30
|
||||
@@ -37,118 +36,25 @@ def translate_boost_count_to_multiplier(boost: int) -> float:
|
||||
return 2 / (1 + math.exp(-1 * boost / 3))
|
||||
|
||||
|
||||
def assemble_document_chunk_info(
|
||||
enriched_document_info_list: list[EnrichedDocumentIndexingInfo],
|
||||
tenant_id: str | None,
|
||||
large_chunks_enabled: bool,
|
||||
) -> list[UUID]:
|
||||
doc_chunk_ids = []
|
||||
|
||||
for enriched_document_info in enriched_document_info_list:
|
||||
for chunk_index in range(
|
||||
enriched_document_info.chunk_start_index,
|
||||
enriched_document_info.chunk_end_index,
|
||||
):
|
||||
if not enriched_document_info.old_version:
|
||||
doc_chunk_ids.append(
|
||||
get_uuid_from_chunk_info(
|
||||
document_id=enriched_document_info.doc_id,
|
||||
chunk_id=chunk_index,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
)
|
||||
else:
|
||||
doc_chunk_ids.append(
|
||||
get_uuid_from_chunk_info_old(
|
||||
document_id=enriched_document_info.doc_id,
|
||||
chunk_id=chunk_index,
|
||||
)
|
||||
)
|
||||
|
||||
if large_chunks_enabled and chunk_index % 4 == 0:
|
||||
large_chunk_id = int(chunk_index / 4)
|
||||
large_chunk_reference_ids = [
|
||||
large_chunk_id + i
|
||||
for i in range(4)
|
||||
if large_chunk_id + i < enriched_document_info.chunk_end_index
|
||||
]
|
||||
if enriched_document_info.old_version:
|
||||
doc_chunk_ids.append(
|
||||
get_uuid_from_chunk_info_old(
|
||||
document_id=enriched_document_info.doc_id,
|
||||
chunk_id=large_chunk_id,
|
||||
large_chunk_reference_ids=large_chunk_reference_ids,
|
||||
)
|
||||
)
|
||||
else:
|
||||
doc_chunk_ids.append(
|
||||
get_uuid_from_chunk_info(
|
||||
document_id=enriched_document_info.doc_id,
|
||||
chunk_id=large_chunk_id,
|
||||
tenant_id=tenant_id,
|
||||
large_chunk_id=large_chunk_id,
|
||||
)
|
||||
)
|
||||
|
||||
return doc_chunk_ids
|
||||
|
||||
|
||||
def get_uuid_from_chunk_info(
|
||||
*,
|
||||
document_id: str,
|
||||
chunk_id: int,
|
||||
tenant_id: str | None,
|
||||
large_chunk_id: int | None = None,
|
||||
) -> UUID:
|
||||
doc_str = document_id
|
||||
|
||||
# Web parsing URL duplicate catching
|
||||
if doc_str and doc_str[-1] == "/":
|
||||
doc_str = doc_str[:-1]
|
||||
|
||||
chunk_index = (
|
||||
"large_" + str(large_chunk_id) if large_chunk_id is not None else str(chunk_id)
|
||||
def get_uuid_from_chunk(
|
||||
chunk: IndexChunk | InferenceChunk, mini_chunk_ind: int = 0
|
||||
) -> uuid.UUID:
|
||||
doc_str = (
|
||||
chunk.document_id
|
||||
if isinstance(chunk, InferenceChunk)
|
||||
else chunk.source_document.id
|
||||
)
|
||||
unique_identifier_string = "_".join([doc_str, chunk_index])
|
||||
if tenant_id:
|
||||
unique_identifier_string += "_" + tenant_id
|
||||
|
||||
return uuid.uuid5(uuid.NAMESPACE_X500, unique_identifier_string)
|
||||
|
||||
|
||||
def get_uuid_from_chunk_info_old(
|
||||
*, document_id: str, chunk_id: int, large_chunk_reference_ids: list[int] = []
|
||||
) -> UUID:
|
||||
doc_str = document_id
|
||||
|
||||
# Web parsing URL duplicate catching
|
||||
if doc_str and doc_str[-1] == "/":
|
||||
doc_str = doc_str[:-1]
|
||||
unique_identifier_string = "_".join([doc_str, str(chunk_id), "0"])
|
||||
if large_chunk_reference_ids:
|
||||
unique_identifier_string = "_".join(
|
||||
[doc_str, str(chunk.chunk_id), str(mini_chunk_ind)]
|
||||
)
|
||||
if chunk.large_chunk_reference_ids:
|
||||
unique_identifier_string += "_large" + "_".join(
|
||||
[
|
||||
str(referenced_chunk_id)
|
||||
for referenced_chunk_id in large_chunk_reference_ids
|
||||
for referenced_chunk_id in chunk.large_chunk_reference_ids
|
||||
]
|
||||
)
|
||||
return uuid.uuid5(uuid.NAMESPACE_X500, unique_identifier_string)
|
||||
|
||||
|
||||
def get_uuid_from_chunk(chunk: DocMetadataAwareIndexChunk) -> uuid.UUID:
|
||||
return get_uuid_from_chunk_info(
|
||||
document_id=chunk.source_document.id,
|
||||
chunk_id=chunk.chunk_id,
|
||||
tenant_id=chunk.tenant_id,
|
||||
large_chunk_id=chunk.large_chunk_id,
|
||||
)
|
||||
|
||||
|
||||
def get_uuid_from_chunk_old(
|
||||
chunk: DocMetadataAwareIndexChunk, large_chunk_reference_ids: list[int] = []
|
||||
) -> UUID:
|
||||
return get_uuid_from_chunk_info_old(
|
||||
document_id=chunk.source_document.id,
|
||||
chunk_id=chunk.chunk_id,
|
||||
large_chunk_reference_ids=large_chunk_reference_ids,
|
||||
)
|
||||
|
||||
@@ -35,38 +35,6 @@ class VespaChunkRequest:
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
|
||||
class IndexBatchParams:
|
||||
"""
|
||||
Information necessary for efficiently indexing a batch of documents
|
||||
"""
|
||||
|
||||
doc_id_to_previous_chunk_cnt: dict[str, int | None]
|
||||
doc_id_to_new_chunk_cnt: dict[str, int]
|
||||
tenant_id: str | None
|
||||
large_chunks_enabled: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
class MinimalDocumentIndexingInfo:
|
||||
"""
|
||||
Minimal information necessary for indexing a document
|
||||
"""
|
||||
|
||||
doc_id: str
|
||||
chunk_start_index: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class EnrichedDocumentIndexingInfo(MinimalDocumentIndexingInfo):
|
||||
"""
|
||||
Enriched information necessary for indexing a document, including version and chunk range.
|
||||
"""
|
||||
|
||||
old_version: bool
|
||||
chunk_end_index: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class DocumentMetadata:
|
||||
"""
|
||||
@@ -180,7 +148,7 @@ class Indexable(abc.ABC):
|
||||
def index(
|
||||
self,
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
index_batch_params: IndexBatchParams,
|
||||
fresh_index: bool = False,
|
||||
) -> set[DocumentInsertionRecord]:
|
||||
"""
|
||||
Takes a list of document chunks and indexes them in the document index
|
||||
@@ -198,11 +166,14 @@ class Indexable(abc.ABC):
|
||||
only needs to index chunks into the PRIMARY index. Do not update the secondary index here,
|
||||
it is done automatically outside of this code.
|
||||
|
||||
NOTE: The fresh_index parameter, when set to True, assumes no documents have been previously
|
||||
indexed for the given index/tenant. This can be used to optimize the indexing process for
|
||||
new or empty indices.
|
||||
|
||||
Parameters:
|
||||
- chunks: Document chunks with all of the information needed for indexing to the document
|
||||
index.
|
||||
- tenant_id: The tenant id of the user whose chunks are being indexed
|
||||
- large_chunks_enabled: Whether large chunks are enabled
|
||||
- fresh_index: Boolean indicating whether this is a fresh index with no existing documents.
|
||||
|
||||
Returns:
|
||||
List of document ids which map to unique documents and are used for deduping chunks
|
||||
@@ -214,7 +185,7 @@ class Indexable(abc.ABC):
|
||||
|
||||
class Deletable(abc.ABC):
|
||||
"""
|
||||
Class must implement the ability to delete document by a given unique document id.
|
||||
Class must implement the ability to delete document by their unique document ids.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -227,6 +198,16 @@ class Deletable(abc.ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def delete(self, doc_ids: list[str]) -> None:
|
||||
"""
|
||||
Given a list of document ids, hard delete them from the document index
|
||||
|
||||
Parameters:
|
||||
- doc_ids: list of document ids as specified by the connector
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Updatable(abc.ABC):
|
||||
"""
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import concurrent.futures
|
||||
from uuid import UUID
|
||||
|
||||
import httpx
|
||||
from retry import retry
|
||||
|
||||
from onyx.document_index.vespa.chunk_retrieval import (
|
||||
get_all_vespa_ids_for_document_id,
|
||||
)
|
||||
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
|
||||
from onyx.document_index.vespa_constants import NUM_THREADS
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -14,27 +16,29 @@ logger = setup_logger()
|
||||
CONTENT_SUMMARY = "content_summary"
|
||||
|
||||
|
||||
@retry(tries=10, delay=1, backoff=2)
|
||||
def _retryable_http_delete(http_client: httpx.Client, url: str) -> None:
|
||||
res = http_client.delete(url)
|
||||
res.raise_for_status()
|
||||
|
||||
|
||||
def _delete_vespa_chunk(
|
||||
doc_chunk_id: UUID, index_name: str, http_client: httpx.Client
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _delete_vespa_doc_chunks(
|
||||
document_id: str, index_name: str, http_client: httpx.Client
|
||||
) -> None:
|
||||
try:
|
||||
_retryable_http_delete(
|
||||
http_client,
|
||||
f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{doc_chunk_id}",
|
||||
)
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"Failed to delete chunk, details: {e.response.text}")
|
||||
raise
|
||||
doc_chunk_ids = get_all_vespa_ids_for_document_id(
|
||||
document_id=document_id,
|
||||
index_name=index_name,
|
||||
get_large_chunks=True,
|
||||
)
|
||||
|
||||
for chunk_id in doc_chunk_ids:
|
||||
try:
|
||||
res = http_client.delete(
|
||||
f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{chunk_id}"
|
||||
)
|
||||
res.raise_for_status()
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"Failed to delete chunk, details: {e.response.text}")
|
||||
raise
|
||||
|
||||
|
||||
def delete_vespa_chunks(
|
||||
doc_chunk_ids: list[UUID],
|
||||
def delete_vespa_docs(
|
||||
document_ids: list[str],
|
||||
index_name: str,
|
||||
http_client: httpx.Client,
|
||||
executor: concurrent.futures.ThreadPoolExecutor | None = None,
|
||||
@@ -46,13 +50,13 @@ def delete_vespa_chunks(
|
||||
executor = concurrent.futures.ThreadPoolExecutor(max_workers=NUM_THREADS)
|
||||
|
||||
try:
|
||||
chunk_deletion_future = {
|
||||
doc_deletion_future = {
|
||||
executor.submit(
|
||||
_delete_vespa_chunk, doc_chunk_id, index_name, http_client
|
||||
): doc_chunk_id
|
||||
for doc_chunk_id in doc_chunk_ids
|
||||
_delete_vespa_doc_chunks, doc_id, index_name, http_client
|
||||
): doc_id
|
||||
for doc_id in document_ids
|
||||
}
|
||||
for future in concurrent.futures.as_completed(chunk_deletion_future):
|
||||
for future in concurrent.futures.as_completed(doc_deletion_future):
|
||||
# Will raise exception if the deletion raised an exception
|
||||
future.result()
|
||||
|
||||
|
||||
@@ -25,12 +25,8 @@ from onyx.configs.chat_configs import VESPA_SEARCHER_THREADS
|
||||
from onyx.configs.constants import KV_REINDEX_KEY
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import InferenceChunkUncleaned
|
||||
from onyx.document_index.document_index_utils import assemble_document_chunk_info
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.document_index.interfaces import DocumentInsertionRecord
|
||||
from onyx.document_index.interfaces import EnrichedDocumentIndexingInfo
|
||||
from onyx.document_index.interfaces import IndexBatchParams
|
||||
from onyx.document_index.interfaces import MinimalDocumentIndexingInfo
|
||||
from onyx.document_index.interfaces import UpdateRequest
|
||||
from onyx.document_index.interfaces import VespaChunkRequest
|
||||
from onyx.document_index.interfaces import VespaDocumentFields
|
||||
@@ -42,10 +38,12 @@ from onyx.document_index.vespa.chunk_retrieval import (
|
||||
parallel_visit_api_retrieval,
|
||||
)
|
||||
from onyx.document_index.vespa.chunk_retrieval import query_vespa
|
||||
from onyx.document_index.vespa.deletion import delete_vespa_chunks
|
||||
from onyx.document_index.vespa.deletion import delete_vespa_docs
|
||||
from onyx.document_index.vespa.indexing_utils import batch_index_vespa_chunks
|
||||
from onyx.document_index.vespa.indexing_utils import check_for_final_chunk_existence
|
||||
from onyx.document_index.vespa.indexing_utils import clean_chunk_id_copy
|
||||
from onyx.document_index.vespa.indexing_utils import (
|
||||
get_existing_documents_from_chunks,
|
||||
)
|
||||
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
|
||||
from onyx.document_index.vespa.shared_utils.utils import (
|
||||
replace_invalid_doc_id_characters,
|
||||
@@ -309,18 +307,12 @@ class VespaIndex(DocumentIndex):
|
||||
def index(
|
||||
self,
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
index_batch_params: IndexBatchParams,
|
||||
fresh_index: bool = False,
|
||||
) -> set[DocumentInsertionRecord]:
|
||||
"""Receive a list of chunks from a batch of documents and index the chunks into Vespa along
|
||||
with updating the associated permissions. Assumes that a document will not be split into
|
||||
multiple chunk batches calling this function multiple times, otherwise only the last set of
|
||||
chunks will be kept"""
|
||||
|
||||
doc_id_to_previous_chunk_cnt = index_batch_params.doc_id_to_previous_chunk_cnt
|
||||
doc_id_to_new_chunk_cnt = index_batch_params.doc_id_to_new_chunk_cnt
|
||||
tenant_id = index_batch_params.tenant_id
|
||||
large_chunks_enabled = index_batch_params.large_chunks_enabled
|
||||
|
||||
# IMPORTANT: This must be done one index at a time, do not use secondary index here
|
||||
cleaned_chunks = [clean_chunk_id_copy(chunk) for chunk in chunks]
|
||||
|
||||
@@ -332,59 +324,30 @@ class VespaIndex(DocumentIndex):
|
||||
concurrent.futures.ThreadPoolExecutor(max_workers=NUM_THREADS) as executor,
|
||||
get_vespa_http_client() as http_client,
|
||||
):
|
||||
# We require the start and end index for each document in order to
|
||||
# know precisely which chunks to delete. This information exists for
|
||||
# documents that have `chunk_count` in the database, but not for
|
||||
# `old_version` documents.
|
||||
|
||||
enriched_doc_infos: list[EnrichedDocumentIndexingInfo] = []
|
||||
for document_id, doc_count in doc_id_to_previous_chunk_cnt.items():
|
||||
last_indexed_chunk = doc_id_to_previous_chunk_cnt.get(document_id, None)
|
||||
# If the document has no `chunk_count` in the database, we know that it
|
||||
# has the old chunk ID system and we must check for the final chunk index
|
||||
is_old_version = False
|
||||
if last_indexed_chunk is None:
|
||||
is_old_version = True
|
||||
minimal_doc_info = MinimalDocumentIndexingInfo(
|
||||
doc_id=document_id,
|
||||
chunk_start_index=doc_id_to_new_chunk_cnt.get(document_id, 0),
|
||||
if not fresh_index:
|
||||
# Check for existing documents, existing documents need to have all of their chunks deleted
|
||||
# prior to indexing as the document size (num chunks) may have shrunk
|
||||
first_chunks = [
|
||||
chunk for chunk in cleaned_chunks if chunk.chunk_id == 0
|
||||
]
|
||||
for chunk_batch in batch_generator(first_chunks, BATCH_SIZE):
|
||||
existing_docs.update(
|
||||
get_existing_documents_from_chunks(
|
||||
chunks=chunk_batch,
|
||||
index_name=self.index_name,
|
||||
http_client=http_client,
|
||||
executor=executor,
|
||||
)
|
||||
)
|
||||
last_indexed_chunk = check_for_final_chunk_existence(
|
||||
minimal_doc_info=minimal_doc_info,
|
||||
start_index=doc_id_to_new_chunk_cnt[document_id],
|
||||
|
||||
for doc_id_batch in batch_generator(existing_docs, BATCH_SIZE):
|
||||
delete_vespa_docs(
|
||||
document_ids=doc_id_batch,
|
||||
index_name=self.index_name,
|
||||
http_client=http_client,
|
||||
executor=executor,
|
||||
)
|
||||
|
||||
# If the document has previously indexed chunks, we know it previously existed
|
||||
if doc_count or last_indexed_chunk:
|
||||
existing_docs.add(document_id)
|
||||
|
||||
enriched_doc_info = EnrichedDocumentIndexingInfo(
|
||||
doc_id=document_id,
|
||||
chunk_start_index=doc_id_to_new_chunk_cnt.get(document_id, 0),
|
||||
chunk_end_index=last_indexed_chunk,
|
||||
old_version=is_old_version,
|
||||
)
|
||||
enriched_doc_infos.append(enriched_doc_info)
|
||||
|
||||
# Now, for each doc, we know exactly where to start and end our deletion
|
||||
# So let's generate the chunk IDs for each chunk to delete
|
||||
chunks_to_delete = assemble_document_chunk_info(
|
||||
enriched_document_info_list=enriched_doc_infos,
|
||||
tenant_id=tenant_id,
|
||||
large_chunks_enabled=large_chunks_enabled,
|
||||
)
|
||||
|
||||
# Delete old Vespa documents
|
||||
for doc_chunk_ids_batch in batch_generator(chunks_to_delete, BATCH_SIZE):
|
||||
delete_vespa_chunks(
|
||||
doc_chunk_ids=doc_chunk_ids_batch,
|
||||
index_name=self.index_name,
|
||||
http_client=http_client,
|
||||
executor=executor,
|
||||
)
|
||||
|
||||
for chunk_batch in batch_generator(cleaned_chunks, BATCH_SIZE):
|
||||
batch_index_vespa_chunks(
|
||||
chunks=chunk_batch,
|
||||
@@ -625,6 +588,24 @@ class VespaIndex(DocumentIndex):
|
||||
|
||||
return total_chunks_updated
|
||||
|
||||
def delete(self, doc_ids: list[str]) -> None:
|
||||
logger.info(f"Deleting {len(doc_ids)} documents from Vespa")
|
||||
|
||||
doc_ids = [replace_invalid_doc_id_characters(doc_id) for doc_id in doc_ids]
|
||||
|
||||
# NOTE: using `httpx` here since `requests` doesn't support HTTP2. This is beneficial for
|
||||
# indexing / updates / deletes since we have to make a large volume of requests.
|
||||
with get_vespa_http_client() as http_client:
|
||||
index_names = [self.index_name]
|
||||
if self.secondary_index_name:
|
||||
index_names.append(self.secondary_index_name)
|
||||
|
||||
for index_name in index_names:
|
||||
delete_vespa_docs(
|
||||
document_ids=doc_ids, index_name=index_name, http_client=http_client
|
||||
)
|
||||
return
|
||||
|
||||
def delete_single(self, doc_id: str) -> int:
|
||||
"""Possibly faster overall than the delete method due to using a single
|
||||
delete call with a selection query."""
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import concurrent.futures
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from http import HTTPStatus
|
||||
@@ -12,8 +11,6 @@ from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||
get_experts_stores_representations,
|
||||
)
|
||||
from onyx.document_index.document_index_utils import get_uuid_from_chunk
|
||||
from onyx.document_index.document_index_utils import get_uuid_from_chunk_info_old
|
||||
from onyx.document_index.interfaces import MinimalDocumentIndexingInfo
|
||||
from onyx.document_index.vespa.shared_utils.utils import remove_invalid_unicode_chars
|
||||
from onyx.document_index.vespa.shared_utils.utils import (
|
||||
replace_invalid_doc_id_characters,
|
||||
@@ -51,9 +48,14 @@ logger = setup_logger()
|
||||
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _does_doc_chunk_exist(
|
||||
doc_chunk_id: uuid.UUID, index_name: str, http_client: httpx.Client
|
||||
def _does_document_exist(
|
||||
doc_chunk_id: str,
|
||||
index_name: str,
|
||||
http_client: httpx.Client,
|
||||
) -> bool:
|
||||
"""Returns whether the document already exists and the users/group whitelists
|
||||
Specifically in this case, document refers to a vespa document which is equivalent to a Onyx
|
||||
chunk. This checks for whether the chunk exists already in the index"""
|
||||
doc_url = f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{doc_chunk_id}"
|
||||
doc_fetch_response = http_client.get(doc_url)
|
||||
if doc_fetch_response.status_code == 404:
|
||||
@@ -62,10 +64,10 @@ def _does_doc_chunk_exist(
|
||||
if doc_fetch_response.status_code != 200:
|
||||
logger.debug(f"Failed to check for document with URL {doc_url}")
|
||||
raise RuntimeError(
|
||||
f"Unexpected fetch document by ID value from Vespa: "
|
||||
f"error={doc_fetch_response.status_code} "
|
||||
f"index={index_name} "
|
||||
f"doc_chunk_id={doc_chunk_id}"
|
||||
f"Unexpected fetch document by ID value from Vespa "
|
||||
f"with error {doc_fetch_response.status_code}"
|
||||
f"Index name: {index_name}"
|
||||
f"Doc chunk id: {doc_chunk_id}"
|
||||
)
|
||||
return True
|
||||
|
||||
@@ -96,8 +98,8 @@ def get_existing_documents_from_chunks(
|
||||
try:
|
||||
chunk_existence_future = {
|
||||
executor.submit(
|
||||
_does_doc_chunk_exist,
|
||||
get_uuid_from_chunk(chunk),
|
||||
_does_document_exist,
|
||||
str(get_uuid_from_chunk(chunk)),
|
||||
index_name,
|
||||
http_client,
|
||||
): chunk
|
||||
@@ -246,22 +248,3 @@ def clean_chunk_id_copy(
|
||||
}
|
||||
)
|
||||
return clean_chunk
|
||||
|
||||
|
||||
def check_for_final_chunk_existence(
|
||||
minimal_doc_info: MinimalDocumentIndexingInfo,
|
||||
start_index: int,
|
||||
index_name: str,
|
||||
http_client: httpx.Client,
|
||||
) -> int:
|
||||
index = start_index
|
||||
while True:
|
||||
doc_chunk_id = get_uuid_from_chunk_info_old(
|
||||
document_id=minimal_doc_info.doc_id,
|
||||
chunk_id=index,
|
||||
large_chunk_reference_ids=[],
|
||||
)
|
||||
if not _does_doc_chunk_exist(doc_chunk_id, index_name, http_client):
|
||||
return index
|
||||
|
||||
index += 1
|
||||
|
||||
@@ -55,7 +55,9 @@ def remove_invalid_unicode_chars(text: str) -> str:
|
||||
return _illegal_xml_chars_RE.sub("", text)
|
||||
|
||||
|
||||
def get_vespa_http_client(no_timeout: bool = False, http2: bool = True) -> httpx.Client:
|
||||
def get_vespa_http_client(
|
||||
no_timeout: bool = False, http2: bool = False
|
||||
) -> httpx.Client:
|
||||
"""
|
||||
Configure and return an HTTP client for communicating with Vespa,
|
||||
including authentication if needed.
|
||||
|
||||
@@ -15,7 +15,6 @@ from onyx.document_index.vespa_constants import METADATA_LIST
|
||||
from onyx.document_index.vespa_constants import SOURCE_TYPE
|
||||
from onyx.document_index.vespa_constants import TENANT_ID
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -60,8 +59,7 @@ def build_vespa_filters(
|
||||
|
||||
filter_str = f"!({HIDDEN}=true) and " if not include_hidden else ""
|
||||
|
||||
# If running in multi-tenant mode, we may want to filter by tenant_id
|
||||
if filters.tenant_id and MULTI_TENANT:
|
||||
if filters.tenant_id:
|
||||
filter_str += f'({TENANT_ID} contains "{filters.tenant_id}") and '
|
||||
|
||||
# CAREFUL touching this one, currently there is no second ACL double-check post retrieval
|
||||
|
||||
@@ -35,8 +35,6 @@ DOCUMENT_ID_ENDPOINT = (
|
||||
f"{VESPA_APP_CONTAINER_URL}/document/v1/default/{{index_name}}/docid"
|
||||
)
|
||||
|
||||
# the default document id endpoint is http://localhost:8080/document/v1/default/danswer_chunk/docid
|
||||
|
||||
SEARCH_ENDPOINT = f"{VESPA_APP_CONTAINER_URL}/search/"
|
||||
|
||||
NUM_THREADS = (
|
||||
|
||||
@@ -67,9 +67,7 @@ def is_text_file_extension(file_name: str) -> bool:
|
||||
|
||||
def get_file_ext(file_path_or_name: str | Path) -> str:
|
||||
_, extension = os.path.splitext(file_path_or_name)
|
||||
# standardize all extensions to be lowercase so that checks against
|
||||
# VALID_FILE_EXTENSIONS and similar will work as intended
|
||||
return extension.lower()
|
||||
return extension
|
||||
|
||||
|
||||
def is_valid_file_ext(ext: str) -> bool:
|
||||
|
||||
@@ -73,25 +73,25 @@ def _get_metadata_suffix_for_document_index(
|
||||
return metadata_semantic, metadata_keyword
|
||||
|
||||
|
||||
def _combine_chunks(chunks: list[DocAwareChunk], large_chunk_id: int) -> DocAwareChunk:
|
||||
def _combine_chunks(chunks: list[DocAwareChunk], index: int) -> DocAwareChunk:
|
||||
merged_chunk = DocAwareChunk(
|
||||
source_document=chunks[0].source_document,
|
||||
chunk_id=chunks[0].chunk_id,
|
||||
chunk_id=index,
|
||||
blurb=chunks[0].blurb,
|
||||
content=chunks[0].content,
|
||||
source_links=chunks[0].source_links or {},
|
||||
section_continuation=(chunks[0].chunk_id > 0),
|
||||
section_continuation=(index > 0),
|
||||
title_prefix=chunks[0].title_prefix,
|
||||
metadata_suffix_semantic=chunks[0].metadata_suffix_semantic,
|
||||
metadata_suffix_keyword=chunks[0].metadata_suffix_keyword,
|
||||
large_chunk_reference_ids=[chunk.chunk_id for chunk in chunks],
|
||||
large_chunk_reference_ids=[chunks[0].chunk_id],
|
||||
mini_chunk_texts=None,
|
||||
large_chunk_id=large_chunk_id,
|
||||
)
|
||||
|
||||
offset = 0
|
||||
for i in range(1, len(chunks)):
|
||||
merged_chunk.content += SECTION_SEPARATOR + chunks[i].content
|
||||
merged_chunk.large_chunk_reference_ids.append(chunks[i].chunk_id)
|
||||
|
||||
offset += len(SECTION_SEPARATOR) + len(chunks[i - 1].content)
|
||||
for link_offset, link_text in (chunks[i].source_links or {}).items():
|
||||
@@ -103,12 +103,11 @@ def _combine_chunks(chunks: list[DocAwareChunk], large_chunk_id: int) -> DocAwar
|
||||
|
||||
|
||||
def generate_large_chunks(chunks: list[DocAwareChunk]) -> list[DocAwareChunk]:
|
||||
large_chunks = []
|
||||
for idx, i in enumerate(range(0, len(chunks), LARGE_CHUNK_RATIO)):
|
||||
chunk_group = chunks[i : i + LARGE_CHUNK_RATIO]
|
||||
if len(chunk_group) > 1:
|
||||
large_chunk = _combine_chunks(chunk_group, idx)
|
||||
large_chunks.append(large_chunk)
|
||||
large_chunks = [
|
||||
_combine_chunks(chunks[i : i + LARGE_CHUNK_RATIO], idx)
|
||||
for idx, i in enumerate(range(0, len(chunks), LARGE_CHUNK_RATIO))
|
||||
if len(chunks[i : i + LARGE_CHUNK_RATIO]) > 1
|
||||
]
|
||||
return large_chunks
|
||||
|
||||
|
||||
@@ -220,7 +219,6 @@ class Chunker:
|
||||
metadata_suffix_semantic=metadata_suffix_semantic,
|
||||
metadata_suffix_keyword=metadata_suffix_keyword,
|
||||
mini_chunk_texts=self._get_mini_chunk_texts(text),
|
||||
large_chunk_id=None,
|
||||
)
|
||||
|
||||
for section_idx, section in enumerate(document.sections):
|
||||
|
||||
@@ -20,10 +20,8 @@ from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||
)
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import IndexAttemptMetadata
|
||||
from onyx.db.document import fetch_chunk_counts_for_documents
|
||||
from onyx.db.document import get_documents_by_ids
|
||||
from onyx.db.document import prepare_to_modify_documents
|
||||
from onyx.db.document import update_docs_chunk_count__no_commit
|
||||
from onyx.db.document import update_docs_last_modified__no_commit
|
||||
from onyx.db.document import update_docs_updated_at__no_commit
|
||||
from onyx.db.document import upsert_document_by_connector_credential_pair
|
||||
@@ -36,7 +34,6 @@ from onyx.db.tag import create_or_add_document_tag
|
||||
from onyx.db.tag import create_or_add_document_tag_list
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.document_index.interfaces import DocumentMetadata
|
||||
from onyx.document_index.interfaces import IndexBatchParams
|
||||
from onyx.indexing.chunker import Chunker
|
||||
from onyx.indexing.embedder import IndexingEmbedder
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
@@ -263,21 +260,6 @@ def index_doc_batch_prepare(
|
||||
def filter_documents(document_batch: list[Document]) -> list[Document]:
|
||||
documents: list[Document] = []
|
||||
for document in document_batch:
|
||||
# Remove any NUL characters from title/semantic_id
|
||||
# This is a known issue with the Zendesk connector
|
||||
# Postgres cannot handle NUL characters in text fields
|
||||
if document.title:
|
||||
document.title = document.title.replace("\x00", "")
|
||||
if document.semantic_identifier:
|
||||
document.semantic_identifier = document.semantic_identifier.replace(
|
||||
"\x00", ""
|
||||
)
|
||||
|
||||
# Remove NUL characters from all sections
|
||||
for section in document.sections:
|
||||
if section.text is not None:
|
||||
section.text = section.text.replace("\x00", "")
|
||||
|
||||
empty_contents = not any(section.text.strip() for section in document.sections)
|
||||
if (
|
||||
(not document.title or not document.title.strip())
|
||||
@@ -373,35 +355,16 @@ def index_doc_batch(
|
||||
# NOTE: don't need to acquire till here, since this is when the actual race condition
|
||||
# with Vespa can occur.
|
||||
with prepare_to_modify_documents(db_session=db_session, document_ids=updatable_ids):
|
||||
doc_id_to_access_info = get_access_for_documents(
|
||||
document_id_to_access_info = get_access_for_documents(
|
||||
document_ids=updatable_ids, db_session=db_session
|
||||
)
|
||||
doc_id_to_document_set = {
|
||||
document_id_to_document_set = {
|
||||
document_id: document_sets
|
||||
for document_id, document_sets in fetch_document_sets_for_documents(
|
||||
document_ids=updatable_ids, db_session=db_session
|
||||
)
|
||||
}
|
||||
|
||||
doc_id_to_previous_chunk_cnt: dict[str, int | None] = {
|
||||
document_id: chunk_count
|
||||
for document_id, chunk_count in fetch_chunk_counts_for_documents(
|
||||
document_ids=updatable_ids,
|
||||
db_session=db_session,
|
||||
)
|
||||
}
|
||||
|
||||
doc_id_to_new_chunk_cnt: dict[str, int] = {
|
||||
document_id: len(
|
||||
[
|
||||
chunk
|
||||
for chunk in chunks_with_embeddings
|
||||
if chunk.source_document.id == document_id
|
||||
]
|
||||
)
|
||||
for document_id in updatable_ids
|
||||
}
|
||||
|
||||
# we're concerned about race conditions where multiple simultaneous indexings might result
|
||||
# in one set of metadata overwriting another one in vespa.
|
||||
# we still write data here for the immediate and most likely correct sync, but
|
||||
@@ -410,9 +373,11 @@ def index_doc_batch(
|
||||
access_aware_chunks = [
|
||||
DocMetadataAwareIndexChunk.from_index_chunk(
|
||||
index_chunk=chunk,
|
||||
access=doc_id_to_access_info.get(chunk.source_document.id, no_access),
|
||||
access=document_id_to_access_info.get(
|
||||
chunk.source_document.id, no_access
|
||||
),
|
||||
document_sets=set(
|
||||
doc_id_to_document_set.get(chunk.source_document.id, [])
|
||||
document_id_to_document_set.get(chunk.source_document.id, [])
|
||||
),
|
||||
boost=(
|
||||
ctx.id_to_db_doc_map[chunk.source_document.id].boost
|
||||
@@ -430,15 +395,7 @@ def index_doc_batch(
|
||||
# A document will not be spread across different batches, so all the
|
||||
# documents with chunks in this set, are fully represented by the chunks
|
||||
# in this set
|
||||
insertion_records = document_index.index(
|
||||
chunks=access_aware_chunks,
|
||||
index_batch_params=IndexBatchParams(
|
||||
doc_id_to_previous_chunk_cnt=doc_id_to_previous_chunk_cnt,
|
||||
doc_id_to_new_chunk_cnt=doc_id_to_new_chunk_cnt,
|
||||
tenant_id=tenant_id,
|
||||
large_chunks_enabled=chunker.enable_large_chunks,
|
||||
),
|
||||
)
|
||||
insertion_records = document_index.index(chunks=access_aware_chunks)
|
||||
|
||||
successful_doc_ids = [record.document_id for record in insertion_records]
|
||||
successful_docs = [
|
||||
@@ -463,12 +420,6 @@ def index_doc_batch(
|
||||
document_ids=last_modified_ids, db_session=db_session
|
||||
)
|
||||
|
||||
update_docs_chunk_count__no_commit(
|
||||
document_ids=updatable_ids,
|
||||
doc_id_to_chunk_count=doc_id_to_new_chunk_cnt,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
result = (
|
||||
@@ -479,28 +430,6 @@ def index_doc_batch(
|
||||
return result
|
||||
|
||||
|
||||
def check_enable_large_chunks_and_multipass(
|
||||
embedder: IndexingEmbedder, db_session: Session
|
||||
) -> tuple[bool, bool]:
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
multipass = (
|
||||
search_settings.multipass_indexing
|
||||
if search_settings
|
||||
else ENABLE_MULTIPASS_INDEXING
|
||||
)
|
||||
|
||||
enable_large_chunks = (
|
||||
multipass
|
||||
and
|
||||
# Only local models that supports larger context are from Nomic
|
||||
(embedder.model_name.startswith("nomic-ai"))
|
||||
and
|
||||
# Cohere does not support larger context they recommend not going above 512 tokens
|
||||
embedder.provider_type != EmbeddingProvider.COHERE
|
||||
)
|
||||
return multipass, enable_large_chunks
|
||||
|
||||
|
||||
def build_indexing_pipeline(
|
||||
*,
|
||||
embedder: IndexingEmbedder,
|
||||
@@ -513,8 +442,24 @@ def build_indexing_pipeline(
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> IndexingPipelineProtocol:
|
||||
"""Builds a pipeline which takes in a list (batch) of docs and indexes them."""
|
||||
multipass, enable_large_chunks = check_enable_large_chunks_and_multipass(
|
||||
embedder, db_session
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
multipass = (
|
||||
search_settings.multipass_indexing
|
||||
if search_settings
|
||||
else ENABLE_MULTIPASS_INDEXING
|
||||
)
|
||||
|
||||
enable_large_chunks = (
|
||||
multipass
|
||||
and
|
||||
# Only local models that supports larger context are from Nomic
|
||||
(
|
||||
embedder.provider_type is not None
|
||||
or embedder.model_name.startswith("nomic-ai")
|
||||
)
|
||||
and
|
||||
# Cohere does not support larger context they recommend not going above 512 tokens
|
||||
embedder.provider_type != EmbeddingProvider.COHERE
|
||||
)
|
||||
|
||||
chunker = chunker or Chunker(
|
||||
|
||||
@@ -47,8 +47,6 @@ class DocAwareChunk(BaseChunk):
|
||||
|
||||
mini_chunk_texts: list[str] | None
|
||||
|
||||
large_chunk_id: int | None
|
||||
|
||||
large_chunk_reference_ids: list[int] = Field(default_factory=list)
|
||||
|
||||
def to_short_descriptor(self) -> str:
|
||||
|
||||
@@ -266,27 +266,18 @@ class DefaultMultiLLM(LLM):
|
||||
# )
|
||||
self._custom_config = custom_config
|
||||
|
||||
# Create a dictionary for model-specific arguments if it's None
|
||||
model_kwargs = model_kwargs or {}
|
||||
|
||||
# NOTE: have to set these as environment variables for Litellm since
|
||||
# not all are able to passed in but they always support them set as env
|
||||
# variables. We'll also try passing them in, since litellm just ignores
|
||||
# addtional kwargs (and some kwargs MUST be passed in rather than set as
|
||||
# env variables)
|
||||
if custom_config:
|
||||
# Specifically pass in "vertex_credentials" as a model_kwarg to the
|
||||
# completion call for vertex AI. More details here:
|
||||
# https://docs.litellm.ai/docs/providers/vertex
|
||||
vertex_credentials_key = "vertex_credentials"
|
||||
vertex_credentials = custom_config.get(vertex_credentials_key)
|
||||
if vertex_credentials and model_provider == "vertex_ai":
|
||||
model_kwargs[vertex_credentials_key] = vertex_credentials
|
||||
else:
|
||||
# standard case
|
||||
for k, v in custom_config.items():
|
||||
os.environ[k] = v
|
||||
for k, v in custom_config.items():
|
||||
os.environ[k] = v
|
||||
|
||||
model_kwargs = model_kwargs or {}
|
||||
if custom_config:
|
||||
model_kwargs.update(custom_config)
|
||||
if extra_headers:
|
||||
model_kwargs.update({"extra_headers": extra_headers})
|
||||
if extra_body:
|
||||
|
||||
@@ -74,9 +74,6 @@ from onyx.server.manage.search_settings import router as search_settings_router
|
||||
from onyx.server.manage.slack_bot import router as slack_bot_management_router
|
||||
from onyx.server.manage.users import router as user_router
|
||||
from onyx.server.middleware.latency_logging import add_latency_logging_middleware
|
||||
from onyx.server.middleware.rate_limiting import close_limiter
|
||||
from onyx.server.middleware.rate_limiting import get_auth_rate_limiters
|
||||
from onyx.server.middleware.rate_limiting import setup_limiter
|
||||
from onyx.server.onyx_api.ingestion import router as onyx_api_router
|
||||
from onyx.server.openai_assistants_api.full_openai_assistants_api import (
|
||||
get_full_openai_assistants_api_router,
|
||||
@@ -156,23 +153,6 @@ def include_router_with_global_prefix_prepended(
|
||||
application.include_router(router, **final_kwargs)
|
||||
|
||||
|
||||
def include_auth_router_with_prefix(
|
||||
application: FastAPI,
|
||||
router: APIRouter,
|
||||
prefix: str | None = None,
|
||||
tags: list[str] | None = None,
|
||||
) -> None:
|
||||
"""Wrapper function to include an 'auth' router with prefix + rate-limiting dependencies."""
|
||||
final_tags = tags or ["auth"]
|
||||
include_router_with_global_prefix_prepended(
|
||||
application,
|
||||
router,
|
||||
prefix=prefix,
|
||||
tags=final_tags,
|
||||
dependencies=get_auth_rate_limiters(),
|
||||
)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator:
|
||||
# Set recursion limit
|
||||
@@ -214,15 +194,8 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
|
||||
setup_multitenant_onyx()
|
||||
|
||||
optional_telemetry(record_type=RecordType.VERSION, data={"version": __version__})
|
||||
|
||||
# Set up rate limiter
|
||||
await setup_limiter()
|
||||
|
||||
yield
|
||||
|
||||
# Close rate limiter
|
||||
await close_limiter()
|
||||
|
||||
|
||||
def log_http_error(_: Request, exc: Exception) -> JSONResponse:
|
||||
status_code = getattr(exc, "status_code", 500)
|
||||
@@ -310,37 +283,42 @@ def get_application() -> FastAPI:
|
||||
pass
|
||||
|
||||
if AUTH_TYPE == AuthType.BASIC or AUTH_TYPE == AuthType.CLOUD:
|
||||
include_auth_router_with_prefix(
|
||||
include_router_with_global_prefix_prepended(
|
||||
application,
|
||||
fastapi_users.get_auth_router(auth_backend),
|
||||
prefix="/auth",
|
||||
tags=["auth"],
|
||||
)
|
||||
|
||||
include_auth_router_with_prefix(
|
||||
include_router_with_global_prefix_prepended(
|
||||
application,
|
||||
fastapi_users.get_register_router(UserRead, UserCreate),
|
||||
prefix="/auth",
|
||||
tags=["auth"],
|
||||
)
|
||||
|
||||
include_auth_router_with_prefix(
|
||||
include_router_with_global_prefix_prepended(
|
||||
application,
|
||||
fastapi_users.get_reset_password_router(),
|
||||
prefix="/auth",
|
||||
tags=["auth"],
|
||||
)
|
||||
include_auth_router_with_prefix(
|
||||
include_router_with_global_prefix_prepended(
|
||||
application,
|
||||
fastapi_users.get_verify_router(UserRead),
|
||||
prefix="/auth",
|
||||
tags=["auth"],
|
||||
)
|
||||
include_auth_router_with_prefix(
|
||||
include_router_with_global_prefix_prepended(
|
||||
application,
|
||||
fastapi_users.get_users_router(UserRead, UserUpdate),
|
||||
prefix="/users",
|
||||
tags=["users"],
|
||||
)
|
||||
|
||||
if AUTH_TYPE == AuthType.GOOGLE_OAUTH:
|
||||
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
|
||||
include_auth_router_with_prefix(
|
||||
include_router_with_global_prefix_prepended(
|
||||
application,
|
||||
create_onyx_oauth_router(
|
||||
oauth_client,
|
||||
@@ -352,13 +330,15 @@ def get_application() -> FastAPI:
|
||||
redirect_url=f"{WEB_DOMAIN}/auth/oauth/callback",
|
||||
),
|
||||
prefix="/auth/oauth",
|
||||
tags=["auth"],
|
||||
)
|
||||
|
||||
# Need basic auth router for `logout` endpoint
|
||||
include_auth_router_with_prefix(
|
||||
include_router_with_global_prefix_prepended(
|
||||
application,
|
||||
fastapi_users.get_logout_router(auth_backend),
|
||||
prefix="/auth",
|
||||
tags=["auth"],
|
||||
)
|
||||
|
||||
application.add_exception_handler(
|
||||
|
||||
@@ -131,15 +131,10 @@ class EmbeddingModel:
|
||||
tries=10, delay=10, exceptions=ModelServerRateLimitError
|
||||
)(final_make_request_func)
|
||||
|
||||
response: Response | None = None
|
||||
|
||||
try:
|
||||
response = final_make_request_func()
|
||||
return EmbedResponse(**response.json())
|
||||
except requests.HTTPError as e:
|
||||
if not response:
|
||||
raise HTTPError("HTTP error occurred - response is None.") from e
|
||||
|
||||
try:
|
||||
error_detail = response.json().get("detail", str(e))
|
||||
except Exception:
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import re
|
||||
from datetime import datetime
|
||||
from re import Match
|
||||
|
||||
import pytz
|
||||
import timeago # type: ignore
|
||||
@@ -57,6 +59,33 @@ def get_feedback_reminder_blocks(thread_link: str, include_followup: bool) -> Bl
|
||||
return SectionBlock(text=text)
|
||||
|
||||
|
||||
def _process_citations_for_slack(text: str) -> str:
|
||||
"""
|
||||
Converts instances of [[x]](LINK) in the input text to Slack's link format <LINK|[x]>.
|
||||
|
||||
Args:
|
||||
- text (str): The input string containing markdown links.
|
||||
|
||||
Returns:
|
||||
- str: The string with markdown links converted to Slack format.
|
||||
"""
|
||||
# Regular expression to find all instances of [[x]](LINK)
|
||||
pattern = r"\[\[(.*?)\]\]\((.*?)\)"
|
||||
|
||||
# Function to replace each found instance with Slack's format
|
||||
def slack_link_format(match: Match) -> str:
|
||||
link_text = match.group(1)
|
||||
link_url = match.group(2)
|
||||
|
||||
# Account for empty link citations
|
||||
if link_url == "":
|
||||
return f"[{link_text}]"
|
||||
return f"<{link_url}|[{link_text}]>"
|
||||
|
||||
# Substitute all matches in the input text
|
||||
return re.sub(pattern, slack_link_format, text)
|
||||
|
||||
|
||||
def _split_text(text: str, limit: int = 3000) -> list[str]:
|
||||
if len(text) <= limit:
|
||||
return [text]
|
||||
@@ -340,12 +369,15 @@ def _build_citations_blocks(
|
||||
|
||||
def _build_qa_response_blocks(
|
||||
answer: ChatOnyxBotResponse,
|
||||
process_message_for_citations: bool = False,
|
||||
) -> list[Block]:
|
||||
retrieval_info = answer.docs
|
||||
if not retrieval_info:
|
||||
# This should not happen, even with no docs retrieved, there is still info returned
|
||||
raise RuntimeError("Failed to retrieve docs, cannot answer question.")
|
||||
|
||||
formatted_answer = format_slack_message(answer.answer) if answer.answer else None
|
||||
|
||||
if DISABLE_GENERATIVE_AI:
|
||||
return []
|
||||
|
||||
@@ -376,18 +408,18 @@ def _build_qa_response_blocks(
|
||||
|
||||
filter_block = SectionBlock(text=f"_{filter_text}_")
|
||||
|
||||
if not answer.answer:
|
||||
if not formatted_answer:
|
||||
answer_blocks = [
|
||||
SectionBlock(
|
||||
text="Sorry, I was unable to find an answer, but I did find some potentially relevant docs 🤓"
|
||||
)
|
||||
]
|
||||
else:
|
||||
# replaces markdown links with slack format links
|
||||
formatted_answer = format_slack_message(answer.answer)
|
||||
answer_processed = decode_escapes(
|
||||
remove_slack_text_interactions(formatted_answer)
|
||||
)
|
||||
if process_message_for_citations:
|
||||
answer_processed = _process_citations_for_slack(answer_processed)
|
||||
answer_blocks = [
|
||||
SectionBlock(text=text) for text in _split_text(answer_processed)
|
||||
]
|
||||
@@ -493,6 +525,7 @@ def build_slack_response_blocks(
|
||||
|
||||
answer_blocks = _build_qa_response_blocks(
|
||||
answer=answer,
|
||||
process_message_for_citations=use_citations,
|
||||
)
|
||||
|
||||
web_follow_up_block = []
|
||||
|
||||
@@ -4,55 +4,73 @@ from onyx.configs.constants import DocumentSource
|
||||
def source_to_github_img_link(source: DocumentSource) -> str | None:
|
||||
# TODO: store these images somewhere better
|
||||
if source == DocumentSource.WEB.value:
|
||||
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/backend/slackbot_images/Web.png"
|
||||
return "https://raw.githubusercontent.com/onyx-ai/onyx/main/backend/slackbot_images/Web.png"
|
||||
if source == DocumentSource.FILE.value:
|
||||
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/backend/slackbot_images/File.png"
|
||||
return "https://raw.githubusercontent.com/onyx-ai/onyx/main/backend/slackbot_images/File.png"
|
||||
if source == DocumentSource.GOOGLE_SITES.value:
|
||||
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/GoogleSites.png"
|
||||
return "https://raw.githubusercontent.com/onyx-ai/onyx/main/web/public/GoogleSites.png"
|
||||
if source == DocumentSource.SLACK.value:
|
||||
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/Slack.png"
|
||||
return (
|
||||
"https://raw.githubusercontent.com/onyx-ai/onyx/main/web/public/Slack.png"
|
||||
)
|
||||
if source == DocumentSource.GMAIL.value:
|
||||
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/Gmail.png"
|
||||
return (
|
||||
"https://raw.githubusercontent.com/onyx-ai/onyx/main/web/public/Gmail.png"
|
||||
)
|
||||
if source == DocumentSource.GOOGLE_DRIVE.value:
|
||||
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/GoogleDrive.png"
|
||||
return "https://raw.githubusercontent.com/onyx-ai/onyx/main/web/public/GoogleDrive.png"
|
||||
if source == DocumentSource.GITHUB.value:
|
||||
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/Github.png"
|
||||
return (
|
||||
"https://raw.githubusercontent.com/onyx-ai/onyx/main/web/public/Github.png"
|
||||
)
|
||||
if source == DocumentSource.GITLAB.value:
|
||||
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/Gitlab.png"
|
||||
return (
|
||||
"https://raw.githubusercontent.com/onyx-ai/onyx/main/web/public/Gitlab.png"
|
||||
)
|
||||
if source == DocumentSource.CONFLUENCE.value:
|
||||
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/backend/slackbot_images/Confluence.png"
|
||||
return "https://raw.githubusercontent.com/onyx-ai/onyx/main/backend/slackbot_images/Confluence.png"
|
||||
if source == DocumentSource.JIRA.value:
|
||||
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/backend/slackbot_images/Jira.png"
|
||||
return "https://raw.githubusercontent.com/onyx-ai/onyx/main/backend/slackbot_images/Jira.png"
|
||||
if source == DocumentSource.NOTION.value:
|
||||
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/Notion.png"
|
||||
return (
|
||||
"https://raw.githubusercontent.com/onyx-ai/onyx/main/web/public/Notion.png"
|
||||
)
|
||||
if source == DocumentSource.ZENDESK.value:
|
||||
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/backend/slackbot_images/Zendesk.png"
|
||||
return "https://raw.githubusercontent.com/onyx-ai/onyx/main/backend/slackbot_images/Zendesk.png"
|
||||
if source == DocumentSource.GONG.value:
|
||||
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/Gong.png"
|
||||
return "https://raw.githubusercontent.com/onyx-ai/onyx/main/web/public/Gong.png"
|
||||
if source == DocumentSource.LINEAR.value:
|
||||
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/Linear.png"
|
||||
return (
|
||||
"https://raw.githubusercontent.com/onyx-ai/onyx/main/web/public/Linear.png"
|
||||
)
|
||||
if source == DocumentSource.PRODUCTBOARD.value:
|
||||
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/Productboard.webp"
|
||||
return "https://raw.githubusercontent.com/onyx-ai/onyx/main/web/public/Productboard.webp"
|
||||
if source == DocumentSource.SLAB.value:
|
||||
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/SlabLogo.png"
|
||||
return "https://raw.githubusercontent.com/onyx-ai/onyx/main/web/public/SlabLogo.png"
|
||||
if source == DocumentSource.ZULIP.value:
|
||||
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/Zulip.png"
|
||||
return (
|
||||
"https://raw.githubusercontent.com/onyx-ai/onyx/main/web/public/Zulip.png"
|
||||
)
|
||||
if source == DocumentSource.GURU.value:
|
||||
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/backend/slackbot_images/Guru.png"
|
||||
return "https://raw.githubusercontent.com/onyx-ai/onyx/main/backend/slackbot_images/Guru.png"
|
||||
if source == DocumentSource.HUBSPOT.value:
|
||||
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/HubSpot.png"
|
||||
return (
|
||||
"https://raw.githubusercontent.com/onyx-ai/onyx/main/web/public/HubSpot.png"
|
||||
)
|
||||
if source == DocumentSource.DOCUMENT360.value:
|
||||
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/Document360.png"
|
||||
return "https://raw.githubusercontent.com/onyx-ai/onyx/main/web/public/Document360.png"
|
||||
if source == DocumentSource.BOOKSTACK.value:
|
||||
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/Bookstack.png"
|
||||
return "https://raw.githubusercontent.com/onyx-ai/onyx/main/web/public/Bookstack.png"
|
||||
if source == DocumentSource.LOOPIO.value:
|
||||
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/Loopio.png"
|
||||
return (
|
||||
"https://raw.githubusercontent.com/onyx-ai/onyx/main/web/public/Loopio.png"
|
||||
)
|
||||
if source == DocumentSource.SHAREPOINT.value:
|
||||
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/Sharepoint.png"
|
||||
return "https://raw.githubusercontent.com/onyx-ai/onyx/main/web/public/Sharepoint.png"
|
||||
if source == DocumentSource.REQUESTTRACKER.value:
|
||||
# just use file icon for now
|
||||
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/backend/slackbot_images/File.png"
|
||||
return "https://raw.githubusercontent.com/onyx-ai/onyx/main/backend/slackbot_images/File.png"
|
||||
if source == DocumentSource.INGESTION_API.value:
|
||||
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/backend/slackbot_images/File.png"
|
||||
return "https://raw.githubusercontent.com/onyx-ai/onyx/main/backend/slackbot_images/File.png"
|
||||
|
||||
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/backend/slackbot_images/File.png"
|
||||
return "https://raw.githubusercontent.com/onyx-ai/onyx/main/backend/slackbot_images/File.png"
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user