mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-21 17:55:45 +00:00
Compare commits
94 Commits
connector-
...
improve
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c26f3f4507 | ||
|
|
ea01c282b7 | ||
|
|
ec5b8b240e | ||
|
|
2d948812a5 | ||
|
|
35c5bbd1aa | ||
|
|
3536b5e7c7 | ||
|
|
bf48eb435c | ||
|
|
71c2559ea9 | ||
|
|
ceb34a41d9 | ||
|
|
82eab9d704 | ||
|
|
2b8d3a6ef5 | ||
|
|
4fb129e77b | ||
|
|
f16ca1b735 | ||
|
|
e3b2c9d944 | ||
|
|
6c9c25642d | ||
|
|
2862d8bbd3 | ||
|
|
143be6a524 | ||
|
|
c2444a5cff | ||
|
|
7f8194798a | ||
|
|
e3947e4b64 | ||
|
|
98005510ad | ||
|
|
ca54bd0b21 | ||
|
|
d26f8ce852 | ||
|
|
c8090ab75b | ||
|
|
e100a5e965 | ||
|
|
ddec239fef | ||
|
|
e83542f572 | ||
|
|
8750f14647 | ||
|
|
27699c8216 | ||
|
|
6fcd712a00 | ||
|
|
b027a08698 | ||
|
|
1db778baa8 | ||
|
|
f895e5f7d0 | ||
|
|
2fc58252f4 | ||
|
|
371d1ccd8f | ||
|
|
7fb92d42a0 | ||
|
|
af2061c4db | ||
|
|
ffec19645b | ||
|
|
67d2c86250 | ||
|
|
6c018cb53f | ||
|
|
62302e3faf | ||
|
|
0460531c72 | ||
|
|
6af07a888b | ||
|
|
ea75f5cd5d | ||
|
|
b92c183022 | ||
|
|
c191e23256 | ||
|
|
66f9124135 | ||
|
|
8f0fb70bbf | ||
|
|
ef5e5c80bb | ||
|
|
03acb6587a | ||
|
|
d1ec72b5e5 | ||
|
|
3b214133a8 | ||
|
|
2232702e99 | ||
|
|
8108ff0a4b | ||
|
|
f64e78e986 | ||
|
|
08312a4394 | ||
|
|
92add655e0 | ||
|
|
d64464ca7c | ||
|
|
ccd3983802 | ||
|
|
240f3e4fff | ||
|
|
1291b3d930 | ||
|
|
d05f1997b5 | ||
|
|
aa2e2a62b9 | ||
|
|
174e5968f8 | ||
|
|
1f27606e17 | ||
|
|
60355b84c1 | ||
|
|
680ab9ea30 | ||
|
|
c2447dbb1c | ||
|
|
52bad522f8 | ||
|
|
63e5e58313 | ||
|
|
2643782e30 | ||
|
|
3eb72e5c1d | ||
|
|
9b65c23a7e | ||
|
|
b43a8e48c6 | ||
|
|
1955c1d67b | ||
|
|
3f92ed9d29 | ||
|
|
618369f4a1 | ||
|
|
2783216781 | ||
|
|
bec0f9fb23 | ||
|
|
97a03e7fc8 | ||
|
|
8d6e8269b7 | ||
|
|
9ce2c6c517 | ||
|
|
2ad8bdbc65 | ||
|
|
a83c9b40d5 | ||
|
|
340fab1375 | ||
|
|
3ec338307f | ||
|
|
27acd3387a | ||
|
|
d14ef431a7 | ||
|
|
9bffeb65af | ||
|
|
f4806da653 | ||
|
|
e2700b2bbd | ||
|
|
fc81a3fb12 | ||
|
|
2203cfabea | ||
|
|
f4050306d6 |
14
.github/workflows/pr-python-connector-tests.yml
vendored
14
.github/workflows/pr-python-connector-tests.yml
vendored
@@ -26,7 +26,19 @@ env:
|
||||
GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR }}
|
||||
# Slab
|
||||
SLAB_BOT_TOKEN: ${{ secrets.SLAB_BOT_TOKEN }}
|
||||
|
||||
# Zendesk
|
||||
ZENDESK_SUBDOMAIN: ${{ secrets.ZENDESK_SUBDOMAIN }}
|
||||
ZENDESK_EMAIL: ${{ secrets.ZENDESK_EMAIL }}
|
||||
ZENDESK_TOKEN: ${{ secrets.ZENDESK_TOKEN }}
|
||||
# Salesforce
|
||||
SF_USERNAME: ${{ secrets.SF_USERNAME }}
|
||||
SF_PASSWORD: ${{ secrets.SF_PASSWORD }}
|
||||
SF_SECURITY_TOKEN: ${{ secrets.SF_SECURITY_TOKEN }}
|
||||
# Airtable
|
||||
AIRTABLE_TEST_BASE_ID: ${{ secrets.AIRTABLE_TEST_BASE_ID }}
|
||||
AIRTABLE_TEST_TABLE_ID: ${{ secrets.AIRTABLE_TEST_TABLE_ID }}
|
||||
AIRTABLE_TEST_TABLE_NAME: ${{ secrets.AIRTABLE_TEST_TABLE_NAME }}
|
||||
AIRTABLE_ACCESS_TOKEN: ${{ secrets.AIRTABLE_ACCESS_TOKEN }}
|
||||
jobs:
|
||||
connectors-check:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
<a href="https://docs.onyx.app/" target="_blank">
|
||||
<img src="https://img.shields.io/badge/docs-view-blue" alt="Documentation">
|
||||
</a>
|
||||
<a href="https://join.slack.com/t/danswer/shared_invite/zt-1w76msxmd-HJHLe3KNFIAIzk_0dSOKaQ" target="_blank">
|
||||
<a href="https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA" target="_blank">
|
||||
<img src="https://img.shields.io/badge/slack-join-blue.svg?logo=slack" alt="Slack">
|
||||
</a>
|
||||
<a href="https://discord.gg/TDJ59cGV2X" target="_blank">
|
||||
|
||||
1
backend/.gitignore
vendored
1
backend/.gitignore
vendored
@@ -9,3 +9,4 @@ api_keys.py
|
||||
vespa-app.zip
|
||||
dynamic_config_storage/
|
||||
celerybeat-schedule*
|
||||
onyx/connectors/salesforce/data/
|
||||
@@ -4,7 +4,7 @@ from onyx.configs.app_configs import USE_IAM_AUTH
|
||||
from onyx.configs.app_configs import POSTGRES_HOST
|
||||
from onyx.configs.app_configs import POSTGRES_PORT
|
||||
from onyx.configs.app_configs import POSTGRES_USER
|
||||
from onyx.configs.app_configs import AWS_REGION
|
||||
from onyx.configs.app_configs import AWS_REGION_NAME
|
||||
from onyx.db.engine import build_connection_string
|
||||
from onyx.db.engine import get_all_tenant_ids
|
||||
from sqlalchemy import event
|
||||
@@ -120,7 +120,7 @@ def provide_iam_token_for_alembic(
|
||||
) -> None:
|
||||
if USE_IAM_AUTH:
|
||||
# Database connection settings
|
||||
region = AWS_REGION
|
||||
region = AWS_REGION_NAME
|
||||
host = POSTGRES_HOST
|
||||
port = POSTGRES_PORT
|
||||
user = POSTGRES_USER
|
||||
|
||||
@@ -0,0 +1,24 @@
|
||||
"""add chunk count to document
|
||||
|
||||
Revision ID: 2955778aa44c
|
||||
Revises: c0aab6edb6dd
|
||||
Create Date: 2025-01-04 11:39:43.268612
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "2955778aa44c"
|
||||
down_revision = "c0aab6edb6dd"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Add the nullable integer `chunk_count` column to the `document` table."""
    op.add_column("document", sa.Column("chunk_count", sa.Integer(), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop the `chunk_count` column from the `document` table."""
    op.drop_column("document", "chunk_count")
|
||||
@@ -6,6 +6,7 @@ from sqlalchemy.orm import Session
|
||||
from ee.onyx.db.user_group import delete_user_group
|
||||
from ee.onyx.db.user_group import fetch_user_group
|
||||
from ee.onyx.db.user_group import mark_user_group_as_synced
|
||||
from ee.onyx.db.user_group import prepare_user_group_for_deletion
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.redis.redis_usergroup import RedisUserGroup
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -46,11 +47,20 @@ def monitor_usergroup_taskset(
|
||||
|
||||
user_group = fetch_user_group(db_session=db_session, user_group_id=usergroup_id)
|
||||
if user_group:
|
||||
usergroup_name = user_group.name
|
||||
if user_group.is_up_for_deletion:
|
||||
# this prepare should have been run when the deletion was scheduled,
|
||||
# but run it again to be sure we're ready to go
|
||||
mark_user_group_as_synced(db_session, user_group)
|
||||
prepare_user_group_for_deletion(db_session, usergroup_id)
|
||||
delete_user_group(db_session=db_session, user_group=user_group)
|
||||
task_logger.info(f"Deleted usergroup. id='{usergroup_id}'")
|
||||
task_logger.info(
|
||||
f"Deleted usergroup: name={usergroup_name} id={usergroup_id}"
|
||||
)
|
||||
else:
|
||||
mark_user_group_as_synced(db_session=db_session, user_group=user_group)
|
||||
task_logger.info(f"Synced usergroup. id='{usergroup_id}'")
|
||||
task_logger.info(
|
||||
f"Synced usergroup. name={usergroup_name} id={usergroup_id}"
|
||||
)
|
||||
|
||||
rug.reset()
|
||||
|
||||
@@ -15,6 +15,12 @@ SAML_CONF_DIR = os.environ.get("SAML_CONF_DIR") or "/app/ee/onyx/configs/saml_co
|
||||
CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY = int(
|
||||
os.environ.get("CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY") or 5 * 60
|
||||
)
|
||||
# This is a boolean that determines if anonymous access is public
|
||||
# Default behavior is to not make the page public and instead add a group
|
||||
# that contains all the users that we found in Confluence
|
||||
CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC = (
|
||||
os.environ.get("CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC", "").lower() == "true"
|
||||
)
|
||||
# In seconds, default is 5 minutes
|
||||
CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY = int(
|
||||
os.environ.get("CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
|
||||
|
||||
@@ -2,6 +2,7 @@ import datetime
|
||||
from collections.abc import Sequence
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_
|
||||
from sqlalchemy import case
|
||||
from sqlalchemy import cast
|
||||
from sqlalchemy import Date
|
||||
@@ -14,6 +15,9 @@ from onyx.configs.constants import MessageType
|
||||
from onyx.db.models import ChatMessage
|
||||
from onyx.db.models import ChatMessageFeedback
|
||||
from onyx.db.models import ChatSession
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserRole
|
||||
|
||||
|
||||
def fetch_query_analytics(
|
||||
@@ -234,3 +238,121 @@ def fetch_persona_unique_users(
|
||||
)
|
||||
|
||||
return [tuple(row) for row in db_session.execute(query).all()]
|
||||
|
||||
|
||||
def fetch_assistant_message_analytics(
    db_session: Session,
    assistant_id: int,
    start: datetime.datetime,
    end: datetime.datetime,
) -> list[tuple[int, datetime.date]]:
    """
    Gets the daily message counts for a specific assistant in the given time range.

    A message is attributed to the assistant when either its
    `alternate_assistant_id` or the `persona_id` of its chat session equals
    `assistant_id`. Only messages of type `MessageType.ASSISTANT` whose
    `time_sent` lies in [start, end] (both bounds inclusive) are counted.

    Returns one `(message_count, day)` tuple per day that has at least one
    matching message, ordered by day ascending.
    """
    # The assistant can be referenced from the message itself or from its session
    belongs_to_assistant = or_(
        ChatMessage.alternate_assistant_id == assistant_id,
        ChatSession.persona_id == assistant_id,
    )

    query = (
        select(
            func.count(ChatMessage.id),
            cast(ChatMessage.time_sent, Date),
        )
        .join(
            ChatSession,
            ChatMessage.chat_session_id == ChatSession.id,
        )
        .where(
            belongs_to_assistant,
            ChatMessage.time_sent >= start,
            ChatMessage.time_sent <= end,
            ChatMessage.message_type == MessageType.ASSISTANT,
        )
        # Truncate the timestamp to a date so rows collapse into one bucket per day
        .group_by(cast(ChatMessage.time_sent, Date))
        .order_by(cast(ChatMessage.time_sent, Date))
    )

    return [tuple(row) for row in db_session.execute(query).all()]
|
||||
|
||||
|
||||
def fetch_assistant_unique_users(
    db_session: Session,
    assistant_id: int,
    start: datetime.datetime,
    end: datetime.datetime,
) -> list[tuple[int, datetime.date]]:
    """
    Gets the daily unique user counts for a specific assistant in the given time range.

    Counts the distinct `ChatSession.user_id` values per calendar day over
    messages of type `MessageType.ASSISTANT` whose `time_sent` lies in
    [start, end] (both bounds inclusive) and whose `alternate_assistant_id`,
    or whose session's `persona_id`, equals `assistant_id`.

    Returns one `(unique_user_count, day)` tuple per day that has at least one
    matching message, ordered by day ascending.
    """
    # The assistant can be referenced from the message itself or from its session
    belongs_to_assistant = or_(
        ChatMessage.alternate_assistant_id == assistant_id,
        ChatSession.persona_id == assistant_id,
    )

    query = (
        select(
            func.count(func.distinct(ChatSession.user_id)),
            cast(ChatMessage.time_sent, Date),
        )
        # Columns from both tables are selected, so state the left side of the
        # join explicitly (same shape as fetch_assistant_unique_users_total)
        .select_from(ChatMessage)
        .join(
            ChatSession,
            ChatMessage.chat_session_id == ChatSession.id,
        )
        .where(
            belongs_to_assistant,
            ChatMessage.time_sent >= start,
            ChatMessage.time_sent <= end,
            ChatMessage.message_type == MessageType.ASSISTANT,
        )
        # Truncate the timestamp to a date so rows collapse into one bucket per day
        .group_by(cast(ChatMessage.time_sent, Date))
        .order_by(cast(ChatMessage.time_sent, Date))
    )

    return [tuple(row) for row in db_session.execute(query).all()]
|
||||
|
||||
|
||||
def fetch_assistant_unique_users_total(
    db_session: Session,
    assistant_id: int,
    start: datetime.datetime,
    end: datetime.datetime,
) -> int:
    """
    Gets the total number of distinct users across the whole time range for
    the specified assistant.

    Counts the distinct `ChatSession.user_id` values over messages of type
    `MessageType.ASSISTANT` whose `time_sent` lies in [start, end] (both
    bounds inclusive) and whose `alternate_assistant_id`, or whose session's
    `persona_id`, equals `assistant_id`.

    Returns 0 when the query yields no value.
    """
    # The assistant can be referenced from the message itself or from its session
    belongs_to_assistant = or_(
        ChatMessage.alternate_assistant_id == assistant_id,
        ChatSession.persona_id == assistant_id,
    )

    query = (
        select(func.count(func.distinct(ChatSession.user_id)))
        .select_from(ChatMessage)
        .join(
            ChatSession,
            ChatMessage.chat_session_id == ChatSession.id,
        )
        .where(
            belongs_to_assistant,
            ChatMessage.time_sent >= start,
            ChatMessage.time_sent <= end,
            ChatMessage.message_type == MessageType.ASSISTANT,
        )
    )

    result = db_session.execute(query).scalar()
    # scalar() may hand back None; normalise any falsy result to 0
    return result or 0
|
||||
|
||||
|
||||
# Users can view assistant stats if they created the persona,
|
||||
# or if they are an admin
|
||||
def user_can_view_assistant_stats(
    db_session: Session, user: User | None, assistant_id: int
) -> bool:
    """
    Decide whether `user` may see the usage stats of the given assistant.

    Returns True when `user` is None, when the user's role is
    `UserRole.ADMIN`, or when a `Persona` row exists whose `id` equals
    `assistant_id` and whose `user_id` equals the user's id. Returns False
    otherwise. The database is only queried for non-admin users.
    """
    # If user is None, assume the user is an admin or auth is disabled
    if user is None:
        return True
    if user.role == UserRole.ADMIN:
        return True

    # Check if the user created the persona
    stmt = select(Persona).where(
        and_(Persona.id == assistant_id, Persona.user_id == user.id)
    )
    persona = db_session.execute(stmt).scalar_one_or_none()
    return persona is not None
|
||||
|
||||
@@ -0,0 +1,4 @@
|
||||
# This is a group that we use to store all the users that we found in Confluence
|
||||
# Instead of setting a page to public, we just add this group so that the page
|
||||
# is only accessible to users who have confluence accounts.
|
||||
ALL_CONF_EMAILS_GROUP_NAME = "All_Confluence_Users_Found_By_Onyx"
|
||||
@@ -4,6 +4,8 @@ https://confluence.atlassian.com/conf85/check-who-can-view-a-page-1283360557.htm
|
||||
"""
|
||||
from typing import Any
|
||||
|
||||
from ee.onyx.configs.app_configs import CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC
|
||||
from ee.onyx.external_permissions.confluence.constants import ALL_CONF_EMAILS_GROUP_NAME
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.connectors.confluence.connector import ConfluenceConnector
|
||||
@@ -31,14 +33,32 @@ def _get_server_space_permissions(
|
||||
permission_category.get("spacePermissions", [])
|
||||
)
|
||||
|
||||
is_public = False
|
||||
user_names = set()
|
||||
group_names = set()
|
||||
for permission in viewspace_permissions:
|
||||
if user_name := permission.get("userName"):
|
||||
user_name = permission.get("userName")
|
||||
if user_name:
|
||||
user_names.add(user_name)
|
||||
if group_name := permission.get("groupName"):
|
||||
group_name = permission.get("groupName")
|
||||
if group_name:
|
||||
group_names.add(group_name)
|
||||
|
||||
# It seems that if anonymous access is turned on for the site and space,
|
||||
# then the space is publicly accessible.
|
||||
# For confluence server, we make a group that contains all users
|
||||
# that exist in confluence and then just add that group to the space permissions
|
||||
# if anonymous access is turned on for the site and space or we set is_public = True
|
||||
# if they set the env variable CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC to True so
|
||||
# that we can support confluence server deployments that want anonymous access
|
||||
# to be public (we can't test this because it's paywalled)
|
||||
if user_name is None and group_name is None:
|
||||
# Defaults to False
|
||||
if CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC:
|
||||
is_public = True
|
||||
else:
|
||||
group_names.add(ALL_CONF_EMAILS_GROUP_NAME)
|
||||
|
||||
user_emails = set()
|
||||
for user_name in user_names:
|
||||
user_email = get_user_email_from_username__server(confluence_client, user_name)
|
||||
@@ -50,11 +70,7 @@ def _get_server_space_permissions(
|
||||
return ExternalAccess(
|
||||
external_user_emails=user_emails,
|
||||
external_user_group_ids=group_names,
|
||||
# TODO: Check if the space is publicly accessible
|
||||
# Currently, we assume the space is not public
|
||||
# We need to check if anonymous access is turned on for the site and space
|
||||
# This information is paywalled so it remains unimplemented
|
||||
is_public=False,
|
||||
is_public=is_public,
|
||||
)
|
||||
|
||||
|
||||
@@ -134,7 +150,7 @@ def _get_space_permissions(
|
||||
|
||||
def _extract_read_access_restrictions(
|
||||
confluence_client: OnyxConfluence, restrictions: dict[str, Any]
|
||||
) -> ExternalAccess | None:
|
||||
) -> tuple[set[str], set[str]]:
|
||||
"""
|
||||
Converts a page's restrictions dict into two sets: the user emails and
|
||||
the group names that have read access. Both sets are empty if there are no restrictions.
|
||||
@@ -177,21 +193,57 @@ def _extract_read_access_restrictions(
|
||||
group["name"] for group in read_access_group_jsons if group.get("name")
|
||||
]
|
||||
|
||||
return set(read_access_user_emails), set(read_access_group_names)
|
||||
|
||||
|
||||
def _get_all_page_restrictions(
|
||||
confluence_client: OnyxConfluence,
|
||||
perm_sync_data: dict[str, Any],
|
||||
) -> ExternalAccess | None:
|
||||
"""
|
||||
This function gets the restrictions for a page by taking the intersection
|
||||
of the page's restrictions and the restrictions of all the ancestors
|
||||
of the page.
|
||||
If the page/ancestor has no restrictions, then it is ignored (no intersection).
|
||||
If no restrictions are found anywhere, then return None, indicating that the page
|
||||
should inherit the space's restrictions.
|
||||
"""
|
||||
found_user_emails: set[str] = set()
|
||||
found_group_names: set[str] = set()
|
||||
|
||||
found_user_emails, found_group_names = _extract_read_access_restrictions(
|
||||
confluence_client=confluence_client,
|
||||
restrictions=perm_sync_data.get("restrictions", {}),
|
||||
)
|
||||
|
||||
ancestors: list[dict[str, Any]] = perm_sync_data.get("ancestors", [])
|
||||
for ancestor in ancestors:
|
||||
ancestor_user_emails, ancestor_group_names = _extract_read_access_restrictions(
|
||||
confluence_client=confluence_client,
|
||||
restrictions=ancestor.get("restrictions", {}),
|
||||
)
|
||||
if not ancestor_user_emails and not ancestor_group_names:
|
||||
# This ancestor has no restrictions, so it has no effect on
|
||||
# the page's restrictions, so we ignore it
|
||||
continue
|
||||
|
||||
found_user_emails.intersection_update(ancestor_user_emails)
|
||||
found_group_names.intersection_update(ancestor_group_names)
|
||||
|
||||
# If there are no restrictions found, then the page
|
||||
# inherits the space's restrictions so return None
|
||||
is_space_public = read_access_user_emails == [] and read_access_group_names == []
|
||||
if is_space_public:
|
||||
if not found_user_emails and not found_group_names:
|
||||
return None
|
||||
|
||||
return ExternalAccess(
|
||||
external_user_emails=set(read_access_user_emails),
|
||||
external_user_group_ids=set(read_access_group_names),
|
||||
external_user_emails=found_user_emails,
|
||||
external_user_group_ids=found_group_names,
|
||||
# there is no way for a page to be individually public if the space isn't public
|
||||
is_public=False,
|
||||
)
|
||||
|
||||
|
||||
def _fetch_all_page_restrictions_for_space(
|
||||
def _fetch_all_page_restrictions(
|
||||
confluence_client: OnyxConfluence,
|
||||
slim_docs: list[SlimDocument],
|
||||
space_permissions_by_space_key: dict[str, ExternalAccess],
|
||||
@@ -208,11 +260,11 @@ def _fetch_all_page_restrictions_for_space(
|
||||
raise ValueError(
|
||||
f"No permission sync data found for document {slim_doc.id}"
|
||||
)
|
||||
restrictions = _extract_read_access_restrictions(
|
||||
|
||||
if restrictions := _get_all_page_restrictions(
|
||||
confluence_client=confluence_client,
|
||||
restrictions=slim_doc.perm_sync_data.get("restrictions", {}),
|
||||
)
|
||||
if restrictions:
|
||||
perm_sync_data=slim_doc.perm_sync_data,
|
||||
):
|
||||
document_restrictions.append(
|
||||
DocExternalAccess(
|
||||
doc_id=slim_doc.id,
|
||||
@@ -301,7 +353,7 @@ def confluence_doc_sync(
|
||||
slim_docs.extend(doc_batch)
|
||||
|
||||
logger.debug("Fetching all page restrictions for space")
|
||||
return _fetch_all_page_restrictions_for_space(
|
||||
return _fetch_all_page_restrictions(
|
||||
confluence_client=confluence_connector.confluence_client,
|
||||
slim_docs=slim_docs,
|
||||
space_permissions_by_space_key=space_permissions_by_space_key,
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
from ee.onyx.db.external_perm import ExternalUserGroup
|
||||
from ee.onyx.external_permissions.confluence.constants import ALL_CONF_EMAILS_GROUP_NAME
|
||||
from onyx.connectors.confluence.onyx_confluence import build_confluence_client
|
||||
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
|
||||
from onyx.connectors.confluence.utils import get_user_email_from_username__server
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@@ -53,6 +53,7 @@ def confluence_group_sync(
|
||||
confluence_client=confluence_client,
|
||||
)
|
||||
onyx_groups: list[ExternalUserGroup] = []
|
||||
all_found_emails = set()
|
||||
for group_id, group_member_emails in group_member_email_map.items():
|
||||
onyx_groups.append(
|
||||
ExternalUserGroup(
|
||||
@@ -60,5 +61,15 @@ def confluence_group_sync(
|
||||
user_emails=list(group_member_emails),
|
||||
)
|
||||
)
|
||||
all_found_emails.update(group_member_emails)
|
||||
|
||||
# This is so that when we find a public Confluence server page, we can
|
||||
# give access to all users only if they have an email in Confluence
|
||||
if cc_pair.connector.connector_specific_config.get("is_cloud", False):
|
||||
all_found_group = ExternalUserGroup(
|
||||
id=ALL_CONF_EMAILS_GROUP_NAME,
|
||||
user_emails=list(all_found_emails),
|
||||
)
|
||||
onyx_groups.append(all_found_group)
|
||||
|
||||
return onyx_groups
|
||||
|
||||
@@ -40,6 +40,7 @@ from onyx.configs.app_configs import USER_AUTH_SECRET
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.main import get_application as get_application_base
|
||||
from onyx.main import include_auth_router_with_prefix
|
||||
from onyx.main import include_router_with_global_prefix_prepended
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
@@ -62,7 +63,7 @@ def get_application() -> FastAPI:
|
||||
|
||||
if AUTH_TYPE == AuthType.CLOUD:
|
||||
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
|
||||
include_router_with_global_prefix_prepended(
|
||||
include_auth_router_with_prefix(
|
||||
application,
|
||||
create_onyx_oauth_router(
|
||||
oauth_client,
|
||||
@@ -74,19 +75,17 @@ def get_application() -> FastAPI:
|
||||
redirect_url=f"{WEB_DOMAIN}/auth/oauth/callback",
|
||||
),
|
||||
prefix="/auth/oauth",
|
||||
tags=["auth"],
|
||||
)
|
||||
|
||||
# Need basic auth router for `logout` endpoint
|
||||
include_router_with_global_prefix_prepended(
|
||||
include_auth_router_with_prefix(
|
||||
application,
|
||||
fastapi_users.get_logout_router(auth_backend),
|
||||
prefix="/auth",
|
||||
tags=["auth"],
|
||||
)
|
||||
|
||||
if AUTH_TYPE == AuthType.OIDC:
|
||||
include_router_with_global_prefix_prepended(
|
||||
include_auth_router_with_prefix(
|
||||
application,
|
||||
create_onyx_oauth_router(
|
||||
OpenID(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET, OPENID_CONFIG_URL),
|
||||
@@ -97,19 +96,20 @@ def get_application() -> FastAPI:
|
||||
redirect_url=f"{WEB_DOMAIN}/auth/oidc/callback",
|
||||
),
|
||||
prefix="/auth/oidc",
|
||||
tags=["auth"],
|
||||
)
|
||||
|
||||
# need basic auth router for `logout` endpoint
|
||||
include_router_with_global_prefix_prepended(
|
||||
include_auth_router_with_prefix(
|
||||
application,
|
||||
fastapi_users.get_auth_router(auth_backend),
|
||||
prefix="/auth",
|
||||
tags=["auth"],
|
||||
)
|
||||
|
||||
elif AUTH_TYPE == AuthType.SAML:
|
||||
include_router_with_global_prefix_prepended(application, saml_router)
|
||||
include_auth_router_with_prefix(
|
||||
application,
|
||||
saml_router,
|
||||
)
|
||||
|
||||
# RBAC / group access control
|
||||
include_router_with_global_prefix_prepended(application, user_group_router)
|
||||
|
||||
@@ -1,17 +1,24 @@
|
||||
import datetime
|
||||
from collections import defaultdict
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.db.analytics import fetch_assistant_message_analytics
|
||||
from ee.onyx.db.analytics import fetch_assistant_unique_users
|
||||
from ee.onyx.db.analytics import fetch_assistant_unique_users_total
|
||||
from ee.onyx.db.analytics import fetch_onyxbot_analytics
|
||||
from ee.onyx.db.analytics import fetch_per_user_query_analytics
|
||||
from ee.onyx.db.analytics import fetch_persona_message_analytics
|
||||
from ee.onyx.db.analytics import fetch_persona_unique_users
|
||||
from ee.onyx.db.analytics import fetch_query_analytics
|
||||
from ee.onyx.db.analytics import user_can_view_assistant_stats
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.models import User
|
||||
|
||||
@@ -191,3 +198,74 @@ def get_persona_unique_users(
|
||||
)
|
||||
)
|
||||
return unique_user_counts
|
||||
|
||||
|
||||
class AssistantDailyUsageResponse(BaseModel):
    """Usage figures of one assistant for a single calendar day."""

    # Calendar day the two counts below refer to
    date: datetime.date
    # Number of messages counted for the assistant on that day
    total_messages: int
    # Number of distinct users counted for the assistant on that day
    total_unique_users: int
|
||||
|
||||
|
||||
class AssistantStatsResponse(BaseModel):
    """Per-day usage of one assistant plus totals over the whole time range."""

    # One entry per day that has data, in ascending date order
    daily_stats: List[AssistantDailyUsageResponse]
    # Sum of the daily message counts
    total_messages: int
    # Distinct users over the whole range (not the sum of the daily counts)
    total_unique_users: int
|
||||
|
||||
|
||||
@router.get("/assistant/{assistant_id}/stats")
def get_assistant_stats(
    assistant_id: int,
    start: datetime.datetime | None = None,
    end: datetime.datetime | None = None,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> AssistantStatsResponse:
    """
    Returns daily message and unique user counts for a user's assistant,
    along with the overall total messages and total distinct users.

    `start` defaults to `_DEFAULT_LOOKBACK_DAYS` days before the current UTC
    time and `end` defaults to the current UTC time.

    Raises an HTTPException with status 403 when
    `user_can_view_assistant_stats` rejects the requesting user.
    """
    # NOTE(review): utcnow() yields naive datetimes and is deprecated since
    # Python 3.12 - confirm what the DB columns expect before changing it
    start = start or (
        datetime.datetime.utcnow() - datetime.timedelta(days=_DEFAULT_LOOKBACK_DAYS)
    )
    end = end or datetime.datetime.utcnow()

    if not user_can_view_assistant_stats(db_session, user, assistant_id):
        raise HTTPException(
            status_code=403, detail="Not allowed to access this assistant's stats."
        )

    # Pull daily usage from the DB calls
    messages_data = fetch_assistant_message_analytics(
        db_session, assistant_id, start, end
    )
    unique_users_data = fetch_assistant_unique_users(
        db_session, assistant_id, start, end
    )

    # Both queries return (count, date) tuples; index each of them by date
    daily_messages_map = {date: count for count, date in messages_data}
    daily_unique_users_map = {date: count for count, date in unique_users_data}
    all_dates = set(daily_messages_map) | set(daily_unique_users_map)

    # Merge both sets of metrics by date; a metric missing for a day counts as 0
    daily_results: list[AssistantDailyUsageResponse] = [
        AssistantDailyUsageResponse(
            date=date,
            total_messages=daily_messages_map.get(date, 0),
            total_unique_users=daily_unique_users_map.get(date, 0),
        )
        for date in sorted(all_dates)
    ]

    total_msgs = sum(d.total_messages for d in daily_results)

    # Now pull a single total distinct user count across the entire time range.
    # Daily unique counts cannot simply be summed: the same user may be active
    # on several days.
    total_users = fetch_assistant_unique_users_total(
        db_session, assistant_id, start, end
    )

    return AssistantStatsResponse(
        daily_stats=daily_results,
        total_messages=total_msgs,
        total_unique_users=total_users,
    )
|
||||
|
||||
@@ -2,15 +2,14 @@ import logging
|
||||
from collections.abc import Awaitable
|
||||
from collections.abc import Callable
|
||||
|
||||
import jwt
|
||||
from fastapi import FastAPI
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
from fastapi import Response
|
||||
|
||||
from onyx.auth.api_key import extract_tenant_from_api_key_header
|
||||
from onyx.configs.app_configs import USER_AUTH_SECRET
|
||||
from onyx.db.engine import is_valid_schema_name
|
||||
from onyx.redis.redis_pool import retrieve_auth_token_data_from_redis
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
@@ -22,11 +21,11 @@ def add_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> Non
|
||||
request: Request, call_next: Callable[[Request], Awaitable[Response]]
|
||||
) -> Response:
|
||||
try:
|
||||
tenant_id = (
|
||||
_get_tenant_id_from_request(request, logger)
|
||||
if MULTI_TENANT
|
||||
else POSTGRES_DEFAULT_SCHEMA
|
||||
)
|
||||
if MULTI_TENANT:
|
||||
tenant_id = await _get_tenant_id_from_request(request, logger)
|
||||
else:
|
||||
tenant_id = POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
return await call_next(request)
|
||||
|
||||
@@ -35,27 +34,36 @@ def add_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> Non
|
||||
raise
|
||||
|
||||
|
||||
def _get_tenant_id_from_request(request: Request, logger: logging.LoggerAdapter) -> str:
|
||||
# First check for API key
|
||||
async def _get_tenant_id_from_request(
|
||||
request: Request, logger: logging.LoggerAdapter
|
||||
) -> str:
|
||||
"""
|
||||
Attempt to extract tenant_id from:
|
||||
1) The API key header
|
||||
2) The Redis-based token (stored in Cookie: fastapiusersauth)
|
||||
Fallback: POSTGRES_DEFAULT_SCHEMA
|
||||
"""
|
||||
# Check for API key
|
||||
tenant_id = extract_tenant_from_api_key_header(request)
|
||||
if tenant_id is not None:
|
||||
if tenant_id:
|
||||
return tenant_id
|
||||
|
||||
# Check for cookie-based auth
|
||||
token = request.cookies.get("fastapiusersauth")
|
||||
if not token:
|
||||
return POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
USER_AUTH_SECRET,
|
||||
audience=["fastapi-users:auth"],
|
||||
algorithms=["HS256"],
|
||||
)
|
||||
tenant_id_from_payload = payload.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
|
||||
# Look up token data in Redis
|
||||
token_data = await retrieve_auth_token_data_from_redis(request)
|
||||
|
||||
# Since payload.get() can return None, ensure we have a string
|
||||
if not token_data:
|
||||
logger.debug(
|
||||
"Token data not found or expired in Redis, defaulting to POSTGRES_DEFAULT_SCHEMA"
|
||||
)
|
||||
# Return POSTGRES_DEFAULT_SCHEMA, so non-authenticated requests are sent to the default schema
|
||||
# The CURRENT_TENANT_ID_CONTEXTVAR is initialized with POSTGRES_DEFAULT_SCHEMA,
|
||||
# so we maintain consistency by returning it here when no valid tenant is found.
|
||||
return POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
tenant_id_from_payload = token_data.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
|
||||
|
||||
# Since token_data.get() can return None, ensure we have a string
|
||||
tenant_id = (
|
||||
str(tenant_id_from_payload)
|
||||
if tenant_id_from_payload is not None
|
||||
@@ -67,9 +75,6 @@ def _get_tenant_id_from_request(request: Request, logger: logging.LoggerAdapter)
|
||||
|
||||
return tenant_id
|
||||
|
||||
except jwt.InvalidTokenError:
|
||||
return POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in set_tenant_id_middleware: {str(e)}")
|
||||
logger.error(f"Unexpected error in _get_tenant_id_from_request: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import base64
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import requests
|
||||
@@ -10,11 +12,29 @@ from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLIENT_ID
|
||||
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLIENT_SECRET
|
||||
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_ID
|
||||
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
|
||||
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_ID
|
||||
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_SECRET
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.google_utils.google_auth import get_google_oauth_creds
|
||||
from onyx.connectors.google_utils.google_auth import sanitize_oauth_credentials
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_AUTHENTICATION_METHOD,
|
||||
)
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_DICT_TOKEN_KEY,
|
||||
)
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
|
||||
)
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
GoogleOAuthAuthenticationMethod,
|
||||
)
|
||||
from onyx.db.credentials import create_credential
|
||||
from onyx.db.engine import get_current_tenant_id
|
||||
from onyx.db.engine import get_session
|
||||
@@ -62,14 +82,7 @@ class SlackOAuth:
|
||||
|
||||
@classmethod
|
||||
def generate_oauth_url(cls, state: str) -> str:
|
||||
url = (
|
||||
f"https://slack.com/oauth/v2/authorize"
|
||||
f"?client_id={cls.CLIENT_ID}"
|
||||
f"&redirect_uri={cls.REDIRECT_URI}"
|
||||
f"&scope={cls.BOT_SCOPE}"
|
||||
f"&state={state}"
|
||||
)
|
||||
return url
|
||||
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
|
||||
|
||||
@classmethod
|
||||
def generate_dev_oauth_url(cls, state: str) -> str:
|
||||
@@ -77,10 +90,14 @@ class SlackOAuth:
|
||||
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
|
||||
"""
|
||||
|
||||
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
|
||||
|
||||
@classmethod
|
||||
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
|
||||
url = (
|
||||
f"https://slack.com/oauth/v2/authorize"
|
||||
f"?client_id={cls.CLIENT_ID}"
|
||||
f"&redirect_uri={cls.DEV_REDIRECT_URI}"
|
||||
f"&redirect_uri={redirect_uri}"
|
||||
f"&scope={cls.BOT_SCOPE}"
|
||||
f"&state={state}"
|
||||
)
|
||||
@@ -102,82 +119,151 @@ class SlackOAuth:
|
||||
return session
|
||||
|
||||
|
||||
# Work in progress
|
||||
# class ConfluenceCloudOAuth:
|
||||
# """work in progress"""
|
||||
class ConfluenceCloudOAuth:
    """OAuth 2.0 (3LO) helper for Confluence Cloud -- work in progress.

    Builds the Atlassian authorization URL and (de)serializes the temporary
    session state that is looked up again when the OAuth callback arrives.
    """

    # https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps/

    class OAuthSession(BaseModel):
        """Stored in redis to be looked up on callback"""

        email: str
        redirect_on_success: str | None  # Where to send the user if OAuth flow succeeds

    CLIENT_ID = OAUTH_CONFLUENCE_CLIENT_ID
    CLIENT_SECRET = OAUTH_CONFLUENCE_CLIENT_SECRET
    TOKEN_URL = "https://auth.atlassian.com/oauth/token"

    # All read scopes per https://developer.atlassian.com/cloud/confluence/scopes-for-oauth-2-3LO-and-forge-apps/
    # The scopes are already joined with "%20" so the value can be dropped
    # straight into the query string.
    CONFLUENCE_OAUTH_SCOPE = (
        "read:confluence-props%20"
        "read:confluence-content.all%20"
        "read:confluence-content.summary%20"
        "read:confluence-content.permission%20"
        "read:confluence-user%20"
        "read:confluence-groups%20"
        "readonly:content.attachment:confluence"
    )

    REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/confluence/oauth/callback"
    DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"

    # eventually for Confluence Data Center
    # oauth_url = (
    #     f"http://localhost:8090/rest/oauth/v2/authorize?client_id={CONFLUENCE_OAUTH_CLIENT_ID}"
    #     f"&scope={CONFLUENCE_OAUTH_SCOPE_2}"
    #     f"&redirect_uri={redirectme_uri}"
    # )

    @classmethod
    def generate_oauth_url(cls, state: str) -> str:
        """Return the authorization URL that redirects back to REDIRECT_URI."""
        return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)

    @classmethod
    def generate_dev_oauth_url(cls, state: str) -> str:
        """dev mode workaround for localhost testing
        - https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
        """
        return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)

    @classmethod
    def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
        """Assemble the Atlassian authorize URL for the given redirect URI and state."""
        url = (
            "https://auth.atlassian.com/authorize"
            f"?audience=api.atlassian.com"
            f"&client_id={cls.CLIENT_ID}"
            f"&redirect_uri={redirect_uri}"
            f"&scope={cls.CONFLUENCE_OAUTH_SCOPE}"
            f"&state={state}"
            "&response_type=code"
            "&prompt=consent"
        )
        return url

    @classmethod
    def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
        """Temporary state to store in redis. to be looked up on auth response.
        Returns a json string.
        """
        session = ConfluenceCloudOAuth.OAuthSession(
            email=email, redirect_on_success=redirect_on_success
        )
        return session.model_dump_json()

    @classmethod
    def parse_session(cls, session_json: str) -> OAuthSession:
        """Parse the JSON produced by session_dump_json back into a session.

        Validates against this class's own OAuthSession model (previously the
        Slack model was used by mistake), mirroring GoogleDriveOAuth.
        """
        session = ConfluenceCloudOAuth.OAuthSession.model_validate_json(session_json)
        return session
|
||||
|
||||
|
||||
class GoogleDriveOAuth:
    """OAuth helper for the Google Drive connector.

    Exposes the client configuration, builds the Google consent URL, and
    (de)serializes the temporary session state used across the redirect.
    """

    # https://developers.google.com/identity/protocols/oauth2
    # https://developers.google.com/identity/protocols/oauth2/web-server

    class OAuthSession(BaseModel):
        """Stored in redis to be looked up on callback"""

        email: str
        redirect_on_success: str | None  # Where to send the user if OAuth flow succeeds

    CLIENT_ID = OAUTH_GOOGLE_DRIVE_CLIENT_ID
    CLIENT_SECRET = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET

    TOKEN_URL = "https://oauth2.googleapis.com/token"

    # SCOPE is per https://docs.onyx.app/connectors/google-drive
    # TODO: Merge with or use google_utils.GOOGLE_SCOPES
    # The individual scopes are joined with an already-encoded space ("%20").
    SCOPE = "%20".join(
        [
            "https://www.googleapis.com/auth/drive.readonly",
            "https://www.googleapis.com/auth/drive.metadata.readonly",
            "https://www.googleapis.com/auth/admin.directory.user.readonly",
            "https://www.googleapis.com/auth/admin.directory.group.readonly",
        ]
    )

    REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/google-drive/oauth/callback"
    DEV_REDIRECT_URI = "https://redirectmeto.com/" + REDIRECT_URI

    @classmethod
    def generate_oauth_url(cls, state: str) -> str:
        """Return the consent URL that redirects back to REDIRECT_URI."""
        return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)

    @classmethod
    def generate_dev_oauth_url(cls, state: str) -> str:
        """Return the consent URL for localhost testing (dev mode workaround).

        - https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
        """
        return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)

    @classmethod
    def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
        """Assemble the Google authorization URL for the given redirect URI and state."""
        # without prompt=consent, a refresh token is only issued the first time the user approves
        query_parts = [
            f"client_id={cls.CLIENT_ID}",
            f"redirect_uri={redirect_uri}",
            "response_type=code",
            f"scope={cls.SCOPE}",
            "access_type=offline",
            f"state={state}",
            "prompt=consent",
        ]
        return "https://accounts.google.com/o/oauth2/v2/auth?" + "&".join(query_parts)

    @classmethod
    def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
        """Serialize the temporary session state to a JSON string.

        The result is what gets stored in redis and looked up on auth response.
        """
        return GoogleDriveOAuth.OAuthSession(
            email=email, redirect_on_success=redirect_on_success
        ).model_dump_json()

    @classmethod
    def parse_session(cls, session_json: str) -> OAuthSession:
        """Parse a JSON string produced by session_dump_json back into a session."""
        return GoogleDriveOAuth.OAuthSession.model_validate_json(session_json)
|
||||
|
||||
|
||||
@router.post("/prepare-authorization-request")
|
||||
@@ -192,8 +278,11 @@ def prepare_authorization_request(
|
||||
Example: https://www.oauth.com/oauth2-servers/authorization/the-authorization-request/
|
||||
"""
|
||||
|
||||
# create random oauth state param for security and to retrieve user data later
|
||||
oauth_uuid = uuid.uuid4()
|
||||
oauth_uuid_str = str(oauth_uuid)
|
||||
|
||||
# urlsafe b64 encode the uuid for the oauth url
|
||||
oauth_state = (
|
||||
base64.urlsafe_b64encode(oauth_uuid.bytes).rstrip(b"=").decode("utf-8")
|
||||
)
|
||||
@@ -203,6 +292,11 @@ def prepare_authorization_request(
|
||||
session = SlackOAuth.session_dump_json(
|
||||
email=user.email, redirect_on_success=redirect_on_success
|
||||
)
|
||||
elif connector == DocumentSource.GOOGLE_DRIVE:
|
||||
oauth_url = GoogleDriveOAuth.generate_oauth_url(oauth_state)
|
||||
session = GoogleDriveOAuth.session_dump_json(
|
||||
email=user.email, redirect_on_success=redirect_on_success
|
||||
)
|
||||
# elif connector == DocumentSource.CONFLUENCE:
|
||||
# oauth_url = ConfluenceCloudOAuth.generate_oauth_url(oauth_state)
|
||||
# session = ConfluenceCloudOAuth.session_dump_json(
|
||||
@@ -210,8 +304,6 @@ def prepare_authorization_request(
|
||||
# )
|
||||
# elif connector == DocumentSource.JIRA:
|
||||
# oauth_url = JiraCloudOAuth.generate_dev_oauth_url(oauth_state)
|
||||
# elif connector == DocumentSource.GOOGLE_DRIVE:
|
||||
# oauth_url = GoogleDriveOAuth.generate_dev_oauth_url(oauth_state)
|
||||
else:
|
||||
oauth_url = None
|
||||
|
||||
@@ -223,6 +315,7 @@ def prepare_authorization_request(
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# store important session state to retrieve when the user is redirected back
|
||||
# 10 min is the max we want an oauth flow to be valid
|
||||
r.set(f"da_oauth:{oauth_uuid_str}", session, ex=600)
|
||||
|
||||
@@ -421,3 +514,116 @@ def handle_slack_oauth_callback(
|
||||
# "redirect_on_success": session.redirect_on_success,
|
||||
# }
|
||||
# )
|
||||
|
||||
|
||||
@router.post("/connector/google-drive/callback")
def handle_google_drive_oauth_callback(
    code: str,
    state: str,
    user: User = Depends(current_user),
    db_session: Session = Depends(get_session),
    tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
    """Complete the Google Drive OAuth flow after the user is redirected back.

    Looks up the session stored in redis under the decoded ``state``,
    exchanges ``code`` for tokens at the Google token endpoint, and saves the
    resulting credential via ``create_credential``.

    Args:
        code: Authorization code returned by Google.
        state: URL-safe base64 (padding stripped) encoding of the OAuth UUID.
        user: The authenticated user; passed to ``create_credential``.
        db_session: Database session used to persist the credential.
        tenant_id: Tenant whose redis client holds the session state.

    Returns:
        A JSON response with ``success``/``message``; on success it also
        carries ``redirect_on_success``. Failures during the exchange are
        reported as a 500 JSON response rather than raised.

    Raises:
        HTTPException: 500 if the client id/secret is not configured; 400 if
            ``state`` is malformed or no session is stored for it.
    """
    if not GoogleDriveOAuth.CLIENT_ID or not GoogleDriveOAuth.CLIENT_SECRET:
        raise HTTPException(
            status_code=500,
            detail="Google Drive client ID or client secret is not configured.",
        )

    r = get_redis_client(tenant_id=tenant_id)

    # recover the state
    # `state` comes straight from the client, so a malformed value must be
    # rejected as a bad request instead of escaping as an unhandled error.
    try:
        # Add padding back (Base64 decoding requires padding)
        padded_state = state + "=" * (-len(state) % 4)
        # Decode the Base64 string back to bytes
        uuid_bytes = base64.urlsafe_b64decode(padded_state)
        # Convert bytes back to a UUID (raises ValueError unless 16 bytes)
        oauth_uuid = uuid.UUID(bytes=uuid_bytes)
    except ValueError as exc:  # binascii.Error is a ValueError subclass
        raise HTTPException(
            status_code=400,
            detail="Google Drive OAuth failed - malformed OAuth state parameter.",
        ) from exc

    oauth_uuid_str = str(oauth_uuid)
    r_key = f"da_oauth:{oauth_uuid_str}"

    session_json_bytes = cast(bytes, r.get(r_key))
    if not session_json_bytes:
        raise HTTPException(
            status_code=400,
            detail=f"Google Drive OAuth failed - OAuth state key not found: key={r_key}",
        )

    session_json = session_json_bytes.decode("utf-8")
    try:
        session = GoogleDriveOAuth.parse_session(session_json)

        # Exchange the authorization code for an access token
        response = requests.post(
            GoogleDriveOAuth.TOKEN_URL,
            headers={"Content-Type": "application/x-www-form-urlencoded"},
            data={
                "client_id": GoogleDriveOAuth.CLIENT_ID,
                "client_secret": GoogleDriveOAuth.CLIENT_SECRET,
                "code": code,
                "redirect_uri": GoogleDriveOAuth.REDIRECT_URI,
                "grant_type": "authorization_code",
            },
            # requests never times out by default; bound the call so a stalled
            # token endpoint cannot hang this request handler indefinitely
            timeout=30,
        )

        response.raise_for_status()

        authorization_response: dict[str, Any] = response.json()

        # Fail with a readable message instead of a bare KeyError
        refresh_token = authorization_response.get("refresh_token")
        if not refresh_token:
            raise RuntimeError("Token response did not include a refresh token.")

        # the connector wants us to store the json in its authorized_user_info format
        # returned from OAuthCredentials.get_authorized_user_info().
        # So refresh immediately via get_google_oauth_creds with the params filled in
        # from fields in authorization_response to get the json we need
        authorized_user_info = {
            "client_id": OAUTH_GOOGLE_DRIVE_CLIENT_ID,
            "client_secret": OAUTH_GOOGLE_DRIVE_CLIENT_SECRET,
            "refresh_token": refresh_token,
        }

        token_json_str = json.dumps(authorized_user_info)
        oauth_creds = get_google_oauth_creds(
            token_json_str=token_json_str, source=DocumentSource.GOOGLE_DRIVE
        )
        if not oauth_creds:
            raise RuntimeError("get_google_oauth_creds returned None.")

        # save off the credentials
        oauth_creds_sanitized_json_str = sanitize_oauth_credentials(oauth_creds)

        credential_dict: dict[str, str] = {}
        credential_dict[DB_CREDENTIALS_DICT_TOKEN_KEY] = oauth_creds_sanitized_json_str
        credential_dict[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] = session.email
        credential_dict[
            DB_CREDENTIALS_AUTHENTICATION_METHOD
        ] = GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE.value

        credential_info = CredentialBase(
            credential_json=credential_dict,
            admin_public=True,
            source=DocumentSource.GOOGLE_DRIVE,
            name="OAuth (interactive)",
        )

        create_credential(credential_info, user, db_session)
    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={
                "success": False,
                "message": f"An error occurred during Google Drive OAuth: {str(e)}",
            },
        )
    finally:
        # the stored state is single-use: remove it whether or not we succeeded
        r.delete(r_key)

    # return the result
    return JSONResponse(
        content={
            "success": True,
            "message": "Google Drive OAuth completed successfully.",
            "redirect_on_success": session.redirect_on_success,
        }
    )
|
||||
|
||||
@@ -13,9 +13,8 @@ from ee.onyx.db.usage_export import get_all_empty_chat_message_entries
|
||||
from ee.onyx.db.usage_export import write_usage_report
|
||||
from ee.onyx.server.reporting.usage_export_models import UsageReportMetadata
|
||||
from ee.onyx.server.reporting.usage_export_models import UserSkeleton
|
||||
from onyx.auth.schemas import UserStatus
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.db.users import list_users
|
||||
from onyx.db.users import get_all_users
|
||||
from onyx.file_store.constants import MAX_IN_MEMORY_SIZE
|
||||
from onyx.file_store.file_store import FileStore
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
@@ -84,15 +83,15 @@ def generate_user_report(
|
||||
max_size=MAX_IN_MEMORY_SIZE, mode="w+"
|
||||
) as temp_file:
|
||||
csvwriter = csv.writer(temp_file, delimiter=",")
|
||||
csvwriter.writerow(["user_id", "status"])
|
||||
csvwriter.writerow(["user_id", "is_active"])
|
||||
|
||||
users = list_users(db_session)
|
||||
users = get_all_users(db_session)
|
||||
for user in users:
|
||||
user_skeleton = UserSkeleton(
|
||||
user_id=str(user.id),
|
||||
status=UserStatus.LIVE if user.is_active else UserStatus.DEACTIVATED,
|
||||
is_active=user.is_active,
|
||||
)
|
||||
csvwriter.writerow([user_skeleton.user_id, user_skeleton.status])
|
||||
csvwriter.writerow([user_skeleton.user_id, user_skeleton.is_active])
|
||||
|
||||
temp_file.seek(0)
|
||||
file_store.save_file(
|
||||
|
||||
@@ -4,8 +4,6 @@ from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.auth.schemas import UserStatus
|
||||
|
||||
|
||||
class FlowType(str, Enum):
|
||||
CHAT = "chat"
|
||||
@@ -22,7 +20,7 @@ class ChatMessageSkeleton(BaseModel):
|
||||
|
||||
class UserSkeleton(BaseModel):
|
||||
user_id: str
|
||||
status: UserStatus
|
||||
is_active: bool
|
||||
|
||||
|
||||
class UsageReportMetadata(BaseModel):
|
||||
|
||||
@@ -19,7 +19,7 @@ from ee.onyx.server.tenants.user_mapping import remove_all_users_from_tenant
|
||||
from ee.onyx.server.tenants.user_mapping import remove_users_from_tenant
|
||||
from onyx.auth.users import auth_backend
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import get_jwt_strategy
|
||||
from onyx.auth.users import get_redis_strategy
|
||||
from onyx.auth.users import User
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.db.auth import get_user_count
|
||||
@@ -112,7 +112,7 @@ async def impersonate_user(
|
||||
)
|
||||
if user_to_impersonate is None:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
token = await get_jwt_strategy().write_token(user_to_impersonate)
|
||||
token = await get_redis_strategy().write_token(user_to_impersonate)
|
||||
|
||||
response = await auth_backend.transport.get_login_response(token)
|
||||
response.set_cookie(
|
||||
|
||||
@@ -46,6 +46,7 @@ def register_tenant_users(tenant_id: str, number_of_users: int) -> stripe.Subscr
|
||||
"""
|
||||
Send a request to the control service to register the number of users for a tenant.
|
||||
"""
|
||||
|
||||
if not STRIPE_PRICE_ID:
|
||||
raise Exception("STRIPE_PRICE_ID is not set")
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from types import TracebackType
|
||||
from typing import cast
|
||||
from typing import Optional
|
||||
@@ -320,8 +321,6 @@ async def embed_text(
|
||||
api_url: str | None,
|
||||
api_version: str | None,
|
||||
) -> list[Embedding]:
|
||||
logger.info(f"Embedding {len(texts)} texts with provider: {provider_type}")
|
||||
|
||||
if not all(texts):
|
||||
logger.error("Empty strings provided for embedding")
|
||||
raise ValueError("Empty strings are not allowed for embedding.")
|
||||
@@ -330,8 +329,17 @@ async def embed_text(
|
||||
logger.error("No texts provided for embedding")
|
||||
raise ValueError("No texts provided for embedding.")
|
||||
|
||||
start = time.monotonic()
|
||||
|
||||
total_chars = 0
|
||||
for text in texts:
|
||||
total_chars += len(text)
|
||||
|
||||
if provider_type is not None:
|
||||
logger.debug(f"Using cloud provider {provider_type} for embedding")
|
||||
logger.info(
|
||||
f"Embedding {len(texts)} texts with {total_chars} total characters with provider: {provider_type}"
|
||||
)
|
||||
|
||||
if api_key is None:
|
||||
logger.error("API key not provided for cloud model")
|
||||
raise RuntimeError("API key not provided for cloud model")
|
||||
@@ -363,8 +371,16 @@ async def embed_text(
|
||||
logger.error(error_message)
|
||||
raise ValueError(error_message)
|
||||
|
||||
elapsed = time.monotonic() - start
|
||||
logger.info(
|
||||
f"Successfully embedded {len(texts)} texts with {total_chars} total characters "
|
||||
f"with provider {provider_type} in {elapsed:.2f}"
|
||||
)
|
||||
elif model_name is not None:
|
||||
logger.debug(f"Using local model {model_name} for embedding")
|
||||
logger.info(
|
||||
f"Embedding {len(texts)} texts with {total_chars} total characters with local model: {model_name}"
|
||||
)
|
||||
|
||||
prefixed_texts = [f"{prefix}{text}" for text in texts] if prefix else texts
|
||||
|
||||
local_model = get_embedding_model(
|
||||
@@ -382,13 +398,17 @@ async def embed_text(
|
||||
for embedding in embeddings_vectors
|
||||
]
|
||||
|
||||
elapsed = time.monotonic() - start
|
||||
logger.info(
|
||||
f"Successfully embedded {len(texts)} texts with {total_chars} total characters "
|
||||
f"with local model {model_name} in {elapsed:.2f}"
|
||||
)
|
||||
else:
|
||||
logger.error("Neither model name nor provider specified for embedding")
|
||||
raise ValueError(
|
||||
"Either model name or provider must be provided to run embeddings."
|
||||
)
|
||||
|
||||
logger.info(f"Successfully embedded {len(texts)} texts")
|
||||
return embeddings
|
||||
|
||||
|
||||
@@ -440,7 +460,8 @@ async def process_embed_request(
|
||||
) -> EmbedResponse:
|
||||
if not embed_request.texts:
|
||||
raise HTTPException(status_code=400, detail="No texts to be embedded")
|
||||
elif not all(embed_request.texts):
|
||||
|
||||
if not all(embed_request.texts):
|
||||
raise ValueError("Empty strings are not allowed for embedding.")
|
||||
|
||||
try:
|
||||
@@ -471,9 +492,12 @@ async def process_embed_request(
|
||||
detail=str(e),
|
||||
)
|
||||
except Exception as e:
|
||||
exception_detail = f"Error during embedding process:\n{str(e)}"
|
||||
logger.exception(exception_detail)
|
||||
raise HTTPException(status_code=500, detail=exception_detail)
|
||||
logger.exception(
|
||||
f"Error during embedding process: provider={embed_request.provider_type} model={embed_request.model_name}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Error during embedding process: {e}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/cross-encoder-scores")
|
||||
|
||||
@@ -44,6 +44,7 @@ def _move_files_recursively(source: Path, dest: Path, overwrite: bool = False) -
|
||||
the files in the existing huggingface cache that don't exist in the temp
|
||||
huggingface cache.
|
||||
"""
|
||||
|
||||
for item in source.iterdir():
|
||||
target_path = dest / item.relative_to(source)
|
||||
if item.is_dir():
|
||||
|
||||
@@ -40,21 +40,24 @@ def send_email(
|
||||
|
||||
|
||||
def send_user_email_invite(user_email: str, current_user: User) -> None:
|
||||
subject = "Invitation to Join Onyx Workspace"
|
||||
subject = "Invitation to Join Onyx Organization"
|
||||
body = dedent(
|
||||
f"""\
|
||||
Hello,
|
||||
|
||||
You have been invited to join a workspace on Onyx.
|
||||
You have been invited to join an organization on Onyx.
|
||||
|
||||
To join the workspace, please visit the following link:
|
||||
To join the organization, please visit the following link:
|
||||
|
||||
{WEB_DOMAIN}/auth/login
|
||||
{WEB_DOMAIN}/auth/signup?email={user_email}
|
||||
|
||||
You'll be asked to set a password or login with Google to complete your registration.
|
||||
|
||||
Best regards,
|
||||
The Onyx Team
|
||||
"""
|
||||
)
|
||||
|
||||
send_email(user_email, subject, body, current_user.email)
|
||||
|
||||
|
||||
|
||||
@@ -30,13 +30,16 @@ def load_no_auth_user_preferences(store: KeyValueStore) -> UserPreferences:
|
||||
)
|
||||
|
||||
|
||||
def fetch_no_auth_user(store: KeyValueStore) -> UserInfo:
|
||||
def fetch_no_auth_user(
|
||||
store: KeyValueStore, *, anonymous_user_enabled: bool | None = None
|
||||
) -> UserInfo:
|
||||
return UserInfo(
|
||||
id=NO_AUTH_USER_ID,
|
||||
email=NO_AUTH_USER_EMAIL,
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
is_verified=True,
|
||||
role=UserRole.ADMIN,
|
||||
role=UserRole.BASIC if anonymous_user_enabled else UserRole.ADMIN,
|
||||
preferences=load_no_auth_user_preferences(store),
|
||||
is_anonymous_user=anonymous_user_enabled,
|
||||
)
|
||||
|
||||
@@ -33,12 +33,6 @@ class UserRole(str, Enum):
|
||||
]
|
||||
|
||||
|
||||
class UserStatus(str, Enum):
|
||||
LIVE = "live"
|
||||
INVITED = "invited"
|
||||
DEACTIVATED = "deactivated"
|
||||
|
||||
|
||||
class UserRead(schemas.BaseUser[uuid.UUID]):
|
||||
role: UserRole
|
||||
|
||||
@@ -49,4 +43,7 @@ class UserCreate(schemas.BaseUserCreate):
|
||||
|
||||
|
||||
class UserUpdate(schemas.BaseUserUpdate):
|
||||
role: UserRole
|
||||
"""
|
||||
Role updates are not allowed through the user update endpoint for security reasons
|
||||
Role changes should be handled through a separate, admin-only process
|
||||
"""
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import json
|
||||
import secrets
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import datetime
|
||||
@@ -29,10 +31,8 @@ from fastapi_users import schemas
|
||||
from fastapi_users import UUIDIDMixin
|
||||
from fastapi_users.authentication import AuthenticationBackend
|
||||
from fastapi_users.authentication import CookieTransport
|
||||
from fastapi_users.authentication import JWTStrategy
|
||||
from fastapi_users.authentication import RedisStrategy
|
||||
from fastapi_users.authentication import Strategy
|
||||
from fastapi_users.authentication.strategy.db import AccessTokenDatabase
|
||||
from fastapi_users.authentication.strategy.db import DatabaseStrategy
|
||||
from fastapi_users.exceptions import UserAlreadyExists
|
||||
from fastapi_users.jwt import decode_jwt
|
||||
from fastapi_users.jwt import generate_jwt
|
||||
@@ -46,7 +46,6 @@ from httpx_oauth.integrations.fastapi import OAuth2AuthorizeCallback
|
||||
from httpx_oauth.oauth2 import BaseOAuth2
|
||||
from httpx_oauth.oauth2 import OAuth2Token
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from onyx.auth.api_key import get_hashed_api_key_from_request
|
||||
@@ -59,6 +58,8 @@ from onyx.auth.schemas import UserUpdate
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
from onyx.configs.app_configs import DISABLE_AUTH
|
||||
from onyx.configs.app_configs import EMAIL_CONFIGURED
|
||||
from onyx.configs.app_configs import REDIS_AUTH_EXPIRE_TIME_SECONDS
|
||||
from onyx.configs.app_configs import REDIS_AUTH_KEY_PREFIX
|
||||
from onyx.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
|
||||
from onyx.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
|
||||
from onyx.configs.app_configs import TRACK_EXTERNAL_IDP_EXPIRY
|
||||
@@ -69,10 +70,10 @@ from onyx.configs.constants import AuthType
|
||||
from onyx.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
||||
from onyx.configs.constants import DANSWER_API_KEY_PREFIX
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.configs.constants import PASSWORD_SPECIAL_CHARS
|
||||
from onyx.configs.constants import UNNAMED_KEY_PLACEHOLDER
|
||||
from onyx.db.api_key import fetch_user_for_api_key
|
||||
from onyx.db.auth import get_access_token_db
|
||||
from onyx.db.auth import get_default_admin_user_emails
|
||||
from onyx.db.auth import get_user_count
|
||||
from onyx.db.auth import get_user_db
|
||||
@@ -80,11 +81,11 @@ from onyx.db.auth import SQLAlchemyUserAdminDB
|
||||
from onyx.db.engine import get_async_session
|
||||
from onyx.db.engine import get_async_session_with_tenant
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.models import AccessToken
|
||||
from onyx.db.models import OAuthAccount
|
||||
from onyx.db.models import User
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.server.utils import BasicAuthenticationError
|
||||
from onyx.redis.redis_pool import get_async_redis_connection
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.telemetry import create_milestone_and_report
|
||||
from onyx.utils.telemetry import optional_telemetry
|
||||
@@ -98,6 +99,11 @@ from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class BasicAuthenticationError(HTTPException):
    """HTTPException that always carries a 403 Forbidden status code."""

    def __init__(self, detail: str):
        forbidden = status.HTTP_403_FORBIDDEN
        super().__init__(status_code=forbidden, detail=detail)
|
||||
|
||||
|
||||
def is_user_admin(user: User | None) -> bool:
|
||||
if AUTH_TYPE == AuthType.DISABLED:
|
||||
return True
|
||||
@@ -138,6 +144,20 @@ def user_needs_to_be_verified() -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def anonymous_user_enabled() -> bool:
    """Report whether the anonymous-user flag is set in redis.

    Always False when MULTI_TENANT is set or when the key is absent;
    otherwise True only if the stored value decodes to the integer 1.
    """
    if MULTI_TENANT:
        return False

    raw_flag = get_redis_client(tenant_id=None).get(
        OnyxRedisLocks.ANONYMOUS_USER_ENABLED
    )
    if raw_flag is None:
        return False

    # narrow the redis return type before decoding
    assert isinstance(raw_flag, bytes)
    flag = int(raw_flag.decode("utf-8"))
    return flag == 1
|
||||
|
||||
|
||||
def verify_email_is_invited(email: str) -> None:
|
||||
whitelist = get_invited_users()
|
||||
if not whitelist:
|
||||
@@ -252,7 +272,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
if not user.role.is_web_login() and user_create.role.is_web_login():
|
||||
user_update = UserUpdate(
|
||||
password=user_create.password,
|
||||
role=user_create.role,
|
||||
is_verified=user_create.is_verified,
|
||||
)
|
||||
user = await self.update(user_update, user)
|
||||
@@ -376,11 +395,9 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
|
||||
# Explicitly set the Postgres schema for this session to ensure
|
||||
# OAuth account creation happens in the correct tenant schema
|
||||
await db_session.execute(text(f'SET search_path = "{tenant_id}"'))
|
||||
|
||||
# Add OAuth account
|
||||
await self.user_db.add_oauth_account(user, oauth_account_dict)
|
||||
|
||||
await self.on_after_register(user, request)
|
||||
|
||||
else:
|
||||
@@ -399,7 +416,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
|
||||
# NOTE: Most IdPs have very short expiry times, and we don't want to force the user to
|
||||
# re-authenticate that frequently, so by default this is disabled
|
||||
|
||||
if expires_at and TRACK_EXTERNAL_IDP_EXPIRY:
|
||||
oidc_expiry = datetime.fromtimestamp(expires_at, tz=timezone.utc)
|
||||
await self.user_db.update(
|
||||
@@ -562,49 +578,70 @@ cookie_transport = CookieTransport(
|
||||
)
|
||||
|
||||
|
||||
# This strategy is used to add tenant_id to the JWT token
|
||||
class TenantAwareJWTStrategy(JWTStrategy):
|
||||
async def _create_token_data(self, user: User, impersonate: bool = False) -> dict:
|
||||
def get_redis_strategy() -> RedisStrategy:
    """Create the auth strategy: a TenantAwareRedisStrategy with default arguments."""
    strategy = TenantAwareRedisStrategy()
    return strategy
|
||||
|
||||
|
||||
class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]):
|
||||
"""
|
||||
A custom strategy that fetches the actual async Redis connection inside each method.
|
||||
We do NOT pass a synchronous or "coroutine" redis object to the constructor.
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    lifetime_seconds: Optional[int] = REDIS_AUTH_EXPIRE_TIME_SECONDS,
    key_prefix: str = REDIS_AUTH_KEY_PREFIX,
):
    """Record the token settings; no Redis connection is created or stored here.

    Args:
        lifetime_seconds: kept on the instance and later passed as the
            expiry (``ex=``) when a token is written.
        key_prefix: kept on the instance and prepended to each token to
            form its Redis key.
    """
    self.key_prefix = key_prefix
    self.lifetime_seconds = lifetime_seconds
|
||||
|
||||
async def write_token(self, user: User) -> str:
|
||||
redis = await get_async_redis_connection()
|
||||
|
||||
tenant_id = await fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.provisioning",
|
||||
"get_or_provision_tenant",
|
||||
async_return_default_schema,
|
||||
)(
|
||||
email=user.email,
|
||||
)
|
||||
)(email=user.email)
|
||||
|
||||
data = {
|
||||
token_data = {
|
||||
"sub": str(user.id),
|
||||
"aud": self.token_audience,
|
||||
"tenant_id": tenant_id,
|
||||
}
|
||||
return data
|
||||
|
||||
async def write_token(self, user: User) -> str:
|
||||
data = await self._create_token_data(user)
|
||||
return generate_jwt(
|
||||
data, self.encode_key, self.lifetime_seconds, algorithm=self.algorithm
|
||||
token = secrets.token_urlsafe()
|
||||
await redis.set(
|
||||
f"{self.key_prefix}{token}",
|
||||
json.dumps(token_data),
|
||||
ex=self.lifetime_seconds,
|
||||
)
|
||||
return token
|
||||
|
||||
async def read_token(
|
||||
self, token: Optional[str], user_manager: BaseUserManager[User, uuid.UUID]
|
||||
) -> Optional[User]:
|
||||
redis = await get_async_redis_connection()
|
||||
token_data_str = await redis.get(f"{self.key_prefix}{token}")
|
||||
if not token_data_str:
|
||||
return None
|
||||
|
||||
def get_jwt_strategy() -> TenantAwareJWTStrategy:
|
||||
return TenantAwareJWTStrategy(
|
||||
secret=USER_AUTH_SECRET,
|
||||
lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS,
|
||||
)
|
||||
try:
|
||||
token_data = json.loads(token_data_str)
|
||||
user_id = token_data["sub"]
|
||||
parsed_id = user_manager.parse_id(user_id)
|
||||
return await user_manager.get(parsed_id)
|
||||
except (exceptions.UserNotExists, exceptions.InvalidID, KeyError):
|
||||
return None
|
||||
|
||||
|
||||
def get_database_strategy(
|
||||
access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db),
|
||||
) -> DatabaseStrategy:
|
||||
return DatabaseStrategy(
|
||||
access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS # type: ignore
|
||||
)
|
||||
async def destroy_token(self, token: str, user: User) -> None:
    """Delete the stored entry for ``token`` from async Redis.

    The key is ``self.key_prefix`` followed by the token. ``user`` is not
    used by this method.
    """
    connection = await get_async_redis_connection()
    token_key = f"{self.key_prefix}{token}"
    await connection.delete(token_key)
|
||||
|
||||
|
||||
auth_backend = AuthenticationBackend(
|
||||
name="jwt", transport=cookie_transport, get_strategy=get_jwt_strategy
|
||||
) # type: ignore
|
||||
name="redis", transport=cookie_transport, get_strategy=get_redis_strategy
|
||||
)
|
||||
|
||||
|
||||
class FastAPIUserWithLogoutRouter(FastAPIUsers[models.UP, models.ID]):
|
||||
@@ -690,30 +727,36 @@ async def double_check_user(
|
||||
user: User | None,
|
||||
optional: bool = DISABLE_AUTH,
|
||||
include_expired: bool = False,
|
||||
allow_anonymous_access: bool = False,
|
||||
) -> User | None:
|
||||
if optional:
|
||||
return user
|
||||
|
||||
if user is not None:
|
||||
# If user attempted to authenticate, verify them, do not default
|
||||
# to anonymous access if it fails.
|
||||
if user_needs_to_be_verified() and not user.is_verified:
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User is not verified.",
|
||||
)
|
||||
|
||||
if (
|
||||
user.oidc_expiry
|
||||
and user.oidc_expiry < datetime.now(timezone.utc)
|
||||
and not include_expired
|
||||
):
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User's OIDC token has expired.",
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
if allow_anonymous_access:
|
||||
return None
|
||||
|
||||
if user is None:
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User is not authenticated.",
|
||||
)
|
||||
|
||||
if user_needs_to_be_verified() and not user.is_verified:
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User is not verified.",
|
||||
)
|
||||
|
||||
if (
|
||||
user.oidc_expiry
|
||||
and user.oidc_expiry < datetime.now(timezone.utc)
|
||||
and not include_expired
|
||||
):
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User's OIDC token has expired.",
|
||||
)
|
||||
|
||||
return user
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User is not authenticated.",
|
||||
)
|
||||
|
||||
|
||||
async def current_user_with_expired_token(
|
||||
@@ -728,6 +771,14 @@ async def current_limited_user(
|
||||
return await double_check_user(user)
|
||||
|
||||
|
||||
async def current_chat_accesssible_user(
    user: User | None = Depends(optional_user),
) -> User | None:
    """Resolve the user for chat-accessible endpoints.

    Passes the optionally-resolved user to ``double_check_user``, with
    anonymous access allowed exactly when ``anonymous_user_enabled()`` is true.
    """
    anonymous_allowed = anonymous_user_enabled()
    return await double_check_user(user, allow_anonymous_access=anonymous_allowed)
|
||||
|
||||
|
||||
async def current_user(
|
||||
user: User | None = Depends(optional_user),
|
||||
) -> User | None:
|
||||
|
||||
@@ -335,6 +335,10 @@ def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
if not celery_is_worker_primary(sender):
|
||||
return
|
||||
|
||||
if not hasattr(sender, "primary_worker_lock"):
|
||||
# primary_worker_lock will not exist when MULTI_TENANT is True
|
||||
return
|
||||
|
||||
if not sender.primary_worker_lock:
|
||||
return
|
||||
|
||||
@@ -414,11 +418,21 @@ def on_setup_logging(
|
||||
task_logger.setLevel(loglevel)
|
||||
task_logger.propagate = False
|
||||
|
||||
# Hide celery task received and succeeded/failed messages
|
||||
# hide celery task received spam
|
||||
# e.g. "Task check_for_pruning[a1e96171-0ba8-4e00-887b-9fbf7442eab3] received"
|
||||
strategy.logger.setLevel(logging.WARNING)
|
||||
|
||||
# hide celery task succeeded/failed spam by default (set_task_finished_log_level can override this)
|
||||
# e.g. "Task check_for_pruning[a1e96171-0ba8-4e00-887b-9fbf7442eab3] succeeded in 0.03137450001668185s: None"
|
||||
trace.logger.setLevel(logging.WARNING)
|
||||
|
||||
|
||||
def set_task_finished_log_level(logLevel: int) -> None:
    """Override the trace logger level that on_setup_logging sets.

    Task timings are interesting in the cloud, but the messages can be
    spammy for self-hosted deployments, so the caller picks the level.
    """
    task_trace_logger = trace.logger
    task_trace_logger.setLevel(logLevel)
|
||||
|
||||
|
||||
class TenantContextFilter(logging.Filter):
|
||||
|
||||
"""Logging filter to inject tenant ID into the logger's name."""
|
||||
|
||||
@@ -60,7 +60,12 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
|
||||
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=sender.concurrency)
|
||||
|
||||
# rkuo: been seeing transient connection exceptions here, so upping the connection count
|
||||
# from just concurrency/concurrency to concurrency/concurrency*2
|
||||
SqlEngine.init_engine(
|
||||
pool_size=sender.concurrency, max_overflow=sender.concurrency * 2
|
||||
)
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import logging
|
||||
import multiprocessing
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
@@ -88,12 +89,12 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
app_base.wait_for_vespa(sender, **kwargs)
|
||||
|
||||
logger.info("Running as the primary celery worker.")
|
||||
|
||||
# Less startup checks in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
|
||||
logger.info("Running as the primary celery worker.")
|
||||
|
||||
# This is singleton work that should be done on startup exactly once
|
||||
# by the primary worker. This is unnecessary in the multi tenant scenario
|
||||
r = get_redis_client(tenant_id=None)
|
||||
@@ -194,6 +195,10 @@ def on_setup_logging(
|
||||
) -> None:
|
||||
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
|
||||
|
||||
# this can be spammy, so just enable it in the cloud for now
|
||||
if MULTI_TENANT:
|
||||
app_base.set_task_finished_log_level(logging.INFO)
|
||||
|
||||
|
||||
class HubPeriodicTask(bootsteps.StartStopStep):
|
||||
"""Regularly reacquires the primary worker lock outside of the task queue.
|
||||
@@ -281,5 +286,6 @@ celery_app.autodiscover_tasks(
|
||||
"onyx.background.celery.tasks.pruning",
|
||||
"onyx.background.celery.tasks.shared",
|
||||
"onyx.background.celery.tasks.vespa",
|
||||
"onyx.background.celery.tasks.llm_model_update",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -19,7 +19,10 @@ task_acks_late = shared_config.task_acks_late
|
||||
# Indexing worker specific ... this lets us track the transition to STARTED in redis
|
||||
# We don't currently rely on this but it has the potential to be useful and
|
||||
# indexing tasks are not high volume
|
||||
task_track_started = True
|
||||
|
||||
# we don't turn this on yet because celery occasionally runs tasks more than once
|
||||
# which means a duplicate run might change the task state unexpectedly
|
||||
# task_track_started = True
|
||||
|
||||
worker_concurrency = CELERY_WORKER_INDEXING_CONCURRENCY
|
||||
worker_pool = "threads"
|
||||
|
||||
@@ -1,9 +1,16 @@
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
from onyx.configs.app_configs import LLM_MODEL_UPDATE_API_URL
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
|
||||
# choosing 15 minutes because it roughly gives us enough time to process many tasks
|
||||
# we might be able to reduce this greatly if we can run a unified
|
||||
# loop across all tenants rather than tasks per tenant
|
||||
|
||||
BEAT_EXPIRES_DEFAULT = 15 * 60 # 15 minutes (in seconds)
|
||||
|
||||
# we set expires because it isn't necessary to queue up these tasks
|
||||
# it's only important that they run relatively regularly
|
||||
tasks_to_schedule = [
|
||||
@@ -13,7 +20,7 @@ tasks_to_schedule = [
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": 60,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -22,7 +29,7 @@ tasks_to_schedule = [
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": 60,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -31,7 +38,7 @@ tasks_to_schedule = [
|
||||
"schedule": timedelta(seconds=15),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": 60,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -40,7 +47,7 @@ tasks_to_schedule = [
|
||||
"schedule": timedelta(seconds=15),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": 60,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -49,7 +56,7 @@ tasks_to_schedule = [
|
||||
"schedule": timedelta(seconds=3600),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOWEST,
|
||||
"expires": 60,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -58,7 +65,7 @@ tasks_to_schedule = [
|
||||
"schedule": timedelta(seconds=5),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": 60,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -67,7 +74,7 @@ tasks_to_schedule = [
|
||||
"schedule": timedelta(seconds=30),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": 60,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -76,11 +83,25 @@ tasks_to_schedule = [
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": 60,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
# Register the LLM model update check only when its API URL is configured.
if LLM_MODEL_UPDATE_API_URL:
    tasks_to_schedule.extend(
        [
            {
                "name": "check-for-llm-model-update",
                "task": OnyxCeleryTask.CHECK_FOR_LLM_MODEL_UPDATE,
                "schedule": timedelta(hours=1),  # Check every hour
                "options": {
                    "priority": OnyxCeleryPriority.LOW,
                    "expires": BEAT_EXPIRES_DEFAULT,
                },
            }
        ]
    )
|
||||
|
||||
|
||||
def get_tasks_to_schedule() -> list[dict[str, Any]]:
    """Return the module-level beat task list (the list object itself, not a copy)."""
    scheduled: list[dict[str, Any]] = tasks_to_schedule
    return scheduled
|
||||
|
||||
@@ -34,7 +34,9 @@ class TaskDependencyError(RuntimeError):
|
||||
trail=False,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> None:
|
||||
def check_for_connector_deletion_task(
|
||||
self: Task, *, tenant_id: str | None
|
||||
) -> bool | None:
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
@@ -45,7 +47,7 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> N
|
||||
try:
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return
|
||||
return None
|
||||
|
||||
# collect cc_pair_ids
|
||||
cc_pair_ids: list[int] = []
|
||||
@@ -81,6 +83,8 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> N
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def try_generate_document_cc_pair_cleanup_tasks(
|
||||
app: Celery,
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import time
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from time import sleep
|
||||
from uuid import uuid4
|
||||
|
||||
from celery import Celery
|
||||
@@ -18,6 +20,7 @@ from onyx.access.models import DocExternalAccess
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from onyx.configs.constants import DocumentSource
|
||||
@@ -88,10 +91,10 @@ def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> b
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> None:
|
||||
def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> bool | None:
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat = r.lock(
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
@@ -99,7 +102,7 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> None
|
||||
try:
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return
|
||||
return None
|
||||
|
||||
# get all cc pairs that need to be synced
|
||||
cc_pair_ids_to_sync: list[int] = []
|
||||
@@ -128,6 +131,8 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> None
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def try_creating_permissions_sync_task(
|
||||
app: Celery,
|
||||
@@ -219,6 +224,43 @@ def connector_permission_sync_generator_task(
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# this wait is needed to avoid a race condition where
|
||||
# the primary worker sends the task and it is immediately executed
|
||||
# before the primary worker can finalize the fence
|
||||
start = time.monotonic()
|
||||
while True:
|
||||
if time.monotonic() - start > CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT:
|
||||
raise ValueError(
|
||||
f"connector_permission_sync_generator_task - timed out waiting for fence to be ready: "
|
||||
f"fence={redis_connector.permissions.fence_key}"
|
||||
)
|
||||
|
||||
if not redis_connector.permissions.fenced: # The fence must exist
|
||||
raise ValueError(
|
||||
f"connector_permission_sync_generator_task - fence not found: "
|
||||
f"fence={redis_connector.permissions.fence_key}"
|
||||
)
|
||||
|
||||
payload = redis_connector.permissions.payload # The payload must exist
|
||||
if not payload:
|
||||
raise ValueError(
|
||||
"connector_permission_sync_generator_task: payload invalid or not found"
|
||||
)
|
||||
|
||||
if payload.celery_task_id is None:
|
||||
logger.info(
|
||||
f"connector_permission_sync_generator_task - Waiting for fence: "
|
||||
f"fence={redis_connector.permissions.fence_key}"
|
||||
)
|
||||
sleep(1)
|
||||
continue
|
||||
|
||||
logger.info(
|
||||
f"connector_permission_sync_generator_task - Fence found, continuing...: "
|
||||
f"fence={redis_connector.permissions.fence_key}"
|
||||
)
|
||||
break
|
||||
|
||||
lock: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX
|
||||
+ f"_{redis_connector.id}",
|
||||
@@ -254,8 +296,11 @@ def connector_permission_sync_generator_task(
|
||||
if not payload:
|
||||
raise ValueError(f"No fence payload found: cc_pair={cc_pair_id}")
|
||||
|
||||
payload.started = datetime.now(timezone.utc)
|
||||
redis_connector.permissions.set_fence(payload)
|
||||
new_payload = RedisConnectorPermissionSyncPayload(
|
||||
started=datetime.now(timezone.utc),
|
||||
celery_task_id=payload.celery_task_id,
|
||||
)
|
||||
redis_connector.permissions.set_fence(new_payload)
|
||||
|
||||
document_external_accesses: list[DocExternalAccess] = doc_sync_func(cc_pair)
|
||||
|
||||
|
||||
@@ -94,10 +94,10 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
|
||||
def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool | None:
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat = r.lock(
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
@@ -105,7 +105,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
|
||||
try:
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return
|
||||
return None
|
||||
|
||||
cc_pair_ids_to_sync: list[int] = []
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
@@ -149,6 +149,8 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def try_creating_external_group_sync_task(
|
||||
app: Celery,
|
||||
@@ -162,7 +164,7 @@ def try_creating_external_group_sync_task(
|
||||
|
||||
LOCK_TIMEOUT = 30
|
||||
|
||||
lock = r.lock(
|
||||
lock: RedisLock = r.lock(
|
||||
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_external_group_sync_tasks",
|
||||
timeout=LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from http import HTTPStatus
|
||||
from time import sleep
|
||||
from typing import cast
|
||||
|
||||
import redis
|
||||
import sentry_sdk
|
||||
@@ -23,6 +24,7 @@ from onyx.background.indexing.job_client import SimpleJobClient
|
||||
from onyx.background.indexing.run_indexing import run_indexing_entrypoint
|
||||
from onyx.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
|
||||
from onyx.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from onyx.configs.constants import DocumentSource
|
||||
@@ -60,6 +62,7 @@ from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.redis.redis_connector_index import RedisConnectorIndex
|
||||
from onyx.redis.redis_connector_index import RedisConnectorIndexPayload
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import redis_lock_dump
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
|
||||
@@ -71,14 +74,18 @@ logger = setup_logger()
|
||||
|
||||
|
||||
class IndexingCallback(IndexingHeartbeatInterface):
|
||||
PARENT_CHECK_INTERVAL = 60
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
parent_pid: int,
|
||||
stop_key: str,
|
||||
generator_progress_key: str,
|
||||
redis_lock: RedisLock,
|
||||
redis_client: Redis,
|
||||
):
|
||||
super().__init__()
|
||||
self.parent_pid = parent_pid
|
||||
self.redis_lock: RedisLock = redis_lock
|
||||
self.stop_key: str = stop_key
|
||||
self.generator_progress_key: str = generator_progress_key
|
||||
@@ -88,17 +95,43 @@ class IndexingCallback(IndexingHeartbeatInterface):
|
||||
|
||||
self.last_tag: str = "IndexingCallback.__init__"
|
||||
self.last_lock_reacquire: datetime = datetime.now(timezone.utc)
|
||||
self.last_lock_monotonic = time.monotonic()
|
||||
|
||||
self.last_parent_check = time.monotonic()
|
||||
|
||||
def should_stop(self) -> bool:
    """Report whether the stop key currently exists in Redis."""
    return bool(self.redis_client.exists(self.stop_key))
|
||||
|
||||
def progress(self, tag: str, amount: int) -> None:
|
||||
# rkuo: this shouldn't be necessary yet because we spawn the process this runs inside
|
||||
# with daemon = True. It seems likely some indexing tasks will need to spawn other processes eventually
|
||||
# so leave this code in until we're ready to test it.
|
||||
|
||||
# if self.parent_pid:
|
||||
# # check if the parent pid is alive so we aren't running as a zombie
|
||||
# now = time.monotonic()
|
||||
# if now - self.last_parent_check > IndexingCallback.PARENT_CHECK_INTERVAL:
|
||||
# try:
|
||||
# # this is unintuitive, but it checks if the parent pid is still running
|
||||
# os.kill(self.parent_pid, 0)
|
||||
# except Exception:
|
||||
# logger.exception("IndexingCallback - parent pid check exceptioned")
|
||||
# raise
|
||||
# self.last_parent_check = now
|
||||
|
||||
try:
|
||||
self.redis_lock.reacquire()
|
||||
current_time = time.monotonic()
|
||||
if current_time - self.last_lock_monotonic >= (
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
|
||||
):
|
||||
self.redis_lock.reacquire()
|
||||
self.last_lock_reacquire = datetime.now(timezone.utc)
|
||||
self.last_lock_monotonic = time.monotonic()
|
||||
|
||||
self.last_tag = tag
|
||||
self.last_lock_reacquire = datetime.now(timezone.utc)
|
||||
except LockError:
|
||||
logger.exception(
|
||||
f"IndexingCallback - lock.reacquire exceptioned: "
|
||||
@@ -109,29 +142,7 @@ class IndexingCallback(IndexingHeartbeatInterface):
|
||||
f"now={datetime.now(timezone.utc)}"
|
||||
)
|
||||
|
||||
# diagnostic logging for lock errors
|
||||
name = self.redis_lock.name
|
||||
ttl = self.redis_client.ttl(name)
|
||||
locked = self.redis_lock.locked()
|
||||
owned = self.redis_lock.owned()
|
||||
local_token: str | None = self.redis_lock.local.token # type: ignore
|
||||
|
||||
remote_token_raw = self.redis_client.get(self.redis_lock.name)
|
||||
if remote_token_raw:
|
||||
remote_token_bytes = cast(bytes, remote_token_raw)
|
||||
remote_token = remote_token_bytes.decode("utf-8")
|
||||
else:
|
||||
remote_token = None
|
||||
|
||||
logger.warning(
|
||||
f"IndexingCallback - lock diagnostics: "
|
||||
f"name={name} "
|
||||
f"locked={locked} "
|
||||
f"owned={owned} "
|
||||
f"local_token={local_token} "
|
||||
f"remote_token={remote_token} "
|
||||
f"ttl={ttl}"
|
||||
)
|
||||
redis_lock_dump(self.redis_lock, self.redis_client)
|
||||
raise
|
||||
|
||||
self.redis_client.incrby(self.generator_progress_key, amount)
|
||||
@@ -323,6 +334,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
# Fail any index attempts in the DB that don't have fences
|
||||
# This shouldn't ever happen!
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
lock_beat.reacquire()
|
||||
unfenced_attempt_ids = get_unfenced_index_attempt_ids(
|
||||
db_session, redis_client
|
||||
)
|
||||
@@ -346,6 +358,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
|
||||
# we want to run this less frequently than the overall task
|
||||
if not redis_client.exists(OnyxRedisSignals.VALIDATE_INDEXING_FENCES):
|
||||
lock_beat.reacquire()
|
||||
# clear any indexing fences that don't have associated celery tasks in progress
|
||||
# tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
|
||||
# or be currently executing
|
||||
@@ -373,6 +386,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
"check_for_indexing - Lock not owned on completion: "
|
||||
f"tenant={tenant_id}"
|
||||
)
|
||||
redis_lock_dump(lock_beat, redis_client)
|
||||
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
task_logger.debug(f"check_for_indexing finished: elapsed={time_elapsed:.2f}")
|
||||
@@ -772,7 +786,6 @@ def connector_indexing_proxy_task(
|
||||
return
|
||||
|
||||
task_logger.info(
|
||||
f"Indexing proxy - spawn succeeded: attempt={index_attempt_id} "
|
||||
f"Indexing watchdog - spawn succeeded: attempt={index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
@@ -789,23 +802,26 @@ def connector_indexing_proxy_task(
|
||||
|
||||
# if the job is done, clean up and break
|
||||
if job.done():
|
||||
if job.status == "error":
|
||||
ignore_exitcode = False
|
||||
try:
|
||||
if job.status == "error":
|
||||
ignore_exitcode = False
|
||||
|
||||
exit_code: int | None = None
|
||||
if job.process:
|
||||
exit_code = job.process.exitcode
|
||||
exit_code: int | None = None
|
||||
if job.process:
|
||||
exit_code = job.process.exitcode
|
||||
|
||||
# seeing odd behavior where spawned tasks usually return exit code 1 in the cloud,
|
||||
# even though logging clearly indicates that they completed successfully
|
||||
# to work around this, we ignore the job error state if the completion signal is OK
|
||||
status_int = redis_connector_index.get_completion()
|
||||
if status_int:
|
||||
status_enum = HTTPStatus(status_int)
|
||||
if status_enum == HTTPStatus.OK:
|
||||
ignore_exitcode = True
|
||||
# seeing odd behavior where spawned tasks usually return exit code 1 in the cloud,
|
||||
# even though logging clearly indicates successful completion
|
||||
# to work around this, we ignore the job error state if the completion signal is OK
|
||||
status_int = redis_connector_index.get_completion()
|
||||
if status_int:
|
||||
status_enum = HTTPStatus(status_int)
|
||||
if status_enum == HTTPStatus.OK:
|
||||
ignore_exitcode = True
|
||||
|
||||
if not ignore_exitcode:
|
||||
raise RuntimeError("Spawned task exceptioned.")
|
||||
|
||||
if ignore_exitcode:
|
||||
task_logger.warning(
|
||||
"Indexing watchdog - spawned task has non-zero exit code "
|
||||
"but completion signal is OK. Continuing...: "
|
||||
@@ -815,18 +831,21 @@ def connector_indexing_proxy_task(
|
||||
f"search_settings={search_settings_id} "
|
||||
f"exit_code={exit_code}"
|
||||
)
|
||||
else:
|
||||
task_logger.error(
|
||||
"Indexing watchdog - spawned task exceptioned: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"exit_code={exit_code} "
|
||||
f"error={job.exception()}"
|
||||
)
|
||||
except Exception:
|
||||
task_logger.error(
|
||||
"Indexing watchdog - spawned task exceptioned: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"exit_code={exit_code} "
|
||||
f"error={job.exception()}"
|
||||
)
|
||||
|
||||
raise
|
||||
finally:
|
||||
job.release()
|
||||
|
||||
job.release()
|
||||
break
|
||||
|
||||
# if a termination signal is detected, clean up and break
|
||||
@@ -911,7 +930,7 @@ def connector_indexing_task_wrapper(
|
||||
tenant_id,
|
||||
is_ee,
|
||||
)
|
||||
except:
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"connector_indexing_task exceptioned: "
|
||||
f"tenant={tenant_id} "
|
||||
@@ -919,7 +938,14 @@ def connector_indexing_task_wrapper(
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
raise
|
||||
|
||||
# There is a cloud related bug outside of our code
|
||||
# where spawned tasks return with an exit code of 1.
|
||||
# Unfortunately, exceptions also return with an exit code of 1,
|
||||
# so just raising an exception isn't informative
|
||||
# Exiting with 255 makes it possible to distinguish between normal exits
|
||||
# and exceptions.
|
||||
sys.exit(255)
|
||||
|
||||
return result
|
||||
|
||||
@@ -991,7 +1017,17 @@ def connector_indexing_task(
|
||||
f"fence={redis_connector.stop.fence_key}"
|
||||
)
|
||||
|
||||
# this wait is needed to avoid a race condition where
|
||||
# the primary worker sends the task and it is immediately executed
|
||||
# before the primary worker can finalize the fence
|
||||
start = time.monotonic()
|
||||
while True:
|
||||
if time.monotonic() - start > CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT:
    # Report the indexing fence this loop is waiting on. The message
    # previously named the permissions fence, which this task never checks.
    raise ValueError(
        f"connector_indexing_task - timed out waiting for fence to be ready: "
        f"fence={redis_connector_index.fence_key}"
    )
|
||||
|
||||
if not redis_connector_index.fenced: # The fence must exist
|
||||
raise ValueError(
|
||||
f"connector_indexing_task - fence not found: fence={redis_connector_index.fence_key}"
|
||||
@@ -1032,7 +1068,9 @@ def connector_indexing_task(
|
||||
if not acquired:
|
||||
logger.warning(
|
||||
f"Indexing task already running, exiting...: "
|
||||
f"index_attempt={index_attempt_id} cc_pair={cc_pair_id} search_settings={search_settings_id}"
|
||||
f"index_attempt={index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -1068,6 +1106,7 @@ def connector_indexing_task(
|
||||
|
||||
# define a callback class
|
||||
callback = IndexingCallback(
|
||||
os.getppid(),
|
||||
redis_connector.stop.fence_key,
|
||||
redis_connector_index.generator_progress_key,
|
||||
lock,
|
||||
@@ -1101,8 +1140,19 @@ def connector_indexing_task(
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
if attempt_found:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
mark_attempt_failed(index_attempt_id, db_session, failure_reason=str(e))
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
mark_attempt_failed(
|
||||
index_attempt_id, db_session, failure_reason=str(e)
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Indexing watchdog - transient exception looking up index attempt: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
raise e
|
||||
finally:
|
||||
|
||||
105
backend/onyx/background/celery/tasks/llm_model_update/tasks.py
Normal file
105
backend/onyx/background/celery/tasks/llm_model_update/tasks.py
Normal file
@@ -0,0 +1,105 @@
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.app_configs import LLM_MODEL_UPDATE_API_URL
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.models import LLMProvider
|
||||
|
||||
|
||||
def _process_model_list_response(model_list_json: Any) -> list[str]:
|
||||
# Handle case where response is wrapped in a "data" field
|
||||
if isinstance(model_list_json, dict) and "data" in model_list_json:
|
||||
model_list_json = model_list_json["data"]
|
||||
|
||||
if not isinstance(model_list_json, list):
|
||||
raise ValueError(
|
||||
f"Invalid response from API - expected list, got {type(model_list_json)}"
|
||||
)
|
||||
|
||||
# Handle both string list and object list cases
|
||||
model_names: list[str] = []
|
||||
for item in model_list_json:
|
||||
if isinstance(item, str):
|
||||
model_names.append(item)
|
||||
elif isinstance(item, dict) and "model_name" in item:
|
||||
model_names.append(item["model_name"])
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid item in model list - expected string or dict with model_name, got {type(item)}"
|
||||
)
|
||||
|
||||
return model_names
|
||||
|
||||
|
||||
@shared_task(
    name=OnyxCeleryTask.CHECK_FOR_LLM_MODEL_UPDATE,
    soft_time_limit=JOB_TIMEOUT,
    trail=False,
    bind=True,
)
def check_for_llm_model_update(self: Task, *, tenant_id: str | None) -> bool | None:
    """Refresh the default LLM provider's model list from the update API.

    Fetches the available model names from LLM_MODEL_UPDATE_API_URL and
    stores them on the default LLM provider. If the provider's default or
    fast-default model is no longer in the list, it is replaced with the
    first model in the list.

    Args:
        tenant_id: Tenant whose database session is used for the update.

    Returns:
        True once the provider has been updated and committed. None if the
        API call fails, the API reports no models, or no default provider
        exists.

    Raises:
        ValueError: If LLM_MODEL_UPDATE_API_URL is not configured.
    """
    # Imported locally so this block does not depend on the module's import
    # section being edited.
    from onyx.configs.app_configs import REQUEST_TIMEOUT_SECONDS

    if not LLM_MODEL_UPDATE_API_URL:
        raise ValueError("LLM model update API URL not configured")

    # First fetch the models from the API
    try:
        # Without a timeout a stalled endpoint would block this worker until
        # the Celery soft time limit fires.
        response = requests.get(
            LLM_MODEL_UPDATE_API_URL, timeout=REQUEST_TIMEOUT_SECONDS
        )
        response.raise_for_status()
        available_models = _process_model_list_response(response.json())
        task_logger.info(f"Found available models: {available_models}")

    except Exception:
        task_logger.exception("Failed to fetch models from API.")
        return None

    # An empty list would wipe the provider's models and make the
    # available_models[0] fallbacks below raise IndexError, so leave the
    # database untouched in that case.
    if not available_models:
        task_logger.warning(
            "Model update API returned no models, skipping provider update."
        )
        return None

    # Then update the database with the fetched models
    with get_session_with_tenant(tenant_id) as db_session:
        # Get the default LLM provider
        default_provider = (
            db_session.query(LLMProvider)
            .filter(LLMProvider.is_default_provider.is_(True))
            .first()
        )

        if not default_provider:
            task_logger.warning("No default LLM provider found")
            return None

        # log change if any
        old_models = set(default_provider.model_names or [])
        new_models = set(available_models)
        added_models = new_models - old_models
        removed_models = old_models - new_models

        if added_models:
            task_logger.info(f"Adding models: {sorted(added_models)}")
        if removed_models:
            task_logger.info(f"Removing models: {sorted(removed_models)}")

        # Update the provider's model list
        default_provider.model_names = available_models
        # if the default model is no longer available, set it to the first model in the list
        if default_provider.default_model_name not in available_models:
            task_logger.info(
                f"Default model {default_provider.default_model_name} not "
                f"available, setting to first model in list: {available_models[0]}"
            )
            default_provider.default_model_name = available_models[0]
        if default_provider.fast_default_model_name not in available_models:
            task_logger.info(
                f"Fast default model {default_provider.fast_default_model_name} "
                f"not available, setting to first model in list: {available_models[0]}"
            )
            default_provider.fast_default_model_name = available_models[0]
        db_session.commit()

        if added_models or removed_models:
            task_logger.info("Updated model list for default provider.")

    return True
|
||||
@@ -81,10 +81,10 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_pruning(self: Task, *, tenant_id: str | None) -> None:
|
||||
def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat = r.lock(
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_PRUNE_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
@@ -92,7 +92,7 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> None:
|
||||
try:
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return
|
||||
return None
|
||||
|
||||
cc_pair_ids: list[int] = []
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
@@ -127,6 +127,8 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> None:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def try_creating_prune_generator_task(
|
||||
celery_app: Celery,
|
||||
@@ -283,6 +285,7 @@ def connector_pruning_generator_task(
|
||||
)
|
||||
|
||||
callback = IndexingCallback(
|
||||
0,
|
||||
redis_connector.stop.fence_key,
|
||||
redis_connector.prune.generator_progress_key,
|
||||
lock,
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import random
|
||||
import time
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
@@ -68,6 +69,7 @@ from onyx.redis.redis_connector_index import RedisConnectorIndex
|
||||
from onyx.redis.redis_connector_prune import RedisConnectorPrune
|
||||
from onyx.redis.redis_document_set import RedisDocumentSet
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import redis_lock_dump
|
||||
from onyx.redis.redis_usergroup import RedisUserGroup
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
@@ -76,6 +78,7 @@ from onyx.utils.variable_functionality import (
|
||||
)
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
from onyx.utils.variable_functionality import noop_fallback
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -88,7 +91,7 @@ logger = setup_logger()
|
||||
trail=False,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> None:
|
||||
def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | None:
|
||||
"""Runs periodically to check if any document needs syncing.
|
||||
Generates sets of tasks for Celery if syncing is needed."""
|
||||
time_start = time.monotonic()
|
||||
@@ -103,7 +106,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> None:
|
||||
try:
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return
|
||||
return None
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
try_generate_stale_document_sync_tasks(
|
||||
@@ -111,6 +114,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> None:
|
||||
)
|
||||
|
||||
# region document set scan
|
||||
lock_beat.reacquire()
|
||||
document_set_ids: list[int] = []
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
# check if any document sets are not synced
|
||||
@@ -122,6 +126,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> None:
|
||||
document_set_ids.append(document_set.id)
|
||||
|
||||
for document_set_id in document_set_ids:
|
||||
lock_beat.reacquire()
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
try_generate_document_set_sync_tasks(
|
||||
self.app, document_set_id, db_session, r, lock_beat, tenant_id
|
||||
@@ -130,6 +135,8 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> None:
|
||||
|
||||
# check if any user groups are not synced
|
||||
if global_version.is_ee_version():
|
||||
lock_beat.reacquire()
|
||||
|
||||
try:
|
||||
fetch_user_groups = fetch_versioned_implementation(
|
||||
"onyx.db.user_group", "fetch_user_groups"
|
||||
@@ -149,6 +156,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> None:
|
||||
usergroup_ids.append(usergroup.id)
|
||||
|
||||
for usergroup_id in usergroup_ids:
|
||||
lock_beat.reacquire()
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
try_generate_user_group_sync_tasks(
|
||||
self.app, usergroup_id, db_session, r, lock_beat, tenant_id
|
||||
@@ -163,10 +171,16 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> None:
|
||||
finally:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
else:
|
||||
task_logger.error(
|
||||
"check_for_vespa_sync_task - Lock not owned on completion: "
|
||||
f"tenant={tenant_id}"
|
||||
)
|
||||
redis_lock_dump(lock_beat, r)
|
||||
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
task_logger.debug(f"check_for_vespa_sync_task finished: elapsed={time_elapsed:.2f}")
|
||||
return
|
||||
return True
|
||||
|
||||
|
||||
def try_generate_stale_document_sync_tasks(
|
||||
@@ -748,7 +762,13 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
|
||||
Returns True if the task actually did work, False if it exited early to prevent overlap
|
||||
"""
|
||||
task_logger.info(f"monitor_vespa_sync starting: tenant={tenant_id}")
|
||||
|
||||
time_start = time.monotonic()
|
||||
|
||||
timings: dict[str, float] = {}
|
||||
timings["start"] = time_start
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
@@ -759,53 +779,76 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
try:
|
||||
# prevent overlapping tasks
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
task_logger.info("monitor_vespa_sync exiting due to overlap")
|
||||
return False
|
||||
|
||||
# print current queue lengths
|
||||
r_celery = self.app.broker_connection().channel().client # type: ignore
|
||||
n_celery = celery_get_queue_length("celery", r_celery)
|
||||
n_indexing = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
|
||||
)
|
||||
n_sync = celery_get_queue_length(OnyxCeleryQueues.VESPA_METADATA_SYNC, r_celery)
|
||||
n_deletion = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CONNECTOR_DELETION, r_celery
|
||||
)
|
||||
n_pruning = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CONNECTOR_PRUNING, r_celery
|
||||
)
|
||||
n_permissions_sync = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC, r_celery
|
||||
)
|
||||
phase_start = time.monotonic()
|
||||
# we don't need every tenant polling redis for this info.
|
||||
if not MULTI_TENANT or random.randint(1, 100) == 100:
|
||||
r_celery = self.app.broker_connection().channel().client # type: ignore
|
||||
n_celery = celery_get_queue_length("celery", r_celery)
|
||||
n_indexing = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
|
||||
)
|
||||
n_sync = celery_get_queue_length(
|
||||
OnyxCeleryQueues.VESPA_METADATA_SYNC, r_celery
|
||||
)
|
||||
n_deletion = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CONNECTOR_DELETION, r_celery
|
||||
)
|
||||
n_pruning = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CONNECTOR_PRUNING, r_celery
|
||||
)
|
||||
n_permissions_sync = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC, r_celery
|
||||
)
|
||||
n_external_group_sync = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC, r_celery
|
||||
)
|
||||
n_permissions_upsert = celery_get_queue_length(
|
||||
OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT, r_celery
|
||||
)
|
||||
|
||||
prefetched = celery_get_unacked_task_ids(
|
||||
OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
|
||||
)
|
||||
prefetched = celery_get_unacked_task_ids(
|
||||
OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"Queue lengths: celery={n_celery} "
|
||||
f"indexing={n_indexing} "
|
||||
f"indexing_prefetched={len(prefetched)} "
|
||||
f"sync={n_sync} "
|
||||
f"deletion={n_deletion} "
|
||||
f"pruning={n_pruning} "
|
||||
f"permissions_sync={n_permissions_sync} "
|
||||
)
|
||||
task_logger.info(
|
||||
f"Queue lengths: celery={n_celery} "
|
||||
f"indexing={n_indexing} "
|
||||
f"indexing_prefetched={len(prefetched)} "
|
||||
f"sync={n_sync} "
|
||||
f"deletion={n_deletion} "
|
||||
f"pruning={n_pruning} "
|
||||
f"permissions_sync={n_permissions_sync} "
|
||||
f"external_group_sync={n_external_group_sync} "
|
||||
f"permissions_upsert={n_permissions_upsert} "
|
||||
)
|
||||
timings["queues"] = time.monotonic() - phase_start
|
||||
|
||||
# scan and monitor activity to completion
|
||||
phase_start = time.monotonic()
|
||||
lock_beat.reacquire()
|
||||
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
|
||||
monitor_connector_taskset(r)
|
||||
timings["connector"] = time.monotonic() - phase_start
|
||||
|
||||
phase_start = time.monotonic()
|
||||
for key_bytes in r.scan_iter(RedisConnectorDelete.FENCE_PREFIX + "*"):
|
||||
lock_beat.reacquire()
|
||||
monitor_connector_deletion_taskset(tenant_id, key_bytes, r)
|
||||
|
||||
timings["connector_deletion"] = time.monotonic() - phase_start
|
||||
|
||||
phase_start = time.monotonic()
|
||||
for key_bytes in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
|
||||
lock_beat.reacquire()
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_document_set_taskset(tenant_id, key_bytes, r, db_session)
|
||||
timings["document_set"] = time.monotonic() - phase_start
|
||||
|
||||
phase_start = time.monotonic()
|
||||
for key_bytes in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
|
||||
lock_beat.reacquire()
|
||||
monitor_usergroup_taskset = fetch_versioned_implementation_with_fallback(
|
||||
@@ -815,22 +858,29 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_usergroup_taskset(tenant_id, key_bytes, r, db_session)
|
||||
timings["usergroup"] = time.monotonic() - phase_start
|
||||
|
||||
phase_start = time.monotonic()
|
||||
for key_bytes in r.scan_iter(RedisConnectorPrune.FENCE_PREFIX + "*"):
|
||||
lock_beat.reacquire()
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_pruning_taskset(tenant_id, key_bytes, r, db_session)
|
||||
timings["pruning"] = time.monotonic() - phase_start
|
||||
|
||||
phase_start = time.monotonic()
|
||||
for key_bytes in r.scan_iter(RedisConnectorIndex.FENCE_PREFIX + "*"):
|
||||
lock_beat.reacquire()
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_indexing_taskset(tenant_id, key_bytes, r, db_session)
|
||||
timings["indexing"] = time.monotonic() - phase_start
|
||||
|
||||
phase_start = time.monotonic()
|
||||
for key_bytes in r.scan_iter(RedisConnectorPermissionSync.FENCE_PREFIX + "*"):
|
||||
lock_beat.reacquire()
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_permissions_taskset(tenant_id, key_bytes, r, db_session)
|
||||
|
||||
timings["permissions"] = time.monotonic() - phase_start
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
@@ -838,9 +888,24 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
finally:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
else:
|
||||
t = timings
|
||||
task_logger.error(
|
||||
"monitor_vespa_sync - Lock not owned on completion: "
|
||||
f"tenant={tenant_id} "
|
||||
f"queues={t.get('queues')} "
|
||||
f"connector={t.get('connector')} "
|
||||
f"connector_deletion={t.get('connector_deletion')} "
|
||||
f"document_set={t.get('document_set')} "
|
||||
f"usergroup={t.get('usergroup')} "
|
||||
f"pruning={t.get('pruning')} "
|
||||
f"indexing={t.get('indexing')} "
|
||||
f"permissions={t.get('permissions')}"
|
||||
)
|
||||
redis_lock_dump(lock_beat, r)
|
||||
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
task_logger.debug(f"monitor_vespa_sync finished: elapsed={time_elapsed:.2f}")
|
||||
task_logger.info(f"monitor_vespa_sync finished: elapsed={time_elapsed:.2f}")
|
||||
return True
|
||||
|
||||
|
||||
@@ -890,6 +955,13 @@ def vespa_metadata_sync_task(
|
||||
# the sync might repeat again later
|
||||
mark_document_as_synced(document_id, db_session)
|
||||
|
||||
redis_syncing_key = RedisConnectorCredentialPair.make_redis_syncing_key(
|
||||
document_id
|
||||
)
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
r.delete(redis_syncing_key)
|
||||
# r.hdel(RedisConnectorCredentialPair.SYNCING_HASH, document_id)
|
||||
|
||||
task_logger.info(f"doc={document_id} action=sync chunks={chunks_affected}")
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
|
||||
|
||||
@@ -14,6 +14,7 @@ from onyx.configs.app_configs import POLL_CONNECTOR_OFFSET
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.connectors.connector_runner import ConnectorRunner
|
||||
from onyx.connectors.factory import instantiate_connector
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import IndexAttemptMetadata
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.connector_credential_pair import get_last_successful_attempt_time
|
||||
@@ -90,6 +91,35 @@ def _get_connector_runner(
|
||||
)
|
||||
|
||||
|
||||
def strip_null_characters(doc_batch: list[Document]) -> list[Document]:
    """Return copies of the documents with NUL characters removed.

    NUL ("\\x00") characters are stripped from each document's ID, its
    semantic identifier and the link of each of its sections. A warning is
    logged for every field that needed cleaning. The input documents and
    their sections are left unmodified.

    Args:
        doc_batch: The documents to clean.

    Returns:
        A new list of cleaned document copies, in the same order as the input.
    """
    import copy

    cleaned_batch = []
    for doc in doc_batch:
        # model_copy() is shallow: without copying the sections as well, the
        # link assignment below would also rewrite the caller's original
        # section objects.
        cleaned_doc = doc.model_copy(
            update={"sections": [copy.copy(section) for section in doc.sections]}
        )

        if "\x00" in cleaned_doc.id:
            logger.warning(f"NUL characters found in document ID: {cleaned_doc.id}")
            cleaned_doc.id = cleaned_doc.id.replace("\x00", "")

        if "\x00" in cleaned_doc.semantic_identifier:
            logger.warning(
                f"NUL characters found in document semantic identifier: {cleaned_doc.semantic_identifier}"
            )
            cleaned_doc.semantic_identifier = cleaned_doc.semantic_identifier.replace(
                "\x00", ""
            )

        for section in cleaned_doc.sections:
            if section.link and "\x00" in section.link:
                logger.warning(
                    f"NUL characters found in document link for document: {cleaned_doc.id}"
                )
                section.link = section.link.replace("\x00", "")

        cleaned_batch.append(cleaned_doc)

    return cleaned_batch
|
||||
|
||||
|
||||
class ConnectorStopSignal(Exception):
|
||||
"""A custom exception used to signal a stop in processing."""
|
||||
|
||||
@@ -238,7 +268,9 @@ def _run_indexing(
|
||||
)
|
||||
|
||||
batch_description = []
|
||||
for doc in doc_batch:
|
||||
|
||||
doc_batch_cleaned = strip_null_characters(doc_batch)
|
||||
for doc in doc_batch_cleaned:
|
||||
batch_description.append(doc.to_short_descriptor())
|
||||
|
||||
doc_size = 0
|
||||
@@ -258,15 +290,15 @@ def _run_indexing(
|
||||
|
||||
# real work happens here!
|
||||
new_docs, total_batch_chunks = indexing_pipeline(
|
||||
document_batch=doc_batch,
|
||||
document_batch=doc_batch_cleaned,
|
||||
index_attempt_metadata=index_attempt_md,
|
||||
)
|
||||
|
||||
batch_num += 1
|
||||
net_doc_change += new_docs
|
||||
chunk_count += total_batch_chunks
|
||||
document_count += len(doc_batch)
|
||||
all_connector_doc_ids.update(doc.id for doc in doc_batch)
|
||||
document_count += len(doc_batch_cleaned)
|
||||
all_connector_doc_ids.update(doc.id for doc in doc_batch_cleaned)
|
||||
|
||||
# commit transaction so that the `update` below begins
|
||||
# with a brand new transaction. Postgres uses the start
|
||||
@@ -276,7 +308,7 @@ def _run_indexing(
|
||||
db_session.commit()
|
||||
|
||||
if callback:
|
||||
callback.progress("_run_indexing", len(doc_batch))
|
||||
callback.progress("_run_indexing", len(doc_batch_cleaned))
|
||||
|
||||
# This new value is updated every batch, so UI can refresh per batch update
|
||||
update_docs_indexed(
|
||||
|
||||
@@ -22,7 +22,9 @@ from onyx.chat.stream_processing.answer_response_handler import (
|
||||
from onyx.chat.stream_processing.answer_response_handler import (
|
||||
DummyAnswerResponseHandler,
|
||||
)
|
||||
from onyx.chat.stream_processing.utils import map_document_id_order
|
||||
from onyx.chat.stream_processing.utils import (
|
||||
map_document_id_order,
|
||||
)
|
||||
from onyx.chat.tool_handling.tool_response_handler import ToolResponseHandler
|
||||
from onyx.file_store.utils import InMemoryChatFile
|
||||
from onyx.llm.interfaces import LLM
|
||||
@@ -206,9 +208,9 @@ class Answer:
|
||||
# + figure out what the next LLM call should be
|
||||
tool_call_handler = ToolResponseHandler(current_llm_call.tools)
|
||||
|
||||
search_result, displayed_search_results_map = SearchTool.get_search_result(
|
||||
final_search_results, displayed_search_results = SearchTool.get_search_result(
|
||||
current_llm_call
|
||||
) or ([], {})
|
||||
) or ([], [])
|
||||
|
||||
# Quotes are no longer supported
|
||||
# answer_handler: AnswerResponseHandler
|
||||
@@ -224,9 +226,9 @@ class Answer:
|
||||
# else:
|
||||
# raise ValueError("No answer style config provided")
|
||||
answer_handler = CitationResponseHandler(
|
||||
context_docs=search_result,
|
||||
doc_id_to_rank_map=map_document_id_order(search_result),
|
||||
display_doc_order_dict=displayed_search_results_map,
|
||||
context_docs=final_search_results,
|
||||
final_doc_id_to_rank_map=map_document_id_order(final_search_results),
|
||||
display_doc_id_to_rank_map=map_document_id_order(displayed_search_results),
|
||||
)
|
||||
|
||||
response_handler_manager = LLMResponseHandlerManager(
|
||||
|
||||
@@ -37,22 +37,22 @@ class CitationResponseHandler(AnswerResponseHandler):
|
||||
def __init__(
|
||||
self,
|
||||
context_docs: list[LlmDoc],
|
||||
doc_id_to_rank_map: DocumentIdOrderMapping,
|
||||
display_doc_order_dict: dict[str, int],
|
||||
final_doc_id_to_rank_map: DocumentIdOrderMapping,
|
||||
display_doc_id_to_rank_map: DocumentIdOrderMapping,
|
||||
):
|
||||
self.context_docs = context_docs
|
||||
self.doc_id_to_rank_map = doc_id_to_rank_map
|
||||
self.display_doc_order_dict = display_doc_order_dict
|
||||
self.final_doc_id_to_rank_map = final_doc_id_to_rank_map
|
||||
self.display_doc_id_to_rank_map = display_doc_id_to_rank_map
|
||||
self.citation_processor = CitationProcessor(
|
||||
context_docs=self.context_docs,
|
||||
doc_id_to_rank_map=self.doc_id_to_rank_map,
|
||||
display_doc_order_dict=self.display_doc_order_dict,
|
||||
final_doc_id_to_rank_map=self.final_doc_id_to_rank_map,
|
||||
display_doc_id_to_rank_map=self.display_doc_id_to_rank_map,
|
||||
)
|
||||
self.processed_text = ""
|
||||
self.citations: list[CitationInfo] = []
|
||||
|
||||
# TODO remove this after citation issue is resolved
|
||||
logger.debug(f"Document to ranking map {self.doc_id_to_rank_map}")
|
||||
logger.debug(f"Document to ranking map {self.final_doc_id_to_rank_map}")
|
||||
|
||||
def handle_response_part(
|
||||
self,
|
||||
|
||||
@@ -21,20 +21,19 @@ class CitationProcessor:
|
||||
def __init__(
|
||||
self,
|
||||
context_docs: list[LlmDoc],
|
||||
doc_id_to_rank_map: DocumentIdOrderMapping,
|
||||
display_doc_order_dict: dict[str, int],
|
||||
final_doc_id_to_rank_map: DocumentIdOrderMapping,
|
||||
display_doc_id_to_rank_map: DocumentIdOrderMapping,
|
||||
stop_stream: str | None = STOP_STREAM_PAT,
|
||||
):
|
||||
self.context_docs = context_docs
|
||||
self.doc_id_to_rank_map = doc_id_to_rank_map
|
||||
self.final_doc_id_to_rank_map = final_doc_id_to_rank_map
|
||||
self.display_doc_id_to_rank_map = display_doc_id_to_rank_map
|
||||
self.stop_stream = stop_stream
|
||||
self.order_mapping = doc_id_to_rank_map.order_mapping
|
||||
self.display_doc_order_dict = (
|
||||
display_doc_order_dict # original order of docs to displayed to user
|
||||
)
|
||||
self.final_order_mapping = final_doc_id_to_rank_map.order_mapping
|
||||
self.display_order_mapping = display_doc_id_to_rank_map.order_mapping
|
||||
self.llm_out = ""
|
||||
self.max_citation_num = len(context_docs)
|
||||
self.citation_order: list[int] = []
|
||||
self.citation_order: list[int] = [] # order of citations in the LLM output
|
||||
self.curr_segment = ""
|
||||
self.cited_inds: set[int] = set()
|
||||
self.hold = ""
|
||||
@@ -93,29 +92,31 @@ class CitationProcessor:
|
||||
|
||||
if 1 <= numerical_value <= self.max_citation_num:
|
||||
context_llm_doc = self.context_docs[numerical_value - 1]
|
||||
real_citation_num = self.order_mapping[context_llm_doc.document_id]
|
||||
final_citation_num = self.final_order_mapping[
|
||||
context_llm_doc.document_id
|
||||
]
|
||||
|
||||
if real_citation_num not in self.citation_order:
|
||||
self.citation_order.append(real_citation_num)
|
||||
if final_citation_num not in self.citation_order:
|
||||
self.citation_order.append(final_citation_num)
|
||||
|
||||
target_citation_num = (
|
||||
self.citation_order.index(real_citation_num) + 1
|
||||
citation_order_idx = (
|
||||
self.citation_order.index(final_citation_num) + 1
|
||||
)
|
||||
|
||||
# get the value that was displayed to user, should always
|
||||
# be in the display_doc_order_dict. But check anyways
|
||||
if context_llm_doc.document_id in self.display_doc_order_dict:
|
||||
displayed_citation_num = self.display_doc_order_dict[
|
||||
if context_llm_doc.document_id in self.display_order_mapping:
|
||||
displayed_citation_num = self.display_order_mapping[
|
||||
context_llm_doc.document_id
|
||||
]
|
||||
else:
|
||||
displayed_citation_num = real_citation_num
|
||||
displayed_citation_num = final_citation_num
|
||||
logger.warning(
|
||||
f"Doc {context_llm_doc.document_id} not in display_doc_order_dict. Used LLM citation number instead."
|
||||
)
|
||||
|
||||
# Skip consecutive citations of the same work
|
||||
if target_citation_num in self.current_citations:
|
||||
if final_citation_num in self.current_citations:
|
||||
start, end = citation.span()
|
||||
real_start = length_to_add + start
|
||||
diff = end - start
|
||||
@@ -134,8 +135,8 @@ class CitationProcessor:
|
||||
doc_id = int(match.group(1))
|
||||
context_llm_doc = self.context_docs[doc_id - 1]
|
||||
yield CitationInfo(
|
||||
# stay with the original for now (order of LLM cites)
|
||||
citation_num=target_citation_num,
|
||||
# citation_num is now the number post initial ranking, i.e. as displayed to user
|
||||
citation_num=displayed_citation_num,
|
||||
document_id=context_llm_doc.document_id,
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -151,13 +152,13 @@ class CitationProcessor:
|
||||
link = context_llm_doc.link
|
||||
|
||||
self.past_cite_count = len(self.llm_out)
|
||||
self.current_citations.append(target_citation_num)
|
||||
self.current_citations.append(final_citation_num)
|
||||
|
||||
if target_citation_num not in self.cited_inds:
|
||||
self.cited_inds.add(target_citation_num)
|
||||
if citation_order_idx not in self.cited_inds:
|
||||
self.cited_inds.add(citation_order_idx)
|
||||
yield CitationInfo(
|
||||
# stay with the original for now (order of LLM cites)
|
||||
citation_num=target_citation_num,
|
||||
# citation number is now the one that was displayed to user
|
||||
citation_num=displayed_citation_num,
|
||||
document_id=context_llm_doc.document_id,
|
||||
)
|
||||
|
||||
@@ -167,7 +168,6 @@ class CitationProcessor:
|
||||
self.curr_segment = (
|
||||
self.curr_segment[: start + length_to_add]
|
||||
+ f"[[{displayed_citation_num}]]({link})" # use the value that was displayed to user
|
||||
# + f"[[{target_citation_num}]]({link})"
|
||||
+ self.curr_segment[end + length_to_add :]
|
||||
)
|
||||
length_to_add += len(self.curr_segment) - prev_length
|
||||
@@ -176,7 +176,6 @@ class CitationProcessor:
|
||||
self.curr_segment = (
|
||||
self.curr_segment[: start + length_to_add]
|
||||
+ f"[[{displayed_citation_num}]]()" # use the value that was displayed to user
|
||||
# + f"[[{target_citation_num}]]()"
|
||||
+ self.curr_segment[end + length_to_add :]
|
||||
)
|
||||
length_to_add += len(self.curr_segment) - prev_length
|
||||
|
||||
@@ -54,10 +54,17 @@ MASK_CREDENTIAL_PREFIX = (
|
||||
os.environ.get("MASK_CREDENTIAL_PREFIX", "True").lower() != "false"
|
||||
)
|
||||
|
||||
REDIS_AUTH_EXPIRE_TIME_SECONDS = int(
|
||||
os.environ.get("REDIS_AUTH_EXPIRE_TIME_SECONDS") or 86400 * 7
|
||||
) # 7 days
|
||||
|
||||
SESSION_EXPIRE_TIME_SECONDS = int(
|
||||
os.environ.get("SESSION_EXPIRE_TIME_SECONDS") or 86400 * 7
|
||||
) # 7 days
|
||||
|
||||
# Default request timeout, mostly used by connectors
|
||||
REQUEST_TIMEOUT_SECONDS = int(os.environ.get("REQUEST_TIMEOUT_SECONDS") or 60)
|
||||
|
||||
# set `VALID_EMAIL_DOMAINS` to a comma separated list of domains in order to
|
||||
# restrict access to Onyx to only users with emails from those domains.
|
||||
# E.g. `VALID_EMAIL_DOMAINS=example.com,example.org` will restrict Onyx
|
||||
@@ -146,7 +153,7 @@ POSTGRES_PASSWORD = urllib.parse.quote_plus(
|
||||
POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost"
|
||||
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432"
|
||||
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
|
||||
AWS_REGION = os.environ.get("AWS_REGION") or "us-east-2"
|
||||
AWS_REGION_NAME = os.environ.get("AWS_REGION_NAME") or "us-east-2"
|
||||
|
||||
POSTGRES_API_SERVER_POOL_SIZE = int(
|
||||
os.environ.get("POSTGRES_API_SERVER_POOL_SIZE") or 40
|
||||
@@ -185,6 +192,27 @@ REDIS_HOST = os.environ.get("REDIS_HOST") or "localhost"
|
||||
REDIS_PORT = int(os.environ.get("REDIS_PORT", 6379))
|
||||
REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD") or ""
|
||||
|
||||
|
||||
REDIS_AUTH_KEY_PREFIX = "fastapi_users_token:"
|
||||
|
||||
|
||||
# Rate limiting for auth endpoints
|
||||
RATE_LIMIT_WINDOW_SECONDS: int | None = None
|
||||
_rate_limit_window_seconds_str = os.environ.get("RATE_LIMIT_WINDOW_SECONDS")
|
||||
if _rate_limit_window_seconds_str is not None:
|
||||
try:
|
||||
RATE_LIMIT_WINDOW_SECONDS = int(_rate_limit_window_seconds_str)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
RATE_LIMIT_MAX_REQUESTS: int | None = None
|
||||
_rate_limit_max_requests_str = os.environ.get("RATE_LIMIT_MAX_REQUESTS")
|
||||
if _rate_limit_max_requests_str is not None:
|
||||
try:
|
||||
RATE_LIMIT_MAX_REQUESTS = int(_rate_limit_max_requests_str)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Used for general redis things
|
||||
REDIS_DB_NUMBER = int(os.environ.get("REDIS_DB_NUMBER", 0))
|
||||
|
||||
@@ -348,12 +376,17 @@ GITLAB_CONNECTOR_INCLUDE_CODE_FILES = (
|
||||
os.environ.get("GITLAB_CONNECTOR_INCLUDE_CODE_FILES", "").lower() == "true"
|
||||
)
|
||||
|
||||
# Typically set to http://localhost:3000 for OAuth connector development
|
||||
CONNECTOR_LOCALHOST_OVERRIDE = os.getenv("CONNECTOR_LOCALHOST_OVERRIDE")
|
||||
|
||||
# Egnyte specific configs
|
||||
EGNYTE_LOCALHOST_OVERRIDE = os.getenv("EGNYTE_LOCALHOST_OVERRIDE")
|
||||
EGNYTE_BASE_DOMAIN = os.getenv("EGNYTE_DOMAIN")
|
||||
EGNYTE_CLIENT_ID = os.getenv("EGNYTE_CLIENT_ID")
|
||||
EGNYTE_CLIENT_SECRET = os.getenv("EGNYTE_CLIENT_SECRET")
|
||||
|
||||
# Linear specific configs
|
||||
LINEAR_CLIENT_ID = os.getenv("LINEAR_CLIENT_ID")
|
||||
LINEAR_CLIENT_SECRET = os.getenv("LINEAR_CLIENT_SECRET")
|
||||
|
||||
DASK_JOB_CLIENT_ENABLED = (
|
||||
os.environ.get("DASK_JOB_CLIENT_ENABLED", "").lower() == "true"
|
||||
)
|
||||
@@ -504,6 +537,9 @@ try:
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# LLM Model Update API endpoint
|
||||
LLM_MODEL_UPDATE_API_URL = os.environ.get("LLM_MODEL_UPDATE_API_URL")
|
||||
|
||||
#####
|
||||
# Enterprise Edition Configs
|
||||
#####
|
||||
@@ -543,7 +579,6 @@ CONTROL_PLANE_API_BASE_URL = os.environ.get(
|
||||
# JWT configuration
|
||||
JWT_ALGORITHM = "HS256"
|
||||
|
||||
|
||||
#####
|
||||
# API Key Configs
|
||||
#####
|
||||
|
||||
@@ -76,13 +76,19 @@ KV_ENTERPRISE_SETTINGS_KEY = "onyx_enterprise_settings"
|
||||
KV_CUSTOM_ANALYTICS_SCRIPT_KEY = "__custom_analytics_script__"
|
||||
KV_DOCUMENTS_SEEDED_KEY = "documents_seeded"
|
||||
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT = 60
|
||||
# NOTE: we use this timeout / 4 in various places to refresh a lock
|
||||
# might be worth separating this timeout into separate timeouts for each situation
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT = 120
|
||||
|
||||
CELERY_PRIMARY_WORKER_LOCK_TIMEOUT = 120
|
||||
|
||||
# needs to be long enough to cover the maximum time it takes to download an object
|
||||
# if we can get callbacks as object bytes download, we could lower this a lot.
|
||||
CELERY_INDEXING_LOCK_TIMEOUT = 3 * 60 * 60 # 3 hours
|
||||
|
||||
# how long a task should wait for associated fence to be ready
|
||||
CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT = 5 * 60 # 5 min
|
||||
|
||||
# needs to be long enough to cover the maximum time it takes to download an object
|
||||
# if we can get callbacks as object bytes download, we could lower this a lot.
|
||||
CELERY_PRUNING_LOCK_TIMEOUT = 300 # 5 min
|
||||
@@ -136,9 +142,11 @@ class DocumentSource(str, Enum):
|
||||
OCI_STORAGE = "oci_storage"
|
||||
XENFORO = "xenforo"
|
||||
NOT_APPLICABLE = "not_applicable"
|
||||
DISCORD = "discord"
|
||||
FRESHDESK = "freshdesk"
|
||||
FIREFLIES = "fireflies"
|
||||
EGNYTE = "egnyte"
|
||||
AIRTABLE = "airtable"
|
||||
|
||||
|
||||
DocumentSourceRequiringTenantContext: list[DocumentSource] = [DocumentSource.FILE]
|
||||
@@ -242,6 +250,7 @@ class OnyxCeleryQueues:
|
||||
VESPA_METADATA_SYNC = "vespa_metadata_sync"
|
||||
DOC_PERMISSIONS_UPSERT = "doc_permissions_upsert"
|
||||
CONNECTOR_DELETION = "connector_deletion"
|
||||
LLM_MODEL_UPDATE = "llm_model_update"
|
||||
|
||||
# Heavy queue
|
||||
CONNECTOR_PRUNING = "connector_pruning"
|
||||
@@ -275,6 +284,7 @@ class OnyxRedisLocks:
|
||||
|
||||
SLACK_BOT_LOCK = "da_lock:slack_bot"
|
||||
SLACK_BOT_HEARTBEAT_PREFIX = "da_heartbeat:slack_bot"
|
||||
ANONYMOUS_USER_ENABLED = "anonymous_user_enabled"
|
||||
|
||||
|
||||
class OnyxRedisSignals:
|
||||
@@ -296,6 +306,7 @@ class OnyxCeleryTask:
|
||||
CHECK_FOR_PRUNING = "check_for_pruning"
|
||||
CHECK_FOR_DOC_PERMISSIONS_SYNC = "check_for_doc_permissions_sync"
|
||||
CHECK_FOR_EXTERNAL_GROUP_SYNC = "check_for_external_group_sync"
|
||||
CHECK_FOR_LLM_MODEL_UPDATE = "check_for_llm_model_update"
|
||||
MONITOR_VESPA_SYNC = "monitor_vespa_sync"
|
||||
KOMBU_MESSAGE_CLEANUP_TASK = "kombu_message_cleanup_task"
|
||||
CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK = (
|
||||
|
||||
289
backend/onyx/connectors/airtable/airtable_connector.py
Normal file
289
backend/onyx/connectors/airtable/airtable_connector.py
Normal file
@@ -0,0 +1,289 @@
|
||||
from io import BytesIO
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from pyairtable import Api as AirtableApi
|
||||
from pyairtable.api.types import RecordDict
|
||||
from pyairtable.models.schema import TableSchema
|
||||
from retry import retry
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.extract_file_text import get_file_ext
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# NOTE: all are made lowercase to avoid case sensitivity issues
|
||||
# these are the field types that are considered metadata rather
|
||||
# than sections
|
||||
_METADATA_FIELD_TYPES = {
|
||||
"singlecollaborator",
|
||||
"collaborator",
|
||||
"createdby",
|
||||
"singleselect",
|
||||
"multipleselects",
|
||||
"checkbox",
|
||||
"date",
|
||||
"datetime",
|
||||
"email",
|
||||
"phone",
|
||||
"url",
|
||||
"number",
|
||||
"currency",
|
||||
"duration",
|
||||
"percent",
|
||||
"rating",
|
||||
"createdtime",
|
||||
"lastmodifiedtime",
|
||||
"autonumber",
|
||||
"rollup",
|
||||
"lookup",
|
||||
"count",
|
||||
"formula",
|
||||
"date",
|
||||
}
|
||||
|
||||
|
||||
class AirtableClientNotSetUpError(PermissionError):
    """Raised when the Airtable client is used before it has been created."""

    def __init__(self) -> None:
        message = "Airtable Client is not set up, was load_credentials called?"
        super().__init__(message)
|
||||
|
||||
|
||||
class AirtableConnector(LoadConnector):
    """Load connector that converts every record of one Airtable table into a Document.

    Field values are routed by their Airtable field type: types listed in
    ``_METADATA_FIELD_TYPES`` become document metadata, every other type
    becomes one or more text ``Section`` objects linking back to the record.
    """

    # Upper bound, in seconds, for a single attachment download attempt so a
    # stalled connection cannot block the indexing run indefinitely.
    _ATTACHMENT_DOWNLOAD_TIMEOUT_SECONDS = 60

    def __init__(
        self,
        base_id: str,
        table_name_or_id: str,
        batch_size: int = INDEX_BATCH_SIZE,
    ) -> None:
        self.base_id = base_id
        self.table_name_or_id = table_name_or_id
        self.batch_size = batch_size
        # Created by load_credentials(); None until then.
        self.airtable_client: AirtableApi | None = None

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        """Create the Airtable client from ``credentials["airtable_access_token"]``."""
        self.airtable_client = AirtableApi(credentials["airtable_access_token"])
        return None

    @retry(tries=5, delay=1, backoff=2, max_delay=10)
    def _download_attachment(self, url: str) -> bytes | None:
        """Download one attachment, retrying when the request raises.

        NOTE: only raised exceptions (connection errors, timeouts) trigger a
        retry; a non-200 response is not retried and yields ``None``.

        Returns:
            The response body on HTTP 200, otherwise ``None``.
        """
        response = requests.get(
            url, timeout=self._ATTACHMENT_DOWNLOAD_TIMEOUT_SECONDS
        )
        if response.status_code == 200:
            return response.content
        return None

    def _extract_attachment_texts(self, attachments: Any) -> list[str]:
        """Download each attachment and return its extracted text, one entry per file.

        Attachments without a URL, with an empty/failed download, or whose
        text extraction fails are skipped (extraction failures are logged).
        """
        attachment_texts: list[str] = []
        for attachment in attachments:
            url = attachment.get("url")
            filename = attachment.get("filename", "")
            if not url:
                continue

            attachment_content = self._download_attachment(url)
            if not attachment_content:
                continue

            try:
                file_ext = get_file_ext(filename)
                attachment_text = extract_file_text(
                    BytesIO(attachment_content),
                    filename,
                    break_on_unprocessable=False,
                    extension=file_ext,
                )
                if attachment_text:
                    attachment_texts.append(f"{filename}:\n{attachment_text}")
            except Exception as e:
                # Best effort: one unreadable attachment must not fail the record.
                logger.warning(
                    f"Failed to process attachment {filename}: {str(e)}"
                )
        return attachment_texts

    def _get_field_value(self, field_info: Any, field_type: str) -> list[str]:
        """Extract the value(s) of a field as a list of strings.

        Args:
            field_info: Raw field value from the Airtable record
            field_type: Airtable field type (as given by the table schema)

        Returns:
            A list of strings; empty when the field has no value or is a
            record link. Attachment fields yield one entry per attachment
            whose text could be extracted.
        """
        if field_info is None:
            return []

        # skip references to other records for now (would need to do another
        # request to get the actual record name/type)
        # TODO: support this
        if field_type == "multipleRecordLinks":
            return []

        if field_type == "multipleAttachments":
            return self._extract_attachment_texts(field_info)

        if field_type in ("singleCollaborator", "collaborator", "createdBy"):
            # Render as "Name (email)", using whichever parts are present.
            combined = []
            collab_name = field_info.get("name")
            collab_email = field_info.get("email")
            if collab_name:
                combined.append(collab_name)
            if collab_email:
                combined.append(f"({collab_email})")
            return [" ".join(combined) if combined else str(field_info)]

        if isinstance(field_info, list):
            return [str(item) for item in field_info]

        return [str(field_info)]

    def _should_be_metadata(self, field_type: str) -> bool:
        """Determine if a field type should be treated as metadata."""
        return field_type.lower() in _METADATA_FIELD_TYPES

    def _process_field(
        self,
        field_name: str,
        field_info: Any,
        field_type: str,
        table_id: str,
        record_id: str,
    ) -> tuple[list[Section], dict[str, Any]]:
        """
        Process a single Airtable field and return sections or metadata.

        Args:
            field_name: Name of the field
            field_info: Raw field information from Airtable
            field_type: Airtable field type
            table_id: ID of the table, used to build the section link
            record_id: ID of the record, used to build the section link

        Returns:
            (list of Sections, dict of metadata); at most one of the two is
            non-empty.
        """
        if field_info is None:
            return [], {}

        # Get the value(s) for the field
        field_values = self._get_field_value(field_info, field_type)
        if not field_values:
            return [], {}

        # Determine if it should be metadata or a section
        if self._should_be_metadata(field_type):
            # A single value is stored as a plain string, several as a list.
            if len(field_values) > 1:
                return [], {field_name: field_values}
            return [], {field_name: field_values[0]}

        # Otherwise, create relevant sections
        sections = [
            Section(
                link=f"https://airtable.com/{self.base_id}/{table_id}/{record_id}",
                text=(
                    f"{field_name}:\n"
                    "------------------------\n"
                    f"{text}\n"
                    "------------------------"
                ),
            )
            for text in field_values
        ]
        return sections, {}

    def _process_record(
        self,
        record: RecordDict,
        table_schema: TableSchema,
        primary_field_name: str | None,
    ) -> Document:
        """Process a single Airtable record into a Document.

        Args:
            record: The Airtable record to process
            table_schema: Schema information for the table (supplies the
                table id, table name and field definitions)
            primary_field_name: Name of the primary field, if any

        Returns:
            Document object representing the record
        """
        table_id = table_schema.id
        table_name = table_schema.name
        record_id = record["id"]
        fields = record["fields"]
        sections: list[Section] = []
        metadata: dict[str, Any] = {}

        # Get primary field value if it exists
        primary_field_value = (
            fields.get(primary_field_name) if primary_field_name else None
        )

        # Walk the schema rather than the record so fields are processed in
        # schema order; fields absent from the record yield nothing.
        for field_schema in table_schema.fields:
            field_name = field_schema.name
            field_sections, field_metadata = self._process_field(
                field_name=field_name,
                field_info=fields.get(field_name),
                field_type=field_schema.type,
                table_id=table_id,
                record_id=record_id,
            )
            sections.extend(field_sections)
            metadata.update(field_metadata)

        semantic_id = (
            f"{table_name}: {primary_field_value}"
            if primary_field_value
            else table_name
        )

        return Document(
            id=f"airtable__{record_id}",
            sections=sections,
            source=DocumentSource.AIRTABLE,
            semantic_identifier=semantic_id,
            metadata=metadata,
        )

    def load_from_state(self) -> GenerateDocumentsOutput:
        """
        Fetch all records from the table.

        NOTE: Airtable does not support filtering by time updated, so
        we have to fetch all records every time.

        Yields:
            Batches of at most ``batch_size`` Documents.

        Raises:
            AirtableClientNotSetUpError: if load_credentials was not called.
        """
        if not self.airtable_client:
            raise AirtableClientNotSetUpError()

        table = self.airtable_client.table(self.base_id, self.table_name_or_id)
        records = table.all()

        table_schema = table.schema()

        # Find a primary field from the schema
        primary_field_name = next(
            (
                field.name
                for field in table_schema.fields
                if field.id == table_schema.primary_field_id
            ),
            None,
        )

        record_documents: list[Document] = []
        for record in records:
            record_documents.append(
                self._process_record(
                    record=record,
                    table_schema=table_schema,
                    primary_field_name=primary_field_name,
                )
            )

            if len(record_documents) >= self.batch_size:
                yield record_documents
                record_documents = []

        if record_documents:
            yield record_documents
|
||||
@@ -52,6 +52,8 @@ _RESTRICTIONS_EXPANSION_FIELDS = [
|
||||
"space",
|
||||
"restrictions.read.restrictions.user",
|
||||
"restrictions.read.restrictions.group",
|
||||
"ancestors.restrictions.read.restrictions.user",
|
||||
"ancestors.restrictions.read.restrictions.group",
|
||||
]
|
||||
|
||||
_SLIM_DOC_BATCH_SIZE = 5000
|
||||
@@ -323,9 +325,11 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
# These will be used by doc_sync.py to sync permissions
|
||||
page_restrictions = page.get("restrictions")
|
||||
page_space_key = page.get("space", {}).get("key")
|
||||
page_ancestors = page.get("ancestors", [])
|
||||
page_perm_sync_data = {
|
||||
"restrictions": page_restrictions or {},
|
||||
"space_key": page_space_key,
|
||||
"ancestors": page_ancestors or [],
|
||||
}
|
||||
|
||||
doc_metadata_list.append(
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import TypeVar
|
||||
|
||||
from dateutil.parser import parse
|
||||
|
||||
from onyx.configs.app_configs import CONNECTOR_LOCALHOST_OVERRIDE
|
||||
from onyx.configs.constants import IGNORE_FOR_QA
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.utils.text_processing import is_valid_email
|
||||
@@ -71,3 +72,10 @@ def process_in_batches(
|
||||
|
||||
def get_metadata_keys_to_ignore() -> list[str]:
|
||||
return [IGNORE_FOR_QA]
|
||||
|
||||
|
||||
def get_oauth_callback_uri(base_domain: str, connector_id: str) -> str:
    """Build the OAuth callback URL for the given connector.

    When CONNECTOR_LOCALHOST_OVERRIDE is set (used for development) it
    replaces ``base_domain``. Leading and trailing slashes are stripped from
    the domain before the callback path is appended.
    """
    effective_domain = CONNECTOR_LOCALHOST_OVERRIDE or base_domain
    stripped_domain = effective_domain.strip("/")
    return f"{stripped_domain}/connector/oauth/callback/{connector_id}"
|
||||
|
||||
321
backend/onyx/connectors/discord/connector.py
Normal file
321
backend/onyx/connectors/discord/connector.py
Normal file
@@ -0,0 +1,321 @@
|
||||
import asyncio
|
||||
from collections.abc import AsyncIterable
|
||||
from collections.abc import Iterable
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
|
||||
from discord import Client
|
||||
from discord.channel import TextChannel
|
||||
from discord.channel import Thread
|
||||
from discord.enums import MessageType
|
||||
from discord.flags import Intents
|
||||
from discord.message import Message as DiscordMessage
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
_DISCORD_DOC_ID_PREFIX = "DISCORD_"
|
||||
_SNIPPET_LENGTH = 30
|
||||
|
||||
|
||||
def _convert_message_to_document(
    message: DiscordMessage,
    sections: list[Section],
) -> Document:
    """
    Convert a discord message to a document
    Sections are collected before calling this function because it relies on async
    calls to fetch the thread history if there is one

    Args:
        message: The Discord message to convert
        sections: Pre-built sections holding the message content

    Returns:
        Document whose id is the message id with the Discord prefix
    """

    metadata: dict[str, str | list[str]] = {}
    semantic_substring = ""

    # Only messages from TextChannels will make it here but we have to check for it anyways
    if isinstance(message.channel, TextChannel) and (
        channel_name := message.channel.name
    ):
        metadata["Channel"] = channel_name
        semantic_substring += f" in Channel: #{channel_name}"

    # Single messages dont have a title
    title = ""

    # If there is a thread, add more detail to the metadata, title, and semantic identifier
    if isinstance(message.channel, Thread):
        # Threads do have a title
        title = message.channel.name
        metadata["Thread"] = title
        semantic_substring += f" in Thread: {title}"

    # Short preview of the message used in the semantic identifier
    content = message.content
    snippet: str = (
        content[:_SNIPPET_LENGTH].rstrip() + "..."
        if len(content) > _SNIPPET_LENGTH
        else content
    )

    semantic_identifier = f"{message.author.name} said{semantic_substring}: {snippet}"

    return Document(
        id=f"{_DISCORD_DOC_ID_PREFIX}{message.id}",
        source=DocumentSource.DISCORD,
        semantic_identifier=semantic_identifier,
        # edited_at is None for messages that were never edited; fall back to
        # the creation time so every document carries a timestamp.
        doc_updated_at=message.edited_at or message.created_at,
        title=title,
        sections=sections,
        metadata=metadata,
    )
|
||||
|
||||
|
||||
async def _fetch_filtered_channels(
    discord_client: Client,
    server_ids: list[int] | None,
    channel_names: list[str] | None,
) -> list[TextChannel]:
    """Return the text channels the bot may read that match the optional filters.

    A channel is kept when the bot has the read-message-history permission,
    it is a TextChannel, its guild id is in ``server_ids`` (when given and
    non-empty) and its name is in ``channel_names`` (when given and non-empty).
    """
    matching_channels: list[TextChannel] = []

    for candidate in discord_client.get_all_channels():
        bot_permissions = candidate.permissions_for(candidate.guild.me)
        if not bot_permissions.read_message_history:
            continue
        if not isinstance(candidate, TextChannel):
            continue

        wrong_server = bool(server_ids) and candidate.guild.id not in server_ids
        if wrong_server:
            continue

        wrong_name = bool(channel_names) and candidate.name not in channel_names
        if wrong_name:
            continue

        matching_channels.append(candidate)

    logger.info(f"Found {len(matching_channels)} channels for the authenticated user")
    return matching_channels
|
||||
|
||||
|
||||
async def _fetch_documents_from_history(
    messageable: TextChannel | Thread,
    start_time: datetime | None,
    end_time: datetime | None,
) -> AsyncIterable[Document]:
    """Yield a Document for every default-type message in the given history."""
    # limit=None is required: discord.py's history() otherwise stops after
    # its default of 100 messages, silently dropping the rest.
    async for message in messageable.history(
        limit=None,
        after=start_time,
        before=end_time,
    ):
        # Skip messages that are not the default type
        if message.type != MessageType.default:
            continue

        sections: list[Section] = [
            Section(
                text=message.content,
                link=message.jump_url,
            )
        ]

        yield _convert_message_to_document(message, sections)


async def _fetch_documents_from_channel(
    channel: TextChannel,
    start_time: datetime | None,
    end_time: datetime | None,
) -> AsyncIterable[Document]:
    """Yield Documents for a channel's messages and all of its threads.

    Covers, in order: the channel's own messages, its active threads, and
    its archived threads, each restricted to the (start_time, end_time)
    window.

    Args:
        channel: The text channel to read
        start_time: Only messages after this time are fetched (clamped to
            the Discord epoch); None for no lower bound
        end_time: Only messages before this time are fetched; None for no
            upper bound
    """
    # Discord's epoch starts at 2015-01-01
    discord_epoch = datetime(2015, 1, 1, tzinfo=timezone.utc)
    if start_time and start_time < discord_epoch:
        start_time = discord_epoch

    async for doc in _fetch_documents_from_history(channel, start_time, end_time):
        yield doc

    for active_thread in channel.threads:
        async for doc in _fetch_documents_from_history(
            active_thread, start_time, end_time
        ):
            yield doc

    # limit=None: archived_threads() also defaults to the first 100 results.
    async for archived_thread in channel.archived_threads(limit=None):
        async for doc in _fetch_documents_from_history(
            archived_thread, start_time, end_time
        ):
            yield doc
|
||||
|
||||
|
||||
def _manage_async_retrieval(
    token: str,
    requested_start_date_string: str,
    channel_names: list[str],
    server_ids: list[int],
    start: datetime | None = None,
    end: datetime | None = None,
) -> Iterable[Document]:
    """Expose the async Discord fetch as a plain synchronous iterable.

    Args:
        token: Discord bot token used to start the client
        requested_start_date_string: Earliest date to pull, as "YYYY-MM-DD";
            empty string for no user-requested lower bound
        channel_names: Channel names to include; empty for all
        server_ids: Server (guild) ids to include; empty for all
        start: Lower bound of the polling window, if any
        end: Upper bound of the polling window, if any

    Returns:
        A generator that drives a private event loop and yields Documents
        one at a time.

    Raises:
        ValueError: if requested_start_date_string is non-empty and not in
            "YYYY-MM-DD" format.
    """
    # parse requested_start_date_string to datetime
    pull_date: datetime | None = (
        datetime.strptime(requested_start_date_string, "%Y-%m-%d").replace(
            tzinfo=timezone.utc
        )
        if requested_start_date_string
        else None
    )

    # Set start_time to the later of start and pull_date, or whichever is provided
    lower_bounds = [moment for moment in (start, pull_date) if moment]
    start_time = max(lower_bounds) if lower_bounds else None

    end_time: datetime | None = end

    async def _async_fetch() -> AsyncIterable[Document]:
        intents = Intents.default()
        intents.message_content = True
        async with Client(intents=intents) as discord_client:
            # Keep a reference to the task: the event loop only holds weak
            # references, so an unreferenced task may be garbage collected
            # mid-execution.
            # NOTE(review): if start() fails (e.g. invalid token) the error
            # stays inside this task and wait_until_ready() below may never
            # return — confirm whether that needs handling.
            client_task = asyncio.create_task(discord_client.start(token))  # noqa: F841
            await discord_client.wait_until_ready()

            filtered_channels: list[TextChannel] = await _fetch_filtered_channels(
                discord_client=discord_client,
                server_ids=server_ids,
                channel_names=channel_names,
            )

            for channel in filtered_channels:
                async for doc in _fetch_documents_from_channel(
                    channel=channel,
                    start_time=start_time,
                    end_time=end_time,
                ):
                    yield doc

    def run_and_yield() -> Iterable[Document]:
        # A dedicated loop lets this synchronous generator pull documents
        # one at a time from the async generator.
        loop = asyncio.new_event_loop()
        try:
            # Get the async generator
            async_gen = _async_fetch()
            # Convert to AsyncIterator
            async_iter = async_gen.__aiter__()
            while True:
                try:
                    # Create a coroutine by calling anext with the async iterator
                    next_coro = anext(async_iter)
                    # Run the coroutine to get the next document
                    doc = loop.run_until_complete(next_coro)
                    yield doc
                except StopAsyncIteration:
                    break
        finally:
            loop.close()

    return run_and_yield()
|
||||
|
||||
|
||||
class DiscordConnector(PollConnector, LoadConnector):
    """Connector that indexes Discord messages as Documents.

    Supports a full load (``load_from_state``) and time-windowed polling
    (``poll_source``); both delegate to ``_manage_async_retrieval``.
    """

    def __init__(
        self,
        server_ids: list[str] | None = None,
        channel_names: list[str] | None = None,
        # YYYY-MM-DD
        start_date: str | None = None,
        batch_size: int = INDEX_BATCH_SIZE,
    ) -> None:
        """
        Args:
            server_ids: Server (guild) ids as strings; None or empty for all
            channel_names: Channel names to include; None or empty for all
            start_date: Earliest date to pull, as "YYYY-MM-DD"; None for no limit
            batch_size: Maximum number of Documents per yielded batch

        Raises:
            ValueError: if a server id is not an integer string.
        """
        self.batch_size = batch_size
        # Copy so later changes to the caller's list do not leak into the connector.
        self.channel_names: list[str] = list(channel_names) if channel_names else []
        self.server_ids: list[int] = (
            [int(server_id) for server_id in server_ids] if server_ids else []
        )
        # Set by load_credentials(); None until then.
        self._discord_bot_token: str | None = None
        self.requested_start_date_string: str = start_date or ""

    @property
    def discord_bot_token(self) -> str:
        """The bot token; raises ConnectorMissingCredentialError if not loaded yet."""
        if self._discord_bot_token is None:
            raise ConnectorMissingCredentialError("Discord")
        return self._discord_bot_token

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        """Store the bot token from ``credentials["discord_bot_token"]``."""
        self._discord_bot_token = credentials["discord_bot_token"]
        return None

    def _manage_doc_batching(
        self,
        start: datetime | None = None,
        end: datetime | None = None,
    ) -> GenerateDocumentsOutput:
        """Group retrieved Documents into batches of at most ``batch_size``."""
        doc_batch: list[Document] = []
        for doc in _manage_async_retrieval(
            token=self.discord_bot_token,
            requested_start_date_string=self.requested_start_date_string,
            channel_names=self.channel_names,
            server_ids=self.server_ids,
            start=start,
            end=end,
        ):
            doc_batch.append(doc)
            if len(doc_batch) >= self.batch_size:
                yield doc_batch
                doc_batch = []

        # Flush the final, partially filled batch.
        if doc_batch:
            yield doc_batch

    def poll_source(
        self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
    ) -> GenerateDocumentsOutput:
        """Yield batches of Documents for messages in the [start, end] window.

        Both bounds are epoch seconds and are converted to UTC datetimes.
        """
        return self._manage_doc_batching(
            datetime.fromtimestamp(start, tz=timezone.utc),
            datetime.fromtimestamp(end, tz=timezone.utc),
        )

    def load_from_state(self) -> GenerateDocumentsOutput:
        """Yield batches of Documents with no polling window applied."""
        return self._manage_doc_batching(None, None)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test: poll the last 24 hours and print every document.
    # Configuration is read from the environment variables used below.
    import os
    import time

    window_end = time.time()
    # 1 day
    window_start = window_end - 24 * 60 * 60 * 1

    # "1,2,3"
    raw_server_ids: str | None = os.environ.get("server_ids", None)
    # "channel1,channel2"
    raw_channel_names: str | None = os.environ.get("channel_names", None)

    parsed_server_ids = raw_server_ids.split(",") if raw_server_ids else []
    parsed_channel_names = raw_channel_names.split(",") if raw_channel_names else []

    connector = DiscordConnector(
        server_ids=parsed_server_ids,
        channel_names=parsed_channel_names,
        start_date=os.environ.get("start_date", None),
    )
    connector.load_credentials(
        {"discord_bot_token": os.environ.get("discord_bot_token")}
    )

    for doc_batch in connector.poll_source(window_start, window_end):
        for doc in doc_batch:
            print(doc)
|
||||
@@ -3,20 +3,19 @@ import os
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from logging import Logger
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import IO
|
||||
from urllib.parse import quote
|
||||
|
||||
import requests
|
||||
from retry import retry
|
||||
from pydantic import Field
|
||||
|
||||
from onyx.configs.app_configs import EGNYTE_BASE_DOMAIN
|
||||
from onyx.configs.app_configs import EGNYTE_CLIENT_ID
|
||||
from onyx.configs.app_configs import EGNYTE_CLIENT_SECRET
|
||||
from onyx.configs.app_configs import EGNYTE_LOCALHOST_OVERRIDE
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||
get_oauth_callback_uri,
|
||||
)
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import OAuthConnector
|
||||
@@ -33,53 +32,13 @@ from onyx.file_processing.extract_file_text import is_text_file_extension
|
||||
from onyx.file_processing.extract_file_text import is_valid_file_ext
|
||||
from onyx.file_processing.extract_file_text import read_text_file
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.retry_wrapper import request_with_retries
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_EGNYTE_API_BASE = "https://{domain}.egnyte.com/pubapi/v1"
|
||||
_EGNYTE_APP_BASE = "https://{domain}.egnyte.com"
|
||||
_TIMEOUT = 60
|
||||
|
||||
|
||||
def _request_with_retries(
|
||||
method: str,
|
||||
url: str,
|
||||
data: dict[str, Any] | None = None,
|
||||
headers: dict[str, Any] | None = None,
|
||||
params: dict[str, Any] | None = None,
|
||||
timeout: int = _TIMEOUT,
|
||||
stream: bool = False,
|
||||
tries: int = 8,
|
||||
delay: float = 1,
|
||||
backoff: float = 2,
|
||||
) -> requests.Response:
|
||||
@retry(tries=tries, delay=delay, backoff=backoff, logger=cast(Logger, logger))
|
||||
def _make_request() -> requests.Response:
|
||||
response = requests.request(
|
||||
method,
|
||||
url,
|
||||
data=data,
|
||||
headers=headers,
|
||||
params=params,
|
||||
timeout=timeout,
|
||||
stream=stream,
|
||||
)
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except requests.exceptions.HTTPError as e:
|
||||
if e.response.status_code != 403:
|
||||
logger.exception(
|
||||
f"Failed to call Egnyte API.\n"
|
||||
f"URL: {url}\n"
|
||||
f"Headers: {headers}\n"
|
||||
f"Data: {data}\n"
|
||||
f"Params: {params}"
|
||||
)
|
||||
raise e
|
||||
return response
|
||||
|
||||
return _make_request()
|
||||
|
||||
|
||||
def _parse_last_modified(last_modified: str) -> datetime:
|
||||
@@ -166,6 +125,15 @@ def _process_egnyte_file(
|
||||
|
||||
|
||||
class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||
class AdditionalOauthKwargs(OAuthConnector.AdditionalOauthKwargs):
|
||||
egnyte_domain: str = Field(
|
||||
title="Egnyte Domain",
|
||||
description=(
|
||||
"The domain for the Egnyte instance "
|
||||
"(e.g. 'company' for company.egnyte.com)"
|
||||
),
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
folder_path: str | None = None,
|
||||
@@ -181,18 +149,20 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||
return DocumentSource.EGNYTE
|
||||
|
||||
@classmethod
|
||||
def oauth_authorization_url(cls, base_domain: str, state: str) -> str:
|
||||
def oauth_authorization_url(
|
||||
cls,
|
||||
base_domain: str,
|
||||
state: str,
|
||||
additional_kwargs: dict[str, str],
|
||||
) -> str:
|
||||
if not EGNYTE_CLIENT_ID:
|
||||
raise ValueError("EGNYTE_CLIENT_ID environment variable must be set")
|
||||
if not EGNYTE_BASE_DOMAIN:
|
||||
raise ValueError("EGNYTE_DOMAIN environment variable must be set")
|
||||
|
||||
if EGNYTE_LOCALHOST_OVERRIDE:
|
||||
base_domain = EGNYTE_LOCALHOST_OVERRIDE
|
||||
oauth_kwargs = cls.AdditionalOauthKwargs(**additional_kwargs)
|
||||
|
||||
callback_uri = f"{base_domain.strip('/')}/connector/oauth/callback/egnyte"
|
||||
callback_uri = get_oauth_callback_uri(base_domain, "egnyte")
|
||||
return (
|
||||
f"https://{EGNYTE_BASE_DOMAIN}.egnyte.com/puboauth/token"
|
||||
f"https://{oauth_kwargs.egnyte_domain}.egnyte.com/puboauth/token"
|
||||
f"?client_id={EGNYTE_CLIENT_ID}"
|
||||
f"&redirect_uri={callback_uri}"
|
||||
f"&scope=Egnyte.filesystem"
|
||||
@@ -201,17 +171,23 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def oauth_code_to_token(cls, base_domain: str, code: str) -> dict[str, Any]:
|
||||
def oauth_code_to_token(
|
||||
cls,
|
||||
base_domain: str,
|
||||
code: str,
|
||||
additional_kwargs: dict[str, str],
|
||||
) -> dict[str, Any]:
|
||||
if not EGNYTE_CLIENT_ID:
|
||||
raise ValueError("EGNYTE_CLIENT_ID environment variable must be set")
|
||||
if not EGNYTE_CLIENT_SECRET:
|
||||
raise ValueError("EGNYTE_CLIENT_SECRET environment variable must be set")
|
||||
if not EGNYTE_BASE_DOMAIN:
|
||||
raise ValueError("EGNYTE_DOMAIN environment variable must be set")
|
||||
|
||||
oauth_kwargs = cls.AdditionalOauthKwargs(**additional_kwargs)
|
||||
|
||||
# Exchange code for token
|
||||
url = f"https://{EGNYTE_BASE_DOMAIN}.egnyte.com/puboauth/token"
|
||||
redirect_uri = f"{EGNYTE_LOCALHOST_OVERRIDE or base_domain}/connector/oauth/callback/egnyte"
|
||||
url = f"https://{oauth_kwargs.egnyte_domain}.egnyte.com/puboauth/token"
|
||||
redirect_uri = get_oauth_callback_uri(base_domain, "egnyte")
|
||||
|
||||
data = {
|
||||
"client_id": EGNYTE_CLIENT_ID,
|
||||
"client_secret": EGNYTE_CLIENT_SECRET,
|
||||
@@ -222,7 +198,7 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||
}
|
||||
headers = {"Content-Type": "application/x-www-form-urlencoded"}
|
||||
|
||||
response = _request_with_retries(
|
||||
response = request_with_retries(
|
||||
method="POST",
|
||||
url=url,
|
||||
data=data,
|
||||
@@ -236,7 +212,7 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||
|
||||
token_data = response.json()
|
||||
return {
|
||||
"domain": EGNYTE_BASE_DOMAIN,
|
||||
"domain": oauth_kwargs.egnyte_domain,
|
||||
"access_token": token_data["access_token"],
|
||||
}
|
||||
|
||||
@@ -260,9 +236,10 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||
"list_content": True,
|
||||
}
|
||||
|
||||
url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs/{path or ''}"
|
||||
response = _request_with_retries(
|
||||
method="GET", url=url, headers=headers, params=params, timeout=_TIMEOUT
|
||||
url_encoded_path = quote(path or "")
|
||||
url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs/{url_encoded_path}"
|
||||
response = request_with_retries(
|
||||
method="GET", url=url, headers=headers, params=params
|
||||
)
|
||||
if not response.ok:
|
||||
raise RuntimeError(f"Failed to fetch files from Egnyte: {response.text}")
|
||||
@@ -315,12 +292,12 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.access_token}",
|
||||
}
|
||||
url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs-content/{file['path']}"
|
||||
response = _request_with_retries(
|
||||
url_encoded_path = quote(file["path"])
|
||||
url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs-content/{url_encoded_path}"
|
||||
response = request_with_retries(
|
||||
method="GET",
|
||||
url=url,
|
||||
headers=headers,
|
||||
timeout=_TIMEOUT,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -5,12 +5,14 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import DocumentSourceRequiringTenantContext
|
||||
from onyx.connectors.airtable.airtable_connector import AirtableConnector
|
||||
from onyx.connectors.asana.connector import AsanaConnector
|
||||
from onyx.connectors.axero.connector import AxeroConnector
|
||||
from onyx.connectors.blob.connector import BlobStorageConnector
|
||||
from onyx.connectors.bookstack.connector import BookstackConnector
|
||||
from onyx.connectors.clickup.connector import ClickupConnector
|
||||
from onyx.connectors.confluence.connector import ConfluenceConnector
|
||||
from onyx.connectors.discord.connector import DiscordConnector
|
||||
from onyx.connectors.discourse.connector import DiscourseConnector
|
||||
from onyx.connectors.document360.connector import Document360Connector
|
||||
from onyx.connectors.dropbox.connector import DropboxConnector
|
||||
@@ -100,9 +102,11 @@ def identify_connector_class(
|
||||
DocumentSource.GOOGLE_CLOUD_STORAGE: BlobStorageConnector,
|
||||
DocumentSource.OCI_STORAGE: BlobStorageConnector,
|
||||
DocumentSource.XENFORO: XenforoConnector,
|
||||
DocumentSource.DISCORD: DiscordConnector,
|
||||
DocumentSource.FRESHDESK: FreshdeskConnector,
|
||||
DocumentSource.FIREFLIES: FirefliesConnector,
|
||||
DocumentSource.EGNYTE: EgnyteConnector,
|
||||
DocumentSource.AIRTABLE: AirtableConnector,
|
||||
}
|
||||
connector_by_source = connector_map.get(source, {})
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ from typing import cast
|
||||
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
|
||||
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
|
||||
from googleapiclient.errors import HttpError # type: ignore
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.app_configs import MAX_FILE_SIZE_BYTES
|
||||
@@ -20,6 +21,7 @@ from onyx.connectors.google_drive.file_retrieval import crawl_folders_for_files
|
||||
from onyx.connectors.google_drive.file_retrieval import get_all_files_for_oauth
|
||||
from onyx.connectors.google_drive.file_retrieval import get_all_files_in_my_drive
|
||||
from onyx.connectors.google_drive.file_retrieval import get_files_in_shared_drive
|
||||
from onyx.connectors.google_drive.file_retrieval import get_root_folder_id
|
||||
from onyx.connectors.google_drive.models import GoogleDriveFileType
|
||||
from onyx.connectors.google_utils.google_auth import get_google_creds
|
||||
from onyx.connectors.google_utils.google_utils import execute_paginated_retrieval
|
||||
@@ -41,6 +43,7 @@ from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.retry_wrapper import retry_builder
|
||||
|
||||
logger = setup_logger()
|
||||
# TODO: Improve this by using the batch utility: https://googleapis.github.io/google-api-python-client/docs/batch.html
|
||||
@@ -286,13 +289,30 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> Iterator[GoogleDriveFileType]:
|
||||
logger.info(f"Impersonating user {user_email}")
|
||||
|
||||
drive_service = get_drive_service(self.creds, user_email)
|
||||
|
||||
# validate that the user has access to the drive APIs by performing a simple
|
||||
# request and checking for a 401
|
||||
try:
|
||||
retry_builder()(get_root_folder_id)(drive_service)
|
||||
except HttpError as e:
|
||||
if e.status_code == 401:
|
||||
# fail gracefully, let the other impersonations continue
|
||||
# one user without access shouldn't block the entire connector
|
||||
logger.exception(
|
||||
f"User '{user_email}' does not have access to the drive APIs."
|
||||
)
|
||||
return
|
||||
raise
|
||||
|
||||
# if we are including my drives, try to get the current user's my
|
||||
# drive if any of the following are true:
|
||||
# - include_my_drives is true
|
||||
# - the current user's email is in the requested emails
|
||||
if self.include_my_drives or user_email in self._requested_my_drive_emails:
|
||||
logger.info(f"Getting all files in my drive as '{user_email}'")
|
||||
yield from get_all_files_in_my_drive(
|
||||
service=drive_service,
|
||||
update_traversed_ids_func=self._update_traversed_parent_ids,
|
||||
@@ -303,6 +323,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
|
||||
remaining_drive_ids = filtered_drive_ids - self._retrieved_ids
|
||||
for drive_id in remaining_drive_ids:
|
||||
logger.info(f"Getting files in shared drive '{drive_id}' as '{user_email}'")
|
||||
yield from get_files_in_shared_drive(
|
||||
service=drive_service,
|
||||
drive_id=drive_id,
|
||||
@@ -314,6 +335,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
|
||||
remaining_folders = filtered_folder_ids - self._retrieved_ids
|
||||
for folder_id in remaining_folders:
|
||||
logger.info(f"Getting files in folder '{folder_id}' as '{user_email}'")
|
||||
yield from crawl_folders_for_files(
|
||||
service=drive_service,
|
||||
parent_id=folder_id,
|
||||
@@ -344,6 +366,15 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
elif self.include_shared_drives:
|
||||
drive_ids_to_retrieve = all_drive_ids
|
||||
|
||||
# checkpoint - we've found all users and drives, now time to actually start
|
||||
# fetching stuff
|
||||
logger.info(f"Found {len(all_org_emails)} users to impersonate")
|
||||
logger.debug(f"Users: {all_org_emails}")
|
||||
logger.info(f"Found {len(drive_ids_to_retrieve)} drives to retrieve")
|
||||
logger.debug(f"Drives: {drive_ids_to_retrieve}")
|
||||
logger.info(f"Found {len(folder_ids_to_retrieve)} folders to retrieve")
|
||||
logger.debug(f"Folders: {folder_ids_to_retrieve}")
|
||||
|
||||
# Process users in parallel using ThreadPoolExecutor
|
||||
with ThreadPoolExecutor(max_workers=10) as executor:
|
||||
future_to_email = {
|
||||
@@ -380,6 +411,13 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
drive_service = get_drive_service(self.creds, self.primary_admin_email)
|
||||
|
||||
if self.include_files_shared_with_me or self.include_my_drives:
|
||||
logger.info(
|
||||
f"Getting shared files/my drive files for OAuth "
|
||||
f"with include_files_shared_with_me={self.include_files_shared_with_me}, "
|
||||
f"include_my_drives={self.include_my_drives}, "
|
||||
f"include_shared_drives={self.include_shared_drives}."
|
||||
f"Using '{self.primary_admin_email}' as the account."
|
||||
)
|
||||
yield from get_all_files_for_oauth(
|
||||
service=drive_service,
|
||||
include_files_shared_with_me=self.include_files_shared_with_me,
|
||||
@@ -412,6 +450,9 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
drive_ids_to_retrieve = all_drive_ids
|
||||
|
||||
for drive_id in drive_ids_to_retrieve:
|
||||
logger.info(
|
||||
f"Getting files in shared drive '{drive_id}' as '{self.primary_admin_email}'"
|
||||
)
|
||||
yield from get_files_in_shared_drive(
|
||||
service=drive_service,
|
||||
drive_id=drive_id,
|
||||
@@ -425,6 +466,9 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
# that could be folders.
|
||||
remaining_folders = folder_ids_to_retrieve - self._retrieved_ids
|
||||
for folder_id in remaining_folders:
|
||||
logger.info(
|
||||
f"Getting files in folder '{folder_id}' as '{self.primary_admin_email}'"
|
||||
)
|
||||
yield from crawl_folders_for_files(
|
||||
service=drive_service,
|
||||
parent_id=folder_id,
|
||||
|
||||
@@ -2,6 +2,8 @@ import abc
|
||||
from collections.abc import Iterator
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import SlimDocument
|
||||
@@ -66,6 +68,10 @@ class SlimConnector(BaseConnector):
|
||||
|
||||
|
||||
class OAuthConnector(BaseConnector):
|
||||
class AdditionalOauthKwargs(BaseModel):
|
||||
# if overridden, all fields should be str type
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def oauth_id(cls) -> DocumentSource:
|
||||
@@ -73,12 +79,22 @@ class OAuthConnector(BaseConnector):
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def oauth_authorization_url(cls, base_domain: str, state: str) -> str:
|
||||
def oauth_authorization_url(
|
||||
cls,
|
||||
base_domain: str,
|
||||
state: str,
|
||||
additional_kwargs: dict[str, str],
|
||||
) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def oauth_code_to_token(cls, base_domain: str, code: str) -> dict[str, Any]:
|
||||
def oauth_code_to_token(
|
||||
cls,
|
||||
base_domain: str,
|
||||
code: str,
|
||||
additional_kwargs: dict[str, str],
|
||||
) -> dict[str, Any]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
|
||||
@@ -7,16 +7,23 @@ from typing import cast
|
||||
import requests
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.app_configs import LINEAR_CLIENT_ID
|
||||
from onyx.configs.app_configs import LINEAR_CLIENT_SECRET
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||
get_oauth_callback_uri,
|
||||
)
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import OAuthConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.retry_wrapper import request_with_retries
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -57,7 +64,7 @@ def _make_query(request_body: dict[str, Any], api_key: str) -> requests.Response
|
||||
)
|
||||
|
||||
|
||||
class LinearConnector(LoadConnector, PollConnector):
|
||||
class LinearConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||
def __init__(
|
||||
self,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
@@ -65,8 +72,68 @@ class LinearConnector(LoadConnector, PollConnector):
|
||||
self.batch_size = batch_size
|
||||
self.linear_api_key: str | None = None
|
||||
|
||||
@classmethod
|
||||
def oauth_id(cls) -> DocumentSource:
|
||||
return DocumentSource.LINEAR
|
||||
|
||||
@classmethod
|
||||
def oauth_authorization_url(
|
||||
cls, base_domain: str, state: str, additional_kwargs: dict[str, str]
|
||||
) -> str:
|
||||
if not LINEAR_CLIENT_ID:
|
||||
raise ValueError("LINEAR_CLIENT_ID environment variable must be set")
|
||||
|
||||
callback_uri = get_oauth_callback_uri(base_domain, DocumentSource.LINEAR.value)
|
||||
return (
|
||||
f"https://linear.app/oauth/authorize"
|
||||
f"?client_id={LINEAR_CLIENT_ID}"
|
||||
f"&redirect_uri={callback_uri}"
|
||||
f"&response_type=code"
|
||||
f"&scope=read"
|
||||
f"&state={state}"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def oauth_code_to_token(
|
||||
cls, base_domain: str, code: str, additional_kwargs: dict[str, str]
|
||||
) -> dict[str, Any]:
|
||||
data = {
|
||||
"code": code,
|
||||
"redirect_uri": get_oauth_callback_uri(
|
||||
base_domain, DocumentSource.LINEAR.value
|
||||
),
|
||||
"client_id": LINEAR_CLIENT_ID,
|
||||
"client_secret": LINEAR_CLIENT_SECRET,
|
||||
"grant_type": "authorization_code",
|
||||
}
|
||||
headers = {"Content-Type": "application/x-www-form-urlencoded"}
|
||||
|
||||
response = request_with_retries(
|
||||
method="POST",
|
||||
url="https://api.linear.app/oauth/token",
|
||||
data=data,
|
||||
headers=headers,
|
||||
backoff=0,
|
||||
delay=0.1,
|
||||
)
|
||||
if not response.ok:
|
||||
raise RuntimeError(f"Failed to exchange code for token: {response.text}")
|
||||
|
||||
token_data = response.json()
|
||||
|
||||
return {
|
||||
"access_token": token_data["access_token"],
|
||||
}
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
self.linear_api_key = cast(str, credentials["linear_api_key"])
|
||||
if "linear_api_key" in credentials:
|
||||
self.linear_api_key = cast(str, credentials["linear_api_key"])
|
||||
elif "access_token" in credentials:
|
||||
self.linear_api_key = "Bearer " + cast(str, credentials["access_token"])
|
||||
else:
|
||||
# May need to handle case in the future if the OAuth flow expires
|
||||
raise ConnectorMissingCredentialError("Linear")
|
||||
|
||||
return None
|
||||
|
||||
def _process_issues(
|
||||
|
||||
@@ -101,8 +101,11 @@ class DocumentBase(BaseModel):
|
||||
source: DocumentSource | None = None
|
||||
semantic_identifier: str # displayed in the UI as the main identifier for the doc
|
||||
metadata: dict[str, str | list[str]]
|
||||
|
||||
# UTC time
|
||||
doc_updated_at: datetime | None = None
|
||||
chunk_count: int | None = None
|
||||
|
||||
# Owner, creator, etc.
|
||||
primary_owners: list[BasicExpertInfo] | None = None
|
||||
# Assignee, space owner, etc.
|
||||
|
||||
@@ -1,11 +1,7 @@
|
||||
import os
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
|
||||
from simple_salesforce import Salesforce
|
||||
from simple_salesforce import SFType
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
@@ -19,24 +15,25 @@ from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.salesforce.utils import extract_dict_text
|
||||
from onyx.connectors.salesforce.doc_conversion import extract_section
|
||||
from onyx.connectors.salesforce.salesforce_calls import fetch_all_csvs_in_parallel
|
||||
from onyx.connectors.salesforce.salesforce_calls import get_all_children_of_sf_type
|
||||
from onyx.connectors.salesforce.sqlite_functions import get_affected_parent_ids_by_type
|
||||
from onyx.connectors.salesforce.sqlite_functions import get_child_ids
|
||||
from onyx.connectors.salesforce.sqlite_functions import get_record
|
||||
from onyx.connectors.salesforce.sqlite_functions import init_db
|
||||
from onyx.connectors.salesforce.sqlite_functions import update_sf_db_with_csv
|
||||
from onyx.connectors.salesforce.utils import SalesforceObject
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
# TODO: this connector does not work well at large scales
|
||||
# the large query against a large Salesforce instance has been reported to take 1.5 hours.
|
||||
# Additionally it seems to eat up more memory over time if the connection is long running (again a scale issue).
|
||||
|
||||
|
||||
DEFAULT_PARENT_OBJECT_TYPES = ["Account"]
|
||||
MAX_QUERY_LENGTH = 10000 # max query length is 20,000 characters
|
||||
ID_PREFIX = "SALESFORCE_"
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
_DEFAULT_PARENT_OBJECT_TYPES = ["Account"]
|
||||
_ID_PREFIX = "SALESFORCE_"
|
||||
|
||||
|
||||
class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -44,200 +41,170 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
requested_objects: list[str] = [],
|
||||
) -> None:
|
||||
self.batch_size = batch_size
|
||||
self.sf_client: Salesforce | None = None
|
||||
self._sf_client: Salesforce | None = None
|
||||
self.parent_object_list = (
|
||||
[obj.capitalize() for obj in requested_objects]
|
||||
if requested_objects
|
||||
else DEFAULT_PARENT_OBJECT_TYPES
|
||||
else _DEFAULT_PARENT_OBJECT_TYPES
|
||||
)
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
self.sf_client = Salesforce(
|
||||
def load_credentials(
|
||||
self,
|
||||
credentials: dict[str, Any],
|
||||
) -> dict[str, Any] | None:
|
||||
self._sf_client = Salesforce(
|
||||
username=credentials["sf_username"],
|
||||
password=credentials["sf_password"],
|
||||
security_token=credentials["sf_security_token"],
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
def _get_sf_type_object_json(self, type_name: str) -> Any:
|
||||
if self.sf_client is None:
|
||||
@property
|
||||
def sf_client(self) -> Salesforce:
|
||||
if self._sf_client is None:
|
||||
raise ConnectorMissingCredentialError("Salesforce")
|
||||
sf_object = SFType(
|
||||
type_name, self.sf_client.session_id, self.sf_client.sf_instance
|
||||
)
|
||||
return sf_object.describe()
|
||||
return self._sf_client
|
||||
|
||||
def _get_name_from_id(self, id: str) -> str:
|
||||
if self.sf_client is None:
|
||||
raise ConnectorMissingCredentialError("Salesforce")
|
||||
try:
|
||||
user_object_info = self.sf_client.query(
|
||||
f"SELECT Name FROM User WHERE Id = '{id}'"
|
||||
)
|
||||
name = user_object_info.get("Records", [{}])[0].get("Name", "Null User")
|
||||
return name
|
||||
except Exception:
|
||||
logger.warning(f"Couldnt find name for object id: {id}")
|
||||
return "Null User"
|
||||
def _extract_primary_owners(
|
||||
self, sf_object: SalesforceObject
|
||||
) -> list[BasicExpertInfo] | None:
|
||||
object_dict = sf_object.data
|
||||
if not (last_modified_by_id := object_dict.get("LastModifiedById")):
|
||||
return None
|
||||
if not (last_modified_by := get_record(last_modified_by_id)):
|
||||
return None
|
||||
if not (last_modified_by_name := last_modified_by.data.get("Name")):
|
||||
return None
|
||||
primary_owners = [BasicExpertInfo(display_name=last_modified_by_name)]
|
||||
return primary_owners
|
||||
|
||||
def _convert_object_instance_to_document(
|
||||
self, object_dict: dict[str, Any]
|
||||
self, sf_object: SalesforceObject
|
||||
) -> Document:
|
||||
if self.sf_client is None:
|
||||
raise ConnectorMissingCredentialError("Salesforce")
|
||||
|
||||
object_dict = sf_object.data
|
||||
salesforce_id = object_dict["Id"]
|
||||
onyx_salesforce_id = f"{ID_PREFIX}{salesforce_id}"
|
||||
extracted_link = f"https://{self.sf_client.sf_instance}/{salesforce_id}"
|
||||
onyx_salesforce_id = f"{_ID_PREFIX}{salesforce_id}"
|
||||
base_url = f"https://{self.sf_client.sf_instance}"
|
||||
extracted_doc_updated_at = time_str_to_utc(object_dict["LastModifiedDate"])
|
||||
extracted_object_text = extract_dict_text(object_dict)
|
||||
extracted_semantic_identifier = object_dict.get("Name", "Unknown Object")
|
||||
extracted_primary_owners = [
|
||||
BasicExpertInfo(
|
||||
display_name=self._get_name_from_id(object_dict["LastModifiedById"])
|
||||
)
|
||||
]
|
||||
|
||||
sections = [extract_section(sf_object, base_url)]
|
||||
for id in get_child_ids(sf_object.id):
|
||||
if not (child_object := get_record(id)):
|
||||
continue
|
||||
sections.append(extract_section(child_object, base_url))
|
||||
|
||||
doc = Document(
|
||||
id=onyx_salesforce_id,
|
||||
sections=[Section(link=extracted_link, text=extracted_object_text)],
|
||||
sections=sections,
|
||||
source=DocumentSource.SALESFORCE,
|
||||
semantic_identifier=extracted_semantic_identifier,
|
||||
doc_updated_at=extracted_doc_updated_at,
|
||||
primary_owners=extracted_primary_owners,
|
||||
primary_owners=self._extract_primary_owners(sf_object),
|
||||
metadata={},
|
||||
)
|
||||
return doc
|
||||
|
||||
def _is_valid_child_object(self, child_relationship: dict) -> bool:
|
||||
if self.sf_client is None:
|
||||
raise ConnectorMissingCredentialError("Salesforce")
|
||||
|
||||
if not child_relationship["childSObject"]:
|
||||
return False
|
||||
if not child_relationship["relationshipName"]:
|
||||
return False
|
||||
|
||||
sf_type = child_relationship["childSObject"]
|
||||
object_description = self._get_sf_type_object_json(sf_type)
|
||||
if not object_description["queryable"]:
|
||||
return False
|
||||
|
||||
try:
|
||||
query = f"SELECT Count() FROM {sf_type} LIMIT 1"
|
||||
result = self.sf_client.query(query)
|
||||
if result["totalSize"] == 0:
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.warning(f"Object type {sf_type} doesn't support query: {e}")
|
||||
return False
|
||||
|
||||
if child_relationship["field"]:
|
||||
if child_relationship["field"] == "RelatedToId":
|
||||
return False
|
||||
else:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _get_all_children_of_sf_type(self, sf_type: str) -> list[dict]:
|
||||
if self.sf_client is None:
|
||||
raise ConnectorMissingCredentialError("Salesforce")
|
||||
|
||||
object_description = self._get_sf_type_object_json(sf_type)
|
||||
|
||||
children_objects: list[dict] = []
|
||||
for child_relationship in object_description["childRelationships"]:
|
||||
if self._is_valid_child_object(child_relationship):
|
||||
children_objects.append(
|
||||
{
|
||||
"relationship_name": child_relationship["relationshipName"],
|
||||
"object_type": child_relationship["childSObject"],
|
||||
}
|
||||
)
|
||||
return children_objects
|
||||
|
||||
def _get_all_fields_for_sf_type(self, sf_type: str) -> list[str]:
|
||||
if self.sf_client is None:
|
||||
raise ConnectorMissingCredentialError("Salesforce")
|
||||
|
||||
object_description = self._get_sf_type_object_json(sf_type)
|
||||
|
||||
fields = [
|
||||
field.get("name")
|
||||
for field in object_description["fields"]
|
||||
if field.get("type", "base64") != "base64"
|
||||
]
|
||||
|
||||
return fields
|
||||
|
||||
def _generate_query_per_parent_type(self, parent_sf_type: str) -> Iterator[str]:
|
||||
"""
|
||||
This function takes in an object_type and generates query(s) designed to grab
|
||||
information associated to objects of that type.
|
||||
It does that by getting all the fields of the parent object type.
|
||||
Then it gets all the child objects of that object type and all the fields of
|
||||
those children as well.
|
||||
"""
|
||||
parent_fields = self._get_all_fields_for_sf_type(parent_sf_type)
|
||||
child_sf_types = self._get_all_children_of_sf_type(parent_sf_type)
|
||||
|
||||
query = f"SELECT {', '.join(parent_fields)}"
|
||||
for child_object_dict in child_sf_types:
|
||||
fields = self._get_all_fields_for_sf_type(child_object_dict["object_type"])
|
||||
query_addition = f", \n(SELECT {', '.join(fields)} FROM {child_object_dict['relationship_name']})"
|
||||
|
||||
if len(query_addition) + len(query) > MAX_QUERY_LENGTH:
|
||||
query += f"\n FROM {parent_sf_type}"
|
||||
yield query
|
||||
query = "SELECT Id" + query_addition
|
||||
else:
|
||||
query += query_addition
|
||||
|
||||
query += f"\n FROM {parent_sf_type}"
|
||||
|
||||
yield query
|
||||
|
||||
def _fetch_from_salesforce(
|
||||
self,
|
||||
start: datetime | None = None,
|
||||
end: datetime | None = None,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> GenerateDocumentsOutput:
|
||||
if self.sf_client is None:
|
||||
raise ConnectorMissingCredentialError("Salesforce")
|
||||
init_db()
|
||||
all_object_types: set[str] = set(self.parent_object_list)
|
||||
|
||||
doc_batch: list[Document] = []
|
||||
logger.info(f"Starting with {len(self.parent_object_list)} parent object types")
|
||||
logger.debug(f"Parent object types: {self.parent_object_list}")
|
||||
|
||||
# This takes like 20 seconds
|
||||
for parent_object_type in self.parent_object_list:
|
||||
logger.debug(f"Processing: {parent_object_type}")
|
||||
|
||||
query_results: dict = {}
|
||||
for query in self._generate_query_per_parent_type(parent_object_type):
|
||||
if start is not None and end is not None:
|
||||
if start and start.tzinfo is None:
|
||||
start = start.replace(tzinfo=timezone.utc)
|
||||
if end and end.tzinfo is None:
|
||||
end = end.replace(tzinfo=timezone.utc)
|
||||
query += f" WHERE LastModifiedDate > {start.isoformat()} AND LastModifiedDate < {end.isoformat()}"
|
||||
|
||||
query_result = self.sf_client.query_all(query)
|
||||
|
||||
for record_dict in query_result["records"]:
|
||||
query_results.setdefault(record_dict["Id"], {}).update(record_dict)
|
||||
|
||||
logger.info(
|
||||
f"Number of {parent_object_type} Objects processed: {len(query_results)}"
|
||||
child_types = get_all_children_of_sf_type(
|
||||
self.sf_client, parent_object_type
|
||||
)
|
||||
all_object_types.update(child_types)
|
||||
logger.debug(
|
||||
f"Found {len(child_types)} child types for {parent_object_type}"
|
||||
)
|
||||
|
||||
for combined_object_dict in query_results.values():
|
||||
doc_batch.append(
|
||||
self._convert_object_instance_to_document(combined_object_dict)
|
||||
)
|
||||
logger.info(f"Found total of {len(all_object_types)} object types to fetch")
|
||||
logger.debug(f"All object types: {all_object_types}")
|
||||
|
||||
if len(doc_batch) > self.batch_size:
|
||||
yield doc_batch
|
||||
doc_batch = []
|
||||
yield doc_batch
|
||||
# checkpoint - we've found all object types, now time to fetch the data
|
||||
logger.info("Starting to fetch CSVs for all object types")
|
||||
# This takes like 30 minutes first time and <2 minutes for updates
|
||||
object_type_to_csv_path = fetch_all_csvs_in_parallel(
|
||||
sf_client=self.sf_client,
|
||||
object_types=all_object_types,
|
||||
start=start,
|
||||
end=end,
|
||||
)
|
||||
|
||||
updated_ids: set[str] = set()
|
||||
# This takes like 10 seconds
|
||||
# This is for testing the rest of the functionality if data has
|
||||
# already been fetched and put in sqlite
|
||||
# from import onyx.connectors.salesforce.sf_db.sqlite_functions find_ids_by_type
|
||||
# for object_type in self.parent_object_list:
|
||||
# updated_ids.update(list(find_ids_by_type(object_type)))
|
||||
|
||||
# This takes 10-70 minutes first time (idk why the range is so big)
|
||||
total_types = len(object_type_to_csv_path)
|
||||
logger.info(f"Starting to process {total_types} object types")
|
||||
|
||||
for i, (object_type, csv_paths) in enumerate(
|
||||
object_type_to_csv_path.items(), 1
|
||||
):
|
||||
logger.info(f"Processing object type {object_type} ({i}/{total_types})")
|
||||
# If path is None, it means it failed to fetch the csv
|
||||
if csv_paths is None:
|
||||
continue
|
||||
# Go through each csv path and use it to update the db
|
||||
for csv_path in csv_paths:
|
||||
logger.debug(f"Updating {object_type} with {csv_path}")
|
||||
new_ids = update_sf_db_with_csv(
|
||||
object_type=object_type,
|
||||
csv_download_path=csv_path,
|
||||
)
|
||||
updated_ids.update(new_ids)
|
||||
logger.debug(
|
||||
f"Added {len(new_ids)} new/updated records for {object_type}"
|
||||
)
|
||||
# Remove the csv file after it has been used
|
||||
# to successfully update the db
|
||||
os.remove(csv_path)
|
||||
|
||||
logger.info(f"Found {len(updated_ids)} total updated records")
|
||||
logger.info(
|
||||
f"Starting to process parent objects of types: {self.parent_object_list}"
|
||||
)
|
||||
|
||||
docs_to_yield: list[Document] = []
|
||||
docs_processed = 0
|
||||
# Takes 15-20 seconds per batch
|
||||
for parent_type, parent_id_batch in get_affected_parent_ids_by_type(
|
||||
updated_ids=list(updated_ids),
|
||||
parent_types=self.parent_object_list,
|
||||
):
|
||||
logger.info(
|
||||
f"Processing batch of {len(parent_id_batch)} {parent_type} objects"
|
||||
)
|
||||
for parent_id in parent_id_batch:
|
||||
if not (parent_object := get_record(parent_id, parent_type)):
|
||||
logger.warning(
|
||||
f"Failed to get parent object {parent_id} for {parent_type}"
|
||||
)
|
||||
continue
|
||||
|
||||
docs_to_yield.append(
|
||||
self._convert_object_instance_to_document(parent_object)
|
||||
)
|
||||
docs_processed += 1
|
||||
|
||||
if len(docs_to_yield) >= self.batch_size:
|
||||
yield docs_to_yield
|
||||
docs_to_yield = []
|
||||
|
||||
yield docs_to_yield
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
return self._fetch_from_salesforce()
|
||||
@@ -245,26 +212,20 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
if self.sf_client is None:
|
||||
raise ConnectorMissingCredentialError("Salesforce")
|
||||
start_datetime = datetime.utcfromtimestamp(start)
|
||||
end_datetime = datetime.utcfromtimestamp(end)
|
||||
return self._fetch_from_salesforce(start=start_datetime, end=end_datetime)
|
||||
return self._fetch_from_salesforce(start=start, end=end)
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
if self.sf_client is None:
|
||||
raise ConnectorMissingCredentialError("Salesforce")
|
||||
doc_metadata_list: list[SlimDocument] = []
|
||||
for parent_object_type in self.parent_object_list:
|
||||
query = f"SELECT Id FROM {parent_object_type}"
|
||||
query_result = self.sf_client.query_all(query)
|
||||
doc_metadata_list.extend(
|
||||
SlimDocument(
|
||||
id=f"{ID_PREFIX}{instance_dict.get('Id', '')}",
|
||||
id=f"{_ID_PREFIX}{instance_dict.get('Id', '')}",
|
||||
perm_sync_data={},
|
||||
)
|
||||
for instance_dict in query_result["records"]
|
||||
@@ -274,9 +235,9 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
connector = SalesforceConnector(
|
||||
requested_objects=os.environ["REQUESTED_OBJECTS"].split(",")
|
||||
)
|
||||
import time
|
||||
|
||||
connector = SalesforceConnector(requested_objects=["Account"])
|
||||
|
||||
connector.load_credentials(
|
||||
{
|
||||
@@ -285,5 +246,20 @@ if __name__ == "__main__":
|
||||
"sf_security_token": os.environ["SF_SECURITY_TOKEN"],
|
||||
}
|
||||
)
|
||||
document_batches = connector.load_from_state()
|
||||
print(next(document_batches))
|
||||
start_time = time.time()
|
||||
doc_count = 0
|
||||
section_count = 0
|
||||
text_count = 0
|
||||
for doc_batch in connector.load_from_state():
|
||||
doc_count += len(doc_batch)
|
||||
print(f"doc_count: {doc_count}")
|
||||
for doc in doc_batch:
|
||||
section_count += len(doc.sections)
|
||||
for section in doc.sections:
|
||||
text_count += len(section.text)
|
||||
end_time = time.time()
|
||||
|
||||
print(f"Doc count: {doc_count}")
|
||||
print(f"Section count: {section_count}")
|
||||
print(f"Text count: {text_count}")
|
||||
print(f"Time taken: {end_time - start_time}")
|
||||
|
||||
156
backend/onyx/connectors/salesforce/doc_conversion.py
Normal file
156
backend/onyx/connectors/salesforce/doc_conversion.py
Normal file
@@ -0,0 +1,156 @@
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.connectors.salesforce.utils import SalesforceObject
|
||||
|
||||
# All of these types of keys are handled by specific fields in the doc
|
||||
# conversion process (E.g. URLs) or are not useful for the user (E.g. UUIDs)
|
||||
_SF_JSON_FILTER = r"Id$|Date$|stamp$|url$"
|
||||
|
||||
|
||||
def _clean_salesforce_dict(data: dict | list) -> dict | list:
|
||||
"""Clean and transform Salesforce API response data by recursively:
|
||||
1. Extracting records from the response if present
|
||||
2. Merging attributes into the main dictionary
|
||||
3. Filtering out keys matching certain patterns (Id, Date, stamp, url)
|
||||
4. Removing '__c' suffix from custom field names
|
||||
5. Removing None values and empty containers
|
||||
|
||||
Args:
|
||||
data: A dictionary or list from Salesforce API response
|
||||
|
||||
Returns:
|
||||
Cleaned dictionary or list with transformed keys and filtered values
|
||||
"""
|
||||
if isinstance(data, dict):
|
||||
if "records" in data.keys():
|
||||
data = data["records"]
|
||||
if isinstance(data, dict):
|
||||
if "attributes" in data.keys():
|
||||
if isinstance(data["attributes"], dict):
|
||||
data.update(data.pop("attributes"))
|
||||
|
||||
if isinstance(data, dict):
|
||||
filtered_dict = {}
|
||||
for key, value in data.items():
|
||||
if not re.search(_SF_JSON_FILTER, key, re.IGNORECASE):
|
||||
# remove the custom object indicator for display
|
||||
if "__c" in key:
|
||||
key = key[:-3]
|
||||
if isinstance(value, (dict, list)):
|
||||
filtered_value = _clean_salesforce_dict(value)
|
||||
# Only add non-empty dictionaries or lists
|
||||
if filtered_value:
|
||||
filtered_dict[key] = filtered_value
|
||||
elif value is not None:
|
||||
filtered_dict[key] = value
|
||||
return filtered_dict
|
||||
elif isinstance(data, list):
|
||||
filtered_list = []
|
||||
for item in data:
|
||||
if isinstance(item, (dict, list)):
|
||||
filtered_item = _clean_salesforce_dict(item)
|
||||
# Only add non-empty dictionaries or lists
|
||||
if filtered_item:
|
||||
filtered_list.append(filtered_item)
|
||||
elif item is not None:
|
||||
filtered_list.append(filtered_item)
|
||||
return filtered_list
|
||||
else:
|
||||
return data
|
||||
|
||||
|
||||
def _json_to_natural_language(data: dict | list, indent: int = 0) -> str:
|
||||
"""Convert a nested dictionary or list into a human-readable string format.
|
||||
|
||||
Recursively traverses the data structure and formats it with:
|
||||
- Key-value pairs on separate lines
|
||||
- Nested structures indented for readability
|
||||
- Lists and dictionaries handled with appropriate formatting
|
||||
|
||||
Args:
|
||||
data: The dictionary or list to convert
|
||||
indent: Number of spaces to indent (default: 0)
|
||||
|
||||
Returns:
|
||||
A formatted string representation of the data structure
|
||||
"""
|
||||
result = []
|
||||
indent_str = " " * indent
|
||||
|
||||
if isinstance(data, dict):
|
||||
for key, value in data.items():
|
||||
if isinstance(value, (dict, list)):
|
||||
result.append(f"{indent_str}{key}:")
|
||||
result.append(_json_to_natural_language(value, indent + 2))
|
||||
else:
|
||||
result.append(f"{indent_str}{key}: {value}")
|
||||
elif isinstance(data, list):
|
||||
for item in data:
|
||||
result.append(_json_to_natural_language(item, indent + 2))
|
||||
|
||||
return "\n".join(result)
|
||||
|
||||
|
||||
def _extract_dict_text(raw_dict: dict) -> str:
    """Turn a Salesforce API response dictionary into readable text.

    The dictionary is cleaned with ``_clean_salesforce_dict`` and the cleaned
    result is formatted with ``_json_to_natural_language``.
    """
    return _json_to_natural_language(_clean_salesforce_dict(raw_dict))
|
||||
|
||||
|
||||
def extract_section(salesforce_object: SalesforceObject, base_url: str) -> Section:
    """Build a Section for a single Salesforce object.

    The section text is the textual rendering of ``salesforce_object.data``
    and the link is ``base_url`` followed by ``/`` and the object's id.
    """
    section_text = _extract_dict_text(salesforce_object.data)
    section_link = f"{base_url}/{salesforce_object.id}"
    return Section(text=section_text, link=section_link)
|
||||
|
||||
|
||||
def _field_value_is_child_object(field_value: dict) -> bool:
|
||||
"""
|
||||
Checks if the field value is a child object.
|
||||
"""
|
||||
return (
|
||||
isinstance(field_value, OrderedDict)
|
||||
and "records" in field_value.keys()
|
||||
and isinstance(field_value["records"], list)
|
||||
and len(field_value["records"]) > 0
|
||||
and "Id" in field_value["records"][0].keys()
|
||||
)
|
||||
|
||||
|
||||
def _extract_sections(salesforce_object: dict, base_url: str) -> list[Section]:
    """
    Split a Salesforce object into one top-level Section plus one Section per
    child record.

    Fields that are not child objects are gathered and rendered together as
    the top-level section (linked via the object's own "Id"). Every record of
    every child-object field becomes its own section linked via that record's
    "Id". The top-level section is returned first.
    """
    own_fields = {}
    child_sections = []

    for field_name, field_value in salesforce_object.items():
        if _field_value_is_child_object(field_value):
            # Each child record gets a dedicated section.
            for child_record in field_value["records"]:
                child_id = child_record["Id"]
                child_text = (
                    f"Child Object(s): {field_name}\n{_extract_dict_text(child_record)}"
                )
                child_sections.append(
                    Section(text=child_text, link=f"{base_url}/{child_id}")
                )
        else:
            # Plain field: contributes to the text of the top-level section.
            own_fields[field_name] = field_value

    object_id = salesforce_object["Id"]
    main_section = Section(
        text=_extract_dict_text(own_fields),
        link=f"{base_url}/{object_id}",
    )
    return [main_section] + child_sections
|
||||
210
backend/onyx/connectors/salesforce/salesforce_calls.py
Normal file
210
backend/onyx/connectors/salesforce/salesforce_calls.py
Normal file
@@ -0,0 +1,210 @@
|
||||
import os
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from pytz import UTC
|
||||
from simple_salesforce import Salesforce
|
||||
from simple_salesforce import SFType
|
||||
from simple_salesforce.bulk2 import SFBulk2Handler
|
||||
from simple_salesforce.bulk2 import SFBulk2Type
|
||||
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.salesforce.sqlite_functions import has_at_least_one_object_of_type
|
||||
from onyx.connectors.salesforce.utils import get_object_type_path
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _build_time_filter_for_salesforce(
|
||||
start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None
|
||||
) -> str:
|
||||
if start is None or end is None:
|
||||
return ""
|
||||
start_datetime = datetime.fromtimestamp(start, UTC)
|
||||
end_datetime = datetime.fromtimestamp(end, UTC)
|
||||
return (
|
||||
f" WHERE LastModifiedDate > {start_datetime.isoformat()} "
|
||||
f"AND LastModifiedDate < {end_datetime.isoformat()}"
|
||||
)
|
||||
|
||||
|
||||
def _get_sf_type_object_json(sf_client: Salesforce, type_name: str) -> Any:
    """Return the result of ``describe()`` for the given object type.

    An ``SFType`` is built from the client's session id and instance, and its
    ``describe()`` result is returned unchanged.
    """
    return SFType(type_name, sf_client.session_id, sf_client.sf_instance).describe()
|
||||
|
||||
|
||||
def _is_valid_child_object(
    sf_client: Salesforce, child_relationship: dict[str, Any]
) -> bool:
    """Decide whether a child relationship should be followed.

    A relationship is accepted only when all of the following hold:
    - it names a child object type ("childSObject") and a "relationshipName"
    - the described child object type is "queryable"
    - its "field" is set and is not "RelatedToId"
    """
    child_type = child_relationship["childSObject"]
    if not child_type or not child_relationship["relationshipName"]:
        return False

    # The description is fetched before the "field" entry is inspected.
    child_description = _get_sf_type_object_json(sf_client, child_type)
    if not child_description["queryable"]:
        return False

    link_field = child_relationship["field"]
    return bool(link_field) and link_field != "RelatedToId"
|
||||
|
||||
|
||||
def get_all_children_of_sf_type(sf_client: Salesforce, sf_type: str) -> set[str]:
    """Collect the child object types of ``sf_type`` that pass validation.

    Iterates the "childRelationships" of the type's description and keeps the
    "childSObject" name of every relationship accepted by
    ``_is_valid_child_object``.
    """
    relationships = _get_sf_type_object_json(sf_client, sf_type)["childRelationships"]

    valid_children = set()
    for relationship in relationships:
        if not _is_valid_child_object(sf_client, relationship):
            continue
        logger.debug(f"Found valid child object {relationship['childSObject']}")
        valid_children.add(relationship["childSObject"])
    return valid_children
|
||||
|
||||
|
||||
def _get_all_queryable_fields_of_sf_type(
    sf_client: Salesforce,
    sf_type: str,
) -> list[str]:
    """List the field names of ``sf_type`` to include in a query.

    Fields whose "type" is "base64" (or missing) are skipped, as are fields
    without a name. Any name that some field reports as its
    "compoundFieldName" is removed from the result. The order of the returned
    list is unspecified because it is derived from a set.
    """
    described_fields: list[dict[str, Any]] = _get_sf_type_object_json(
        sf_client, sf_type
    )["fields"]

    # Names that some field reports as its compound parent.
    compound_names: set[str] = {
        described["compoundFieldName"]
        for described in described_fields
        if described.get("compoundFieldName")
    }
    # Named, non-base64 fields.
    candidate_names: set[str] = {
        described["name"]
        for described in described_fields
        if described.get("type", "base64") != "base64" and described.get("name")
    }
    return list(candidate_names - compound_names)
|
||||
|
||||
|
||||
def _check_if_object_type_is_empty(sf_client: Salesforce, sf_type: str) -> bool:
|
||||
"""
|
||||
Send a small query to check if the object type is empty so we don't
|
||||
perform extra bulk queries
|
||||
"""
|
||||
try:
|
||||
query = f"SELECT Count() FROM {sf_type} LIMIT 1"
|
||||
result = sf_client.query(query)
|
||||
if result["totalSize"] == 0:
|
||||
return False
|
||||
except Exception as e:
|
||||
if "OPERATION_TOO_LARGE" not in str(e):
|
||||
logger.warning(f"Object type {sf_type} doesn't support query: {e}")
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _check_for_existing_csvs(sf_type: str) -> list[str] | None:
    """Return the paths of CSV files already present for ``sf_type``.

    Existing files are likely left over from a previous run that failed after
    downloading the csv but before the data was written to the db.

    Returns:
        The full paths of all ``.csv`` files in the object type's directory,
        or None when the directory is missing or holds no CSV files.
    """
    type_dir = get_object_type_path(sf_type)
    if not os.path.exists(type_dir):
        return None

    csv_paths = [
        os.path.join(type_dir, entry)
        for entry in os.listdir(type_dir)
        if entry.endswith(".csv")
    ]
    return csv_paths or None
|
||||
|
||||
|
||||
def _build_bulk_query(sf_client: Salesforce, sf_type: str, time_filter: str) -> str:
    """Compose the SELECT statement used for the bulk download of ``sf_type``.

    All queryable fields of the type are selected, and ``time_filter`` is
    appended verbatim after the object type name.
    """
    field_list = ", ".join(_get_all_queryable_fields_of_sf_type(sf_client, sf_type))
    return f"SELECT {field_list} FROM {sf_type}{time_filter}"
|
||||
|
||||
|
||||
def _bulk_retrieve_from_salesforce(
    sf_client: Salesforce,
    sf_type: str,
    time_filter: str,
) -> tuple[str, list[str] | None]:
    """Download all records of ``sf_type`` as CSV files via the Bulk 2.0 API.

    Returns:
        ``(sf_type, paths)`` where ``paths`` is
        - None when the pre-check says the type should be skipped,
        - the already existing CSV paths when some are found on disk,
        - the freshly downloaded file paths on success,
        - None when the download raised (the failure is logged).
    """
    if not _check_if_object_type_is_empty(sf_client, sf_type):
        return sf_type, None

    leftover_csvs = _check_for_existing_csvs(sf_type)
    if leftover_csvs:
        return sf_type, leftover_csvs

    query = _build_bulk_query(sf_client, sf_type, time_filter)

    # The handler is built from the existing client's session so that the
    # bulk type below can reuse its url, headers and session.
    handler = SFBulk2Handler(
        session_id=sf_client.session_id,
        bulk2_url=sf_client.bulk2_url,
        proxies=sf_client.proxies,
        session=sf_client.session,
    )
    bulk_type = SFBulk2Type(
        object_name=sf_type,
        bulk2_url=handler.bulk2_url,
        headers=handler.headers,
        session=handler.session,
    )

    logger.info(f"Downloading {sf_type}")
    logger.info(f"Query: {query}")

    downloaded_paths: list[str] | None
    try:
        # This downloads the file to a file in the target path with a random name
        download_results = bulk_type.download(
            query=query,
            path=get_object_type_path(sf_type),
            max_records=1000000,
        )
        downloaded_paths = [entry["file"] for entry in download_results]
        logger.info(f"Downloaded {sf_type} to {downloaded_paths}")
    except Exception as e:
        logger.info(f"Failed to download salesforce csv for object type {sf_type}: {e}")
        downloaded_paths = None
    return sf_type, downloaded_paths
|
||||
|
||||
|
||||
def fetch_all_csvs_in_parallel(
    sf_client: Salesforce,
    object_types: set[str],
    start: SecondsSinceUnixEpoch | None,
    end: SecondsSinceUnixEpoch | None,
) -> dict[str, list[str] | None]:
    """
    Fetch the csvs for all given object types in parallel.

    Returns a dict mapping each sf_type to its list of download paths, or to
    None when nothing was retrieved for that type.
    """
    shared_time_filter = _build_time_filter_for_salesforce(start, end)

    # The filters are resolved before entering the thread pool because this
    # requires a database connection and we don't want to block the thread
    # pool executor from running.
    #
    # Only add the time filter if there is at least one object of the type in
    # the database. We aren't worried about partially completed object update
    # runs because this occurs after we check for existing csvs which covers
    # this case.
    filter_by_type = {
        sf_type: shared_time_filter if has_at_least_one_object_of_type(sf_type) else ""
        for sf_type in object_types
    }

    def _retrieve(object_type: str) -> tuple[str, list[str] | None]:
        return _bulk_retrieve_from_salesforce(
            sf_client=sf_client,
            sf_type=object_type,
            time_filter=filter_by_type[object_type],
        )

    # Run the bulk retrieve in parallel
    with ThreadPoolExecutor() as executor:
        return dict(executor.map(_retrieve, object_types))
|
||||
@@ -0,0 +1,209 @@
|
||||
import csv
|
||||
import shelve
|
||||
|
||||
from onyx.connectors.salesforce.shelve_stuff.shelve_utils import (
|
||||
get_child_to_parent_shelf_path,
|
||||
)
|
||||
from onyx.connectors.salesforce.shelve_stuff.shelve_utils import get_id_type_shelf_path
|
||||
from onyx.connectors.salesforce.shelve_stuff.shelve_utils import get_object_shelf_path
|
||||
from onyx.connectors.salesforce.shelve_stuff.shelve_utils import (
|
||||
get_parent_to_child_shelf_path,
|
||||
)
|
||||
from onyx.connectors.salesforce.utils import SalesforceObject
|
||||
from onyx.connectors.salesforce.utils import validate_salesforce_id
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _update_relationship_shelves(
    child_id: str,
    parent_ids: set[str],
) -> None:
    """Record the parents of ``child_id`` in both relationship shelves.

    The child-to-parent shelf entry for the child is replaced with
    ``parent_ids``. The parent-to-child shelf is then adjusted for exactly the
    parents that were dropped or gained compared with the previous entry.

    Raises:
        Any exception raised while updating; it is logged and re-raised.
    """
    try:
        child_key = str(child_id)

        # Step 1: overwrite the child -> parents entry, remembering the old one.
        with shelve.open(
            get_child_to_parent_shelf_path(),
            flag="c",
            protocol=None,
            writeback=True,
        ) as child_to_parent_db:
            previous_parents = set(child_to_parent_db.get(child_key, []))
            child_to_parent_db[child_key] = list(parent_ids)

            dropped_parents = previous_parents - parent_ids
            gained_parents = parent_ids - previous_parents

            child_to_parent_db.sync()

        # Nothing changed on the parent side, so the second shelf is untouched.
        if not (dropped_parents or gained_parents):
            return

        # Step 2: mirror the change in the parent -> children shelf.
        with shelve.open(
            get_parent_to_child_shelf_path(),
            flag="c",
            protocol=None,
            writeback=True,
        ) as parent_to_child_db:
            for dropped in dropped_parents:
                parent_key = str(dropped)
                children = set(parent_to_child_db.get(parent_key, []))
                if child_key not in children:
                    continue
                children.remove(child_key)
                parent_to_child_db[parent_key] = list(children)

            for gained in gained_parents:
                parent_key = str(gained)
                children = set(parent_to_child_db.get(parent_key, []))
                children.add(child_key)
                parent_to_child_db[parent_key] = list(children)

            parent_to_child_db.sync()

    except Exception as e:
        logger.error(f"Error updating relationship shelves: {e}")
        logger.error(f"Child ID: {child_id}, Parent IDs: {parent_ids}")
        raise
|
||||
|
||||
|
||||
def get_child_ids(parent_id: str) -> set[str]:
    """Look up the children recorded for a parent.

    Args:
        parent_id: The ID of the parent object

    Returns:
        The set of child object IDs stored for the parent in the
        parent-to-child shelf; empty when the parent has no entry.
    """
    with shelve.open(get_parent_to_child_shelf_path()) as parent_to_child_db:
        stored_children = parent_to_child_db.get(parent_id, [])
        return set(stored_children)
|
||||
|
||||
|
||||
def update_sf_db_with_csv(
    object_type: str,
    csv_download_path: str,
) -> list[str]:
    """Update the SF DB with a CSV file using shelve storage.

    For every CSV row:
    - every non-"Id" column whose value is a valid Salesforce ID is treated
      as a parent reference and recorded in the relationship shelves
    - parent-reference columns and empty columns are dropped from the row,
      except "LastModifiedById", which is kept
    - the remaining row is stored under its "Id" in the object type's shelf
      and the ID is registered in the ID-to-type shelf

    The object shelf and the ID-to-type shelf are opened once for the whole
    file rather than once per row.

    Args:
        object_type: The Salesforce object type the CSV belongs to
        csv_download_path: Path of the CSV file to ingest

    Returns:
        The IDs of all rows written, in file order.
    """
    updated_ids: list[str] = []
    shelf_path = get_object_shelf_path(object_type)

    with open(
        csv_download_path, "r", newline="", encoding="utf-8"
    ) as f, shelve.open(shelf_path) as object_type_db, shelve.open(
        get_id_type_shelf_path()
    ) as id_type_db:
        for row in csv.DictReader(f):
            record_id = row["Id"]
            parent_ids: set[str] = set()
            fields_to_remove: set[str] = set()

            for field, value in row.items():
                if validate_salesforce_id(value) and field != "Id":
                    parent_ids.add(value)
                    fields_to_remove.add(field)
                if not value:
                    fields_to_remove.add(field)

            # Update relationship shelves for any parent references
            _update_relationship_shelves(record_id, parent_ids)

            for field in fields_to_remove:
                # We use this to extract the Primary Owner later
                if field != "LastModifiedById":
                    del row[field]

            # Update the main object shelf and the ID-to-type mapping shelf
            object_type_db[record_id] = row
            id_type_db[record_id] = object_type

            updated_ids.append(record_id)

    # NOTE: the CSV file is not removed here.
    return updated_ids
|
||||
|
||||
|
||||
def get_type_from_id(object_id: str) -> str | None:
    """Resolve an object ID to its object type.

    Returns:
        The type stored in the ID-to-type mapping shelf, or None (after
        logging a warning) when the ID has no entry.
    """
    with shelve.open(get_id_type_shelf_path()) as id_type_db:
        if object_id in id_type_db:
            return id_type_db[object_id]
        logger.warning(f"Object ID {object_id} not found in ID-to-type mapping")
        return None
|
||||
|
||||
|
||||
def get_record(
    object_id: str, object_type: str | None = None
) -> SalesforceObject | None:
    """
    Load a stored record and wrap it in a SalesforceObject.

    When ``object_type`` is not given it is looked up from the ID-to-type
    mapping shelf. None is returned when the type cannot be resolved or when
    the ID is absent from the object type's shelf (a warning is logged in the
    latter case).
    """
    resolved_type = object_type
    if resolved_type is None:
        resolved_type = get_type_from_id(object_id)
        if not resolved_type:
            return None

    shelf_path = get_object_shelf_path(resolved_type)
    with shelve.open(shelf_path) as db:
        if object_id in db:
            return SalesforceObject(
                id=object_id,
                type=resolved_type,
                data=db[object_id],
            )
        logger.warning(f"Object ID {object_id} not found in {shelf_path}")
        return None
|
||||
|
||||
|
||||
def find_ids_by_type(object_type: str) -> list[str]:
    """
    Return every object ID stored in the shelf of the given object type.

    An empty list is returned when opening the shelf raises
    FileNotFoundError.
    """
    try:
        with shelve.open(get_object_shelf_path(object_type)) as db:
            stored_ids = list(db.keys())
    except FileNotFoundError:
        stored_ids = []
    return stored_ids
|
||||
|
||||
|
||||
def get_affected_parent_ids_by_type(
    updated_ids: set[str], parent_types: list[str]
) -> dict[str, set[str]]:
    """Group the parent-type objects touched by a set of updated IDs.

    An updated ID is included directly when its own type is one of
    ``parent_types``. Otherwise its recorded parents are inspected and every
    parent whose type is in ``parent_types`` is included instead.

    Args:
        updated_ids: Set of IDs that were updated
        parent_types: List of object types to filter by

    Returns:
        A dictionary mapping each matching object type to the set of
        affected IDs of that type
    """
    affected: dict[str, set[str]] = {}

    for updated_id in updated_ids:
        own_type = get_type_from_id(updated_id)
        if own_type in parent_types:
            # The updated object is itself of a parent type.
            affected.setdefault(own_type, set()).add(updated_id)
        else:
            # Fall back to the parents recorded for this ID.
            with shelve.open(get_child_to_parent_shelf_path()) as child_to_parent_db:
                linked_parent_ids = child_to_parent_db.get(updated_id, [])
                for linked_id in linked_parent_ids:
                    linked_type = get_type_from_id(linked_id)
                    if linked_type in parent_types:
                        affected.setdefault(linked_type, set()).add(linked_id)

    return affected
|
||||
@@ -0,0 +1,29 @@
|
||||
import os
|
||||
|
||||
from onyx.connectors.salesforce.utils import BASE_DATA_PATH
|
||||
from onyx.connectors.salesforce.utils import get_object_type_path
|
||||
|
||||
|
||||
def get_object_shelf_path(object_type: str) -> str:
    """Return the shelf file path for one object type.

    The object type's directory is created if it does not exist yet; the
    shelf is named "data.shelf" inside it.
    """
    object_dir = get_object_type_path(object_type)
    os.makedirs(object_dir, exist_ok=True)
    shelf_file = os.path.join(object_dir, "data.shelf")
    return shelf_file
|
||||
|
||||
|
||||
def get_id_type_shelf_path() -> str:
    """Return the path of the ID-to-type mapping shelf.

    The base data directory is created first if it is missing.
    """
    os.makedirs(BASE_DATA_PATH, exist_ok=True)
    shelf_name = "id_type_mapping.shelf.4g"
    return os.path.join(BASE_DATA_PATH, shelf_name)
|
||||
|
||||
|
||||
def get_parent_to_child_shelf_path() -> str:
    """Return the path of the parent-to-child mapping shelf.

    The base data directory is created first if it is missing.
    """
    os.makedirs(BASE_DATA_PATH, exist_ok=True)
    shelf_name = "parent_to_child_mapping.shelf.4g"
    return os.path.join(BASE_DATA_PATH, shelf_name)
|
||||
|
||||
|
||||
def get_child_to_parent_shelf_path() -> str:
    """Return the path of the child-to-parent mapping shelf.

    The base data directory is created first if it is missing.
    """
    os.makedirs(BASE_DATA_PATH, exist_ok=True)
    shelf_name = "child_to_parent_mapping.shelf.4g"
    return os.path.join(BASE_DATA_PATH, shelf_name)
|
||||
@@ -0,0 +1,737 @@
|
||||
import csv
|
||||
import os
|
||||
import shutil
|
||||
|
||||
from onyx.connectors.salesforce.shelve_stuff.shelve_functions import find_ids_by_type
|
||||
from onyx.connectors.salesforce.shelve_stuff.shelve_functions import (
|
||||
get_affected_parent_ids_by_type,
|
||||
)
|
||||
from onyx.connectors.salesforce.shelve_stuff.shelve_functions import get_child_ids
|
||||
from onyx.connectors.salesforce.shelve_stuff.shelve_functions import get_record
|
||||
from onyx.connectors.salesforce.shelve_stuff.shelve_functions import (
|
||||
update_sf_db_with_csv,
|
||||
)
|
||||
from onyx.connectors.salesforce.utils import BASE_DATA_PATH
|
||||
from onyx.connectors.salesforce.utils import get_object_type_path
|
||||
|
||||
_VALID_SALESFORCE_IDS = [
|
||||
"001bm00000fd9Z3AAI",
|
||||
"001bm00000fdYTdAAM",
|
||||
"001bm00000fdYTeAAM",
|
||||
"001bm00000fdYTfAAM",
|
||||
"001bm00000fdYTgAAM",
|
||||
"001bm00000fdYThAAM",
|
||||
"001bm00000fdYTiAAM",
|
||||
"001bm00000fdYTjAAM",
|
||||
"001bm00000fdYTkAAM",
|
||||
"001bm00000fdYTlAAM",
|
||||
"001bm00000fdYTmAAM",
|
||||
"001bm00000fdYTnAAM",
|
||||
"001bm00000fdYToAAM",
|
||||
"500bm00000XoOxtAAF",
|
||||
"500bm00000XoOxuAAF",
|
||||
"500bm00000XoOxvAAF",
|
||||
"500bm00000XoOxwAAF",
|
||||
"500bm00000XoOxxAAF",
|
||||
"500bm00000XoOxyAAF",
|
||||
"500bm00000XoOxzAAF",
|
||||
"500bm00000XoOy0AAF",
|
||||
"500bm00000XoOy1AAF",
|
||||
"500bm00000XoOy2AAF",
|
||||
"500bm00000XoOy3AAF",
|
||||
"500bm00000XoOy4AAF",
|
||||
"500bm00000XoOy5AAF",
|
||||
"500bm00000XoOy6AAF",
|
||||
"500bm00000XoOy7AAF",
|
||||
"500bm00000XoOy8AAF",
|
||||
"500bm00000XoOy9AAF",
|
||||
"500bm00000XoOyAAAV",
|
||||
"500bm00000XoOyBAAV",
|
||||
"500bm00000XoOyCAAV",
|
||||
"500bm00000XoOyDAAV",
|
||||
"500bm00000XoOyEAAV",
|
||||
"500bm00000XoOyFAAV",
|
||||
"500bm00000XoOyGAAV",
|
||||
"500bm00000XoOyHAAV",
|
||||
"500bm00000XoOyIAAV",
|
||||
"003bm00000EjHCjAAN",
|
||||
"003bm00000EjHCkAAN",
|
||||
"003bm00000EjHClAAN",
|
||||
"003bm00000EjHCmAAN",
|
||||
"003bm00000EjHCnAAN",
|
||||
"003bm00000EjHCoAAN",
|
||||
"003bm00000EjHCpAAN",
|
||||
"003bm00000EjHCqAAN",
|
||||
"003bm00000EjHCrAAN",
|
||||
"003bm00000EjHCsAAN",
|
||||
"003bm00000EjHCtAAN",
|
||||
"003bm00000EjHCuAAN",
|
||||
"003bm00000EjHCvAAN",
|
||||
"003bm00000EjHCwAAN",
|
||||
"003bm00000EjHCxAAN",
|
||||
"003bm00000EjHCyAAN",
|
||||
"003bm00000EjHCzAAN",
|
||||
"003bm00000EjHD0AAN",
|
||||
"003bm00000EjHD1AAN",
|
||||
"003bm00000EjHD2AAN",
|
||||
"550bm00000EXc2tAAD",
|
||||
"006bm000006kyDpAAI",
|
||||
"006bm000006kyDqAAI",
|
||||
"006bm000006kyDrAAI",
|
||||
"006bm000006kyDsAAI",
|
||||
"006bm000006kyDtAAI",
|
||||
"006bm000006kyDuAAI",
|
||||
"006bm000006kyDvAAI",
|
||||
"006bm000006kyDwAAI",
|
||||
"006bm000006kyDxAAI",
|
||||
"006bm000006kyDyAAI",
|
||||
"006bm000006kyDzAAI",
|
||||
"006bm000006kyE0AAI",
|
||||
"006bm000006kyE1AAI",
|
||||
"006bm000006kyE2AAI",
|
||||
"006bm000006kyE3AAI",
|
||||
"006bm000006kyE4AAI",
|
||||
"006bm000006kyE5AAI",
|
||||
"006bm000006kyE6AAI",
|
||||
"006bm000006kyE7AAI",
|
||||
"006bm000006kyE8AAI",
|
||||
"006bm000006kyE9AAI",
|
||||
"006bm000006kyEAAAY",
|
||||
"006bm000006kyEBAAY",
|
||||
"006bm000006kyECAAY",
|
||||
"006bm000006kyEDAAY",
|
||||
"006bm000006kyEEAAY",
|
||||
"006bm000006kyEFAAY",
|
||||
"006bm000006kyEGAAY",
|
||||
"006bm000006kyEHAAY",
|
||||
"006bm000006kyEIAAY",
|
||||
"006bm000006kyEJAAY",
|
||||
"005bm000009zy0TAAQ",
|
||||
"005bm000009zy25AAA",
|
||||
"005bm000009zy26AAA",
|
||||
"005bm000009zy28AAA",
|
||||
"005bm000009zy29AAA",
|
||||
"005bm000009zy2AAAQ",
|
||||
"005bm000009zy2BAAQ",
|
||||
]
|
||||
|
||||
|
||||
def clear_sf_db() -> None:
    """
    Clears the SF DB by deleting the data directory and everything in it.

    A missing data directory is treated as already cleared, so this can be
    called safely before anything has been written. Other errors raised by
    the removal still propagate.
    """
    try:
        shutil.rmtree(BASE_DATA_PATH)
    except FileNotFoundError:
        # Nothing to clear.
        pass
|
||||
|
||||
|
||||
def create_csv_file(
    object_type: str, records: list[dict], filename: str = "test_data.csv"
) -> None:
    """
    Creates a CSV file for the given object type and records, then loads it
    into the database.

    The header is the sorted union of all keys found in the records, so the
    column order is deterministic. Records lacking a column get an empty
    cell. Nothing is written when ``records`` is empty.

    Args:
        object_type: The Salesforce object type (e.g. "Account", "Contact")
        records: List of dictionaries containing the record data
        filename: Name of the CSV file to create (default: test_data.csv)
    """
    if not records:
        return

    # Sorted list (not a set) so the column order is stable between runs.
    fieldnames = sorted({field for record in records for field in record})

    # Make sure the target directory exists before opening the file in it.
    target_dir = get_object_type_path(object_type)
    os.makedirs(target_dir, exist_ok=True)
    csv_path = os.path.join(target_dir, filename)

    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(records)

    # Update the database with the CSV
    update_sf_db_with_csv(object_type, csv_path)
|
||||
|
||||
|
||||
def create_csv_with_example_data() -> None:
    """
    Creates CSV files with example data, organized by object type.

    Eight Account, eight Contact and eight Opportunity records are built and
    passed to ``create_csv_file`` in that order. Record IDs are taken from
    ``_VALID_SALESFORCE_IDS`` starting at index 0, 40 and 62 respectively.
    The optional last column of each row is only added to the record when it
    is not None.
    """
    # (Name, BillingCity, Industry, AnnualRevenue)
    account_rows = [
        ("Acme Inc.", "New York", "Technology", None),
        ("Globex Corp", "Los Angeles", "Manufacturing", None),
        ("Initech", "Austin", "Software", None),
        ("TechCorp Solutions", "San Francisco", "Software", 5000000),
        ("BioMed Research", "Boston", "Healthcare", 12000000),
        ("Green Energy Co", "Portland", "Energy", 8000000),
        ("DataFlow Analytics", "Seattle", "Technology", 3000000),
        ("Cloud Nine Services", "Denver", "Cloud Computing", 7000000),
    ]
    # (FirstName, LastName, Email, Title, Phone)
    contact_rows = [
        ("John", "Doe", "john.doe@acme.com", "CEO", None),
        ("Jane", "Smith", "jane.smith@acme.com", "CTO", None),
        ("Bob", "Johnson", "bob.j@globex.com", "Sales Director", None),
        ("Sarah", "Chen", "sarah.chen@techcorp.com", "Product Manager", "415-555-0101"),
        ("Michael", "Rodriguez", "m.rodriguez@biomed.com", "Research Director", "617-555-0202"),
        ("Emily", "Green", "emily.g@greenenergy.com", "Sustainability Lead", "503-555-0303"),
        ("David", "Kim", "david.kim@dataflow.com", "Data Scientist", "206-555-0404"),
        ("Rachel", "Taylor", "r.taylor@cloudnine.com", "Cloud Architect", "303-555-0505"),
    ]
    # (Name, Amount, Stage, CloseDate, Probability)
    opportunity_rows = [
        ("Acme Server Upgrade", 50000, "Prospecting", "2024-06-30", None),
        ("Globex Manufacturing Line", 150000, "Negotiation", "2024-03-15", None),
        ("Initech Software License", 75000, "Closed Won", "2024-01-30", None),
        ("TechCorp AI Implementation", 250000, "Needs Analysis", "2024-08-15", 60),
        ("BioMed Lab Equipment", 500000, "Value Proposition", "2024-09-30", 75),
        ("Green Energy Solar Project", 750000, "Proposal", "2024-07-15", 80),
        ("DataFlow Analytics Platform", 180000, "Negotiation", "2024-05-30", 90),
        ("Cloud Nine Infrastructure", 300000, "Qualification", "2024-10-15", 40),
    ]

    accounts: list[dict] = []
    for offset, (name, city, industry, revenue) in enumerate(account_rows):
        account = {
            "Id": _VALID_SALESFORCE_IDS[offset],
            "Name": name,
            "BillingCity": city,
            "Industry": industry,
        }
        if revenue is not None:
            account["AnnualRevenue"] = revenue
        accounts.append(account)

    contacts: list[dict] = []
    for offset, (first, last, email, title, phone) in enumerate(contact_rows):
        contact = {
            "Id": _VALID_SALESFORCE_IDS[40 + offset],
            "FirstName": first,
            "LastName": last,
            "Email": email,
            "Title": title,
        }
        if phone is not None:
            contact["Phone"] = phone
        contacts.append(contact)

    opportunities: list[dict] = []
    for offset, (name, amount, stage, close_date, probability) in enumerate(
        opportunity_rows
    ):
        opportunity = {
            "Id": _VALID_SALESFORCE_IDS[62 + offset],
            "Name": name,
            "Amount": amount,
            "Stage": stage,
            "CloseDate": close_date,
        }
        if probability is not None:
            opportunity["Probability"] = probability
        opportunities.append(opportunity)

    # Create CSV files for each object type
    for object_type, records in (
        ("Account", accounts),
        ("Contact", contacts),
        ("Opportunity", opportunities),
    ):
        create_csv_file(object_type, records)
|
||||
|
||||
|
||||
def test_query() -> None:
    """
    Check that the seeded Accounts can be listed and read back intact.

    The test asserts that:
    1. The stored Account IDs are exactly the expected IDs
    2. Every expected field of every Account reads back as its string form
    """
    # Reference data keyed by Account ID
    expected_accounts: dict[str, dict[str, str | int]] = {
        _VALID_SALESFORCE_IDS[0]: {
            "Name": "Acme Inc.",
            "BillingCity": "New York",
            "Industry": "Technology",
        },
        _VALID_SALESFORCE_IDS[1]: {
            "Name": "Globex Corp",
            "BillingCity": "Los Angeles",
            "Industry": "Manufacturing",
        },
        _VALID_SALESFORCE_IDS[2]: {
            "Name": "Initech",
            "BillingCity": "Austin",
            "Industry": "Software",
        },
        _VALID_SALESFORCE_IDS[3]: {
            "Name": "TechCorp Solutions",
            "BillingCity": "San Francisco",
            "Industry": "Software",
            "AnnualRevenue": 5000000,
        },
        _VALID_SALESFORCE_IDS[4]: {
            "Name": "BioMed Research",
            "BillingCity": "Boston",
            "Industry": "Healthcare",
            "AnnualRevenue": 12000000,
        },
        _VALID_SALESFORCE_IDS[5]: {
            "Name": "Green Energy Co",
            "BillingCity": "Portland",
            "Industry": "Energy",
            "AnnualRevenue": 8000000,
        },
        _VALID_SALESFORCE_IDS[6]: {
            "Name": "DataFlow Analytics",
            "BillingCity": "Seattle",
            "Industry": "Technology",
            "AnnualRevenue": 3000000,
        },
        _VALID_SALESFORCE_IDS[7]: {
            "Name": "Cloud Nine Services",
            "BillingCity": "Denver",
            "Industry": "Cloud Computing",
            "AnnualRevenue": 7000000,
        },
    }

    # Every Account ID currently stored
    account_ids = find_ids_by_type("Account")

    # Same count first (catches duplicates), then same membership
    found_count = len(account_ids)
    assert found_count == len(
        expected_accounts
    ), f"Expected {len(expected_accounts)} accounts, found {len(account_ids)}"
    assert set(expected_accounts) == set(
        account_ids
    ), "Found account IDs don't match expected IDs"

    # Field-by-field comparison; expected values are compared as strings
    for acc_id in account_ids:
        combined = get_record(acc_id)
        assert combined is not None, f"Could not find account {acc_id}"

        for key, raw_value in expected_accounts[acc_id].items():
            value = str(raw_value)
            assert (
                combined.data[key] == value
            ), f"Account {acc_id} field {key} expected {value}, got {combined.data[key]}"

    print("All query tests passed successfully!")
|
||||
|
||||
|
||||
def test_upsert() -> None:
    """
    Exercise the upsert path by loading a second Account CSV.

    Steps:
    1. Load a CSV with two Account rows
    2. Read both records back by ID
    3. Check that the stored fields hold the newly loaded values

    NOTE(review): the second row is described as a "new" account but uses
    _VALID_SALESFORCE_IDS[2], the same ID test_query expects for "Initech",
    so it looks like an overwrite rather than a create - confirm intent.
    """
    first_id = _VALID_SALESFORCE_IDS[0]
    second_id = _VALID_SALESFORCE_IDS[2]

    # One row that changes an existing account, one row described as new
    update_data: list[dict[str, str | int]] = [
        {
            "Id": first_id,
            "Name": "Acme Inc. Updated",
            "BillingCity": "New York",
            "Industry": "Technology",
            "Description": "Updated company info",
        },
        {
            "Id": second_id,
            "Name": "New Company Inc.",
            "BillingCity": "Miami",
            "Industry": "Finance",
            "AnnualRevenue": 1000000,
        },
    ]

    create_csv_file("Account", update_data, "update_data.csv")

    # The first row must have replaced the stored Acme data
    updated_record = get_record(first_id)
    assert updated_record is not None, "Updated record not found"
    updated_fields = updated_record.data
    assert updated_fields["Name"] == "Acme Inc. Updated", "Name not updated"
    assert (
        updated_fields["Description"] == "Updated company info"
    ), "Description not added"

    # The second row must be readable under its ID
    new_record = get_record(second_id)
    assert new_record is not None, "New record not found"
    new_fields = new_record.data
    assert new_fields["Name"] == "New Company Inc.", "New record name incorrect"
    assert new_fields["AnnualRevenue"] == "1000000", "New record revenue incorrect"

    print("All upsert tests passed successfully!")
|
||||
|
||||
|
||||
def test_relationships() -> None:
    """
    Check that parent/child links are stored and can be queried.

    Steps:
    1. Load Cases, a Contact and an Opportunity that reference one Account
    2. Confirm all of them are reported as children of that Account
    3. Confirm another Account reports no children
    """
    parent_id = _VALID_SALESFORCE_IDS[0]
    case_one_id = _VALID_SALESFORCE_IDS[13]
    case_two_id = _VALID_SALESFORCE_IDS[14]
    contact_id = _VALID_SALESFORCE_IDS[48]
    opportunity_id = _VALID_SALESFORCE_IDS[62]

    # Child records per object type, each pointing at the parent via AccountId
    test_data: dict[str, list[dict[str, str | int]]] = {
        "Case": [
            {
                "Id": case_one_id,
                "AccountId": parent_id,
                "Subject": "Test Case 1",
            },
            {
                "Id": case_two_id,
                "AccountId": parent_id,
                "Subject": "Test Case 2",
            },
        ],
        "Contact": [
            {
                "Id": contact_id,
                "AccountId": parent_id,
                "FirstName": "Test",
                "LastName": "Contact",
            }
        ],
        "Opportunity": [
            {
                "Id": opportunity_id,
                "AccountId": parent_id,
                "Name": "Test Opportunity",
                "Amount": 100000,
            }
        ],
    }

    # Load each object type through the CSV path
    for object_type, records in test_data.items():
        create_csv_file(object_type, records, "relationship_test.csv")

    # All loaded records should be children of the parent account
    child_ids = get_child_ids(parent_id)
    assert len(child_ids) == 4, f"Expected 4 child objects, found {len(child_ids)}"

    expected_children = (
        (case_one_id, "Case 1 not found in relationship"),
        (case_two_id, "Case 2 not found in relationship"),
        (contact_id, "Contact not found in relationship"),
        (opportunity_id, "Opportunity not found in relationship"),
    )
    for expected_child, failure_message in expected_children:
        assert expected_child in child_ids, failure_message

    # A different account should have no children
    other_account_children = get_child_ids(_VALID_SALESFORCE_IDS[1])
    assert (
        len(other_account_children) == 0
    ), "Expected no children for different account"

    print("All relationship tests passed successfully!")
|
||||
|
||||
|
||||
def test_account_with_children() -> None:
    """
    Walk every Account and inspect the child records of the first one.

    The test checks that:
    1. At least one Account exists and each one can be loaded
    2. The Account with _VALID_SALESFORCE_IDS[0] has exactly four children
    3. Those children are two Cases, one Contact and one Opportunity
       carrying the expected field values
    """
    account_ids = find_ids_by_type("Account")
    assert len(account_ids) > 0, "No accounts found"

    for account_id in account_ids:
        account = get_record(account_id)
        assert account is not None, f"Could not find account {account_id}"

        # Children are looked up for every account, but only one is verified
        child_ids = get_child_ids(account_id)

        if account_id != _VALID_SALESFORCE_IDS[0]:  # not Acme Inc.
            continue

        assert (
            len(child_ids) == 4
        ), f"Expected 4 children for Acme Inc., found {len(child_ids)}"

        # Load the child records, dropping any that cannot be found
        child_records = [
            record for record in map(get_record, child_ids) if record is not None
        ]

        # Cases
        cases = [r for r in child_records if r.type == "Case"]
        assert (
            len(cases) == 2
        ), f"Expected 2 cases for Acme Inc., found {len(cases)}"
        case_subjects = {case.data["Subject"] for case in cases}
        for subject in ("Test Case 1", "Test Case 2"):
            assert subject in case_subjects, f"{subject} not found"

        # Contacts
        contacts = [r for r in child_records if r.type == "Contact"]
        assert (
            len(contacts) == 1
        ), f"Expected 1 contact for Acme Inc., found {len(contacts)}"
        contact_fields = contacts[0].data
        assert contact_fields["FirstName"] == "Test", "Contact first name mismatch"
        assert contact_fields["LastName"] == "Contact", "Contact last name mismatch"

        # Opportunities
        opportunities = [r for r in child_records if r.type == "Opportunity"]
        assert (
            len(opportunities) == 1
        ), f"Expected 1 opportunity for Acme Inc., found {len(opportunities)}"
        opportunity_fields = opportunities[0].data
        assert (
            opportunity_fields["Name"] == "Test Opportunity"
        ), "Opportunity name mismatch"
        assert opportunity_fields["Amount"] == "100000", "Opportunity amount mismatch"

    print("All account with children tests passed successfully!")
|
||||
|
||||
|
||||
def test_relationship_updates() -> None:
    """
    Check that re-parenting a child record moves its relationship.

    Steps:
    1. Load a Contact that points at one Account and confirm the link exists
    2. Reload the same Contact pointing at a second Account
    3. Confirm the link to the first Account is gone and the new one exists
    """
    contact_id = _VALID_SALESFORCE_IDS[40]
    first_account_id = _VALID_SALESFORCE_IDS[0]  # Acme Inc.
    second_account_id = _VALID_SALESFORCE_IDS[1]  # Globex Corp

    # Contact initially linked to Acme Inc.
    initial_contact = [
        {
            "Id": contact_id,
            "AccountId": first_account_id,
            "FirstName": "Test",
            "LastName": "Contact",
        }
    ]
    create_csv_file("Contact", initial_contact, "initial_contact.csv")

    acme_children = get_child_ids(first_account_id)
    assert contact_id in acme_children, "Initial relationship not created"

    # Same contact, now linked to Globex Corp
    updated_contact = [
        {
            "Id": contact_id,
            "AccountId": second_account_id,
            "FirstName": "Test",
            "LastName": "Contact",
        }
    ]
    create_csv_file("Contact", updated_contact, "updated_contact.csv")

    # The old link must be gone
    acme_children = get_child_ids(first_account_id)
    assert contact_id not in acme_children, "Old relationship not removed"

    # The new link must exist
    globex_children = get_child_ids(second_account_id)
    assert contact_id in globex_children, "New relationship not created"

    print("All relationship update tests passed successfully!")
|
||||
|
||||
|
||||
def test_get_affected_parent_ids() -> None:
    """
    Tests get_affected_parent_ids_by_type functionality by verifying:
    1. IDs that are directly in the parent_types list are included
    2. IDs that have children in the updated_ids list are included
    3. IDs that are neither of the above are not included

    get_affected_parent_ids_by_type takes a list of IDs and is a generator
    of (parent_type, ids) tuples, so the IDs are passed as lists and the
    yielded ID sets are merged before asserting on them.
    """

    def collect_affected(updated_ids: list[str], parent_types: list[str]) -> set[str]:
        """Merge every yielded ID set into one flat set."""
        affected: set[str] = set()
        for _parent_type, ids in get_affected_parent_ids_by_type(
            updated_ids, parent_types
        ):
            affected.update(ids)
        return affected

    # Create test data with relationships
    test_data = {
        "Account": [
            {
                "Id": _VALID_SALESFORCE_IDS[0],
                "Name": "Parent Account 1",
            },
            {
                "Id": _VALID_SALESFORCE_IDS[1],
                "Name": "Parent Account 2",
            },
            {
                "Id": _VALID_SALESFORCE_IDS[2],
                "Name": "Not Affected Account",
            },
        ],
        "Contact": [
            {
                "Id": _VALID_SALESFORCE_IDS[40],
                "AccountId": _VALID_SALESFORCE_IDS[0],
                "FirstName": "Child",
                "LastName": "Contact",
            }
        ],
    }

    # Create and update CSV files for test data
    for object_type, records in test_data.items():
        create_csv_file(object_type, records)

    # Test Case 1: Account directly in updated_ids and parent_types
    updated_ids = [_VALID_SALESFORCE_IDS[1]]  # Parent Account 2
    parent_types = ["Account"]
    affected_ids = collect_affected(updated_ids, parent_types)
    assert _VALID_SALESFORCE_IDS[1] in affected_ids, "Direct parent ID not included"

    # Test Case 2: Account with child in updated_ids
    updated_ids = [_VALID_SALESFORCE_IDS[40]]  # Child Contact
    parent_types = ["Account"]
    affected_ids = collect_affected(updated_ids, parent_types)
    assert (
        _VALID_SALESFORCE_IDS[0] in affected_ids
    ), "Parent of updated child not included"

    # Test Case 3: Both direct and indirect affects
    updated_ids = [_VALID_SALESFORCE_IDS[1], _VALID_SALESFORCE_IDS[40]]  # Both cases
    parent_types = ["Account"]
    affected_ids = collect_affected(updated_ids, parent_types)
    assert len(affected_ids) == 2, "Expected exactly two affected parent IDs"
    assert _VALID_SALESFORCE_IDS[0] in affected_ids, "Parent of child not included"
    assert _VALID_SALESFORCE_IDS[1] in affected_ids, "Direct parent ID not included"
    assert (
        _VALID_SALESFORCE_IDS[2] not in affected_ids
    ), "Unaffected ID incorrectly included"

    # Test Case 4: No matches
    updated_ids = [_VALID_SALESFORCE_IDS[40]]  # Child Contact
    parent_types = ["Opportunity"]  # Wrong type
    affected_ids = collect_affected(updated_ids, parent_types)
    assert len(affected_ids) == 0, "Should return empty list when no matches"

    print("All get_affected_parent_ids tests passed successfully!")
|
||||
|
||||
|
||||
def main_build() -> None:
    """Clear the Salesforce DB, seed the example data, then run each test in order."""
    # Order matters: later tests read data written by earlier steps
    steps = (
        clear_sf_db,
        create_csv_with_example_data,
        test_query,
        test_upsert,
        test_relationships,
        test_account_with_children,
        test_relationship_updates,
        test_get_affected_parent_ids,
    )
    for step in steps:
        step()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main_build()
|
||||
386
backend/onyx/connectors/salesforce/sqlite_functions.py
Normal file
386
backend/onyx/connectors/salesforce/sqlite_functions.py
Normal file
@@ -0,0 +1,386 @@
|
||||
import csv
|
||||
import json
|
||||
import os
|
||||
import sqlite3
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
|
||||
from onyx.connectors.salesforce.utils import get_sqlite_db_path
|
||||
from onyx.connectors.salesforce.utils import SalesforceObject
|
||||
from onyx.connectors.salesforce.utils import validate_salesforce_id
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.utils import batch_list
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@contextmanager
def get_db_connection(
    isolation_level: str | None = None,
) -> Iterator[sqlite3.Connection]:
    """Open a connection to the Salesforce SQLite file and always close it.

    If the managed block raises, the connection is rolled back and the
    exception is re-raised. Committing is left to the caller.

    Args:
        isolation_level: Value assigned to the connection's
            ``isolation_level``, e.g. "IMMEDIATE" or "EXCLUSIVE" for stricter
            locking. When None the attribute is left at the sqlite3 default.
    """
    db_path = get_sqlite_db_path()
    # Wait up to 60 seconds for locks
    connection = sqlite3.connect(db_path, timeout=60.0)

    if isolation_level is not None:
        connection.isolation_level = isolation_level

    try:
        yield connection
    except Exception:
        connection.rollback()
        raise
    finally:
        connection.close()
|
||||
|
||||
|
||||
def init_db() -> None:
    """Initialize the SQLite database with required tables if they don't exist.

    If the database file already exists the function returns immediately and
    touches nothing. Otherwise it creates the parent directory, applies the
    PRAGMA settings, creates the three tables and their indexes, runs ANALYZE
    and commits.
    """
    # An existing file is taken as "already initialized"; nothing is verified
    if os.path.exists(get_sqlite_db_path()):
        return

    # Create database directory if it doesn't exist
    os.makedirs(os.path.dirname(get_sqlite_db_path()), exist_ok=True)

    with get_db_connection("EXCLUSIVE") as conn:
        cursor = conn.cursor()

        # Enable WAL mode for better concurrent access and write performance
        cursor.execute("PRAGMA journal_mode=WAL")
        cursor.execute("PRAGMA synchronous=NORMAL")
        cursor.execute("PRAGMA temp_store=MEMORY")
        cursor.execute("PRAGMA cache_size=-2000000")  # Use 2GB memory for cache

        # Main table for storing Salesforce objects
        cursor.execute(
            """
            CREATE TABLE IF NOT EXISTS salesforce_objects (
                id TEXT PRIMARY KEY,
                object_type TEXT NOT NULL,
                data TEXT NOT NULL,  -- JSON serialized data
                last_modified INTEGER DEFAULT (strftime('%s', 'now'))  -- Add timestamp for better cache management
            ) WITHOUT ROWID  -- Optimize for primary key lookups
        """
        )

        # Table for parent-child relationships with covering index
        cursor.execute(
            """
            CREATE TABLE IF NOT EXISTS relationships (
                child_id TEXT NOT NULL,
                parent_id TEXT NOT NULL,
                PRIMARY KEY (child_id, parent_id)
            ) WITHOUT ROWID  -- Optimize for primary key lookups
        """
        )

        # New table for caching parent-child relationships with object types
        cursor.execute(
            """
            CREATE TABLE IF NOT EXISTS relationship_types (
                child_id TEXT NOT NULL,
                parent_id TEXT NOT NULL,
                parent_type TEXT NOT NULL,
                PRIMARY KEY (child_id, parent_id, parent_type)
            ) WITHOUT ROWID
        """
        )

        # Always recreate indexes to ensure they exist
        # (the CREATE INDEX statements below have no IF NOT EXISTS, so each
        # index is dropped first)
        cursor.execute("DROP INDEX IF EXISTS idx_object_type")
        cursor.execute("DROP INDEX IF EXISTS idx_parent_id")
        cursor.execute("DROP INDEX IF EXISTS idx_child_parent")
        cursor.execute("DROP INDEX IF EXISTS idx_object_type_id")
        cursor.execute("DROP INDEX IF EXISTS idx_relationship_types_lookup")

        # Create covering indexes for common queries
        cursor.execute(
            """
            CREATE INDEX idx_object_type
            ON salesforce_objects(object_type, id)
            WHERE object_type IS NOT NULL
        """
        )

        # Looked up by name (INDEXED BY idx_parent_id) when fetching children
        cursor.execute(
            """
            CREATE INDEX idx_parent_id
            ON relationships(parent_id, child_id)
        """
        )

        cursor.execute(
            """
            CREATE INDEX idx_child_parent
            ON relationships(child_id)
            WHERE child_id IS NOT NULL
        """
        )

        # New composite index for fast parent type lookups
        cursor.execute(
            """
            CREATE INDEX idx_relationship_types_lookup
            ON relationship_types(parent_type, child_id, parent_id)
        """
        )

        # Analyze tables to help query planner
        cursor.execute("ANALYZE relationships")
        cursor.execute("ANALYZE salesforce_objects")
        cursor.execute("ANALYZE relationship_types")

        conn.commit()
|
||||
|
||||
|
||||
def _update_relationship_tables(
    conn: sqlite3.Connection, child_id: str, parent_ids: set[str]
) -> None:
    """Bring the stored parents of one child in line with ``parent_ids``.

    Parents that are stored but no longer listed are deleted from both
    ``relationships`` and ``relationship_types``; parents that are listed but
    not stored are inserted. A ``relationship_types`` row is only written for
    a new parent that already has a row in ``salesforce_objects``.

    Args:
        conn: The database connection to use (must be in a transaction)
        child_id: The ID of the child record
        parent_ids: Set of parent IDs the child should be linked to

    Raises:
        Exception: any error is logged and re-raised unchanged.
    """
    try:
        cursor = conn.cursor()

        # Parents currently stored for this child
        cursor.execute(
            "SELECT parent_id FROM relationships WHERE child_id = ?", (child_id,)
        )
        stored_parents = {stored_row[0] for stored_row in cursor.fetchall()}

        stale_parents = stored_parents - parent_ids
        missing_parents = parent_ids - stored_parents

        # Drop links that no longer apply, from both tables
        if stale_parents:
            cursor.executemany(
                "DELETE FROM relationships WHERE child_id = ? AND parent_id = ?",
                [(child_id, stale) for stale in stale_parents],
            )
            cursor.executemany(
                "DELETE FROM relationship_types WHERE child_id = ? AND parent_id = ?",
                [(child_id, stale) for stale in stale_parents],
            )

        # Add the new links
        if missing_parents:
            cursor.executemany(
                "INSERT INTO relationships (child_id, parent_id) VALUES (?, ?)",
                [(child_id, added) for added in missing_parents],
            )

            # Record the parent's type when the parent object is already stored
            for parent_id in missing_parents:
                cursor.execute(
                    "SELECT object_type FROM salesforce_objects WHERE id = ?",
                    (parent_id,),
                )
                type_row = cursor.fetchone()
                if not type_row:
                    continue
                cursor.execute(
                    """
                    INSERT INTO relationship_types (child_id, parent_id, parent_type)
                    VALUES (?, ?, ?)
                    """,
                    (child_id, parent_id, type_row[0]),
                )

    except Exception as e:
        logger.error(f"Error updating relationship tables: {e}")
        logger.error(f"Child ID: {child_id}, Parent IDs: {parent_ids}")
        raise
|
||||
|
||||
|
||||
def update_sf_db_with_csv(object_type: str, csv_download_path: str) -> list[str]:
    """Update the SF DB with a CSV file using SQLite storage.

    Each CSV row is stored as one JSON document keyed by its ``Id``. Any other
    field whose value is a valid Salesforce ID is treated as a parent
    reference: it is recorded in the relationship tables and removed from the
    stored data. Empty fields are removed as well. ``LastModifiedById`` is
    always kept in the stored data.

    Args:
        object_type: Salesforce object type stored with every row.
        csv_download_path: Path of the UTF-8 CSV file to load.

    Returns:
        The IDs of all rows written, in file order.
    """
    updated_ids: list[str] = []

    # Use IMMEDIATE to get a write lock at the start of the transaction
    with get_db_connection("IMMEDIATE") as conn:
        cursor = conn.cursor()

        with open(csv_download_path, "r", newline="", encoding="utf-8") as f:
            reader = csv.DictReader(f)
            for row in reader:
                # DictReader fills the missing cells of a short row with None,
                # so a present-but-empty Id is treated like a missing one
                record_id = row.get("Id")
                if not record_id:
                    logger.warning(
                        f"Row {row} does not have an Id field in {csv_download_path}"
                    )
                    continue

                parent_ids: set[str] = set()
                fields_to_remove: set[str] = set()

                # Process relationships and clean data
                for field, value in row.items():
                    if not value:
                        fields_to_remove.add(field)
                    elif (
                        field != "Id"
                        # Only strings can be IDs; None and overflow lists
                        # would break the validator
                        and isinstance(value, str)
                        and validate_salesforce_id(value)
                    ):
                        parent_ids.add(value)
                        fields_to_remove.add(field)

                # Remove unwanted fields
                for field in fields_to_remove:
                    if field != "LastModifiedById":
                        del row[field]

                # Update main object data
                cursor.execute(
                    """
                    INSERT OR REPLACE INTO salesforce_objects (id, object_type, data)
                    VALUES (?, ?, ?)
                    """,
                    (record_id, object_type, json.dumps(row)),
                )

                # Update relationships using the same connection
                _update_relationship_tables(conn, record_id, parent_ids)
                updated_ids.append(record_id)

        conn.commit()

    return updated_ids
|
||||
|
||||
|
||||
def get_child_ids(parent_id: str) -> set[str]:
    """Return the IDs of all records stored as children of ``parent_id``."""
    # INDEXED BY forces the lookup through idx_parent_id
    query = "SELECT child_id FROM relationships INDEXED BY idx_parent_id WHERE parent_id = ?"
    with get_db_connection() as conn:
        cursor = conn.cursor()
        cursor.execute(query, (parent_id,))
        rows = cursor.fetchall()
    return {child_row[0] for child_row in rows}
|
||||
|
||||
|
||||
def get_type_from_id(object_id: str) -> str | None:
    """Look up the stored object type for ``object_id``.

    Returns:
        The object type, or None (after logging a warning) when no row has
        that ID.
    """
    with get_db_connection() as conn:
        cursor = conn.cursor()
        cursor.execute(
            "SELECT object_type FROM salesforce_objects WHERE id = ?", (object_id,)
        )
        type_row = cursor.fetchone()
        if type_row:
            return type_row[0]

        logger.warning(f"Object ID {object_id} not found")
        return None
|
||||
|
||||
|
||||
def get_record(
    object_id: str, object_type: str | None = None
) -> SalesforceObject | None:
    """Load one stored record and wrap it in a SalesforceObject.

    Args:
        object_id: ID of the record to load.
        object_type: Type to put on the returned object. When None the type
            is looked up from the database first.

    Returns:
        The record with its JSON data decoded, or None when the type lookup
        or the data lookup finds no row (a warning is logged).
    """
    resolved_type = object_type
    if resolved_type is None:
        resolved_type = get_type_from_id(object_id)
        if not resolved_type:
            return None

    with get_db_connection() as conn:
        cursor = conn.cursor()
        cursor.execute("SELECT data FROM salesforce_objects WHERE id = ?", (object_id,))
        data_row = cursor.fetchone()
        if data_row:
            return SalesforceObject(
                id=object_id, type=resolved_type, data=json.loads(data_row[0])
            )

        logger.warning(f"Object ID {object_id} not found")
        return None
|
||||
|
||||
|
||||
def find_ids_by_type(object_type: str) -> list[str]:
    """Return the ID of every stored object whose type is ``object_type``."""
    with get_db_connection() as conn:
        cursor = conn.cursor()
        cursor.execute(
            "SELECT id FROM salesforce_objects WHERE object_type = ?", (object_type,)
        )
        matching_ids: list[str] = []
        for id_row in cursor.fetchall():
            matching_ids.append(id_row[0])
        return matching_ids
|
||||
|
||||
|
||||
def get_affected_parent_ids_by_type(
    updated_ids: list[str],
    parent_types: list[str],
    batch_size: int = 500,
) -> Iterator[tuple[str, set[str]]]:
    """Get IDs of objects that are of the specified parent types and are either in the
    updated_ids or have children in the updated_ids. Yields tuples of (parent_type, affected_ids).

    This is a generator: nothing is queried until it is iterated. One tuple
    may be yielded per (batch, parent type) pair, so the same parent type can
    appear more than once. An ID is yielded only the first time it is found.

    Args:
        updated_ids: IDs of the records that changed. Must be a list, since
            each batch is concatenated with a list to build query parameters.
        parent_types: Object types to report affected IDs for.
        batch_size: Number of IDs bound into each query.

    Yields:
        (parent_type, ids) where ids is the non-empty set of newly found
        affected IDs of that type for the current batch.
    """
    # SQLite typically has a limit of 999 variables
    updated_ids_batches = batch_list(updated_ids, batch_size)
    # IDs already yielded, shared across all batches and parent types
    updated_parent_ids: set[str] = set()

    with get_db_connection() as conn:
        cursor = conn.cursor()

        for batch_ids in updated_ids_batches:
            # One "?" per ID; only placeholders are interpolated into the SQL,
            # the values themselves are bound as parameters
            id_placeholders = ",".join(["?" for _ in batch_ids])

            for parent_type in parent_types:
                affected_ids: set[str] = set()

                # Get directly updated objects of parent types - using index on object_type
                cursor.execute(
                    f"""
                    SELECT id FROM salesforce_objects
                    WHERE id IN ({id_placeholders})
                    AND object_type = ?
                    """,
                    batch_ids + [parent_type],
                )
                affected_ids.update(row[0] for row in cursor.fetchall())

                # Get parent objects of updated objects - using optimized relationship_types table
                cursor.execute(
                    f"""
                    SELECT DISTINCT parent_id
                    FROM relationship_types
                    INDEXED BY idx_relationship_types_lookup
                    WHERE parent_type = ?
                    AND child_id IN ({id_placeholders})
                    """,
                    [parent_type] + batch_ids,
                )
                affected_ids.update(row[0] for row in cursor.fetchall())

                # Remove any parent IDs that have already been processed
                new_affected_ids = affected_ids - updated_parent_ids
                # Add the new affected IDs to the set of updated parent IDs
                updated_parent_ids.update(new_affected_ids)

                # Nothing is yielded for a batch/type pair with no new IDs
                if new_affected_ids:
                    yield parent_type, new_affected_ids
|
||||
|
||||
|
||||
def has_at_least_one_object_of_type(object_type: str) -> bool:
    """Check if there is at least one object of the specified type in the database.

    Uses EXISTS so SQLite can stop at the first matching row instead of
    counting every row of that type.

    Args:
        object_type: The Salesforce object type to check

    Returns:
        bool: True if at least one object exists, False otherwise
    """
    with get_db_connection() as conn:
        cursor = conn.cursor()
        cursor.execute(
            "SELECT EXISTS(SELECT 1 FROM salesforce_objects WHERE object_type = ?)",
            (object_type,),
        )
        # EXISTS always produces exactly one row holding 0 or 1
        return bool(cursor.fetchone()[0])
|
||||
@@ -1,66 +1,72 @@
|
||||
import re
|
||||
from typing import Union
|
||||
|
||||
SF_JSON_FILTER = r"Id$|Date$|stamp$|url$"
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
|
||||
def _clean_salesforce_dict(data: Union[dict, list]) -> Union[dict, list]:
|
||||
if isinstance(data, dict):
|
||||
if "records" in data.keys():
|
||||
data = data["records"]
|
||||
if isinstance(data, dict):
|
||||
if "attributes" in data.keys():
|
||||
if isinstance(data["attributes"], dict):
|
||||
data.update(data.pop("attributes"))
|
||||
@dataclass
|
||||
class SalesforceObject:
|
||||
id: str
|
||||
type: str
|
||||
data: dict[str, Any]
|
||||
|
||||
if isinstance(data, dict):
|
||||
filtered_dict = {}
|
||||
for key, value in data.items():
|
||||
if not re.search(SF_JSON_FILTER, key, re.IGNORECASE):
|
||||
if "__c" in key: # remove the custom object indicator for display
|
||||
key = key[:-3]
|
||||
if isinstance(value, (dict, list)):
|
||||
filtered_value = _clean_salesforce_dict(value)
|
||||
if filtered_value: # Only add non-empty dictionaries or lists
|
||||
filtered_dict[key] = filtered_value
|
||||
elif value is not None:
|
||||
filtered_dict[key] = value
|
||||
return filtered_dict
|
||||
elif isinstance(data, list):
|
||||
filtered_list = []
|
||||
for item in data:
|
||||
if isinstance(item, (dict, list)):
|
||||
filtered_item = _clean_salesforce_dict(item)
|
||||
if filtered_item: # Only add non-empty dictionaries or lists
|
||||
filtered_list.append(filtered_item)
|
||||
elif item is not None:
|
||||
filtered_list.append(filtered_item)
|
||||
return filtered_list
|
||||
else:
|
||||
return data
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"ID": self.id,
|
||||
"Type": self.type,
|
||||
"Data": self.data,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "SalesforceObject":
|
||||
return cls(
|
||||
id=data["Id"],
|
||||
type=data["Type"],
|
||||
data=data,
|
||||
)
|
||||
|
||||
|
||||
def _json_to_natural_language(data: Union[dict, list], indent: int = 0) -> str:
|
||||
result = []
|
||||
indent_str = " " * indent
|
||||
|
||||
if isinstance(data, dict):
|
||||
for key, value in data.items():
|
||||
if isinstance(value, (dict, list)):
|
||||
result.append(f"{indent_str}{key}:")
|
||||
result.append(_json_to_natural_language(value, indent + 2))
|
||||
else:
|
||||
result.append(f"{indent_str}{key}: {value}")
|
||||
elif isinstance(data, list):
|
||||
for item in data:
|
||||
result.append(_json_to_natural_language(item, indent))
|
||||
else:
|
||||
result.append(f"{indent_str}{data}")
|
||||
|
||||
return "\n".join(result)
|
||||
# This defines the base path for all data files relative to this file
|
||||
# AKA BE CAREFUL WHEN MOVING THIS FILE
|
||||
BASE_DATA_PATH = os.path.join(os.path.dirname(__file__), "data")
|
||||
|
||||
|
||||
def extract_dict_text(raw_dict: dict) -> str:
    """Clean a raw Salesforce dict, then render it as indented text."""
    return _json_to_natural_language(_clean_salesforce_dict(raw_dict))
|
||||
def get_sqlite_db_path() -> str:
    """Return the location of the SQLite file inside BASE_DATA_PATH."""
    db_filename = "salesforce_db.sqlite"
    return os.path.join(BASE_DATA_PATH, db_filename)
|
||||
|
||||
|
||||
def get_object_type_path(object_type: str) -> str:
    """Return the directory for ``object_type`` under BASE_DATA_PATH, creating it if needed."""
    directory = os.path.join(BASE_DATA_PATH, object_type)
    # exist_ok makes repeated calls for the same type harmless
    os.makedirs(directory, exist_ok=True)
    return directory
|
||||
|
||||
|
||||
# Alphabet of the 3-character suffix: one character per 5-bit value (0-31)
_CHECKSUM_CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ012345"
# Maps a 5-character bit string ("00000".."11111") to its suffix character
_LOOKUP = {format(i, "05b"): _CHECKSUM_CHARS[i] for i in range(32)}


def validate_salesforce_id(salesforce_id: str) -> bool:
    """Validate the checksum portion of an 18-character Salesforce ID.

    The first 15 characters are split into three chunks of five. Each chunk
    becomes a 5-bit pattern (1 for an uppercase character, 0 otherwise, read
    in reverse order) that selects one suffix character. The ID is valid when
    the three selected characters equal its last three characters.

    Values that cannot be IDs are rejected instead of raising or matching by
    accident: non-strings (such as the None that csv.DictReader uses for
    missing cells) and strings containing anything other than ASCII letters
    and digits.

    Args:
        salesforce_id: An 18-character Salesforce ID

    Returns:
        bool: True if the checksum is valid, False otherwise
    """
    if not isinstance(salesforce_id, str) or len(salesforce_id) != 18:
        return False

    # Without this check, 18-character free text could match by coincidence
    if not (salesforce_id.isascii() and salesforce_id.isalnum()):
        return False

    calculated_checksum = "".join(
        _LOOKUP[
            "".join(
                "1" if char.isupper() else "0"
                for char in reversed(salesforce_id[start : start + 5])
            )
        ]
        for start in (0, 5, 10)
    )

    return salesforce_id[15:18] == calculated_checksum
|
||||
|
||||
@@ -264,24 +264,6 @@ class SlackTextCleaner:
|
||||
message = message.replace("<!everyone>", "@everyone")
|
||||
return message
|
||||
|
||||
@staticmethod
|
||||
def replace_links(message: str) -> str:
|
||||
"""Replaces slack links e.g. `<URL>` -> `URL` and `<URL|DISPLAY>` -> `DISPLAY`"""
|
||||
# Find user IDs in the message
|
||||
possible_link_matches = re.findall(r"<(.*?)>", message)
|
||||
for possible_link in possible_link_matches:
|
||||
if not possible_link:
|
||||
continue
|
||||
# Special slack patterns that aren't for links
|
||||
if possible_link[0] not in ["#", "@", "!"]:
|
||||
link_display = (
|
||||
possible_link
|
||||
if "|" not in possible_link
|
||||
else possible_link.split("|")[1]
|
||||
)
|
||||
message = message.replace(f"<{possible_link}>", link_display)
|
||||
return message
|
||||
|
||||
@staticmethod
|
||||
def replace_special_catchall(message: str) -> str:
|
||||
"""Replaces pattern of <!something|another-thing> with another-thing
|
||||
|
||||
@@ -33,6 +33,7 @@ from onyx.file_processing.extract_file_text import read_pdf_file
|
||||
from onyx.file_processing.html_utils import web_html_cleanup
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.sitemap import list_pages_for_site
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -241,6 +242,12 @@ class WebConnector(LoadConnector):
|
||||
self.to_visit_list = extract_urls_from_sitemap(_ensure_valid_url(base_url))
|
||||
|
||||
elif web_connector_type == WEB_CONNECTOR_VALID_SETTINGS.UPLOAD:
|
||||
# Explicitly check if running in multi-tenant mode to prevent potential security risks
|
||||
if MULTI_TENANT:
|
||||
raise ValueError(
|
||||
"Upload input for web connector is not supported in cloud environments"
|
||||
)
|
||||
|
||||
logger.warning(
|
||||
"This is not a UI supported Web Connector flow, "
|
||||
"are you sure you want to do this?"
|
||||
|
||||
@@ -10,17 +10,21 @@ from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||
time_str_to_utc,
|
||||
)
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.file_processing.html_utils import parse_html_page_basic
|
||||
from onyx.utils.retry_wrapper import retry_builder
|
||||
|
||||
|
||||
MAX_PAGE_SIZE = 30 # Zendesk API maximum
|
||||
_SLIM_BATCH_SIZE = 1000
|
||||
|
||||
|
||||
class ZendeskCredentialsNotSetUpError(PermissionError):
|
||||
@@ -272,7 +276,7 @@ def _ticket_to_document(
|
||||
)
|
||||
|
||||
|
||||
class ZendeskConnector(LoadConnector, PollConnector):
|
||||
class ZendeskConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
def __init__(
|
||||
self,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
@@ -397,6 +401,43 @@ class ZendeskConnector(LoadConnector, PollConnector):
|
||||
if doc_batch:
|
||||
yield doc_batch
|
||||
|
||||
def retrieve_all_slim_documents(
    self,
    start: SecondsSinceUnixEpoch | None = None,
    end: SecondsSinceUnixEpoch | None = None,
) -> GenerateSlimDocumentOutput:
    """Yield batches of SlimDocuments holding only the document ids.

    Depending on ``self.content_type`` the ids are built from articles
    (``article:<id>``) or tickets (``zendesk_ticket_<id>``). Batches hold up
    to ``_SLIM_BATCH_SIZE`` entries; a final partial batch is yielded too.

    ``start`` (when truthy) is forwarded as an integer ``start_time``;
    ``end`` is accepted but not used by this implementation.

    Raises:
        ValueError: if ``self.content_type`` is neither "articles" nor
            "tickets" (raised once the generator is first advanced).
    """
    if self.content_type == "articles":
        slim_docs = (
            SlimDocument(
                id=f"article:{article['id']}",
            )
            for article in _get_articles(
                self.client, start_time=int(start) if start else None
            )
        )
    elif self.content_type == "tickets":
        slim_docs = (
            SlimDocument(
                id=f"zendesk_ticket_{ticket['id']}",
            )
            for ticket in _get_tickets(
                self.client, start_time=int(start) if start else None
            )
        )
    else:
        raise ValueError(f"Unsupported content_type: {self.content_type}")

    pending: list[SlimDocument] = []
    for slim_doc in slim_docs:
        pending.append(slim_doc)
        if len(pending) >= _SLIM_BATCH_SIZE:
            yield pending
            pending = []

    # Flush whatever is left over after the last full batch.
    if pending:
        yield pending
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
|
||||
@@ -7,6 +7,7 @@ from sqlalchemy import exists
|
||||
from sqlalchemy import Select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import aliased
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
@@ -90,15 +91,22 @@ def get_connector_credential_pairs(
|
||||
user: User | None = None,
|
||||
get_editable: bool = True,
|
||||
ids: list[int] | None = None,
|
||||
eager_load_connector: bool = False,
|
||||
) -> list[ConnectorCredentialPair]:
|
||||
stmt = select(ConnectorCredentialPair).distinct()
|
||||
|
||||
if eager_load_connector:
|
||||
stmt = stmt.options(joinedload(ConnectorCredentialPair.connector))
|
||||
|
||||
stmt = _add_user_filters(stmt, user, get_editable)
|
||||
|
||||
if not include_disabled:
|
||||
stmt = stmt.where(
|
||||
ConnectorCredentialPair.status == ConnectorCredentialPairStatus.ACTIVE
|
||||
) # noqa
|
||||
)
|
||||
if ids:
|
||||
stmt = stmt.where(ConnectorCredentialPair.id.in_(ids))
|
||||
|
||||
return list(db_session.scalars(stmt).all())
|
||||
|
||||
|
||||
|
||||
@@ -416,6 +416,18 @@ def update_docs_last_modified__no_commit(
|
||||
doc.last_modified = now
|
||||
|
||||
|
||||
def update_docs_chunk_count__no_commit(
    document_ids: list[str],
    doc_id_to_chunk_count: dict[str, int],
    db_session: Session,
) -> None:
    """Set ``chunk_count`` on every stored document whose id is in ``document_ids``.

    The new value for each document is looked up in ``doc_id_to_chunk_count``
    by document id, so the mapping must contain every id that is found.
    Nothing is committed here.
    """
    matching_docs_query = db_session.query(DbDocument).filter(
        DbDocument.id.in_(document_ids)
    )
    for db_doc in matching_docs_query.all():
        db_doc.chunk_count = doc_id_to_chunk_count[db_doc.id]
|
||||
|
||||
|
||||
def mark_document_as_modified(
|
||||
document_id: str,
|
||||
db_session: Session,
|
||||
@@ -612,3 +624,25 @@ def get_document(
|
||||
stmt = select(DbDocument).where(DbDocument.id == document_id)
|
||||
doc: DbDocument | None = db_session.execute(stmt).scalar_one_or_none()
|
||||
return doc
|
||||
|
||||
|
||||
def fetch_chunk_counts_for_documents(
    document_ids: list[str],
    db_session: Session,
) -> list[tuple[str, int | None]]:
    """
    Return (document_id, chunk_count) tuples for the given document ids.

    chunk_count may be None when it has not been set in the DB, which is
    why the second tuple element is typed as optional.
    """
    chunk_count_query = select(DbDocument.id, DbDocument.chunk_count).where(
        DbDocument.id.in_(document_ids)
    )

    # Every result row carries the two selected columns, in select order.
    # The id is passed through str() so the result is a string regardless
    # of the column's Python type.
    return [
        (str(doc_id), chunk_count)
        for doc_id, chunk_count in db_session.execute(chunk_count_query).all()
    ]
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import contextlib
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import ssl
|
||||
@@ -14,7 +15,6 @@ from typing import ContextManager
|
||||
|
||||
import asyncpg # type: ignore
|
||||
import boto3
|
||||
import jwt
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
from sqlalchemy import event
|
||||
@@ -27,7 +27,7 @@ from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from onyx.configs.app_configs import AWS_REGION
|
||||
from onyx.configs.app_configs import AWS_REGION_NAME
|
||||
from onyx.configs.app_configs import LOG_POSTGRES_CONN_COUNTS
|
||||
from onyx.configs.app_configs import LOG_POSTGRES_LATENCY
|
||||
from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_OVERFLOW
|
||||
@@ -40,9 +40,9 @@ from onyx.configs.app_configs import POSTGRES_POOL_PRE_PING
|
||||
from onyx.configs.app_configs import POSTGRES_POOL_RECYCLE
|
||||
from onyx.configs.app_configs import POSTGRES_PORT
|
||||
from onyx.configs.app_configs import POSTGRES_USER
|
||||
from onyx.configs.app_configs import USER_AUTH_SECRET
|
||||
from onyx.configs.constants import POSTGRES_UNKNOWN_APP_NAME
|
||||
from onyx.configs.constants import SSL_CERT_FILE
|
||||
from onyx.redis.redis_pool import retrieve_auth_token_data_from_redis
|
||||
from onyx.server.utils import BasicAuthenticationError
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
@@ -273,7 +273,7 @@ async def get_async_connection() -> Any:
|
||||
port = POSTGRES_PORT
|
||||
user = POSTGRES_USER
|
||||
db = POSTGRES_DB
|
||||
token = get_iam_auth_token(host, port, user, AWS_REGION)
|
||||
token = get_iam_auth_token(host, port, user, AWS_REGION_NAME)
|
||||
|
||||
# asyncpg requires 'ssl="require"' if SSL needed
|
||||
return await asyncpg.connect(
|
||||
@@ -315,38 +315,40 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
|
||||
host = POSTGRES_HOST
|
||||
port = POSTGRES_PORT
|
||||
user = POSTGRES_USER
|
||||
token = get_iam_auth_token(host, port, user, AWS_REGION)
|
||||
token = get_iam_auth_token(host, port, user, AWS_REGION_NAME)
|
||||
cparams["password"] = token
|
||||
cparams["ssl"] = ssl_context
|
||||
|
||||
return _ASYNC_ENGINE
|
||||
|
||||
|
||||
def get_current_tenant_id(request: Request) -> str:
|
||||
async def get_current_tenant_id(request: Request) -> str:
|
||||
if not MULTI_TENANT:
|
||||
tenant_id = POSTGRES_DEFAULT_SCHEMA
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
return tenant_id
|
||||
|
||||
token = request.cookies.get("fastapiusersauth")
|
||||
if not token:
|
||||
current_value = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
return current_value
|
||||
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
USER_AUTH_SECRET,
|
||||
audience=["fastapi-users:auth"],
|
||||
algorithms=["HS256"],
|
||||
)
|
||||
tenant_id = payload.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
|
||||
# Look up token data in Redis
|
||||
token_data = await retrieve_auth_token_data_from_redis(request)
|
||||
|
||||
if not token_data:
|
||||
current_value = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
logger.debug(
|
||||
f"Token data not found or expired in Redis, defaulting to {current_value}"
|
||||
)
|
||||
return current_value
|
||||
|
||||
tenant_id = token_data.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
|
||||
|
||||
if not is_valid_schema_name(tenant_id):
|
||||
raise HTTPException(status_code=400, detail="Invalid tenant ID format")
|
||||
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
return tenant_id
|
||||
except jwt.InvalidTokenError:
|
||||
return CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
except json.JSONDecodeError:
|
||||
logger.error("Error decoding token data from Redis")
|
||||
return POSTGRES_DEFAULT_SCHEMA
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in get_current_tenant_id: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Internal server error")
|
||||
@@ -368,9 +370,23 @@ async def get_async_session_with_tenant(
|
||||
bind=engine, expire_on_commit=False, class_=AsyncSession
|
||||
) # type: ignore
|
||||
|
||||
async def _set_search_path(session: AsyncSession, tenant_id: str) -> None:
|
||||
await session.execute(text(f'SET search_path = "{tenant_id}"'))
|
||||
|
||||
async with async_session_factory() as session:
|
||||
# Register an event listener that is called whenever a new transaction starts
|
||||
@event.listens_for(session.sync_session, "after_begin")
|
||||
def after_begin(session_: Any, transaction: Any, connection: Any) -> None:
|
||||
# Because the event is sync, we can't directly await here.
|
||||
# Instead we queue up an asyncio task to ensures
|
||||
# the next statement sets the search_path
|
||||
session_.do_orm_execute = lambda state: connection.exec_driver_sql(
|
||||
f'SET search_path = "{tenant_id}"'
|
||||
)
|
||||
|
||||
try:
|
||||
await session.execute(text(f'SET search_path = "{tenant_id}"'))
|
||||
await _set_search_path(session, tenant_id)
|
||||
|
||||
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
|
||||
await session.execute(
|
||||
text(
|
||||
@@ -525,6 +541,6 @@ def provide_iam_token(dialect: Any, conn_rec: Any, cargs: Any, cparams: Any) ->
|
||||
host = POSTGRES_HOST
|
||||
port = POSTGRES_PORT
|
||||
user = POSTGRES_USER
|
||||
region = os.getenv("AWS_REGION", "us-east-2")
|
||||
region = os.getenv("AWS_REGION_NAME", "us-east-2")
|
||||
# Configure for psycopg2 with IAM token
|
||||
configure_psycopg2_iam_auth(cparams, host, port, user, region)
|
||||
|
||||
@@ -54,6 +54,7 @@ from onyx.db.enums import IndexingStatus
|
||||
from onyx.db.enums import IndexModelStatus
|
||||
from onyx.db.enums import TaskStatus
|
||||
from onyx.db.pydantic_type import PydanticType
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.special_types import JSON_ro
|
||||
from onyx.file_store.models import FileDescriptor
|
||||
from onyx.llm.override_models import LLMOverride
|
||||
@@ -65,6 +66,8 @@ from onyx.utils.headers import HeaderItemDict
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
from shared_configs.enums import RerankerProvider
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
__abstract__ = True
|
||||
@@ -72,6 +75,8 @@ class Base(DeclarativeBase):
|
||||
|
||||
class EncryptedString(TypeDecorator):
|
||||
impl = LargeBinary
|
||||
# This type's behavior is fully deterministic and doesn't depend on any external factors.
|
||||
cache_ok = True
|
||||
|
||||
def process_bind_param(self, value: str | None, dialect: Dialect) -> bytes | None:
|
||||
if value is not None:
|
||||
@@ -86,6 +91,8 @@ class EncryptedString(TypeDecorator):
|
||||
|
||||
class EncryptedJson(TypeDecorator):
|
||||
impl = LargeBinary
|
||||
# This type's behavior is fully deterministic and doesn't depend on any external factors.
|
||||
cache_ok = True
|
||||
|
||||
def process_bind_param(self, value: dict | None, dialect: Dialect) -> bytes | None:
|
||||
if value is not None:
|
||||
@@ -102,6 +109,21 @@ class EncryptedJson(TypeDecorator):
|
||||
return value
|
||||
|
||||
|
||||
class NullFilteredString(TypeDecorator):
    """String column type that strips NUL characters from values being bound."""

    impl = String
    # This type's behavior is fully deterministic and doesn't depend on any external factors.
    cache_ok = True

    def process_bind_param(self, value: str | None, dialect: Dialect) -> str | None:
        """Return ``value`` with every NUL character removed, logging a warning if any were found."""
        if value is None or "\x00" not in value:
            return value

        logger.warning(f"NUL characters found in value: {value}")
        return value.replace("\x00", "")

    def process_result_value(self, value: str | None, dialect: Dialect) -> str | None:
        """Return values read from the database unchanged."""
        return value
|
||||
|
||||
|
||||
"""
|
||||
Auth/Authz (users, permissions, access) Tables
|
||||
"""
|
||||
@@ -451,16 +473,16 @@ class Document(Base):
|
||||
|
||||
# this should correspond to the ID of the document
|
||||
# (as is passed around in Onyx)
|
||||
id: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
id: Mapped[str] = mapped_column(NullFilteredString, primary_key=True)
|
||||
from_ingestion_api: Mapped[bool] = mapped_column(
|
||||
Boolean, default=False, nullable=True
|
||||
)
|
||||
# 0 for neutral, positive for mostly endorse, negative for mostly reject
|
||||
boost: Mapped[int] = mapped_column(Integer, default=DEFAULT_BOOST)
|
||||
hidden: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
semantic_id: Mapped[str] = mapped_column(String)
|
||||
semantic_id: Mapped[str] = mapped_column(NullFilteredString)
|
||||
# First Section's link
|
||||
link: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
link: Mapped[str | None] = mapped_column(NullFilteredString, nullable=True)
|
||||
|
||||
# The updated time is also used as a measure of the last successful state of the doc
|
||||
# pulled from the source (to help skip reindexing already updated docs in case of
|
||||
@@ -472,6 +494,10 @@ class Document(Base):
|
||||
DateTime(timezone=True), nullable=True
|
||||
)
|
||||
|
||||
# Number of chunks in the document (in Vespa)
|
||||
# Only null for documents indexed prior to this change
|
||||
chunk_count: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
|
||||
# last time any vespa relevant row metadata or the doc changed.
|
||||
# does not include last_synced
|
||||
last_modified: Mapped[datetime.datetime | None] = mapped_column(
|
||||
@@ -482,6 +508,7 @@ class Document(Base):
|
||||
last_synced: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True, index=True
|
||||
)
|
||||
|
||||
# The following are not attached to User because the account/email may not be known
|
||||
# within Onyx
|
||||
# Something like the document creator
|
||||
|
||||
@@ -99,6 +99,9 @@ def _add_user_filters(
|
||||
return stmt.where(where_clause)
|
||||
|
||||
|
||||
# fetch_persona_by_id fetches a single persona by its ID.
|
||||
|
||||
|
||||
def fetch_persona_by_id(
|
||||
db_session: Session, persona_id: int, user: User | None, get_editable: bool = True
|
||||
) -> Persona:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
@@ -6,10 +7,14 @@ from fastapi_users.password import PasswordHelper
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql import expression
|
||||
from sqlalchemy.sql.elements import ColumnElement
|
||||
from sqlalchemy.sql.elements import KeyedColumnElement
|
||||
|
||||
from onyx.auth.invited_users import get_invited_users
|
||||
from onyx.auth.invited_users import write_invited_users
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.db.api_key import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
||||
from onyx.db.models import DocumentSet__User
|
||||
from onyx.db.models import Persona__User
|
||||
from onyx.db.models import SamlAccount
|
||||
@@ -90,8 +95,10 @@ def validate_user_role_update(requested_role: UserRole, current_role: UserRole)
|
||||
)
|
||||
|
||||
|
||||
def list_users(
|
||||
db_session: Session, email_filter_string: str = "", include_external: bool = False
|
||||
def get_all_users(
|
||||
db_session: Session,
|
||||
email_filter_string: str | None = None,
|
||||
include_external: bool = False,
|
||||
) -> Sequence[User]:
|
||||
"""List all users. No pagination as of now, as the # of users
|
||||
is assumed to be relatively small (<< 1 million)"""
|
||||
@@ -102,7 +109,7 @@ def list_users(
|
||||
if not include_external:
|
||||
where_clause.append(User.role != UserRole.EXT_PERM_USER)
|
||||
|
||||
if email_filter_string:
|
||||
if email_filter_string is not None:
|
||||
where_clause.append(User.email.ilike(f"%{email_filter_string}%")) # type: ignore
|
||||
|
||||
stmt = stmt.where(*where_clause)
|
||||
@@ -110,13 +117,101 @@ def list_users(
|
||||
return db_session.scalars(stmt).unique().all()
|
||||
|
||||
|
||||
def _get_accepted_user_where_clause(
    email_filter_string: str | None = None,
    roles_filter: list[UserRole] | None = None,
    include_external: bool = False,
    is_active_filter: bool | None = None,
) -> list[ColumnElement[bool]]:
    """
    Generates a SQLAlchemy where clause for filtering users based on the provided parameters.
    This is used to build the filters for the function that retrieves the users for the users table in the admin panel.

    Users whose email ends with DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN are always excluded.

    Parameters:
    - email_filter_string: A substring to filter user emails. Only users whose emails contain this substring
      (case-insensitive) will be included. None disables the filter.
    - roles_filter: A list of user roles to filter by. Only users with roles in this list will be included.
      None or an empty list disables the filter.
    - include_external: If False, external permissioned users will be excluded.
    - is_active_filter: When True, only active users will be included. When False, only inactive users will be
      included. None disables the filter.

    Returns:
    - list: A list of conditions to be used in a SQLAlchemy query to filter users.
    """

    # Access table columns directly via __table__.c to get proper SQLAlchemy column types
    # This ensures type checking works correctly for SQL operations like ilike, endswith, and is_
    email_col: KeyedColumnElement[Any] = User.__table__.c.email
    is_active_col: KeyedColumnElement[Any] = User.__table__.c.is_active

    where_clause: list[ColumnElement[bool]] = [
        expression.not_(email_col.endswith(DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN))
    ]

    if not include_external:
        where_clause.append(User.role != UserRole.EXT_PERM_USER)

    if email_filter_string is not None:
        where_clause.append(email_col.ilike(f"%{email_filter_string}%"))

    # The default is None rather than a shared mutable list; both None and
    # an empty list mean "no role filtering".
    if roles_filter:
        where_clause.append(User.role.in_(roles_filter))

    if is_active_filter is not None:
        where_clause.append(is_active_col.is_(is_active_filter))

    return where_clause
|
||||
|
||||
|
||||
def get_page_of_filtered_users(
    db_session: Session,
    page_size: int,
    page_num: int,
    email_filter_string: str | None = None,
    is_active_filter: bool | None = None,
    roles_filter: list[UserRole] | None = None,
    include_external: bool = False,
) -> Sequence[User]:
    """Return one page of users that match the given filters.

    Parameters:
    - page_size: Maximum number of users in the returned page.
    - page_num: Zero-based page index; the page starts at offset page_num * page_size.
    - email_filter_string, is_active_filter, roles_filter, include_external: forwarded to
      _get_accepted_user_where_clause to build the filter conditions. roles_filter of None
      or an empty list applies no role filtering.

    Returns:
    - The users on the requested page.
    """
    where_clause = _get_accepted_user_where_clause(
        email_filter_string=email_filter_string,
        # Normalize the None default so the helper always receives a list.
        roles_filter=roles_filter or [],
        include_external=include_external,
        is_active_filter=is_active_filter,
    )

    # NOTE(review): no ORDER BY is applied, so which rows land on which page
    # presumably depends on the database's row order - confirm whether a
    # stable ordering should be added.
    users_stmt = (
        select(User)
        # Apply pagination
        .offset((page_num) * page_size)
        .limit(page_size)
        # Apply filtering
        .where(*where_clause)
    )

    return db_session.scalars(users_stmt).unique().all()
|
||||
|
||||
|
||||
def get_total_filtered_users_count(
    db_session: Session,
    email_filter_string: str | None = None,
    is_active_filter: bool | None = None,
    roles_filter: list[UserRole] | None = None,
    include_external: bool = False,
) -> int:
    """Count the users that match the given filters.

    The filter arguments are forwarded to _get_accepted_user_where_clause.
    roles_filter of None or an empty list applies no role filtering.

    Returns:
    - The number of matching users, or 0 when the count query yields no value.
    """
    where_clause = _get_accepted_user_where_clause(
        email_filter_string=email_filter_string,
        # Normalize the None default so the helper always receives a list.
        roles_filter=roles_filter or [],
        include_external=include_external,
        is_active_filter=is_active_filter,
    )

    # Apply filtering
    total_count_stmt = select(func.count()).select_from(User).where(*where_clause)

    return db_session.scalar(total_count_stmt) or 0
|
||||
|
||||
|
||||
def get_user_by_email(email: str, db_session: Session) -> User | None:
    """Return the first user whose email equals ``email`` ignoring case, or None."""
    # Both sides are lower-cased in SQL so the comparison is case-insensitive.
    emails_match = func.lower(User.email) == func.lower(email)
    return db_session.query(User).filter(emails_match).first()
|
||||
|
||||
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
import math
|
||||
import uuid
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.db.search_settings import get_secondary_search_settings
|
||||
from onyx.indexing.models import IndexChunk
|
||||
from onyx.document_index.interfaces import EnrichedDocumentIndexingInfo
|
||||
from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
|
||||
|
||||
DEFAULT_BATCH_SIZE = 30
|
||||
@@ -36,25 +37,118 @@ def translate_boost_count_to_multiplier(boost: int) -> float:
|
||||
return 2 / (1 + math.exp(-1 * boost / 3))
|
||||
|
||||
|
||||
def get_uuid_from_chunk(
|
||||
chunk: IndexChunk | InferenceChunk, mini_chunk_ind: int = 0
|
||||
) -> uuid.UUID:
|
||||
doc_str = (
|
||||
chunk.document_id
|
||||
if isinstance(chunk, InferenceChunk)
|
||||
else chunk.source_document.id
|
||||
)
|
||||
def assemble_document_chunk_info(
    enriched_document_info_list: list[EnrichedDocumentIndexingInfo],
    tenant_id: str | None,
    large_chunks_enabled: bool,
) -> list[UUID]:
    """Build the chunk UUIDs for every chunk index of the given documents.

    For each document, one UUID is produced per index in
    [chunk_start_index, chunk_end_index). Documents flagged ``old_version``
    use get_uuid_from_chunk_info_old; all others use get_uuid_from_chunk_info
    with ``tenant_id``. When ``large_chunks_enabled`` is set, every index
    divisible by 4 additionally contributes a large-chunk UUID.

    Returns:
        The UUIDs in generation order (regular chunk first, then its large
        chunk if any).
    """
    collected_ids: list[UUID] = []

    for doc_info in enriched_document_info_list:
        uses_old_scheme = doc_info.old_version

        for chunk_index in range(doc_info.chunk_start_index, doc_info.chunk_end_index):
            if uses_old_scheme:
                chunk_uuid = get_uuid_from_chunk_info_old(
                    document_id=doc_info.doc_id,
                    chunk_id=chunk_index,
                )
            else:
                chunk_uuid = get_uuid_from_chunk_info(
                    document_id=doc_info.doc_id,
                    chunk_id=chunk_index,
                    tenant_id=tenant_id,
                )
            collected_ids.append(chunk_uuid)

            # Only every fourth chunk index starts a large chunk.
            if not (large_chunks_enabled and chunk_index % 4 == 0):
                continue

            large_chunk_id = int(chunk_index / 4)

            if uses_old_scheme:
                # NOTE(review): the reference ids are derived from
                # large_chunk_id, not from chunk_index - confirm this matches
                # the ids used when the large chunks were originally indexed.
                large_chunk_reference_ids = [
                    large_chunk_id + i
                    for i in range(4)
                    if large_chunk_id + i < doc_info.chunk_end_index
                ]
                large_chunk_uuid = get_uuid_from_chunk_info_old(
                    document_id=doc_info.doc_id,
                    chunk_id=large_chunk_id,
                    large_chunk_reference_ids=large_chunk_reference_ids,
                )
            else:
                large_chunk_uuid = get_uuid_from_chunk_info(
                    document_id=doc_info.doc_id,
                    chunk_id=large_chunk_id,
                    tenant_id=tenant_id,
                    large_chunk_id=large_chunk_id,
                )
            collected_ids.append(large_chunk_uuid)

    return collected_ids
|
||||
|
||||
|
||||
def get_uuid_from_chunk_info(
|
||||
*,
|
||||
document_id: str,
|
||||
chunk_id: int,
|
||||
tenant_id: str | None,
|
||||
large_chunk_id: int | None = None,
|
||||
) -> UUID:
|
||||
doc_str = document_id
|
||||
|
||||
# Web parsing URL duplicate catching
|
||||
if doc_str and doc_str[-1] == "/":
|
||||
doc_str = doc_str[:-1]
|
||||
unique_identifier_string = "_".join(
|
||||
[doc_str, str(chunk.chunk_id), str(mini_chunk_ind)]
|
||||
|
||||
chunk_index = (
|
||||
"large_" + str(large_chunk_id) if large_chunk_id is not None else str(chunk_id)
|
||||
)
|
||||
if chunk.large_chunk_reference_ids:
|
||||
unique_identifier_string = "_".join([doc_str, chunk_index])
|
||||
if tenant_id:
|
||||
unique_identifier_string += "_" + tenant_id
|
||||
|
||||
return uuid.uuid5(uuid.NAMESPACE_X500, unique_identifier_string)
|
||||
|
||||
|
||||
def get_uuid_from_chunk_info_old(
|
||||
*, document_id: str, chunk_id: int, large_chunk_reference_ids: list[int] = []
|
||||
) -> UUID:
|
||||
doc_str = document_id
|
||||
|
||||
# Web parsing URL duplicate catching
|
||||
if doc_str and doc_str[-1] == "/":
|
||||
doc_str = doc_str[:-1]
|
||||
unique_identifier_string = "_".join([doc_str, str(chunk_id), "0"])
|
||||
if large_chunk_reference_ids:
|
||||
unique_identifier_string += "_large" + "_".join(
|
||||
[
|
||||
str(referenced_chunk_id)
|
||||
for referenced_chunk_id in chunk.large_chunk_reference_ids
|
||||
for referenced_chunk_id in large_chunk_reference_ids
|
||||
]
|
||||
)
|
||||
return uuid.uuid5(uuid.NAMESPACE_X500, unique_identifier_string)
|
||||
|
||||
|
||||
def get_uuid_from_chunk(chunk: DocMetadataAwareIndexChunk) -> uuid.UUID:
    """Return the UUID for ``chunk`` by forwarding its ids to get_uuid_from_chunk_info."""
    source_document_id = chunk.source_document.id
    return get_uuid_from_chunk_info(
        document_id=source_document_id,
        chunk_id=chunk.chunk_id,
        tenant_id=chunk.tenant_id,
        large_chunk_id=chunk.large_chunk_id,
    )
|
||||
|
||||
|
||||
def get_uuid_from_chunk_old(
    chunk: DocMetadataAwareIndexChunk,
    large_chunk_reference_ids: list[int] | None = None,
) -> UUID:
    """Return the old-scheme UUID for ``chunk`` via get_uuid_from_chunk_info_old.

    Parameters:
    - chunk: The chunk whose source document id and chunk id are used.
    - large_chunk_reference_ids: Ids of the chunks referenced by a large chunk.
      None or an empty list means none are passed on.
    """
    return get_uuid_from_chunk_info_old(
        document_id=chunk.source_document.id,
        chunk_id=chunk.chunk_id,
        # The default is None rather than a shared mutable list; normalize it
        # so the callee always receives a list.
        large_chunk_reference_ids=large_chunk_reference_ids or [],
    )
|
||||
|
||||
@@ -35,6 +35,38 @@ class VespaChunkRequest:
|
||||
return None
|
||||
|
||||
|
||||
@dataclass
class IndexBatchParams:
    """
    Information necessary for efficiently indexing a batch of documents
    """

    # Per document id: the chunk count before this batch, or None if unknown.
    doc_id_to_previous_chunk_cnt: dict[str, int | None]
    # Per document id: the chunk count produced by this batch.
    doc_id_to_new_chunk_cnt: dict[str, int]
    # Tenant the batch belongs to; may be None.
    tenant_id: str | None
    # Whether large chunks are enabled for this batch.
    large_chunks_enabled: bool
|
||||
|
||||
|
||||
@dataclass
class MinimalDocumentIndexingInfo:
    """
    Minimal information necessary for indexing a document
    """

    # Id of the document.
    doc_id: str
    # Index of the document's first chunk in the range being described.
    chunk_start_index: int
|
||||
|
||||
|
||||
@dataclass
class EnrichedDocumentIndexingInfo(MinimalDocumentIndexingInfo):
    """
    Enriched information necessary for indexing a document, including version and chunk range.
    """

    # True when the document's chunks use the old version (presumably the
    # older chunk-id scheme - confirm against callers).
    old_version: bool
    # End of the chunk range; presumably exclusive - TODO confirm against callers.
    chunk_end_index: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class DocumentMetadata:
|
||||
"""
|
||||
@@ -148,7 +180,7 @@ class Indexable(abc.ABC):
|
||||
def index(
|
||||
self,
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
fresh_index: bool = False,
|
||||
index_batch_params: IndexBatchParams,
|
||||
) -> set[DocumentInsertionRecord]:
|
||||
"""
|
||||
Takes a list of document chunks and indexes them in the document index
|
||||
@@ -166,14 +198,11 @@ class Indexable(abc.ABC):
|
||||
only needs to index chunks into the PRIMARY index. Do not update the secondary index here,
|
||||
it is done automatically outside of this code.
|
||||
|
||||
NOTE: The fresh_index parameter, when set to True, assumes no documents have been previously
|
||||
indexed for the given index/tenant. This can be used to optimize the indexing process for
|
||||
new or empty indices.
|
||||
|
||||
Parameters:
|
||||
- chunks: Document chunks with all of the information needed for indexing to the document
|
||||
index.
|
||||
- fresh_index: Boolean indicating whether this is a fresh index with no existing documents.
|
||||
- tenant_id: The tenant id of the user whose chunks are being indexed
|
||||
- large_chunks_enabled: Whether large chunks are enabled
|
||||
|
||||
Returns:
|
||||
List of document ids which map to unique documents and are used for deduping chunks
|
||||
@@ -185,7 +214,7 @@ class Indexable(abc.ABC):
|
||||
|
||||
class Deletable(abc.ABC):
|
||||
"""
|
||||
Class must implement the ability to delete document by their unique document ids.
|
||||
Class must implement the ability to delete document by a given unique document id.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -198,16 +227,6 @@ class Deletable(abc.ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def delete(self, doc_ids: list[str]) -> None:
|
||||
"""
|
||||
Given a list of document ids, hard delete them from the document index
|
||||
|
||||
Parameters:
|
||||
- doc_ids: list of document ids as specified by the connector
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Updatable(abc.ABC):
|
||||
"""
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
import concurrent.futures
|
||||
from uuid import UUID
|
||||
|
||||
import httpx
|
||||
from retry import retry
|
||||
|
||||
from onyx.document_index.vespa.chunk_retrieval import (
|
||||
get_all_vespa_ids_for_document_id,
|
||||
)
|
||||
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
|
||||
from onyx.document_index.vespa_constants import NUM_THREADS
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -16,29 +14,27 @@ logger = setup_logger()
|
||||
CONTENT_SUMMARY = "content_summary"
|
||||
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _delete_vespa_doc_chunks(
|
||||
document_id: str, index_name: str, http_client: httpx.Client
|
||||
@retry(tries=10, delay=1, backoff=2)
|
||||
def _retryable_http_delete(http_client: httpx.Client, url: str) -> None:
|
||||
res = http_client.delete(url)
|
||||
res.raise_for_status()
|
||||
|
||||
|
||||
def _delete_vespa_chunk(
|
||||
doc_chunk_id: UUID, index_name: str, http_client: httpx.Client
|
||||
) -> None:
|
||||
doc_chunk_ids = get_all_vespa_ids_for_document_id(
|
||||
document_id=document_id,
|
||||
index_name=index_name,
|
||||
get_large_chunks=True,
|
||||
)
|
||||
|
||||
for chunk_id in doc_chunk_ids:
|
||||
try:
|
||||
res = http_client.delete(
|
||||
f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{chunk_id}"
|
||||
)
|
||||
res.raise_for_status()
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"Failed to delete chunk, details: {e.response.text}")
|
||||
raise
|
||||
try:
|
||||
_retryable_http_delete(
|
||||
http_client,
|
||||
f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{doc_chunk_id}",
|
||||
)
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"Failed to delete chunk, details: {e.response.text}")
|
||||
raise
|
||||
|
||||
|
||||
def delete_vespa_docs(
|
||||
document_ids: list[str],
|
||||
def delete_vespa_chunks(
|
||||
doc_chunk_ids: list[UUID],
|
||||
index_name: str,
|
||||
http_client: httpx.Client,
|
||||
executor: concurrent.futures.ThreadPoolExecutor | None = None,
|
||||
@@ -50,13 +46,13 @@ def delete_vespa_docs(
|
||||
executor = concurrent.futures.ThreadPoolExecutor(max_workers=NUM_THREADS)
|
||||
|
||||
try:
|
||||
doc_deletion_future = {
|
||||
chunk_deletion_future = {
|
||||
executor.submit(
|
||||
_delete_vespa_doc_chunks, doc_id, index_name, http_client
|
||||
): doc_id
|
||||
for doc_id in document_ids
|
||||
_delete_vespa_chunk, doc_chunk_id, index_name, http_client
|
||||
): doc_chunk_id
|
||||
for doc_chunk_id in doc_chunk_ids
|
||||
}
|
||||
for future in concurrent.futures.as_completed(doc_deletion_future):
|
||||
for future in concurrent.futures.as_completed(chunk_deletion_future):
|
||||
# Will raise exception if the deletion raised an exception
|
||||
future.result()
|
||||
|
||||
|
||||
@@ -25,8 +25,12 @@ from onyx.configs.chat_configs import VESPA_SEARCHER_THREADS
|
||||
from onyx.configs.constants import KV_REINDEX_KEY
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import InferenceChunkUncleaned
|
||||
from onyx.document_index.document_index_utils import assemble_document_chunk_info
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.document_index.interfaces import DocumentInsertionRecord
|
||||
from onyx.document_index.interfaces import EnrichedDocumentIndexingInfo
|
||||
from onyx.document_index.interfaces import IndexBatchParams
|
||||
from onyx.document_index.interfaces import MinimalDocumentIndexingInfo
|
||||
from onyx.document_index.interfaces import UpdateRequest
|
||||
from onyx.document_index.interfaces import VespaChunkRequest
|
||||
from onyx.document_index.interfaces import VespaDocumentFields
|
||||
@@ -38,12 +42,10 @@ from onyx.document_index.vespa.chunk_retrieval import (
|
||||
parallel_visit_api_retrieval,
|
||||
)
|
||||
from onyx.document_index.vespa.chunk_retrieval import query_vespa
|
||||
from onyx.document_index.vespa.deletion import delete_vespa_docs
|
||||
from onyx.document_index.vespa.deletion import delete_vespa_chunks
|
||||
from onyx.document_index.vespa.indexing_utils import batch_index_vespa_chunks
|
||||
from onyx.document_index.vespa.indexing_utils import check_for_final_chunk_existence
|
||||
from onyx.document_index.vespa.indexing_utils import clean_chunk_id_copy
|
||||
from onyx.document_index.vespa.indexing_utils import (
|
||||
get_existing_documents_from_chunks,
|
||||
)
|
||||
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
|
||||
from onyx.document_index.vespa.shared_utils.utils import (
|
||||
replace_invalid_doc_id_characters,
|
||||
@@ -307,12 +309,18 @@ class VespaIndex(DocumentIndex):
|
||||
def index(
|
||||
self,
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
fresh_index: bool = False,
|
||||
index_batch_params: IndexBatchParams,
|
||||
) -> set[DocumentInsertionRecord]:
|
||||
"""Receive a list of chunks from a batch of documents and index the chunks into Vespa along
|
||||
with updating the associated permissions. Assumes that a document will not be split into
|
||||
multiple chunk batches calling this function multiple times, otherwise only the last set of
|
||||
chunks will be kept"""
|
||||
|
||||
doc_id_to_previous_chunk_cnt = index_batch_params.doc_id_to_previous_chunk_cnt
|
||||
doc_id_to_new_chunk_cnt = index_batch_params.doc_id_to_new_chunk_cnt
|
||||
tenant_id = index_batch_params.tenant_id
|
||||
large_chunks_enabled = index_batch_params.large_chunks_enabled
|
||||
|
||||
# IMPORTANT: This must be done one index at a time, do not use secondary index here
|
||||
cleaned_chunks = [clean_chunk_id_copy(chunk) for chunk in chunks]
|
||||
|
||||
@@ -324,30 +332,59 @@ class VespaIndex(DocumentIndex):
|
||||
concurrent.futures.ThreadPoolExecutor(max_workers=NUM_THREADS) as executor,
|
||||
get_vespa_http_client() as http_client,
|
||||
):
|
||||
if not fresh_index:
|
||||
# Check for existing documents, existing documents need to have all of their chunks deleted
|
||||
# prior to indexing as the document size (num chunks) may have shrunk
|
||||
first_chunks = [
|
||||
chunk for chunk in cleaned_chunks if chunk.chunk_id == 0
|
||||
]
|
||||
for chunk_batch in batch_generator(first_chunks, BATCH_SIZE):
|
||||
existing_docs.update(
|
||||
get_existing_documents_from_chunks(
|
||||
chunks=chunk_batch,
|
||||
index_name=self.index_name,
|
||||
http_client=http_client,
|
||||
executor=executor,
|
||||
)
|
||||
)
|
||||
# We require the start and end index for each document in order to
|
||||
# know precisely which chunks to delete. This information exists for
|
||||
# documents that have `chunk_count` in the database, but not for
|
||||
# `old_version` documents.
|
||||
|
||||
for doc_id_batch in batch_generator(existing_docs, BATCH_SIZE):
|
||||
delete_vespa_docs(
|
||||
document_ids=doc_id_batch,
|
||||
enriched_doc_infos: list[EnrichedDocumentIndexingInfo] = []
|
||||
for document_id, doc_count in doc_id_to_previous_chunk_cnt.items():
|
||||
last_indexed_chunk = doc_id_to_previous_chunk_cnt.get(document_id, None)
|
||||
# If the document has no `chunk_count` in the database, we know that it
|
||||
# has the old chunk ID system and we must check for the final chunk index
|
||||
is_old_version = False
|
||||
if last_indexed_chunk is None:
|
||||
is_old_version = True
|
||||
minimal_doc_info = MinimalDocumentIndexingInfo(
|
||||
doc_id=document_id,
|
||||
chunk_start_index=doc_id_to_new_chunk_cnt.get(document_id, 0),
|
||||
)
|
||||
last_indexed_chunk = check_for_final_chunk_existence(
|
||||
minimal_doc_info=minimal_doc_info,
|
||||
start_index=doc_id_to_new_chunk_cnt[document_id],
|
||||
index_name=self.index_name,
|
||||
http_client=http_client,
|
||||
executor=executor,
|
||||
)
|
||||
|
||||
# If the document has previously indexed chunks, we know it previously existed
|
||||
if doc_count or last_indexed_chunk:
|
||||
existing_docs.add(document_id)
|
||||
|
||||
enriched_doc_info = EnrichedDocumentIndexingInfo(
|
||||
doc_id=document_id,
|
||||
chunk_start_index=doc_id_to_new_chunk_cnt.get(document_id, 0),
|
||||
chunk_end_index=last_indexed_chunk,
|
||||
old_version=is_old_version,
|
||||
)
|
||||
enriched_doc_infos.append(enriched_doc_info)
|
||||
|
||||
# Now, for each doc, we know exactly where to start and end our deletion
|
||||
# So let's generate the chunk IDs for each chunk to delete
|
||||
chunks_to_delete = assemble_document_chunk_info(
|
||||
enriched_document_info_list=enriched_doc_infos,
|
||||
tenant_id=tenant_id,
|
||||
large_chunks_enabled=large_chunks_enabled,
|
||||
)
|
||||
|
||||
# Delete old Vespa documents
|
||||
for doc_chunk_ids_batch in batch_generator(chunks_to_delete, BATCH_SIZE):
|
||||
delete_vespa_chunks(
|
||||
doc_chunk_ids=doc_chunk_ids_batch,
|
||||
index_name=self.index_name,
|
||||
http_client=http_client,
|
||||
executor=executor,
|
||||
)
|
||||
|
||||
for chunk_batch in batch_generator(cleaned_chunks, BATCH_SIZE):
|
||||
batch_index_vespa_chunks(
|
||||
chunks=chunk_batch,
|
||||
@@ -588,24 +625,6 @@ class VespaIndex(DocumentIndex):
|
||||
|
||||
return total_chunks_updated
|
||||
|
||||
def delete(self, doc_ids: list[str]) -> None:
|
||||
logger.info(f"Deleting {len(doc_ids)} documents from Vespa")
|
||||
|
||||
doc_ids = [replace_invalid_doc_id_characters(doc_id) for doc_id in doc_ids]
|
||||
|
||||
# NOTE: using `httpx` here since `requests` doesn't support HTTP2. This is beneficial for
|
||||
# indexing / updates / deletes since we have to make a large volume of requests.
|
||||
with get_vespa_http_client() as http_client:
|
||||
index_names = [self.index_name]
|
||||
if self.secondary_index_name:
|
||||
index_names.append(self.secondary_index_name)
|
||||
|
||||
for index_name in index_names:
|
||||
delete_vespa_docs(
|
||||
document_ids=doc_ids, index_name=index_name, http_client=http_client
|
||||
)
|
||||
return
|
||||
|
||||
def delete_single(self, doc_id: str) -> int:
|
||||
"""Possibly faster overall than the delete method due to using a single
|
||||
delete call with a selection query."""
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import concurrent.futures
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from http import HTTPStatus
|
||||
@@ -11,6 +12,8 @@ from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||
get_experts_stores_representations,
|
||||
)
|
||||
from onyx.document_index.document_index_utils import get_uuid_from_chunk
|
||||
from onyx.document_index.document_index_utils import get_uuid_from_chunk_info_old
|
||||
from onyx.document_index.interfaces import MinimalDocumentIndexingInfo
|
||||
from onyx.document_index.vespa.shared_utils.utils import remove_invalid_unicode_chars
|
||||
from onyx.document_index.vespa.shared_utils.utils import (
|
||||
replace_invalid_doc_id_characters,
|
||||
@@ -48,14 +51,9 @@ logger = setup_logger()
|
||||
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _does_document_exist(
|
||||
doc_chunk_id: str,
|
||||
index_name: str,
|
||||
http_client: httpx.Client,
|
||||
def _does_doc_chunk_exist(
|
||||
doc_chunk_id: uuid.UUID, index_name: str, http_client: httpx.Client
|
||||
) -> bool:
|
||||
"""Returns whether the document already exists and the users/group whitelists
|
||||
Specifically in this case, document refers to a vespa document which is equivalent to a Onyx
|
||||
chunk. This checks for whether the chunk exists already in the index"""
|
||||
doc_url = f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{doc_chunk_id}"
|
||||
doc_fetch_response = http_client.get(doc_url)
|
||||
if doc_fetch_response.status_code == 404:
|
||||
@@ -64,10 +62,10 @@ def _does_document_exist(
|
||||
if doc_fetch_response.status_code != 200:
|
||||
logger.debug(f"Failed to check for document with URL {doc_url}")
|
||||
raise RuntimeError(
|
||||
f"Unexpected fetch document by ID value from Vespa "
|
||||
f"with error {doc_fetch_response.status_code}"
|
||||
f"Index name: {index_name}"
|
||||
f"Doc chunk id: {doc_chunk_id}"
|
||||
f"Unexpected fetch document by ID value from Vespa: "
|
||||
f"error={doc_fetch_response.status_code} "
|
||||
f"index={index_name} "
|
||||
f"doc_chunk_id={doc_chunk_id}"
|
||||
)
|
||||
return True
|
||||
|
||||
@@ -98,8 +96,8 @@ def get_existing_documents_from_chunks(
|
||||
try:
|
||||
chunk_existence_future = {
|
||||
executor.submit(
|
||||
_does_document_exist,
|
||||
str(get_uuid_from_chunk(chunk)),
|
||||
_does_doc_chunk_exist,
|
||||
get_uuid_from_chunk(chunk),
|
||||
index_name,
|
||||
http_client,
|
||||
): chunk
|
||||
@@ -248,3 +246,22 @@ def clean_chunk_id_copy(
|
||||
}
|
||||
)
|
||||
return clean_chunk
|
||||
|
||||
|
||||
def check_for_final_chunk_existence(
|
||||
minimal_doc_info: MinimalDocumentIndexingInfo,
|
||||
start_index: int,
|
||||
index_name: str,
|
||||
http_client: httpx.Client,
|
||||
) -> int:
|
||||
index = start_index
|
||||
while True:
|
||||
doc_chunk_id = get_uuid_from_chunk_info_old(
|
||||
document_id=minimal_doc_info.doc_id,
|
||||
chunk_id=index,
|
||||
large_chunk_reference_ids=[],
|
||||
)
|
||||
if not _does_doc_chunk_exist(doc_chunk_id, index_name, http_client):
|
||||
return index
|
||||
|
||||
index += 1
|
||||
|
||||
@@ -15,6 +15,7 @@ from onyx.document_index.vespa_constants import METADATA_LIST
|
||||
from onyx.document_index.vespa_constants import SOURCE_TYPE
|
||||
from onyx.document_index.vespa_constants import TENANT_ID
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -59,7 +60,8 @@ def build_vespa_filters(
|
||||
|
||||
filter_str = f"!({HIDDEN}=true) and " if not include_hidden else ""
|
||||
|
||||
if filters.tenant_id:
|
||||
# If running in multi-tenant mode, we may want to filter by tenant_id
|
||||
if filters.tenant_id and MULTI_TENANT:
|
||||
filter_str += f'({TENANT_ID} contains "{filters.tenant_id}") and '
|
||||
|
||||
# CAREFUL touching this one, currently there is no second ACL double-check post retrieval
|
||||
|
||||
@@ -35,6 +35,8 @@ DOCUMENT_ID_ENDPOINT = (
|
||||
f"{VESPA_APP_CONTAINER_URL}/document/v1/default/{{index_name}}/docid"
|
||||
)
|
||||
|
||||
# the default document id endpoint is http://localhost:8080/document/v1/default/danswer_chunk/docid
|
||||
|
||||
SEARCH_ENDPOINT = f"{VESPA_APP_CONTAINER_URL}/search/"
|
||||
|
||||
NUM_THREADS = (
|
||||
|
||||
@@ -67,7 +67,9 @@ def is_text_file_extension(file_name: str) -> bool:
|
||||
|
||||
def get_file_ext(file_path_or_name: str | Path) -> str:
|
||||
_, extension = os.path.splitext(file_path_or_name)
|
||||
return extension
|
||||
# standardize all extensions to be lowercase so that checks against
|
||||
# VALID_FILE_EXTENSIONS and similar will work as intended
|
||||
return extension.lower()
|
||||
|
||||
|
||||
def is_valid_file_ext(ext: str) -> bool:
|
||||
|
||||
@@ -73,25 +73,25 @@ def _get_metadata_suffix_for_document_index(
|
||||
return metadata_semantic, metadata_keyword
|
||||
|
||||
|
||||
def _combine_chunks(chunks: list[DocAwareChunk], index: int) -> DocAwareChunk:
|
||||
def _combine_chunks(chunks: list[DocAwareChunk], large_chunk_id: int) -> DocAwareChunk:
|
||||
merged_chunk = DocAwareChunk(
|
||||
source_document=chunks[0].source_document,
|
||||
chunk_id=index,
|
||||
chunk_id=chunks[0].chunk_id,
|
||||
blurb=chunks[0].blurb,
|
||||
content=chunks[0].content,
|
||||
source_links=chunks[0].source_links or {},
|
||||
section_continuation=(index > 0),
|
||||
section_continuation=(chunks[0].chunk_id > 0),
|
||||
title_prefix=chunks[0].title_prefix,
|
||||
metadata_suffix_semantic=chunks[0].metadata_suffix_semantic,
|
||||
metadata_suffix_keyword=chunks[0].metadata_suffix_keyword,
|
||||
large_chunk_reference_ids=[chunks[0].chunk_id],
|
||||
large_chunk_reference_ids=[chunk.chunk_id for chunk in chunks],
|
||||
mini_chunk_texts=None,
|
||||
large_chunk_id=large_chunk_id,
|
||||
)
|
||||
|
||||
offset = 0
|
||||
for i in range(1, len(chunks)):
|
||||
merged_chunk.content += SECTION_SEPARATOR + chunks[i].content
|
||||
merged_chunk.large_chunk_reference_ids.append(chunks[i].chunk_id)
|
||||
|
||||
offset += len(SECTION_SEPARATOR) + len(chunks[i - 1].content)
|
||||
for link_offset, link_text in (chunks[i].source_links or {}).items():
|
||||
@@ -103,11 +103,12 @@ def _combine_chunks(chunks: list[DocAwareChunk], index: int) -> DocAwareChunk:
|
||||
|
||||
|
||||
def generate_large_chunks(chunks: list[DocAwareChunk]) -> list[DocAwareChunk]:
|
||||
large_chunks = [
|
||||
_combine_chunks(chunks[i : i + LARGE_CHUNK_RATIO], idx)
|
||||
for idx, i in enumerate(range(0, len(chunks), LARGE_CHUNK_RATIO))
|
||||
if len(chunks[i : i + LARGE_CHUNK_RATIO]) > 1
|
||||
]
|
||||
large_chunks = []
|
||||
for idx, i in enumerate(range(0, len(chunks), LARGE_CHUNK_RATIO)):
|
||||
chunk_group = chunks[i : i + LARGE_CHUNK_RATIO]
|
||||
if len(chunk_group) > 1:
|
||||
large_chunk = _combine_chunks(chunk_group, idx)
|
||||
large_chunks.append(large_chunk)
|
||||
return large_chunks
|
||||
|
||||
|
||||
@@ -219,6 +220,7 @@ class Chunker:
|
||||
metadata_suffix_semantic=metadata_suffix_semantic,
|
||||
metadata_suffix_keyword=metadata_suffix_keyword,
|
||||
mini_chunk_texts=self._get_mini_chunk_texts(text),
|
||||
large_chunk_id=None,
|
||||
)
|
||||
|
||||
for section_idx, section in enumerate(document.sections):
|
||||
|
||||
@@ -20,8 +20,10 @@ from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||
)
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import IndexAttemptMetadata
|
||||
from onyx.db.document import fetch_chunk_counts_for_documents
|
||||
from onyx.db.document import get_documents_by_ids
|
||||
from onyx.db.document import prepare_to_modify_documents
|
||||
from onyx.db.document import update_docs_chunk_count__no_commit
|
||||
from onyx.db.document import update_docs_last_modified__no_commit
|
||||
from onyx.db.document import update_docs_updated_at__no_commit
|
||||
from onyx.db.document import upsert_document_by_connector_credential_pair
|
||||
@@ -34,6 +36,7 @@ from onyx.db.tag import create_or_add_document_tag
|
||||
from onyx.db.tag import create_or_add_document_tag_list
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.document_index.interfaces import DocumentMetadata
|
||||
from onyx.document_index.interfaces import IndexBatchParams
|
||||
from onyx.indexing.chunker import Chunker
|
||||
from onyx.indexing.embedder import IndexingEmbedder
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
@@ -370,16 +373,35 @@ def index_doc_batch(
|
||||
# NOTE: don't need to acquire till here, since this is when the actual race condition
|
||||
# with Vespa can occur.
|
||||
with prepare_to_modify_documents(db_session=db_session, document_ids=updatable_ids):
|
||||
document_id_to_access_info = get_access_for_documents(
|
||||
doc_id_to_access_info = get_access_for_documents(
|
||||
document_ids=updatable_ids, db_session=db_session
|
||||
)
|
||||
document_id_to_document_set = {
|
||||
doc_id_to_document_set = {
|
||||
document_id: document_sets
|
||||
for document_id, document_sets in fetch_document_sets_for_documents(
|
||||
document_ids=updatable_ids, db_session=db_session
|
||||
)
|
||||
}
|
||||
|
||||
doc_id_to_previous_chunk_cnt: dict[str, int | None] = {
|
||||
document_id: chunk_count
|
||||
for document_id, chunk_count in fetch_chunk_counts_for_documents(
|
||||
document_ids=updatable_ids,
|
||||
db_session=db_session,
|
||||
)
|
||||
}
|
||||
|
||||
doc_id_to_new_chunk_cnt: dict[str, int] = {
|
||||
document_id: len(
|
||||
[
|
||||
chunk
|
||||
for chunk in chunks_with_embeddings
|
||||
if chunk.source_document.id == document_id
|
||||
]
|
||||
)
|
||||
for document_id in updatable_ids
|
||||
}
|
||||
|
||||
# we're concerned about race conditions where multiple simultaneous indexings might result
|
||||
# in one set of metadata overwriting another one in vespa.
|
||||
# we still write data here for the immediate and most likely correct sync, but
|
||||
@@ -388,11 +410,9 @@ def index_doc_batch(
|
||||
access_aware_chunks = [
|
||||
DocMetadataAwareIndexChunk.from_index_chunk(
|
||||
index_chunk=chunk,
|
||||
access=document_id_to_access_info.get(
|
||||
chunk.source_document.id, no_access
|
||||
),
|
||||
access=doc_id_to_access_info.get(chunk.source_document.id, no_access),
|
||||
document_sets=set(
|
||||
document_id_to_document_set.get(chunk.source_document.id, [])
|
||||
doc_id_to_document_set.get(chunk.source_document.id, [])
|
||||
),
|
||||
boost=(
|
||||
ctx.id_to_db_doc_map[chunk.source_document.id].boost
|
||||
@@ -410,7 +430,15 @@ def index_doc_batch(
|
||||
# A document will not be spread across different batches, so all the
|
||||
# documents with chunks in this set, are fully represented by the chunks
|
||||
# in this set
|
||||
insertion_records = document_index.index(chunks=access_aware_chunks)
|
||||
insertion_records = document_index.index(
|
||||
chunks=access_aware_chunks,
|
||||
index_batch_params=IndexBatchParams(
|
||||
doc_id_to_previous_chunk_cnt=doc_id_to_previous_chunk_cnt,
|
||||
doc_id_to_new_chunk_cnt=doc_id_to_new_chunk_cnt,
|
||||
tenant_id=tenant_id,
|
||||
large_chunks_enabled=chunker.enable_large_chunks,
|
||||
),
|
||||
)
|
||||
|
||||
successful_doc_ids = [record.document_id for record in insertion_records]
|
||||
successful_docs = [
|
||||
@@ -435,6 +463,12 @@ def index_doc_batch(
|
||||
document_ids=last_modified_ids, db_session=db_session
|
||||
)
|
||||
|
||||
update_docs_chunk_count__no_commit(
|
||||
document_ids=updatable_ids,
|
||||
doc_id_to_chunk_count=doc_id_to_new_chunk_cnt,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
result = (
|
||||
@@ -445,6 +479,28 @@ def index_doc_batch(
|
||||
return result
|
||||
|
||||
|
||||
def check_enable_large_chunks_and_multipass(
|
||||
embedder: IndexingEmbedder, db_session: Session
|
||||
) -> tuple[bool, bool]:
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
multipass = (
|
||||
search_settings.multipass_indexing
|
||||
if search_settings
|
||||
else ENABLE_MULTIPASS_INDEXING
|
||||
)
|
||||
|
||||
enable_large_chunks = (
|
||||
multipass
|
||||
and
|
||||
# Only local models that supports larger context are from Nomic
|
||||
(embedder.model_name.startswith("nomic-ai"))
|
||||
and
|
||||
# Cohere does not support larger context they recommend not going above 512 tokens
|
||||
embedder.provider_type != EmbeddingProvider.COHERE
|
||||
)
|
||||
return multipass, enable_large_chunks
|
||||
|
||||
|
||||
def build_indexing_pipeline(
|
||||
*,
|
||||
embedder: IndexingEmbedder,
|
||||
@@ -457,24 +513,8 @@ def build_indexing_pipeline(
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> IndexingPipelineProtocol:
|
||||
"""Builds a pipeline which takes in a list (batch) of docs and indexes them."""
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
multipass = (
|
||||
search_settings.multipass_indexing
|
||||
if search_settings
|
||||
else ENABLE_MULTIPASS_INDEXING
|
||||
)
|
||||
|
||||
enable_large_chunks = (
|
||||
multipass
|
||||
and
|
||||
# Only local models that supports larger context are from Nomic
|
||||
(
|
||||
embedder.provider_type is not None
|
||||
or embedder.model_name.startswith("nomic-ai")
|
||||
)
|
||||
and
|
||||
# Cohere does not support larger context they recommend not going above 512 tokens
|
||||
embedder.provider_type != EmbeddingProvider.COHERE
|
||||
multipass, enable_large_chunks = check_enable_large_chunks_and_multipass(
|
||||
embedder, db_session
|
||||
)
|
||||
|
||||
chunker = chunker or Chunker(
|
||||
|
||||
@@ -47,6 +47,8 @@ class DocAwareChunk(BaseChunk):
|
||||
|
||||
mini_chunk_texts: list[str] | None
|
||||
|
||||
large_chunk_id: int | None
|
||||
|
||||
large_chunk_reference_ids: list[int] = Field(default_factory=list)
|
||||
|
||||
def to_short_descriptor(self) -> str:
|
||||
|
||||
@@ -74,6 +74,9 @@ from onyx.server.manage.search_settings import router as search_settings_router
|
||||
from onyx.server.manage.slack_bot import router as slack_bot_management_router
|
||||
from onyx.server.manage.users import router as user_router
|
||||
from onyx.server.middleware.latency_logging import add_latency_logging_middleware
|
||||
from onyx.server.middleware.rate_limiting import close_limiter
|
||||
from onyx.server.middleware.rate_limiting import get_auth_rate_limiters
|
||||
from onyx.server.middleware.rate_limiting import setup_limiter
|
||||
from onyx.server.onyx_api.ingestion import router as onyx_api_router
|
||||
from onyx.server.openai_assistants_api.full_openai_assistants_api import (
|
||||
get_full_openai_assistants_api_router,
|
||||
@@ -153,6 +156,23 @@ def include_router_with_global_prefix_prepended(
|
||||
application.include_router(router, **final_kwargs)
|
||||
|
||||
|
||||
def include_auth_router_with_prefix(
|
||||
application: FastAPI,
|
||||
router: APIRouter,
|
||||
prefix: str | None = None,
|
||||
tags: list[str] | None = None,
|
||||
) -> None:
|
||||
"""Wrapper function to include an 'auth' router with prefix + rate-limiting dependencies."""
|
||||
final_tags = tags or ["auth"]
|
||||
include_router_with_global_prefix_prepended(
|
||||
application,
|
||||
router,
|
||||
prefix=prefix,
|
||||
tags=final_tags,
|
||||
dependencies=get_auth_rate_limiters(),
|
||||
)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator:
|
||||
# Set recursion limit
|
||||
@@ -194,8 +214,15 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
|
||||
setup_multitenant_onyx()
|
||||
|
||||
optional_telemetry(record_type=RecordType.VERSION, data={"version": __version__})
|
||||
|
||||
# Set up rate limiter
|
||||
await setup_limiter()
|
||||
|
||||
yield
|
||||
|
||||
# Close rate limiter
|
||||
await close_limiter()
|
||||
|
||||
|
||||
def log_http_error(_: Request, exc: Exception) -> JSONResponse:
|
||||
status_code = getattr(exc, "status_code", 500)
|
||||
@@ -283,42 +310,37 @@ def get_application() -> FastAPI:
|
||||
pass
|
||||
|
||||
if AUTH_TYPE == AuthType.BASIC or AUTH_TYPE == AuthType.CLOUD:
|
||||
include_router_with_global_prefix_prepended(
|
||||
include_auth_router_with_prefix(
|
||||
application,
|
||||
fastapi_users.get_auth_router(auth_backend),
|
||||
prefix="/auth",
|
||||
tags=["auth"],
|
||||
)
|
||||
|
||||
include_router_with_global_prefix_prepended(
|
||||
include_auth_router_with_prefix(
|
||||
application,
|
||||
fastapi_users.get_register_router(UserRead, UserCreate),
|
||||
prefix="/auth",
|
||||
tags=["auth"],
|
||||
)
|
||||
|
||||
include_router_with_global_prefix_prepended(
|
||||
include_auth_router_with_prefix(
|
||||
application,
|
||||
fastapi_users.get_reset_password_router(),
|
||||
prefix="/auth",
|
||||
tags=["auth"],
|
||||
)
|
||||
include_router_with_global_prefix_prepended(
|
||||
include_auth_router_with_prefix(
|
||||
application,
|
||||
fastapi_users.get_verify_router(UserRead),
|
||||
prefix="/auth",
|
||||
tags=["auth"],
|
||||
)
|
||||
include_router_with_global_prefix_prepended(
|
||||
include_auth_router_with_prefix(
|
||||
application,
|
||||
fastapi_users.get_users_router(UserRead, UserUpdate),
|
||||
prefix="/users",
|
||||
tags=["users"],
|
||||
)
|
||||
|
||||
if AUTH_TYPE == AuthType.GOOGLE_OAUTH:
|
||||
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
|
||||
include_router_with_global_prefix_prepended(
|
||||
include_auth_router_with_prefix(
|
||||
application,
|
||||
create_onyx_oauth_router(
|
||||
oauth_client,
|
||||
@@ -330,15 +352,13 @@ def get_application() -> FastAPI:
|
||||
redirect_url=f"{WEB_DOMAIN}/auth/oauth/callback",
|
||||
),
|
||||
prefix="/auth/oauth",
|
||||
tags=["auth"],
|
||||
)
|
||||
|
||||
# Need basic auth router for `logout` endpoint
|
||||
include_router_with_global_prefix_prepended(
|
||||
include_auth_router_with_prefix(
|
||||
application,
|
||||
fastapi_users.get_logout_router(auth_backend),
|
||||
prefix="/auth",
|
||||
tags=["auth"],
|
||||
)
|
||||
|
||||
application.add_exception_handler(
|
||||
|
||||
@@ -131,10 +131,15 @@ class EmbeddingModel:
|
||||
tries=10, delay=10, exceptions=ModelServerRateLimitError
|
||||
)(final_make_request_func)
|
||||
|
||||
response: Response | None = None
|
||||
|
||||
try:
|
||||
response = final_make_request_func()
|
||||
return EmbedResponse(**response.json())
|
||||
except requests.HTTPError as e:
|
||||
if not response:
|
||||
raise HTTPError("HTTP error occurred - response is None.") from e
|
||||
|
||||
try:
|
||||
error_detail = response.json().get("detail", str(e))
|
||||
except Exception:
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
import re
|
||||
from datetime import datetime
|
||||
from re import Match
|
||||
|
||||
import pytz
|
||||
import timeago # type: ignore
|
||||
@@ -59,33 +57,6 @@ def get_feedback_reminder_blocks(thread_link: str, include_followup: bool) -> Bl
|
||||
return SectionBlock(text=text)
|
||||
|
||||
|
||||
def _process_citations_for_slack(text: str) -> str:
|
||||
"""
|
||||
Converts instances of [[x]](LINK) in the input text to Slack's link format <LINK|[x]>.
|
||||
|
||||
Args:
|
||||
- text (str): The input string containing markdown links.
|
||||
|
||||
Returns:
|
||||
- str: The string with markdown links converted to Slack format.
|
||||
"""
|
||||
# Regular expression to find all instances of [[x]](LINK)
|
||||
pattern = r"\[\[(.*?)\]\]\((.*?)\)"
|
||||
|
||||
# Function to replace each found instance with Slack's format
|
||||
def slack_link_format(match: Match) -> str:
|
||||
link_text = match.group(1)
|
||||
link_url = match.group(2)
|
||||
|
||||
# Account for empty link citations
|
||||
if link_url == "":
|
||||
return f"[{link_text}]"
|
||||
return f"<{link_url}|[{link_text}]>"
|
||||
|
||||
# Substitute all matches in the input text
|
||||
return re.sub(pattern, slack_link_format, text)
|
||||
|
||||
|
||||
def _split_text(text: str, limit: int = 3000) -> list[str]:
|
||||
if len(text) <= limit:
|
||||
return [text]
|
||||
@@ -369,15 +340,12 @@ def _build_citations_blocks(
|
||||
|
||||
def _build_qa_response_blocks(
|
||||
answer: ChatOnyxBotResponse,
|
||||
process_message_for_citations: bool = False,
|
||||
) -> list[Block]:
|
||||
retrieval_info = answer.docs
|
||||
if not retrieval_info:
|
||||
# This should not happen, even with no docs retrieved, there is still info returned
|
||||
raise RuntimeError("Failed to retrieve docs, cannot answer question.")
|
||||
|
||||
formatted_answer = format_slack_message(answer.answer) if answer.answer else None
|
||||
|
||||
if DISABLE_GENERATIVE_AI:
|
||||
return []
|
||||
|
||||
@@ -408,18 +376,18 @@ def _build_qa_response_blocks(
|
||||
|
||||
filter_block = SectionBlock(text=f"_{filter_text}_")
|
||||
|
||||
if not formatted_answer:
|
||||
if not answer.answer:
|
||||
answer_blocks = [
|
||||
SectionBlock(
|
||||
text="Sorry, I was unable to find an answer, but I did find some potentially relevant docs 🤓"
|
||||
)
|
||||
]
|
||||
else:
|
||||
# replaces markdown links with slack format links
|
||||
formatted_answer = format_slack_message(answer.answer)
|
||||
answer_processed = decode_escapes(
|
||||
remove_slack_text_interactions(formatted_answer)
|
||||
)
|
||||
if process_message_for_citations:
|
||||
answer_processed = _process_citations_for_slack(answer_processed)
|
||||
answer_blocks = [
|
||||
SectionBlock(text=text) for text in _split_text(answer_processed)
|
||||
]
|
||||
@@ -525,7 +493,6 @@ def build_slack_response_blocks(
|
||||
|
||||
answer_blocks = _build_qa_response_blocks(
|
||||
answer=answer,
|
||||
process_message_for_citations=use_citations,
|
||||
)
|
||||
|
||||
web_follow_up_block = []
|
||||
|
||||
@@ -4,73 +4,55 @@ from onyx.configs.constants import DocumentSource
|
||||
def source_to_github_img_link(source: DocumentSource) -> str | None:
|
||||
# TODO: store these images somewhere better
|
||||
if source == DocumentSource.WEB.value:
|
||||
return "https://raw.githubusercontent.com/onyx-ai/onyx/main/backend/slackbot_images/Web.png"
|
||||
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/backend/slackbot_images/Web.png"
|
||||
if source == DocumentSource.FILE.value:
|
||||
return "https://raw.githubusercontent.com/onyx-ai/onyx/main/backend/slackbot_images/File.png"
|
||||
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/backend/slackbot_images/File.png"
|
||||
if source == DocumentSource.GOOGLE_SITES.value:
|
||||
return "https://raw.githubusercontent.com/onyx-ai/onyx/main/web/public/GoogleSites.png"
|
||||
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/GoogleSites.png"
|
||||
if source == DocumentSource.SLACK.value:
|
||||
return (
|
||||
"https://raw.githubusercontent.com/onyx-ai/onyx/main/web/public/Slack.png"
|
||||
)
|
||||
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/Slack.png"
|
||||
if source == DocumentSource.GMAIL.value:
|
||||
return (
|
||||
"https://raw.githubusercontent.com/onyx-ai/onyx/main/web/public/Gmail.png"
|
||||
)
|
||||
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/Gmail.png"
|
||||
if source == DocumentSource.GOOGLE_DRIVE.value:
|
||||
return "https://raw.githubusercontent.com/onyx-ai/onyx/main/web/public/GoogleDrive.png"
|
||||
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/GoogleDrive.png"
|
||||
if source == DocumentSource.GITHUB.value:
|
||||
return (
|
||||
"https://raw.githubusercontent.com/onyx-ai/onyx/main/web/public/Github.png"
|
||||
)
|
||||
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/Github.png"
|
||||
if source == DocumentSource.GITLAB.value:
|
||||
return (
|
||||
"https://raw.githubusercontent.com/onyx-ai/onyx/main/web/public/Gitlab.png"
|
||||
)
|
||||
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/Gitlab.png"
|
||||
if source == DocumentSource.CONFLUENCE.value:
|
||||
return "https://raw.githubusercontent.com/onyx-ai/onyx/main/backend/slackbot_images/Confluence.png"
|
||||
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/backend/slackbot_images/Confluence.png"
|
||||
if source == DocumentSource.JIRA.value:
|
||||
return "https://raw.githubusercontent.com/onyx-ai/onyx/main/backend/slackbot_images/Jira.png"
|
||||
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/backend/slackbot_images/Jira.png"
|
||||
if source == DocumentSource.NOTION.value:
|
||||
return (
|
||||
"https://raw.githubusercontent.com/onyx-ai/onyx/main/web/public/Notion.png"
|
||||
)
|
||||
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/Notion.png"
|
||||
if source == DocumentSource.ZENDESK.value:
|
||||
return "https://raw.githubusercontent.com/onyx-ai/onyx/main/backend/slackbot_images/Zendesk.png"
|
||||
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/backend/slackbot_images/Zendesk.png"
|
||||
if source == DocumentSource.GONG.value:
|
||||
return "https://raw.githubusercontent.com/onyx-ai/onyx/main/web/public/Gong.png"
|
||||
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/Gong.png"
|
||||
if source == DocumentSource.LINEAR.value:
|
||||
return (
|
||||
"https://raw.githubusercontent.com/onyx-ai/onyx/main/web/public/Linear.png"
|
||||
)
|
||||
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/Linear.png"
|
||||
if source == DocumentSource.PRODUCTBOARD.value:
|
||||
return "https://raw.githubusercontent.com/onyx-ai/onyx/main/web/public/Productboard.webp"
|
||||
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/Productboard.webp"
|
||||
if source == DocumentSource.SLAB.value:
|
||||
return "https://raw.githubusercontent.com/onyx-ai/onyx/main/web/public/SlabLogo.png"
|
||||
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/SlabLogo.png"
|
||||
if source == DocumentSource.ZULIP.value:
|
||||
return (
|
||||
"https://raw.githubusercontent.com/onyx-ai/onyx/main/web/public/Zulip.png"
|
||||
)
|
||||
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/Zulip.png"
|
||||
if source == DocumentSource.GURU.value:
|
||||
return "https://raw.githubusercontent.com/onyx-ai/onyx/main/backend/slackbot_images/Guru.png"
|
||||
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/backend/slackbot_images/Guru.png"
|
||||
if source == DocumentSource.HUBSPOT.value:
|
||||
return (
|
||||
"https://raw.githubusercontent.com/onyx-ai/onyx/main/web/public/HubSpot.png"
|
||||
)
|
||||
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/HubSpot.png"
|
||||
if source == DocumentSource.DOCUMENT360.value:
|
||||
return "https://raw.githubusercontent.com/onyx-ai/onyx/main/web/public/Document360.png"
|
||||
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/Document360.png"
|
||||
if source == DocumentSource.BOOKSTACK.value:
|
||||
return "https://raw.githubusercontent.com/onyx-ai/onyx/main/web/public/Bookstack.png"
|
||||
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/Bookstack.png"
|
||||
if source == DocumentSource.LOOPIO.value:
|
||||
return (
|
||||
"https://raw.githubusercontent.com/onyx-ai/onyx/main/web/public/Loopio.png"
|
||||
)
|
||||
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/Loopio.png"
|
||||
if source == DocumentSource.SHAREPOINT.value:
|
||||
return "https://raw.githubusercontent.com/onyx-ai/onyx/main/web/public/Sharepoint.png"
|
||||
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/Sharepoint.png"
|
||||
if source == DocumentSource.REQUESTTRACKER.value:
|
||||
# just use file icon for now
|
||||
return "https://raw.githubusercontent.com/onyx-ai/onyx/main/backend/slackbot_images/File.png"
|
||||
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/backend/slackbot_images/File.png"
|
||||
if source == DocumentSource.INGESTION_API.value:
|
||||
return "https://raw.githubusercontent.com/onyx-ai/onyx/main/backend/slackbot_images/File.png"
|
||||
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/backend/slackbot_images/File.png"
|
||||
|
||||
return "https://raw.githubusercontent.com/onyx-ai/onyx/main/backend/slackbot_images/File.png"
|
||||
return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/backend/slackbot_images/File.png"
|
||||
|
||||
@@ -375,7 +375,6 @@ def remove_slack_text_interactions(slack_str: str) -> str:
|
||||
slack_str = SlackTextCleaner.replace_tags_basic(slack_str)
|
||||
slack_str = SlackTextCleaner.replace_channels_basic(slack_str)
|
||||
slack_str = SlackTextCleaner.replace_special_mentions(slack_str)
|
||||
slack_str = SlackTextCleaner.replace_links(slack_str)
|
||||
slack_str = SlackTextCleaner.replace_special_catchall(slack_str)
|
||||
slack_str = SlackTextCleaner.add_zero_width_whitespace_after_tag(slack_str)
|
||||
return slack_str
|
||||
|
||||
@@ -30,6 +30,9 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
# SYNCING_HASH = PREFIX + ":vespa_syncing"
|
||||
SYNCING_PREFIX = PREFIX + ":vespa_syncing"
|
||||
|
||||
def __init__(self, tenant_id: str | None, id: int) -> None:
|
||||
super().__init__(tenant_id, str(id))
|
||||
|
||||
@@ -56,6 +59,10 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
|
||||
# the list on the fly
|
||||
self.skip_docs = skip_docs
|
||||
|
||||
@staticmethod
|
||||
def make_redis_syncing_key(doc_id: str) -> str:
|
||||
return f"{RedisConnectorCredentialPair.SYNCING_PREFIX}:{doc_id}"
|
||||
|
||||
def generate_tasks(
|
||||
self,
|
||||
celery_app: Celery,
|
||||
@@ -64,6 +71,9 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
|
||||
lock: RedisLock,
|
||||
tenant_id: str | None,
|
||||
) -> tuple[int, int] | None:
|
||||
# an arbitrary number in seconds to prevent the same doc from syncing repeatedly
|
||||
SYNC_EXPIRATION = 24 * 60 * 60
|
||||
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
async_results = []
|
||||
@@ -92,6 +102,14 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
|
||||
if doc.id in self.skip_docs:
|
||||
continue
|
||||
|
||||
# is the document sync already queued?
|
||||
# if redis_client.hexists(doc.id):
|
||||
# continue
|
||||
|
||||
redis_syncing_key = self.make_redis_syncing_key(doc.id)
|
||||
if redis_client.exists(redis_syncing_key):
|
||||
continue
|
||||
|
||||
# celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
|
||||
# we prefix the task id so it's easier to keep track of who created the task
|
||||
@@ -104,6 +122,13 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
|
||||
RedisConnectorCredentialPair.get_taskset_key(), custom_task_id
|
||||
)
|
||||
|
||||
# track the doc.id in redis so that we don't resubmit it repeatedly
|
||||
# redis_client.hset(
|
||||
# self.SYNCING_HASH, doc.id, custom_task_id
|
||||
# )
|
||||
|
||||
redis_client.set(redis_syncing_key, custom_task_id, ex=SYNC_EXPIRATION)
|
||||
|
||||
# Priority on sync's triggered by new indexing should be medium
|
||||
result = celery_app.send_task(
|
||||
OnyxCeleryTask.VESPA_METADATA_SYNC_TASK,
|
||||
|
||||
@@ -162,7 +162,7 @@ class RedisConnectorPermissionSync:
|
||||
),
|
||||
queue=OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT,
|
||||
task_id=custom_task_id,
|
||||
priority=OnyxCeleryPriority.MEDIUM,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
)
|
||||
async_results.append(result)
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.db.document_set import construct_document_select_by_docset
|
||||
from onyx.db.models import Document
|
||||
from onyx.redis.redis_object_helper import RedisObjectHelper
|
||||
|
||||
|
||||
@@ -60,6 +61,7 @@ class RedisDocumentSet(RedisObjectHelper):
|
||||
async_results = []
|
||||
stmt = construct_document_select_by_docset(int(self._id), current_only=False)
|
||||
for doc in db_session.scalars(stmt).yield_per(1):
|
||||
doc = cast(Document, doc)
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
|
||||
|
||||
@@ -1,12 +1,19 @@
|
||||
import asyncio
|
||||
import functools
|
||||
import json
|
||||
import threading
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Optional
|
||||
|
||||
import redis
|
||||
from fastapi import Request
|
||||
from redis import asyncio as aioredis
|
||||
from redis.client import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
|
||||
from onyx.configs.app_configs import REDIS_AUTH_KEY_PREFIX
|
||||
from onyx.configs.app_configs import REDIS_DB_NUMBER
|
||||
from onyx.configs.app_configs import REDIS_HEALTH_CHECK_INTERVAL
|
||||
from onyx.configs.app_configs import REDIS_HOST
|
||||
@@ -105,6 +112,9 @@ class TenantRedis(redis.Redis):
|
||||
"sadd",
|
||||
"srem",
|
||||
"scard",
|
||||
"hexists",
|
||||
"hset",
|
||||
"hdel",
|
||||
] # Regular methods that need simple prefixing
|
||||
|
||||
if item == "scan_iter":
|
||||
@@ -196,3 +206,87 @@ def get_redis_client(*, tenant_id: str | None) -> Redis:
|
||||
# redis_client.set('key', 'value')
|
||||
# value = redis_client.get('key')
|
||||
# print(value.decode()) # Output: 'value'
|
||||
|
||||
_async_redis_connection: aioredis.Redis | None = None
|
||||
_async_lock = asyncio.Lock()
|
||||
|
||||
|
||||
async def get_async_redis_connection() -> aioredis.Redis:
|
||||
"""
|
||||
Provides a shared async Redis connection, using the same configs (host, port, SSL, etc.).
|
||||
Ensures that the connection is created only once (lazily) and reused for all future calls.
|
||||
"""
|
||||
global _async_redis_connection
|
||||
|
||||
# If we haven't yet created an async Redis connection, we need to create one
|
||||
if _async_redis_connection is None:
|
||||
# Acquire the lock to ensure that only one coroutine attempts to create the connection
|
||||
async with _async_lock:
|
||||
# Double-check inside the lock to avoid race conditions
|
||||
if _async_redis_connection is None:
|
||||
scheme = "rediss" if REDIS_SSL else "redis"
|
||||
url = f"{scheme}://{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER}"
|
||||
|
||||
# Create a new Redis connection (or connection pool) from the URL
|
||||
_async_redis_connection = aioredis.from_url(
|
||||
url,
|
||||
password=REDIS_PASSWORD,
|
||||
max_connections=REDIS_POOL_MAX_CONNECTIONS,
|
||||
)
|
||||
|
||||
# Return the established connection (or pool) for all future operations
|
||||
return _async_redis_connection
|
||||
|
||||
|
||||
async def retrieve_auth_token_data_from_redis(request: Request) -> dict | None:
|
||||
token = request.cookies.get("fastapiusersauth")
|
||||
if not token:
|
||||
logger.debug("No auth token cookie found")
|
||||
return None
|
||||
|
||||
try:
|
||||
redis = await get_async_redis_connection()
|
||||
redis_key = REDIS_AUTH_KEY_PREFIX + token
|
||||
token_data_str = await redis.get(redis_key)
|
||||
|
||||
if not token_data_str:
|
||||
logger.debug(f"Token key {redis_key} not found or expired in Redis")
|
||||
return None
|
||||
|
||||
return json.loads(token_data_str)
|
||||
except json.JSONDecodeError:
|
||||
logger.error("Error decoding token data from Redis")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Unexpected error in retrieve_auth_token_data_from_redis: {str(e)}"
|
||||
)
|
||||
raise ValueError(
|
||||
f"Unexpected error in retrieve_auth_token_data_from_redis: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
def redis_lock_dump(lock: RedisLock, r: Redis) -> None:
|
||||
# diagnostic logging for lock errors
|
||||
name = lock.name
|
||||
ttl = r.ttl(name)
|
||||
locked = lock.locked()
|
||||
owned = lock.owned()
|
||||
local_token: str | None = lock.local.token # type: ignore
|
||||
|
||||
remote_token_raw = r.get(lock.name)
|
||||
if remote_token_raw:
|
||||
remote_token_bytes = cast(bytes, remote_token_raw)
|
||||
remote_token = remote_token_bytes.decode("utf-8")
|
||||
else:
|
||||
remote_token = None
|
||||
|
||||
logger.warning(
|
||||
f"RedisLock diagnostic: "
|
||||
f"name={name} "
|
||||
f"locked={locked} "
|
||||
f"owned={owned} "
|
||||
f"local_token={local_token} "
|
||||
f"remote_token={remote_token} "
|
||||
f"ttl={ttl}"
|
||||
)
|
||||
|
||||
@@ -12,6 +12,7 @@ from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.db.models import Document
|
||||
from onyx.redis.redis_object_helper import RedisObjectHelper
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
@@ -73,6 +74,7 @@ class RedisUserGroup(RedisObjectHelper):
|
||||
|
||||
stmt = construct_document_select_by_usergroup(int(self._id))
|
||||
for doc in db_session.scalars(stmt).yield_per(1):
|
||||
doc = cast(Document, doc)
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (
|
||||
CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
|
||||
|
||||
@@ -25,6 +25,7 @@ from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.index_attempt import mock_successful_index_attempt
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.document_index.interfaces import IndexBatchParams
|
||||
from onyx.indexing.indexing_pipeline import index_doc_batch_prepare
|
||||
from onyx.indexing.models import ChunkEmbedding
|
||||
from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
@@ -86,6 +87,7 @@ def _create_indexable_chunks(
|
||||
access=default_public_access,
|
||||
document_sets=set(),
|
||||
boost=DEFAULT_BOOST,
|
||||
large_chunk_id=None,
|
||||
)
|
||||
chunks.append(chunk)
|
||||
|
||||
@@ -217,7 +219,15 @@ def seed_initial_documents(
|
||||
# as we just sent over the Vespa schema and there is a slight delay
|
||||
|
||||
index_with_retries = retry_builder(tries=15)(document_index.index)
|
||||
index_with_retries(chunks=chunks, fresh_index=cohere_enabled)
|
||||
index_with_retries(
|
||||
chunks=chunks,
|
||||
index_batch_params=IndexBatchParams(
|
||||
doc_id_to_previous_chunk_cnt={},
|
||||
doc_id_to_new_chunk_cnt={},
|
||||
large_chunks_enabled=False,
|
||||
tenant_id=tenant_id,
|
||||
),
|
||||
)
|
||||
|
||||
# Mock a run for the UI even though it did not actually call out to anything
|
||||
mock_successful_index_attempt(
|
||||
|
||||
@@ -5,6 +5,7 @@ from fastapi.dependencies.models import Dependant
|
||||
from starlette.routing import BaseRoute
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_chat_accesssible_user
|
||||
from onyx.auth.users import current_curator_or_admin_user
|
||||
from onyx.auth.users import current_limited_user
|
||||
from onyx.auth.users import current_user
|
||||
@@ -109,6 +110,7 @@ def check_router_auth(
|
||||
or depends_fn == current_curator_or_admin_user
|
||||
or depends_fn == api_key_dep
|
||||
or depends_fn == current_user_with_expired_token
|
||||
or depends_fn == current_chat_accesssible_user
|
||||
or depends_fn == control_plane_dep
|
||||
or depends_fn == current_cloud_superuser
|
||||
):
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import math
|
||||
from datetime import datetime
|
||||
from http import HTTPStatus
|
||||
|
||||
@@ -48,7 +47,8 @@ from onyx.server.documents.models import CCStatusUpdateRequest
|
||||
from onyx.server.documents.models import ConnectorCredentialPairIdentifier
|
||||
from onyx.server.documents.models import ConnectorCredentialPairMetadata
|
||||
from onyx.server.documents.models import DocumentSyncStatus
|
||||
from onyx.server.documents.models import PaginatedIndexAttempts
|
||||
from onyx.server.documents.models import IndexAttemptSnapshot
|
||||
from onyx.server.documents.models import PaginatedReturn
|
||||
from onyx.server.models import StatusResponse
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
@@ -64,7 +64,7 @@ def get_cc_pair_index_attempts(
|
||||
page_size: int = Query(10, ge=1, le=1000),
|
||||
user: User | None = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> PaginatedIndexAttempts:
|
||||
) -> PaginatedReturn[IndexAttemptSnapshot]:
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
cc_pair_id, db_session, user, get_editable=False
|
||||
)
|
||||
@@ -82,10 +82,12 @@ def get_cc_pair_index_attempts(
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
)
|
||||
return PaginatedIndexAttempts.from_models(
|
||||
index_attempt_models=index_attempts,
|
||||
page=page,
|
||||
total_pages=math.ceil(total_count / page_size),
|
||||
return PaginatedReturn(
|
||||
items=[
|
||||
IndexAttemptSnapshot.from_index_attempt_db_model(index_attempt)
|
||||
for index_attempt in index_attempts
|
||||
],
|
||||
total_items=total_count,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_chat_accesssible_user
|
||||
from onyx.auth.users import current_curator_or_admin_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.background.celery.celery_utils import get_deletion_attempt_snapshot
|
||||
@@ -1055,10 +1056,10 @@ class BasicCCPairInfo(BaseModel):
|
||||
|
||||
@router.get("/connector-status")
|
||||
def get_basic_connector_indexing_status(
|
||||
_: User = Depends(current_user),
|
||||
_: User = Depends(current_chat_accesssible_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[BasicCCPairInfo]:
|
||||
cc_pairs = get_connector_credential_pairs(db_session)
|
||||
cc_pairs = get_connector_credential_pairs(db_session, eager_load_connector=True)
|
||||
return [
|
||||
BasicCCPairInfo(
|
||||
has_successful_run=cc_pair.last_successful_index_time is not None,
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import Generic
|
||||
from typing import TypeVar
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -19,6 +21,8 @@ from onyx.db.models import IndexAttempt
|
||||
from onyx.db.models import IndexAttemptError as DbIndexAttemptError
|
||||
from onyx.db.models import IndexingStatus
|
||||
from onyx.db.models import TaskStatus
|
||||
from onyx.server.models import FullUserSnapshot
|
||||
from onyx.server.models import InvitedUserSnapshot
|
||||
from onyx.server.utils import mask_credential_dict
|
||||
|
||||
|
||||
@@ -201,26 +205,19 @@ class IndexAttemptError(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class PaginatedIndexAttempts(BaseModel):
|
||||
index_attempts: list[IndexAttemptSnapshot]
|
||||
page: int
|
||||
total_pages: int
|
||||
# These are the types currently supported by the pagination hook
|
||||
# More api endpoints can be refactored and be added here for use with the pagination hook
|
||||
PaginatedType = TypeVar(
|
||||
"PaginatedType",
|
||||
IndexAttemptSnapshot,
|
||||
FullUserSnapshot,
|
||||
InvitedUserSnapshot,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_models(
|
||||
cls,
|
||||
index_attempt_models: list[IndexAttempt],
|
||||
page: int,
|
||||
total_pages: int,
|
||||
) -> "PaginatedIndexAttempts":
|
||||
return cls(
|
||||
index_attempts=[
|
||||
IndexAttemptSnapshot.from_index_attempt_db_model(index_attempt_model)
|
||||
for index_attempt_model in index_attempt_models
|
||||
],
|
||||
page=page,
|
||||
total_pages=total_pages,
|
||||
)
|
||||
|
||||
class PaginatedReturn(BaseModel, Generic[PaginatedType]):
|
||||
items: list[PaginatedType]
|
||||
total_items: int
|
||||
|
||||
|
||||
class CCPairFullInfo(BaseModel):
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
import uuid
|
||||
from typing import Annotated
|
||||
from typing import cast
|
||||
@@ -6,7 +7,9 @@ from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Query
|
||||
from fastapi import Request
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_user
|
||||
@@ -28,6 +31,8 @@ router = APIRouter(prefix="/connector/oauth")
|
||||
|
||||
_OAUTH_STATE_KEY_FMT = "oauth_state:{state}"
|
||||
_OAUTH_STATE_EXPIRATION_SECONDS = 10 * 60 # 10 minutes
|
||||
_DESIRED_RETURN_URL_KEY = "desired_return_url"
|
||||
_ADDITIONAL_KWARGS_KEY = "additional_kwargs"
|
||||
|
||||
# Cache for OAuth connectors, populated at module load time
|
||||
_OAUTH_CONNECTORS: dict[DocumentSource, type[OAuthConnector]] = {}
|
||||
@@ -51,12 +56,36 @@ def _discover_oauth_connectors() -> dict[DocumentSource, type[OAuthConnector]]:
|
||||
_discover_oauth_connectors()
|
||||
|
||||
|
||||
def _get_additional_kwargs(
|
||||
request: Request, connector_cls: type[OAuthConnector], args_to_ignore: list[str]
|
||||
) -> dict[str, str]:
|
||||
# get additional kwargs from request
|
||||
# e.g. anything except for desired_return_url
|
||||
additional_kwargs_dict = {
|
||||
k: v for k, v in request.query_params.items() if k not in args_to_ignore
|
||||
}
|
||||
try:
|
||||
# validate
|
||||
connector_cls.AdditionalOauthKwargs(**additional_kwargs_dict)
|
||||
except ValidationError:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=(
|
||||
f"Invalid additional kwargs. Got {additional_kwargs_dict}, expected "
|
||||
f"{connector_cls.AdditionalOauthKwargs.model_json_schema()}"
|
||||
),
|
||||
)
|
||||
|
||||
return additional_kwargs_dict
|
||||
|
||||
|
||||
class AuthorizeResponse(BaseModel):
|
||||
redirect_url: str
|
||||
|
||||
|
||||
@router.get("/authorize/{source}")
|
||||
def oauth_authorize(
|
||||
request: Request,
|
||||
source: DocumentSource,
|
||||
desired_return_url: Annotated[str | None, Query()] = None,
|
||||
_: User = Depends(current_user),
|
||||
@@ -71,6 +100,12 @@ def oauth_authorize(
|
||||
connector_cls = oauth_connectors[source]
|
||||
base_url = WEB_DOMAIN
|
||||
|
||||
# get additional kwargs from request
|
||||
# e.g. anything except for desired_return_url
|
||||
additional_kwargs = _get_additional_kwargs(
|
||||
request, connector_cls, ["desired_return_url"]
|
||||
)
|
||||
|
||||
# store state in redis
|
||||
if not desired_return_url:
|
||||
desired_return_url = f"{base_url}/admin/connectors/{source}?step=0"
|
||||
@@ -78,12 +113,19 @@ def oauth_authorize(
|
||||
state = str(uuid.uuid4())
|
||||
redis_client.set(
|
||||
_OAUTH_STATE_KEY_FMT.format(state=state),
|
||||
desired_return_url,
|
||||
json.dumps(
|
||||
{
|
||||
_DESIRED_RETURN_URL_KEY: desired_return_url,
|
||||
_ADDITIONAL_KWARGS_KEY: additional_kwargs,
|
||||
}
|
||||
),
|
||||
ex=_OAUTH_STATE_EXPIRATION_SECONDS,
|
||||
)
|
||||
|
||||
return AuthorizeResponse(
|
||||
redirect_url=connector_cls.oauth_authorization_url(base_url, state)
|
||||
redirect_url=connector_cls.oauth_authorization_url(
|
||||
base_url, state, additional_kwargs
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -110,15 +152,18 @@ def oauth_callback(
|
||||
|
||||
# get state from redis
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
original_url_bytes = cast(
|
||||
oauth_state_bytes = cast(
|
||||
bytes, redis_client.get(_OAUTH_STATE_KEY_FMT.format(state=state))
|
||||
)
|
||||
if not original_url_bytes:
|
||||
if not oauth_state_bytes:
|
||||
raise HTTPException(status_code=400, detail="Invalid OAuth state")
|
||||
original_url = original_url_bytes.decode("utf-8")
|
||||
oauth_state = json.loads(oauth_state_bytes.decode("utf-8"))
|
||||
|
||||
desired_return_url = cast(str, oauth_state[_DESIRED_RETURN_URL_KEY])
|
||||
additional_kwargs = cast(dict[str, str], oauth_state[_ADDITIONAL_KWARGS_KEY])
|
||||
|
||||
base_url = WEB_DOMAIN
|
||||
token_info = connector_cls.oauth_code_to_token(base_url, code)
|
||||
token_info = connector_cls.oauth_code_to_token(base_url, code, additional_kwargs)
|
||||
|
||||
# Create a new credential with the token info
|
||||
credential_data = CredentialBase(
|
||||
@@ -136,8 +181,52 @@ def oauth_callback(
|
||||
|
||||
return CallbackResponse(
|
||||
redirect_url=(
|
||||
f"{original_url}?credentialId={credential.id}"
|
||||
if "?" not in original_url
|
||||
else f"{original_url}&credentialId={credential.id}"
|
||||
f"{desired_return_url}?credentialId={credential.id}"
|
||||
if "?" not in desired_return_url
|
||||
else f"{desired_return_url}&credentialId={credential.id}"
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class OAuthAdditionalKwargDescription(BaseModel):
|
||||
name: str
|
||||
display_name: str
|
||||
description: str
|
||||
|
||||
|
||||
class OAuthDetails(BaseModel):
|
||||
oauth_enabled: bool
|
||||
additional_kwargs: list[OAuthAdditionalKwargDescription]
|
||||
|
||||
|
||||
@router.get("/details/{source}")
|
||||
def oauth_details(
|
||||
source: DocumentSource,
|
||||
_: User = Depends(current_user),
|
||||
) -> OAuthDetails:
|
||||
oauth_connectors = _discover_oauth_connectors()
|
||||
|
||||
if source not in oauth_connectors:
|
||||
return OAuthDetails(
|
||||
oauth_enabled=False,
|
||||
additional_kwargs=[],
|
||||
)
|
||||
|
||||
connector_cls = oauth_connectors[source]
|
||||
|
||||
additional_kwarg_descriptions = []
|
||||
for key, value in connector_cls.AdditionalOauthKwargs.model_json_schema()[
|
||||
"properties"
|
||||
].items():
|
||||
additional_kwarg_descriptions.append(
|
||||
OAuthAdditionalKwargDescription(
|
||||
name=key,
|
||||
display_name=value.get("title", key),
|
||||
description=value.get("description", ""),
|
||||
)
|
||||
)
|
||||
|
||||
return OAuthDetails(
|
||||
oauth_enabled=True,
|
||||
additional_kwargs=additional_kwarg_descriptions,
|
||||
)
|
||||
|
||||
@@ -10,6 +10,7 @@ from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_chat_accesssible_user
|
||||
from onyx.auth.users import current_curator_or_admin_user
|
||||
from onyx.auth.users import current_limited_user
|
||||
from onyx.auth.users import current_user
|
||||
@@ -323,7 +324,7 @@ def get_image_generation_tool(
|
||||
|
||||
@basic_router.get("")
|
||||
def list_personas(
|
||||
user: User | None = Depends(current_user),
|
||||
user: User | None = Depends(current_chat_accesssible_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
include_deleted: bool = False,
|
||||
persona_ids: list[int] = Query(None),
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from onyx import __version__
|
||||
from onyx.auth.users import anonymous_user_enabled
|
||||
from onyx.auth.users import user_needs_to_be_verified
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
from onyx.server.manage.models import AuthTypeResponse
|
||||
@@ -18,7 +19,9 @@ def healthcheck() -> StatusResponse:
|
||||
@router.get("/auth/type")
|
||||
def get_auth_type() -> AuthTypeResponse:
|
||||
return AuthTypeResponse(
|
||||
auth_type=AUTH_TYPE, requires_verification=user_needs_to_be_verified()
|
||||
auth_type=AUTH_TYPE,
|
||||
requires_verification=user_needs_to_be_verified(),
|
||||
anonymous_user_enabled=anonymous_user_enabled(),
|
||||
)
|
||||
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user