Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-02-18 16:25:45 +00:00)

Compare commits: mystery...cloud_debu (3 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 09e6bd3c9c | |
| | c1803cdd56 | |
| | a5b9c76012 | |
18 .github/pull_request_template.md (vendored)
@@ -6,6 +6,24 @@
[Describe the tests you ran to verify your changes]

## Accepted Risk (provide if relevant)
N/A

## Related Issue(s) (provide if relevant)
N/A

## Mental Checklist:
- All of the automated tests pass
- All PR comments are addressed and marked resolved
- If there are migrations, they have been rebased to latest main
- If there are new dependencies, they are added to the requirements
- If there are new environment variables, they are added to all of the deployment methods
- If there are new APIs that don't require auth, they are added to PUBLIC_ENDPOINT_SPECS
- Docker images build and basic functionalities work
- Author has done a final read through of the PR right before merge

## Backporting (check the box to trigger backport action)
Note: You have to check that the action passes, otherwise resolve the conflicts manually and tag the patches.
- [ ] This PR should be backported (make sure to check that the backport attempt succeeds)
@@ -66,7 +66,6 @@ jobs:
NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
NEXT_PUBLIC_GTM_ENABLED=true
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
# needed due to weird interactions with the builds for different platforms
no-cache: true
labels: ${{ steps.meta.outputs.labels }}
10 .github/workflows/pr-python-connector-tests.yml (vendored)
@@ -26,15 +26,7 @@ env:
GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR }}
# Slab
SLAB_BOT_TOKEN: ${{ secrets.SLAB_BOT_TOKEN }}
# Salesforce
SF_USERNAME: ${{ secrets.SF_USERNAME }}
SF_PASSWORD: ${{ secrets.SF_PASSWORD }}
SF_SECURITY_TOKEN: ${{ secrets.SF_SECURITY_TOKEN }}
# Airtable
AIRTABLE_TEST_BASE_ID: ${{ secrets.AIRTABLE_TEST_BASE_ID }}
AIRTABLE_TEST_TABLE_ID: ${{ secrets.AIRTABLE_TEST_TABLE_ID }}
AIRTABLE_TEST_TABLE_NAME: ${{ secrets.AIRTABLE_TEST_TABLE_NAME }}
AIRTABLE_ACCESS_TOKEN: ${{ secrets.AIRTABLE_ACCESS_TOKEN }}

jobs:
connectors-check:
# See https://runs-on.com/runners/linux/
16 README.md
@@ -3,7 +3,7 @@
<a name="readme-top"></a>

<h2 align="center">
<a href="https://www.onyx.app/"> <img width="50%" src="https://github.com/onyx-dot-app/onyx/blob/logo/OnyxLogoCropped.jpg?raw=true)" /></a>
<a href="https://www.onyx.app/"> <img width="50%" src="https://github.com/onyx-dot-app/onyx/blob/logo/LogoOnyx.png?raw=true)" /></a>
</h2>

<p align="center">
@@ -24,7 +24,7 @@
</a>
</p>

<strong>[Onyx](https://www.onyx.app/)</strong> (formerly Danswer) is the AI Assistant connected to your company's docs, apps, and people.
<strong>[Onyx](https://www.onyx.app/)</strong> (Formerly Danswer) is the AI Assistant connected to your company's docs, apps, and people.
Onyx provides a Chat interface and plugs into any LLM of your choice. Onyx can be deployed anywhere and for any
scale - on a laptop, on-premise, or to cloud. Since you own the deployment, your user data and chats are fully in your
own control. Onyx is dual Licensed with most of it under MIT license and designed to be modular and easily extensible. The system also comes fully ready
@@ -133,3 +133,15 @@ Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md
## ⭐Star History

[](https://star-history.com/#onyx-dot-app/onyx&Date)

## ✨Contributors

<a href="https://github.com/onyx-dot-app/onyx/graphs/contributors">
<img alt="contributors" src="https://contrib.rocks/image?repo=onyx-dot-app/onyx"/>
</a>

<p align="right" style="font-size: 14px; color: #555; margin-top: 20px;">
<a href="#readme-top" style="text-decoration: none; color: #007bff; font-weight: bold;">
↑ Back to Top ↑
</a>
</p>
@@ -4,7 +4,7 @@ from onyx.configs.app_configs import USE_IAM_AUTH
from onyx.configs.app_configs import POSTGRES_HOST
from onyx.configs.app_configs import POSTGRES_PORT
from onyx.configs.app_configs import POSTGRES_USER
from onyx.configs.app_configs import AWS_REGION_NAME
from onyx.configs.app_configs import AWS_REGION
from onyx.db.engine import build_connection_string
from onyx.db.engine import get_all_tenant_ids
from sqlalchemy import event
@@ -120,7 +120,7 @@ def provide_iam_token_for_alembic(
) -> None:
    if USE_IAM_AUTH:
        # Database connection settings
        region = AWS_REGION_NAME
        region = AWS_REGION
        host = POSTGRES_HOST
        port = POSTGRES_PORT
        user = POSTGRES_USER
@@ -122,7 +122,7 @@ def _cleanup_document_set__user_group_relationships__no_commit(
    )


def validate_object_creation_for_user(
def validate_user_creation_permissions(
    db_session: Session,
    user: User | None,
    target_group_ids: list[int] | None = None,
@@ -440,108 +440,32 @@ def remove_curator_status__no_commit(db_session: Session, user: User) -> None:
    _validate_curator_status__no_commit(db_session, [user])


def _validate_curator_relationship_update_requester(
    db_session: Session,
    user_group_id: int,
    user_making_change: User | None = None,
) -> None:
    """
    This function validates that the user making the change has the necessary permissions
    to update the curator relationship for the target user in the given user group.
    """

    if user_making_change is None or user_making_change.role == UserRole.ADMIN:
        return

    # check if the user making the change is a curator in the group they are changing the curator relationship for
    user_making_change_curator_groups = fetch_user_groups_for_user(
        db_session=db_session,
        user_id=user_making_change.id,
        # only check if the user making the change is a curator if they are a curator
        # otherwise, they are a global_curator and can update the curator relationship
        # for any group they are a member of
        only_curator_groups=user_making_change.role == UserRole.CURATOR,
    )
    requestor_curator_group_ids = [
        group.id for group in user_making_change_curator_groups
    ]
    if user_group_id not in requestor_curator_group_ids:
        raise ValueError(
            f"user making change {user_making_change.email} is not a curator,"
            f" admin, or global_curator for group '{user_group_id}'"
        )


def _validate_curator_relationship_update_request(
    db_session: Session,
    user_group_id: int,
    target_user: User,
) -> None:
    """
    This function validates that the curator_relationship_update request itself is valid.
    """
    if target_user.role == UserRole.ADMIN:
        raise ValueError(
            f"User '{target_user.email}' is an admin and therefore has all permissions "
            "of a curator. If you'd like this user to only have curator permissions, "
            "you must update their role to BASIC then assign them to be CURATOR in the "
            "appropriate groups."
        )
    elif target_user.role == UserRole.GLOBAL_CURATOR:
        raise ValueError(
            f"User '{target_user.email}' is a global_curator and therefore has all "
            "permissions of a curator for all groups. If you'd like this user to only "
            "have curator permissions for a specific group, you must update their role "
            "to BASIC then assign them to be CURATOR in the appropriate groups."
        )
    elif target_user.role not in [UserRole.CURATOR, UserRole.BASIC]:
        raise ValueError(
            f"This endpoint can only be used to update the curator relationship for "
            "users with the CURATOR or BASIC role. \n"
            f"Target user: {target_user.email} \n"
            f"Target user role: {target_user.role} \n"
        )

    # check if the target user is in the group they are changing the curator relationship for
    requested_user_groups = fetch_user_groups_for_user(
        db_session=db_session,
        user_id=target_user.id,
        only_curator_groups=False,
    )
    group_ids = [group.id for group in requested_user_groups]
    if user_group_id not in group_ids:
        raise ValueError(
            f"target user {target_user.email} is not in group '{user_group_id}'"
        )


def update_user_curator_relationship(
    db_session: Session,
    user_group_id: int,
    set_curator_request: SetCuratorRequest,
    user_making_change: User | None = None,
) -> None:
    target_user = fetch_user_by_id(db_session, set_curator_request.user_id)
    if not target_user:
    user = fetch_user_by_id(db_session, set_curator_request.user_id)
    if not user:
        raise ValueError(f"User with id '{set_curator_request.user_id}' not found")

    _validate_curator_relationship_update_request(
    if user.role == UserRole.ADMIN:
        raise ValueError(
            f"User '{user.email}' is an admin and therefore has all permissions "
            "of a curator. If you'd like this user to only have curator permissions, "
            "you must update their role to BASIC then assign them to be CURATOR in the "
            "appropriate groups."
        )

    requested_user_groups = fetch_user_groups_for_user(
        db_session=db_session,
        user_group_id=user_group_id,
        target_user=target_user,
        user_id=set_curator_request.user_id,
        only_curator_groups=False,
    )

    _validate_curator_relationship_update_requester(
        db_session=db_session,
        user_group_id=user_group_id,
        user_making_change=user_making_change,
    )

    logger.info(
        f"user_making_change={user_making_change.email if user_making_change else 'None'} is "
        f"updating the curator relationship for user={target_user.email} "
        f"in group={user_group_id} to is_curator={set_curator_request.is_curator}"
    )
    group_ids = [group.id for group in requested_user_groups]
    if user_group_id not in group_ids:
        raise ValueError(f"user is not in group '{user_group_id}'")

    relationship_to_update = (
        db_session.query(User__UserGroup)
@@ -562,7 +486,7 @@ def update_user_curator_relationship(
    )
    db_session.add(relationship_to_update)

    _validate_curator_status__no_commit(db_session, [target_user])
    _validate_curator_status__no_commit(db_session, [user])
    db_session.commit()
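Read as a whole, this hunk replaces one monolithic update path with two explicit validators plus a thin update routine. A condensed sketch of the resulting call order, using the names from the diff (the User__UserGroup upsert and commit are elided):

```python
def update_curator_flow_sketch(
    db_session, user_group_id, set_curator_request, user_making_change=None
):
    target_user = fetch_user_by_id(db_session, set_curator_request.user_id)
    if not target_user:
        raise ValueError(f"User with id '{set_curator_request.user_id}' not found")

    # 1. is the request itself valid? (target's role and group membership)
    _validate_curator_relationship_update_request(
        db_session=db_session, user_group_id=user_group_id, target_user=target_user
    )
    # 2. may this requester make the change? (admin / curator-of-group check)
    _validate_curator_relationship_update_requester(
        db_session=db_session,
        user_group_id=user_group_id,
        user_making_change=user_making_change,
    )
    # 3. upsert User__UserGroup.is_curator, re-validate curator status, commit (elided)
```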
@@ -40,7 +40,6 @@ from onyx.configs.app_configs import USER_AUTH_SECRET
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import AuthType
from onyx.main import get_application as get_application_base
from onyx.main import include_auth_router_with_prefix
from onyx.main import include_router_with_global_prefix_prepended
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import global_version
@@ -63,7 +62,7 @@ def get_application() -> FastAPI:

    if AUTH_TYPE == AuthType.CLOUD:
        oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
        include_auth_router_with_prefix(
        include_router_with_global_prefix_prepended(
            application,
            create_onyx_oauth_router(
                oauth_client,
@@ -75,17 +74,19 @@ def get_application() -> FastAPI:
                redirect_url=f"{WEB_DOMAIN}/auth/oauth/callback",
            ),
            prefix="/auth/oauth",
            tags=["auth"],
        )

        # Need basic auth router for `logout` endpoint
        include_auth_router_with_prefix(
        include_router_with_global_prefix_prepended(
            application,
            fastapi_users.get_logout_router(auth_backend),
            prefix="/auth",
            tags=["auth"],
        )

    if AUTH_TYPE == AuthType.OIDC:
        include_auth_router_with_prefix(
        include_router_with_global_prefix_prepended(
            application,
            create_onyx_oauth_router(
                OpenID(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET, OPENID_CONFIG_URL),
@@ -96,21 +97,19 @@ def get_application() -> FastAPI:
                redirect_url=f"{WEB_DOMAIN}/auth/oidc/callback",
            ),
            prefix="/auth/oidc",
            tags=["auth"],
        )

        # need basic auth router for `logout` endpoint
        include_auth_router_with_prefix(
        include_router_with_global_prefix_prepended(
            application,
            fastapi_users.get_auth_router(auth_backend),
            prefix="/auth",
            tags=["auth"],
        )

    elif AUTH_TYPE == AuthType.SAML:
        include_auth_router_with_prefix(
            application,
            saml_router,
            prefix="/auth/saml",
        )
        include_router_with_global_prefix_prepended(application, saml_router)

    # RBAC / group access control
    include_router_with_global_prefix_prepended(application, user_group_router)
@@ -1,7 +1,5 @@
import base64
import json
import uuid
from typing import Any
from typing import cast

import requests
@@ -12,29 +10,11 @@ from fastapi.responses import JSONResponse
from pydantic import BaseModel
from sqlalchemy.orm import Session

from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLIENT_SECRET
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_SECRET
from onyx.auth.users import current_user
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import DocumentSource
from onyx.connectors.google_utils.google_auth import get_google_oauth_creds
from onyx.connectors.google_utils.google_auth import sanitize_oauth_credentials
from onyx.connectors.google_utils.shared_constants import (
    DB_CREDENTIALS_AUTHENTICATION_METHOD,
)
from onyx.connectors.google_utils.shared_constants import (
    DB_CREDENTIALS_DICT_TOKEN_KEY,
)
from onyx.connectors.google_utils.shared_constants import (
    DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
)
from onyx.connectors.google_utils.shared_constants import (
    GoogleOAuthAuthenticationMethod,
)
from onyx.db.credentials import create_credential
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
@@ -82,7 +62,14 @@ class SlackOAuth:

    @classmethod
    def generate_oauth_url(cls, state: str) -> str:
        return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
        url = (
            f"https://slack.com/oauth/v2/authorize"
            f"?client_id={cls.CLIENT_ID}"
            f"&redirect_uri={cls.REDIRECT_URI}"
            f"&scope={cls.BOT_SCOPE}"
            f"&state={state}"
        )
        return url

    @classmethod
    def generate_dev_oauth_url(cls, state: str) -> str:
@@ -90,14 +77,10 @@ class SlackOAuth:
        - https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
        """

        return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)

    @classmethod
    def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
        url = (
            f"https://slack.com/oauth/v2/authorize"
            f"?client_id={cls.CLIENT_ID}"
            f"&redirect_uri={redirect_uri}"
            f"&redirect_uri={cls.DEV_REDIRECT_URI}"
            f"&scope={cls.BOT_SCOPE}"
            f"&state={state}"
        )
@@ -119,151 +102,82 @@ class SlackOAuth:
        return session


class ConfluenceCloudOAuth:
    """work in progress"""
# Work in progress
# class ConfluenceCloudOAuth:
#     """work in progress"""

    # https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps/
#     # https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps/

    class OAuthSession(BaseModel):
        """Stored in redis to be looked up on callback"""
#     class OAuthSession(BaseModel):
#         """Stored in redis to be looked up on callback"""

        email: str
        redirect_on_success: str | None  # Where to send the user if OAuth flow succeeds
#         email: str
#         redirect_on_success: str | None  # Where to send the user if OAuth flow succeeds

    CLIENT_ID = OAUTH_CONFLUENCE_CLIENT_ID
    CLIENT_SECRET = OAUTH_CONFLUENCE_CLIENT_SECRET
    TOKEN_URL = "https://auth.atlassian.com/oauth/token"
#     CLIENT_ID = OAUTH_CONFLUENCE_CLIENT_ID
#     CLIENT_SECRET = OAUTH_CONFLUENCE_CLIENT_SECRET
#     TOKEN_URL = "https://auth.atlassian.com/oauth/token"

    # All read scopes per https://developer.atlassian.com/cloud/confluence/scopes-for-oauth-2-3LO-and-forge-apps/
    CONFLUENCE_OAUTH_SCOPE = (
        "read:confluence-props%20"
        "read:confluence-content.all%20"
        "read:confluence-content.summary%20"
        "read:confluence-content.permission%20"
        "read:confluence-user%20"
        "read:confluence-groups%20"
        "readonly:content.attachment:confluence"
    )
#     # All read scopes per https://developer.atlassian.com/cloud/confluence/scopes-for-oauth-2-3LO-and-forge-apps/
#     CONFLUENCE_OAUTH_SCOPE = (
#         "read:confluence-props%20"
#         "read:confluence-content.all%20"
#         "read:confluence-content.summary%20"
#         "read:confluence-content.permission%20"
#         "read:confluence-user%20"
#         "read:confluence-groups%20"
#         "readonly:content.attachment:confluence"
#     )

    REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/confluence/oauth/callback"
    DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
#     REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/confluence/oauth/callback"
#     DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"

    # eventually for Confluence Data Center
    # oauth_url = (
    #     f"http://localhost:8090/rest/oauth/v2/authorize?client_id={CONFLUENCE_OAUTH_CLIENT_ID}"
    #     f"&scope={CONFLUENCE_OAUTH_SCOPE_2}"
    #     f"&redirect_uri={redirectme_uri}"
    # )
#     # eventually for Confluence Data Center
#     # oauth_url = (
#     #     f"http://localhost:8090/rest/oauth/v2/authorize?client_id={CONFLUENCE_OAUTH_CLIENT_ID}"
#     #     f"&scope={CONFLUENCE_OAUTH_SCOPE_2}"
#     #     f"&redirect_uri={redirectme_uri}"
#     # )

    @classmethod
    def generate_oauth_url(cls, state: str) -> str:
        return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
#     @classmethod
#     def generate_oauth_url(cls, state: str) -> str:
#         return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)

    @classmethod
    def generate_dev_oauth_url(cls, state: str) -> str:
        """dev mode workaround for localhost testing
        - https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
        """
        return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
#     @classmethod
#     def generate_dev_oauth_url(cls, state: str) -> str:
#         """dev mode workaround for localhost testing
#         - https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
#         """
#         return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)

    @classmethod
    def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
        url = (
            "https://auth.atlassian.com/authorize"
            f"?audience=api.atlassian.com"
            f"&client_id={cls.CLIENT_ID}"
            f"&redirect_uri={redirect_uri}"
            f"&scope={cls.CONFLUENCE_OAUTH_SCOPE}"
            f"&state={state}"
            "&response_type=code"
            "&prompt=consent"
        )
        return url
#     @classmethod
#     def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
#         url = (
#             "https://auth.atlassian.com/authorize"
#             f"?audience=api.atlassian.com"
#             f"&client_id={cls.CLIENT_ID}"
#             f"&redirect_uri={redirect_uri}"
#             f"&scope={cls.CONFLUENCE_OAUTH_SCOPE}"
#             f"&state={state}"
#             "&response_type=code"
#             "&prompt=consent"
#         )
#         return url

    @classmethod
    def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
        """Temporary state to store in redis. to be looked up on auth response.
        Returns a json string.
        """
        session = ConfluenceCloudOAuth.OAuthSession(
            email=email, redirect_on_success=redirect_on_success
        )
        return session.model_dump_json()
#     @classmethod
#     def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
#         """Temporary state to store in redis. to be looked up on auth response.
#         Returns a json string.
#         """
#         session = ConfluenceCloudOAuth.OAuthSession(
#             email=email, redirect_on_success=redirect_on_success
#         )
#         return session.model_dump_json()

    @classmethod
    def parse_session(cls, session_json: str) -> SlackOAuth.OAuthSession:
        session = SlackOAuth.OAuthSession.model_validate_json(session_json)
        return session
class GoogleDriveOAuth:
    # https://developers.google.com/identity/protocols/oauth2
    # https://developers.google.com/identity/protocols/oauth2/web-server

    class OAuthSession(BaseModel):
        """Stored in redis to be looked up on callback"""

        email: str
        redirect_on_success: str | None  # Where to send the user if OAuth flow succeeds

    CLIENT_ID = OAUTH_GOOGLE_DRIVE_CLIENT_ID
    CLIENT_SECRET = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET

    TOKEN_URL = "https://oauth2.googleapis.com/token"

    # SCOPE is per https://docs.onyx.app/connectors/google-drive
    # TODO: Merge with or use google_utils.GOOGLE_SCOPES
    SCOPE = (
        "https://www.googleapis.com/auth/drive.readonly%20"
        "https://www.googleapis.com/auth/drive.metadata.readonly%20"
        "https://www.googleapis.com/auth/admin.directory.user.readonly%20"
        "https://www.googleapis.com/auth/admin.directory.group.readonly"
    )

    REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/google-drive/oauth/callback"
    DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"

    @classmethod
    def generate_oauth_url(cls, state: str) -> str:
        return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)

    @classmethod
    def generate_dev_oauth_url(cls, state: str) -> str:
        """dev mode workaround for localhost testing
        - https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
        """

        return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)

    @classmethod
    def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
        # without prompt=consent, a refresh token is only issued the first time the user approves
        url = (
            f"https://accounts.google.com/o/oauth2/v2/auth"
            f"?client_id={cls.CLIENT_ID}"
            f"&redirect_uri={redirect_uri}"
            "&response_type=code"
            f"&scope={cls.SCOPE}"
            "&access_type=offline"
            f"&state={state}"
            "&prompt=consent"
        )
        return url

    @classmethod
    def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
        """Temporary state to store in redis. to be looked up on auth response.
        Returns a json string.
        """
        session = GoogleDriveOAuth.OAuthSession(
            email=email, redirect_on_success=redirect_on_success
        )
        return session.model_dump_json()

    @classmethod
    def parse_session(cls, session_json: str) -> OAuthSession:
        session = GoogleDriveOAuth.OAuthSession.model_validate_json(session_json)
        return session
#     @classmethod
#     def parse_session(cls, session_json: str) -> SlackOAuth.OAuthSession:
#         session = SlackOAuth.OAuthSession.model_validate_json(session_json)
#         return session
@router.post("/prepare-authorization-request")
|
||||
@@ -278,11 +192,8 @@ def prepare_authorization_request(
|
||||
Example: https://www.oauth.com/oauth2-servers/authorization/the-authorization-request/
|
||||
"""
|
||||
|
||||
# create random oauth state param for security and to retrieve user data later
|
||||
oauth_uuid = uuid.uuid4()
|
||||
oauth_uuid_str = str(oauth_uuid)
|
||||
|
||||
# urlsafe b64 encode the uuid for the oauth url
|
||||
oauth_state = (
|
||||
base64.urlsafe_b64encode(oauth_uuid.bytes).rstrip(b"=").decode("utf-8")
|
||||
)
|
||||
@@ -292,11 +203,6 @@ def prepare_authorization_request(
|
||||
session = SlackOAuth.session_dump_json(
|
||||
email=user.email, redirect_on_success=redirect_on_success
|
||||
)
|
||||
elif connector == DocumentSource.GOOGLE_DRIVE:
|
||||
oauth_url = GoogleDriveOAuth.generate_oauth_url(oauth_state)
|
||||
session = GoogleDriveOAuth.session_dump_json(
|
||||
email=user.email, redirect_on_success=redirect_on_success
|
||||
)
|
||||
# elif connector == DocumentSource.CONFLUENCE:
|
||||
# oauth_url = ConfluenceCloudOAuth.generate_oauth_url(oauth_state)
|
||||
# session = ConfluenceCloudOAuth.session_dump_json(
|
||||
@@ -304,6 +210,8 @@ def prepare_authorization_request(
|
||||
# )
|
||||
# elif connector == DocumentSource.JIRA:
|
||||
# oauth_url = JiraCloudOAuth.generate_dev_oauth_url(oauth_state)
|
||||
# elif connector == DocumentSource.GOOGLE_DRIVE:
|
||||
# oauth_url = GoogleDriveOAuth.generate_dev_oauth_url(oauth_state)
|
||||
else:
|
||||
oauth_url = None
|
||||
|
||||
@@ -315,7 +223,6 @@ def prepare_authorization_request(
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# store important session state to retrieve when the user is redirected back
|
||||
# 10 min is the max we want an oauth flow to be valid
|
||||
r.set(f"da_oauth:{oauth_uuid_str}", session, ex=600)
|
||||
|
||||
@@ -514,116 +421,3 @@ def handle_slack_oauth_callback(
|
||||
# "redirect_on_success": session.redirect_on_success,
|
||||
# }
|
||||
# )
|
||||
|
||||
|
||||
@router.post("/connector/google-drive/callback")
|
||||
def handle_google_drive_oauth_callback(
|
||||
code: str,
|
||||
state: str,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> JSONResponse:
|
||||
if not GoogleDriveOAuth.CLIENT_ID or not GoogleDriveOAuth.CLIENT_SECRET:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Google Drive client ID or client secret is not configured.",
|
||||
)
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# recover the state
|
||||
padded_state = state + "=" * (
|
||||
-len(state) % 4
|
||||
) # Add padding back (Base64 decoding requires padding)
|
||||
uuid_bytes = base64.urlsafe_b64decode(
|
||||
padded_state
|
||||
) # Decode the Base64 string back to bytes
|
||||
|
||||
# Convert bytes back to a UUID
|
||||
oauth_uuid = uuid.UUID(bytes=uuid_bytes)
|
||||
oauth_uuid_str = str(oauth_uuid)
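The padding arithmetic in this callback pairs with the `rstrip(b"=")` in prepare_authorization_request above. A self-contained round-trip, standard library only, shows why `-len(state) % 4` restores exactly the stripped padding:

```python
import base64
import uuid

original = uuid.uuid4()
# encode: 16 UUID bytes -> 24-char base64 -> 22 chars once "=" padding is stripped
state = base64.urlsafe_b64encode(original.bytes).rstrip(b"=").decode("utf-8")

# decode: pad back out to a multiple of 4, then recover the UUID
padded_state = state + "=" * (-len(state) % 4)
assert uuid.UUID(bytes=base64.urlsafe_b64decode(padded_state)) == original
```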
    r_key = f"da_oauth:{oauth_uuid_str}"

    session_json_bytes = cast(bytes, r.get(r_key))
    if not session_json_bytes:
        raise HTTPException(
            status_code=400,
            detail=f"Google Drive OAuth failed - OAuth state key not found: key={r_key}",
        )

    session_json = session_json_bytes.decode("utf-8")
    try:
        session = GoogleDriveOAuth.parse_session(session_json)

        # Exchange the authorization code for an access token
        response = requests.post(
            GoogleDriveOAuth.TOKEN_URL,
            headers={"Content-Type": "application/x-www-form-urlencoded"},
            data={
                "client_id": GoogleDriveOAuth.CLIENT_ID,
                "client_secret": GoogleDriveOAuth.CLIENT_SECRET,
                "code": code,
                "redirect_uri": GoogleDriveOAuth.REDIRECT_URI,
                "grant_type": "authorization_code",
            },
        )

        response.raise_for_status()

        authorization_response: dict[str, Any] = response.json()

        # the connector wants us to store the json in its authorized_user_info format
        # returned from OAuthCredentials.get_authorized_user_info().
        # So refresh immediately via get_google_oauth_creds with the params filled in
        # from fields in authorization_response to get the json we need
        authorized_user_info = {}
        authorized_user_info["client_id"] = OAUTH_GOOGLE_DRIVE_CLIENT_ID
        authorized_user_info["client_secret"] = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
        authorized_user_info["refresh_token"] = authorization_response["refresh_token"]

        token_json_str = json.dumps(authorized_user_info)
        oauth_creds = get_google_oauth_creds(
            token_json_str=token_json_str, source=DocumentSource.GOOGLE_DRIVE
        )
        if not oauth_creds:
            raise RuntimeError("get_google_oauth_creds returned None.")

        # save off the credentials
        oauth_creds_sanitized_json_str = sanitize_oauth_credentials(oauth_creds)

        credential_dict: dict[str, str] = {}
        credential_dict[DB_CREDENTIALS_DICT_TOKEN_KEY] = oauth_creds_sanitized_json_str
        credential_dict[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] = session.email
        credential_dict[
            DB_CREDENTIALS_AUTHENTICATION_METHOD
        ] = GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE.value

        credential_info = CredentialBase(
            credential_json=credential_dict,
            admin_public=True,
            source=DocumentSource.GOOGLE_DRIVE,
            name="OAuth (interactive)",
        )

        create_credential(credential_info, user, db_session)
    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={
                "success": False,
                "message": f"An error occurred during Google Drive OAuth: {str(e)}",
            },
        )
    finally:
        r.delete(r_key)

    # return the result
    return JSONResponse(
        content={
            "success": True,
            "message": "Google Drive OAuth completed successfully.",
            "redirect_on_success": session.redirect_on_success,
        }
    )
@@ -3,7 +3,6 @@ from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Response
from sqlalchemy.orm import Session

from ee.onyx.auth.users import current_cloud_superuser
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
@@ -13,23 +12,15 @@ from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
from ee.onyx.server.tenants.models import BillingInformation
from ee.onyx.server.tenants.models import ImpersonateRequest
from ee.onyx.server.tenants.models import ProductGatingRequest
from ee.onyx.server.tenants.provisioning import delete_user_from_control_plane
from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
from ee.onyx.server.tenants.user_mapping import remove_all_users_from_tenant
from ee.onyx.server.tenants.user_mapping import remove_users_from_tenant
from onyx.auth.users import auth_backend
from onyx.auth.users import current_admin_user
from onyx.auth.users import get_jwt_strategy
from onyx.auth.users import User
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.db.auth import get_user_count
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.engine import get_session_with_tenant
from onyx.db.notification import create_notification
from onyx.db.users import delete_user_from_db
from onyx.db.users import get_user_by_email
from onyx.server.manage.models import UserByEmail
from onyx.server.settings.store import load_settings
from onyx.server.settings.store import store_settings
from onyx.utils.logger import setup_logger
@@ -123,48 +114,3 @@ async def impersonate_user(
        samesite="lax",
    )
    return response


@router.post("/leave-organization")
async def leave_organization(
    user_email: UserByEmail,
    current_user: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
    tenant_id: str = Depends(get_current_tenant_id),
) -> None:
    if current_user is None or current_user.email != user_email.user_email:
        raise HTTPException(
            status_code=403, detail="You can only leave the organization as yourself"
        )

    user_to_delete = get_user_by_email(user_email.user_email, db_session)
    if user_to_delete is None:
        raise HTTPException(status_code=404, detail="User not found")

    num_admin_users = await get_user_count(only_admin_users=True)

    should_delete_tenant = num_admin_users == 1

    if should_delete_tenant:
        logger.info(
            "Last admin user is leaving the organization. Deleting tenant from control plane."
        )
        try:
            await delete_user_from_control_plane(tenant_id, user_to_delete.email)
            logger.debug("User deleted from control plane")
        except Exception as e:
            logger.exception(
                f"Failed to delete user from control plane for tenant {tenant_id}: {e}"
            )
            raise HTTPException(
                status_code=500,
                detail=f"Failed to remove user from control plane: {str(e)}",
            )

    db_session.expunge(user_to_delete)
    delete_user_from_db(user_to_delete, db_session)

    if should_delete_tenant:
        remove_all_users_from_tenant(tenant_id)
    else:
        remove_users_from_tenant([user_to_delete.email], tenant_id)
@@ -39,8 +39,3 @@ class TenantCreationPayload(BaseModel):
    tenant_id: str
    email: str
    referral_source: str | None = None


class TenantDeletionPayload(BaseModel):
    tenant_id: str
    email: str
@@ -15,7 +15,6 @@ from ee.onyx.configs.app_configs import HUBSPOT_TRACKING_URL
from ee.onyx.configs.app_configs import OPENAI_DEFAULT_API_KEY
from ee.onyx.server.tenants.access import generate_data_plane_token
from ee.onyx.server.tenants.models import TenantCreationPayload
from ee.onyx.server.tenants.models import TenantDeletionPayload
from ee.onyx.server.tenants.schema_management import create_schema_if_not_exists
from ee.onyx.server.tenants.schema_management import drop_schema
from ee.onyx.server.tenants.schema_management import run_alembic_migrations
@@ -186,7 +185,6 @@ async def rollback_tenant_provisioning(tenant_id: str) -> None:
    try:
        # Drop the tenant's schema to rollback provisioning
        drop_schema(tenant_id)

        # Remove tenant mapping
        with Session(get_sqlalchemy_engine()) as db_session:
            db_session.query(UserTenantMapping).filter(
@@ -322,26 +320,3 @@ async def submit_to_hubspot(

    if response.status_code != 200:
        logger.error(f"Failed to submit to HubSpot: {response.text}")


async def delete_user_from_control_plane(tenant_id: str, email: str) -> None:
    token = generate_data_plane_token()
    headers = {
        "Authorization": f"Bearer {token}",
        "Content-Type": "application/json",
    }
    payload = TenantDeletionPayload(tenant_id=tenant_id, email=email)

    async with aiohttp.ClientSession() as session:
        async with session.delete(
            f"{CONTROL_PLANE_API_BASE_URL}/tenants/delete",
            headers=headers,
            json=payload.model_dump(),
        ) as response:
            print(response)
            if response.status != 200:
                error_text = await response.text()
                logger.error(f"Control plane tenant creation failed: {error_text}")
                raise Exception(
                    f"Failed to delete tenant on control plane: {error_text}"
                )
@@ -68,11 +68,3 @@ def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
            f"Failed to remove users from tenant {tenant_id}: {str(e)}"
        )
        db_session.rollback()


def remove_all_users_from_tenant(tenant_id: str) -> None:
    with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
        db_session.query(UserTenantMapping).filter(
            UserTenantMapping.tenant_id == tenant_id
        ).delete()
        db_session.commit()
@@ -83,7 +83,7 @@ def patch_user_group(
def set_user_curator(
    user_group_id: int,
    set_curator_request: SetCuratorRequest,
    user: User | None = Depends(current_curator_or_admin_user),
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> None:
    try:
@@ -91,7 +91,6 @@ def set_user_curator(
            db_session=db_session,
            user_group_id=user_group_id,
            set_curator_request=set_curator_request,
            user_making_change=user,
        )
    except ValueError as e:
        logger.error(f"Error setting user curator: {e}")
@@ -10,7 +10,6 @@ logger = setup_logger()


def posthog_on_error(error: Any, items: Any) -> None:
    """Log any PostHog delivery errors."""
    logger.error(f"PostHog error: {error}, items: {items}")


@@ -25,10 +24,15 @@ posthog = Posthog(
def event_telemetry(
    distinct_id: str, event: str, properties: dict | None = None
) -> None:
    """Capture and send an event to PostHog, flushing immediately."""
    logger.info(f"Capturing PostHog event: {distinct_id} {event} {properties}")
    logger.info(f"Capturing Posthog event: {distinct_id} {event} {properties}")
    print("API KEY", POSTHOG_API_KEY)
    print("HOST", POSTHOG_HOST)
    try:
        posthog.capture(distinct_id, event, properties)
        print(type(distinct_id))
        print(type(event))
        print(type(properties))
        response = posthog.capture(distinct_id, event, properties)
        posthog.flush()
        print(response)
    except Exception as e:
        logger.error(f"Error capturing PostHog event: {e}")
        logger.error(f"Error capturing Posthog event: {e}")
@@ -44,7 +44,6 @@ def _move_files_recursively(source: Path, dest: Path, overwrite: bool = False) -
    the files in the existing huggingface cache that don't exist in the temp
    huggingface cache.
    """

    for item in source.iterdir():
        target_path = dest / item.relative_to(source)
        if item.is_dir():
@@ -1,80 +0,0 @@
import smtplib
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from textwrap import dedent

from onyx.configs.app_configs import EMAIL_CONFIGURED
from onyx.configs.app_configs import EMAIL_FROM
from onyx.configs.app_configs import SMTP_PASS
from onyx.configs.app_configs import SMTP_PORT
from onyx.configs.app_configs import SMTP_SERVER
from onyx.configs.app_configs import SMTP_USER
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.db.models import User


def send_email(
    user_email: str,
    subject: str,
    body: str,
    mail_from: str = EMAIL_FROM,
) -> None:
    if not EMAIL_CONFIGURED:
        raise ValueError("Email is not configured.")

    msg = MIMEMultipart()
    msg["Subject"] = subject
    msg["To"] = user_email
    if mail_from:
        msg["From"] = mail_from

    msg.attach(MIMEText(body))

    try:
        with smtplib.SMTP(SMTP_SERVER, SMTP_PORT) as s:
            s.starttls()
            s.login(SMTP_USER, SMTP_PASS)
            s.send_message(msg)
    except Exception as e:
        raise e


def send_user_email_invite(user_email: str, current_user: User) -> None:
    subject = "Invitation to Join Onyx Workspace"
    body = dedent(
        f"""\
        Hello,

        You have been invited to join a workspace on Onyx.

        To join the workspace, please visit the following link:

        {WEB_DOMAIN}/auth/login

        Best regards,
        The Onyx Team
        """
    )
    send_email(user_email, subject, body, current_user.email)


def send_forgot_password_email(
    user_email: str,
    token: str,
    mail_from: str = EMAIL_FROM,
) -> None:
    subject = "Onyx Forgot Password"
    link = f"{WEB_DOMAIN}/auth/reset-password?token={token}"
    body = f"Click the following link to reset your password: {link}"
    send_email(user_email, subject, body, mail_from)


def send_user_verification_email(
    user_email: str,
    token: str,
    mail_from: str = EMAIL_FROM,
) -> None:
    subject = "Onyx Email Verification"
    link = f"{WEB_DOMAIN}/auth/verify-email?token={token}"
    body = f"Click the following link to verify your email address: {link}"
    send_email(user_email, subject, body, mail_from)
@@ -30,16 +30,13 @@ def load_no_auth_user_preferences(store: KeyValueStore) -> UserPreferences:
    )


def fetch_no_auth_user(
    store: KeyValueStore, *, anonymous_user_enabled: bool | None = None
) -> UserInfo:
def fetch_no_auth_user(store: KeyValueStore) -> UserInfo:
    return UserInfo(
        id=NO_AUTH_USER_ID,
        email=NO_AUTH_USER_EMAIL,
        is_active=True,
        is_superuser=False,
        is_verified=True,
        role=UserRole.BASIC if anonymous_user_enabled else UserRole.ADMIN,
        role=UserRole.ADMIN,
        preferences=load_no_auth_user_preferences(store),
        is_anonymous_user=anonymous_user_enabled,
    )
@@ -49,7 +49,4 @@ class UserCreate(schemas.BaseUserCreate):


class UserUpdate(schemas.BaseUserUpdate):
    """
    Role updates are not allowed through the user update endpoint for security reasons
    Role changes should be handled through a separate, admin-only process
    """
    role: UserRole
@@ -1,7 +1,10 @@
import smtplib
import uuid
from collections.abc import AsyncGenerator
from datetime import datetime
from datetime import timezone
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from typing import cast
from typing import Dict
from typing import List
@@ -50,17 +53,19 @@ from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession

from onyx.auth.api_key import get_hashed_api_key_from_request
from onyx.auth.email_utils import send_forgot_password_email
from onyx.auth.email_utils import send_user_verification_email
from onyx.auth.invited_users import get_invited_users
from onyx.auth.schemas import UserCreate
from onyx.auth.schemas import UserRole
from onyx.auth.schemas import UserUpdate
from onyx.configs.app_configs import AUTH_TYPE
from onyx.configs.app_configs import DISABLE_AUTH
from onyx.configs.app_configs import EMAIL_CONFIGURED
from onyx.configs.app_configs import EMAIL_FROM
from onyx.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
from onyx.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
from onyx.configs.app_configs import SMTP_PASS
from onyx.configs.app_configs import SMTP_PORT
from onyx.configs.app_configs import SMTP_SERVER
from onyx.configs.app_configs import SMTP_USER
from onyx.configs.app_configs import TRACK_EXTERNAL_IDP_EXPIRY
from onyx.configs.app_configs import USER_AUTH_SECRET
from onyx.configs.app_configs import VALID_EMAIL_DOMAINS
@@ -69,7 +74,6 @@ from onyx.configs.constants import AuthType
from onyx.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
from onyx.configs.constants import DANSWER_API_KEY_PREFIX
from onyx.configs.constants import MilestoneRecordType
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import PASSWORD_SPECIAL_CHARS
from onyx.configs.constants import UNNAMED_KEY_PLACEHOLDER
from onyx.db.api_key import fetch_user_for_api_key
@@ -85,7 +89,7 @@ from onyx.db.models import AccessToken
from onyx.db.models import OAuthAccount
from onyx.db.models import User
from onyx.db.users import get_user_by_email
from onyx.redis.redis_pool import get_redis_client
from onyx.server.utils import BasicAuthenticationError
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import create_milestone_and_report
from onyx.utils.telemetry import optional_telemetry
@@ -99,11 +103,6 @@ from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

logger = setup_logger()


class BasicAuthenticationError(HTTPException):
    def __init__(self, detail: str):
        super().__init__(status_code=status.HTTP_403_FORBIDDEN, detail=detail)


def is_user_admin(user: User | None) -> bool:
    if AUTH_TYPE == AuthType.DISABLED:
        return True
@@ -144,20 +143,6 @@ def user_needs_to_be_verified() -> bool:
    return False


def anonymous_user_enabled() -> bool:
    if MULTI_TENANT:
        return False

    redis_client = get_redis_client(tenant_id=None)
    value = redis_client.get(OnyxRedisLocks.ANONYMOUS_USER_ENABLED)

    if value is None:
        return False

    assert isinstance(value, bytes)
    return int(value.decode("utf-8")) == 1
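The flag this function reads is a plain Redis value; a self-contained sketch of the round trip (the key name stands in for OnyxRedisLocks.ANONYMOUS_USER_ENABLED, and the setter is assumed to live in an admin settings endpoint not shown in this diff):

```python
import redis

r = redis.Redis()
KEY = "anonymous_user_enabled"  # stand-in for OnyxRedisLocks.ANONYMOUS_USER_ENABLED

r.set(KEY, 1)       # enabling anonymous access (done elsewhere in the app)
value = r.get(KEY)  # raw bytes, e.g. b"1"
enabled = value is not None and int(value.decode("utf-8")) == 1
assert enabled
```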
def verify_email_is_invited(email: str) -> None:
    whitelist = get_invited_users()
    if not whitelist:
@@ -208,6 +193,30 @@ def verify_email_domain(email: str) -> None:
        )


def send_user_verification_email(
    user_email: str,
    token: str,
    mail_from: str = EMAIL_FROM,
) -> None:
    msg = MIMEMultipart()
    msg["Subject"] = "Onyx Email Verification"
    msg["To"] = user_email
    if mail_from:
        msg["From"] = mail_from

    link = f"{WEB_DOMAIN}/auth/verify-email?token={token}"

    body = MIMEText(f"Click the following link to verify your email address: {link}")
    msg.attach(body)

    with smtplib.SMTP(SMTP_SERVER, SMTP_PORT) as s:
        s.starttls()
        # If credentials fails with gmail, check (You need an app password, not just the basic email password)
        # https://support.google.com/accounts/answer/185833?sjid=8512343437447396151-NA
        s.login(SMTP_USER, SMTP_PASS)
        s.send_message(msg)


class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
    reset_password_token_secret = USER_AUTH_SECRET
    verification_token_secret = USER_AUTH_SECRET
@@ -272,6 +281,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
        if not user.role.is_web_login() and user_create.role.is_web_login():
            user_update = UserUpdate(
                password=user_create.password,
                role=user_create.role,
                is_verified=user_create.is_verified,
            )
            user = await self.update(user_update, user)
@@ -496,15 +506,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
    async def on_after_forgot_password(
        self, user: User, token: str, request: Optional[Request] = None
    ) -> None:
        if not EMAIL_CONFIGURED:
            logger.error(
                "Email is not configured. Please configure email in the admin panel"
            )
            raise HTTPException(
                status.HTTP_500_INTERNAL_SERVER_ERROR,
                "Your admin has not enbaled this feature.",
            )
        send_forgot_password_email(user.email, token)
        logger.notice(f"User {user.id} has forgot their password. Reset token: {token}")

    async def on_after_request_verify(
        self, user: User, token: str, request: Optional[Request] = None
@@ -622,7 +624,9 @@ def get_database_strategy(


auth_backend = AuthenticationBackend(
    name="jwt", transport=cookie_transport, get_strategy=get_jwt_strategy
    name="jwt" if MULTI_TENANT else "database",
    transport=cookie_transport,
    get_strategy=get_jwt_strategy if MULTI_TENANT else get_database_strategy,  # type: ignore
)  # type: ignore
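The new conditional folds the whole backend choice into one expression. Written long-form, the two branches it selects between are (names as in the surrounding module; MULTI_TENANT is a bool config):

```python
# equivalent long-form of the conditional arguments above
if MULTI_TENANT:
    # cloud: stateless JWT sessions
    auth_backend = AuthenticationBackend(
        name="jwt", transport=cookie_transport, get_strategy=get_jwt_strategy
    )
else:
    # self-hosted: tokens stored server-side (see the AccessToken model import)
    auth_backend = AuthenticationBackend(
        name="database", transport=cookie_transport, get_strategy=get_database_strategy
    )
```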
@@ -709,36 +713,30 @@ async def double_check_user(
    user: User | None,
    optional: bool = DISABLE_AUTH,
    include_expired: bool = False,
    allow_anonymous_access: bool = False,
) -> User | None:
    if optional:
        return user

    if user is not None:
        # If user attempted to authenticate, verify them, do not default
        # to anonymous access if it fails.
        if user_needs_to_be_verified() and not user.is_verified:
            raise BasicAuthenticationError(
                detail="Access denied. User is not verified.",
            )

        if (
            user.oidc_expiry
            and user.oidc_expiry < datetime.now(timezone.utc)
            and not include_expired
        ):
            raise BasicAuthenticationError(
                detail="Access denied. User's OIDC token has expired.",
            )

        return user

    if allow_anonymous_access:
        return None

    raise BasicAuthenticationError(
        detail="Access denied. User is not authenticated.",
    )
    if user is None:
        raise BasicAuthenticationError(
            detail="Access denied. User is not authenticated.",
        )

    if user_needs_to_be_verified() and not user.is_verified:
        raise BasicAuthenticationError(
            detail="Access denied. User is not verified.",
        )

    if (
        user.oidc_expiry
        and user.oidc_expiry < datetime.now(timezone.utc)
        and not include_expired
    ):
        raise BasicAuthenticationError(
            detail="Access denied. User's OIDC token has expired.",
        )

    return user


async def current_user_with_expired_token(
@@ -753,14 +751,6 @@ async def current_limited_user(
    return await double_check_user(user)


async def current_chat_accesssible_user(
    user: User | None = Depends(optional_user),
) -> User | None:
    return await double_check_user(
        user, allow_anonymous_access=anonymous_user_enabled()
    )
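The reworked double_check_user boils down to a small decision tree: an authenticated user is always fully checked, and only a missing user may fall through to anonymous access. A standalone sketch (check_verified_and_not_expired is a hypothetical stand-in for the verification and OIDC-expiry checks):

```python
def check_verified_and_not_expired(user) -> None:
    """Hypothetical stand-in; the real code raises BasicAuthenticationError."""


def resolve_access(user, allow_anonymous_access: bool, optional: bool = False):
    if optional:
        return user  # auth disabled entirely
    if user is not None:
        check_verified_and_not_expired(user)  # raises; never falls back to anonymous
        return user
    if allow_anonymous_access:
        return None  # downstream treats None as the anonymous user
    raise PermissionError("Access denied. User is not authenticated.")
```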
async def current_user(
    user: User | None = Depends(optional_user),
) -> User | None:
@@ -414,21 +414,11 @@ def on_setup_logging(
    task_logger.setLevel(loglevel)
    task_logger.propagate = False

    # hide celery task received spam
    # e.g. "Task check_for_pruning[a1e96171-0ba8-4e00-887b-9fbf7442eab3] received"
    # Hide celery task received and succeeded/failed messages
    strategy.logger.setLevel(logging.WARNING)

    # uncomment this to hide celery task succeeded/failed spam
    # e.g. "Task check_for_pruning[a1e96171-0ba8-4e00-887b-9fbf7442eab3] succeeded in 0.03137450001668185s: None"
    trace.logger.setLevel(logging.WARNING)


def set_task_finished_log_level(logLevel: int) -> None:
    """call this to override the setLevel in on_setup_logging. We are interested
    in the task timings in the cloud but it can be spammy for self hosted."""
    trace.logger.setLevel(logLevel)


class TenantContextFilter(logging.Filter):

    """Logging filter to inject tenant ID into the logger's name."""
@@ -60,12 +60,7 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
    logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")

    SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_APP_NAME)

    # rkuo: been seeing transient connection exceptions here, so upping the connection count
    # from just concurrency/concurrency to concurrency/concurrency*2
    SqlEngine.init_engine(
        pool_size=sender.concurrency, max_overflow=sender.concurrency * 2
    )
    SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=sender.concurrency)

    app_base.wait_for_redis(sender, **kwargs)
    app_base.wait_for_db(sender, **kwargs)
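In plain SQLAlchemy terms, the change doubles only the overflow budget: with concurrency C the pool keeps C persistent connections and now allows up to 2*C extra burst connections. A sketch (the DSN and concurrency value are illustrative, not the project's real configuration):

```python
from sqlalchemy import create_engine

concurrency = 8  # illustrative worker concurrency
engine = create_engine(
    "postgresql+psycopg2://onyx:password@localhost:5432/onyx",  # assumed DSN
    pool_size=concurrency,         # persistent connections kept open
    max_overflow=concurrency * 2,  # short-lived extras under burst load
)
```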
@@ -1,4 +1,3 @@
import logging
import multiprocessing
from typing import Any
from typing import cast
@@ -195,10 +194,6 @@ def on_setup_logging(
) -> None:
    app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)

    # this can be spammy, so just enable it in the cloud for now
    if MULTI_TENANT:
        app_base.set_task_finished_log_level(logging.INFO)


class HubPeriodicTask(bootsteps.StartStopStep):
    """Regularly reacquires the primary worker lock outside of the task queue.
@@ -3,54 +3,12 @@ import json
from typing import Any
from typing import cast

from celery import Celery
from redis import Redis

from onyx.background.celery.configs.base import CELERY_SEPARATOR
from onyx.configs.constants import OnyxCeleryPriority


def celery_get_unacked_length(r: Redis) -> int:
    """Checking the unacked queue is useful because a non-zero length tells us there
    may be prefetched tasks.

    There can be other tasks in here besides indexing tasks, so this is mostly useful
    just to see if the task count is non zero.

    ref: https://blog.hikaru.run/2022/08/29/get-waiting-tasks-count-in-celery.html
    """
    length = cast(int, r.hlen("unacked"))
    return length


def celery_get_unacked_task_ids(queue: str, r: Redis) -> set[str]:
    """Gets the set of task id's matching the given queue in the unacked hash.

    Unacked entries belonging to the indexing queue are "prefetched", so this gives
    us crucial visibility as to what tasks are in that state.
    """
    tasks: set[str] = set()

    for _, v in r.hscan_iter("unacked"):
        v_bytes = cast(bytes, v)
        v_str = v_bytes.decode("utf-8")
        task = json.loads(v_str)

        task_description = task[0]
        task_queue = task[2]

        if task_queue != queue:
            continue

        task_id = task_description.get("headers", {}).get("id")
        if not task_id:
            continue

        # if the queue matches and we see the task_id, add it
        tasks.add(task_id)
    return tasks
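The indexing here assumes a specific shape for each value in the Redis "unacked" hash: a JSON array whose first element is the message (task id under headers) and whose third element is the queue name. An illustration with made-up values (the id reuses the example from the logging comments elsewhere in this diff; the queue name is hypothetical):

```python
import json

# assumed entry layout: [message, exchange, queue]
entry = json.dumps(
    [{"headers": {"id": "a1e96171-0ba8-4e00-887b-9fbf7442eab3"}}, None, "some_queue"]
)
task_description, _exchange, task_queue = json.loads(entry)
assert task_description["headers"]["id"] == "a1e96171-0ba8-4e00-887b-9fbf7442eab3"
assert task_queue == "some_queue"
```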
def celery_get_queue_length(queue: str, r: Redis) -> int:
    """This is a redis specific way to get the length of a celery queue.
    It is priority aware and knows how to count across the multiple redis lists
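The hunk cuts off before the body, but the docstring, together with the CELERY_SEPARATOR and OnyxCeleryPriority imports above, points at the usual Kombu layout on Redis: each priority level lives in its own list named with a separator suffix. A hedged sketch of that counting, assuming level 0 uses the bare queue name:

```python
from redis import Redis


def queue_length_sketch(queue: str, r: Redis, num_priorities: int, separator: str) -> int:
    total = 0
    for priority in range(num_priorities):
        # priority 0 is the bare queue name; higher priorities get a suffix
        name = queue if priority == 0 else f"{queue}{separator}{priority}"
        total += int(r.llen(name))
    return total
```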
@@ -89,74 +47,3 @@ def celery_find_task(task_id: str, queue: str, r: Redis) -> int:
        return True

    return False


def celery_inspect_get_workers(name_filter: str | None, app: Celery) -> list[str]:
    """Returns a list of current workers containing name_filter, or all workers if
    name_filter is None.

    We've empirically discovered that the celery inspect API is potentially unstable
    and may hang or return empty results when celery is under load. Suggest using this
    more to debug and troubleshoot than in production code.
    """
    worker_names: list[str] = []

    # filter for and create an indexing specific inspect object
    inspect = app.control.inspect()
    workers: dict[str, Any] = inspect.ping()  # type: ignore
    if workers:
        for worker_name in list(workers.keys()):
            # if the name filter not set, return all worker names
            if not name_filter:
                worker_names.append(worker_name)
                continue

            # if the name filter is set, return only worker names that contain the name filter
            if name_filter not in worker_name:
                continue

            worker_names.append(worker_name)

    return worker_names


def celery_inspect_get_reserved(worker_names: list[str], app: Celery) -> set[str]:
    """Returns a list of reserved tasks on the specified workers.

    We've empirically discovered that the celery inspect API is potentially unstable
    and may hang or return empty results when celery is under load. Suggest using this
    more to debug and troubleshoot than in production code.
    """
    reserved_task_ids: set[str] = set()

    inspect = app.control.inspect(destination=worker_names)

    # get the list of reserved tasks
    reserved_tasks: dict[str, list] | None = inspect.reserved()  # type: ignore
    if reserved_tasks:
        for _, task_list in reserved_tasks.items():
            for task in task_list:
                reserved_task_ids.add(task["id"])

    return reserved_task_ids


def celery_inspect_get_active(worker_names: list[str], app: Celery) -> set[str]:
    """Returns a list of active tasks on the specified workers.

    We've empirically discovered that the celery inspect API is potentially unstable
    and may hang or return empty results when celery is under load. Suggest using this
    more to debug and troubleshoot than in production code.
    """
    active_task_ids: set[str] = set()

    inspect = app.control.inspect(destination=worker_names)

    # get the list of reserved tasks
    active_tasks: dict[str, list] | None = inspect.active()  # type: ignore
    if active_tasks:
        for _, task_list in active_tasks.items():
            for task in task_list:
                active_task_ids.add(task["id"])

    return active_task_ids
|
||||
|
||||
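The helpers above each cover one place a task can hide. Below is a minimal sketch of how they could be combined for ad-hoc debugging; the broker URL, db number, and the `task_is_visible` name are placeholders, not part of the codebase:

```python
from celery import Celery
from redis import Redis

# placeholder wiring for illustration; real code obtains these from the app context
celery_app = Celery(broker="redis://localhost:6379/15")
r_celery = Redis(host="localhost", port=6379, db=15)


def task_is_visible(task_id: str, queue: str) -> bool:
    """Best-effort check across the three places a task can live:
    queued in redis, prefetched (unacked), or known to a worker."""
    if celery_find_task(task_id, queue, r_celery):
        return True  # still sitting in the queue's redis list(s)

    if task_id in celery_get_unacked_task_ids(queue, r_celery):
        return True  # prefetched by a worker but not yet acked

    # per the docstrings above, the inspect API may hang or come back empty
    # under load, so treat the answer as advisory rather than authoritative
    workers = celery_inspect_get_workers(None, celery_app)
    if task_id in celery_inspect_get_reserved(workers, celery_app):
        return True
    if task_id in celery_inspect_get_active(workers, celery_app):
        return True

    return False  # may be a false negative during state transitions
```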
@@ -16,14 +16,6 @@ result_expires = shared_config.result_expires  # 86400 seconds is the default
task_default_priority = shared_config.task_default_priority
task_acks_late = shared_config.task_acks_late

# Indexing worker specific ... this lets us track the transition to STARTED in redis
# We don't currently rely on this but it has the potential to be useful and
# indexing tasks are not high volume

# we don't turn this on yet because celery occasionally runs tasks more than once
# which means a duplicate run might change the task state unexpectedly
# task_track_started = True

worker_concurrency = CELERY_WORKER_INDEXING_CONCURRENCY
worker_pool = "threads"
worker_prefetch_multiplier = 1

@@ -4,12 +4,6 @@ from typing import Any
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryTask

# choosing 15 minutes because it roughly gives us enough time to process many tasks
# we might be able to reduce this greatly if we can run a unified
# loop across all tenants rather than tasks per tenant

BEAT_EXPIRES_DEFAULT = 15 * 60  # 15 minutes (in seconds)

# we set expires because it isn't necessary to queue up these tasks
# it's only important that they run relatively regularly
tasks_to_schedule = [
@@ -19,7 +13,7 @@ tasks_to_schedule = [
        "schedule": timedelta(seconds=20),
        "options": {
            "priority": OnyxCeleryPriority.HIGH,
            "expires": BEAT_EXPIRES_DEFAULT,
            "expires": 60,
        },
    },
    {
@@ -28,7 +22,7 @@ tasks_to_schedule = [
        "schedule": timedelta(seconds=20),
        "options": {
            "priority": OnyxCeleryPriority.HIGH,
            "expires": BEAT_EXPIRES_DEFAULT,
            "expires": 60,
        },
    },
    {
@@ -37,7 +31,7 @@ tasks_to_schedule = [
        "schedule": timedelta(seconds=15),
        "options": {
            "priority": OnyxCeleryPriority.HIGH,
            "expires": BEAT_EXPIRES_DEFAULT,
            "expires": 60,
        },
    },
    {
@@ -46,7 +40,7 @@ tasks_to_schedule = [
        "schedule": timedelta(seconds=15),
        "options": {
            "priority": OnyxCeleryPriority.HIGH,
            "expires": BEAT_EXPIRES_DEFAULT,
            "expires": 60,
        },
    },
    {
@@ -55,7 +49,7 @@ tasks_to_schedule = [
        "schedule": timedelta(seconds=3600),
        "options": {
            "priority": OnyxCeleryPriority.LOWEST,
            "expires": BEAT_EXPIRES_DEFAULT,
            "expires": 60,
        },
    },
    {
@@ -64,7 +58,7 @@ tasks_to_schedule = [
        "schedule": timedelta(seconds=5),
        "options": {
            "priority": OnyxCeleryPriority.HIGH,
            "expires": BEAT_EXPIRES_DEFAULT,
            "expires": 60,
        },
    },
    {
@@ -73,7 +67,7 @@ tasks_to_schedule = [
        "schedule": timedelta(seconds=30),
        "options": {
            "priority": OnyxCeleryPriority.HIGH,
            "expires": BEAT_EXPIRES_DEFAULT,
            "expires": 60,
        },
    },
    {
@@ -82,7 +76,7 @@ tasks_to_schedule = [
        "schedule": timedelta(seconds=20),
        "options": {
            "priority": OnyxCeleryPriority.HIGH,
            "expires": BEAT_EXPIRES_DEFAULT,
            "expires": 60,
        },
    },
]
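For reference, a full schedule entry with the tightened expiry looks like the sketch below; the `name` and `task` values are placeholders, only the options shape comes from the list above:

```python
from datetime import timedelta

example_entry = {
    "name": "check-for-indexing",  # placeholder name
    "task": "check_for_indexing",  # placeholder task reference
    "schedule": timedelta(seconds=15),
    "options": {
        "priority": OnyxCeleryPriority.HIGH,
        # with a 60 second expiry, stale beats are dropped instead of piling up
        # behind a slow worker; the 15 minute default tolerated far more backlog
        "expires": 60,
    },
}
```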
@@ -34,9 +34,7 @@ class TaskDependencyError(RuntimeError):
    trail=False,
    bind=True,
)
def check_for_connector_deletion_task(
    self: Task, *, tenant_id: str | None
) -> bool | None:
def check_for_connector_deletion_task(self: Task, *, tenant_id: str | None) -> None:
    r = get_redis_client(tenant_id=tenant_id)

    lock_beat: RedisLock = r.lock(
@@ -47,7 +45,7 @@ def check_for_connector_deletion_task(
    try:
        # these tasks should never overlap
        if not lock_beat.acquire(blocking=False):
            return None
            return

        # collect cc_pair_ids
        cc_pair_ids: list[int] = []
@@ -83,8 +81,6 @@ def check_for_connector_deletion_task(
    if lock_beat.owned():
        lock_beat.release()

    return True


def try_generate_document_cc_pair_cleanup_tasks(
    app: Celery,
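The acquire/release shape in `check_for_connector_deletion_task` recurs in every checker task in this changeset. A minimal sketch of the pattern, with `r`, `lock_key`, and `work` as illustrative stand-ins:

```python
from redis import Redis
from redis.lock import Lock as RedisLock


def run_exclusive_check(r: Redis, lock_key: str, timeout: int, work) -> None:
    """Sketch of the beat-lock pattern; `work` is any zero-arg callable."""
    lock_beat: RedisLock = r.lock(lock_key, timeout=timeout)
    try:
        # these tasks should never overlap, so skip this round entirely
        # if another instance already holds the lock
        if not lock_beat.acquire(blocking=False):
            return
        work()
    finally:
        # the lock may have timed out mid-run, so only release if still owned
        if lock_beat.owned():
            lock_beat.release()
```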
@@ -1,8 +1,6 @@
import time
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from time import sleep
from uuid import uuid4

from celery import Celery
@@ -20,7 +18,6 @@ from onyx.access.models import DocExternalAccess
from onyx.background.celery.apps.app_base import task_logger
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from onyx.configs.constants import DocumentSource
@@ -91,10 +88,10 @@ def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> b
    soft_time_limit=JOB_TIMEOUT,
    bind=True,
)
def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> bool | None:
def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> None:
    r = get_redis_client(tenant_id=tenant_id)

    lock_beat: RedisLock = r.lock(
    lock_beat = r.lock(
        OnyxRedisLocks.CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK,
        timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
    )
@@ -102,7 +99,7 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> bool
    try:
        # these tasks should never overlap
        if not lock_beat.acquire(blocking=False):
            return None
            return

        # get all cc pairs that need to be synced
        cc_pair_ids_to_sync: list[int] = []
@@ -131,8 +128,6 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> bool
    if lock_beat.owned():
        lock_beat.release()

    return True


def try_creating_permissions_sync_task(
    app: Celery,
@@ -224,43 +219,6 @@ def connector_permission_sync_generator_task(

    r = get_redis_client(tenant_id=tenant_id)

    # this wait is needed to avoid a race condition where
    # the primary worker sends the task and it is immediately executed
    # before the primary worker can finalize the fence
    start = time.monotonic()
    while True:
        if time.monotonic() - start > CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT:
            raise ValueError(
                f"connector_permission_sync_generator_task - timed out waiting for fence to be ready: "
                f"fence={redis_connector.permissions.fence_key}"
            )

        if not redis_connector.permissions.fenced:  # The fence must exist
            raise ValueError(
                f"connector_permission_sync_generator_task - fence not found: "
                f"fence={redis_connector.permissions.fence_key}"
            )

        payload = redis_connector.permissions.payload  # The payload must exist
        if not payload:
            raise ValueError(
                "connector_permission_sync_generator_task: payload invalid or not found"
            )

        if payload.celery_task_id is None:
            logger.info(
                f"connector_permission_sync_generator_task - Waiting for fence: "
                f"fence={redis_connector.permissions.fence_key}"
            )
            sleep(1)
            continue

        logger.info(
            f"connector_permission_sync_generator_task - Fence found, continuing...: "
            f"fence={redis_connector.permissions.fence_key}"
        )
        break

    lock: RedisLock = r.lock(
        OnyxRedisLocks.CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX
        + f"_{redis_connector.id}",
@@ -296,11 +254,8 @@ def connector_permission_sync_generator_task(
        if not payload:
            raise ValueError(f"No fence payload found: cc_pair={cc_pair_id}")

        new_payload = RedisConnectorPermissionSyncPayload(
            started=datetime.now(timezone.utc),
            celery_task_id=payload.celery_task_id,
        )
        redis_connector.permissions.set_fence(new_payload)
        payload.started = datetime.now(timezone.utc)
        redis_connector.permissions.set_fence(payload)

        document_external_accesses: list[DocExternalAccess] = doc_sync_func(cc_pair)

@@ -94,10 +94,10 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
    soft_time_limit=JOB_TIMEOUT,
    bind=True,
)
def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool | None:
def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
    r = get_redis_client(tenant_id=tenant_id)

    lock_beat: RedisLock = r.lock(
    lock_beat = r.lock(
        OnyxRedisLocks.CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK,
        timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
    )
@@ -105,7 +105,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool
    try:
        # these tasks should never overlap
        if not lock_beat.acquire(blocking=False):
            return None
            return

        cc_pair_ids_to_sync: list[int] = []
        with get_session_with_tenant(tenant_id) as db_session:
@@ -149,8 +149,6 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool
    if lock_beat.owned():
        lock_beat.release()

    return True


def try_creating_external_group_sync_task(
    app: Celery,
@@ -164,7 +162,7 @@ def try_creating_external_group_sync_task(

    LOCK_TIMEOUT = 30

    lock: RedisLock = r.lock(
    lock = r.lock(
        DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_external_group_sync_tasks",
        timeout=LOCK_TIMEOUT,
    )

@@ -1,11 +1,9 @@
import os
import sys
import time
from datetime import datetime
from datetime import timezone
from http import HTTPStatus
from time import sleep
from typing import cast
from typing import Any

import redis
import sentry_sdk
@@ -20,12 +18,10 @@ from sqlalchemy.orm import Session

from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.indexing.job_client import SimpleJobClient
from onyx.background.indexing.run_indexing import run_indexing_entrypoint
from onyx.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from onyx.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from onyx.configs.constants import DocumentSource
@@ -33,7 +29,6 @@ from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import OnyxRedisSignals
from onyx.db.connector import mark_ccpair_with_indexing_trigger
from onyx.db.connector_credential_pair import fetch_connector_credential_pairs
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
@@ -74,18 +69,14 @@ logger = setup_logger()


class IndexingCallback(IndexingHeartbeatInterface):
    PARENT_CHECK_INTERVAL = 60

    def __init__(
        self,
        parent_pid: int,
        stop_key: str,
        generator_progress_key: str,
        redis_lock: RedisLock,
        redis_client: Redis,
    ):
        super().__init__()
        self.parent_pid = parent_pid
        self.redis_lock: RedisLock = redis_lock
        self.stop_key: str = stop_key
        self.generator_progress_key: str = generator_progress_key
@@ -96,68 +87,25 @@ class IndexingCallback(IndexingHeartbeatInterface):
        self.last_tag: str = "IndexingCallback.__init__"
        self.last_lock_reacquire: datetime = datetime.now(timezone.utc)

        self.last_parent_check = time.monotonic()

    def should_stop(self) -> bool:
        if self.redis_client.exists(self.stop_key):
            return True

        return False

    def progress(self, tag: str, amount: int) -> None:
        # rkuo: this shouldn't be necessary yet because we spawn the process this runs inside
        # with daemon = True. It seems likely some indexing tasks will need to spawn other processes eventually,
        # so leave this code in until we're ready to test it.

        # if self.parent_pid:
        #     # check if the parent pid is alive so we aren't running as a zombie
        #     now = time.monotonic()
        #     if now - self.last_parent_check > IndexingCallback.PARENT_CHECK_INTERVAL:
        #         try:
        #             # this is unintuitive, but it checks if the parent pid is still running
        #             os.kill(self.parent_pid, 0)
        #         except Exception:
        #             logger.exception("IndexingCallback - parent pid check exceptioned")
        #             raise
        #         self.last_parent_check = now

        try:
            self.redis_lock.reacquire()
            self.last_tag = tag
            self.last_lock_reacquire = datetime.now(timezone.utc)
        except LockError:
            logger.exception(
                f"IndexingCallback - lock.reacquire exceptioned: "
                f"IndexingCallback - lock.reacquire exceptioned. "
                f"lock_timeout={self.redis_lock.timeout} "
                f"start={self.started} "
                f"last_tag={self.last_tag} "
                f"last_reacquired={self.last_lock_reacquire} "
                f"now={datetime.now(timezone.utc)}"
            )

            # diagnostic logging for lock errors
            name = self.redis_lock.name
            ttl = self.redis_client.ttl(name)
            locked = self.redis_lock.locked()
            owned = self.redis_lock.owned()
            local_token: str | None = self.redis_lock.local.token  # type: ignore

            remote_token_raw = self.redis_client.get(self.redis_lock.name)
            if remote_token_raw:
                remote_token_bytes = cast(bytes, remote_token_raw)
                remote_token = remote_token_bytes.decode("utf-8")
            else:
                remote_token = None

            logger.warning(
                f"IndexingCallback - lock diagnostics: "
                f"name={name} "
                f"locked={locked} "
                f"owned={owned} "
                f"local_token={local_token} "
                f"remote_token={remote_token} "
                f"ttl={ttl}"
            )
            raise

        self.redis_client.incrby(self.generator_progress_key, amount)
@@ -227,7 +175,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:

    # we need to use celery's redis client to access its redis data
    # (which lives on a different db number)
    redis_client_celery: Redis = self.app.broker_connection().channel().client  # type: ignore
    # redis_client_celery: Redis = self.app.broker_connection().channel().client  # type: ignore

    lock_beat: RedisLock = redis_client.lock(
        OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK,
@@ -370,19 +318,23 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
                attempt.id, db_session, failure_reason=failure_reason
            )

        # we want to run this less frequently than the overall task
        if not redis_client.exists(OnyxRedisSignals.VALIDATE_INDEXING_FENCES):
            # clear any indexing fences that don't have associated celery tasks in progress
            # tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
            # or be currently executing
            try:
                validate_indexing_fences(
                    tenant_id, self.app, redis_client, redis_client_celery, lock_beat
                )
            except Exception:
                task_logger.exception("Exception while validating indexing fences")
        # rkuo: The following code logically appears to work, but the celery inspect code may be unstable
        # turning it off for the moment to see if it helps cloud stability

            redis_client.set(OnyxRedisSignals.VALIDATE_INDEXING_FENCES, 1, ex=60)
        # we want to run this less frequently than the overall task
        # if not redis_client.exists(OnyxRedisSignals.VALIDATE_INDEXING_FENCES):
        #     # clear any indexing fences that don't have associated celery tasks in progress
        #     # tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
        #     # or be currently executing
        #     try:
        #         task_logger.info("Validating indexing fences...")
        #         validate_indexing_fences(
        #             tenant_id, self.app, redis_client, redis_client_celery, lock_beat
        #         )
        #     except Exception:
        #         task_logger.exception("Exception while validating indexing fences")

        #     redis_client.set(OnyxRedisSignals.VALIDATE_INDEXING_FENCES, 1, ex=60)

    except SoftTimeLimitExceeded:
        task_logger.info(
@@ -401,7 +353,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
        )

    time_elapsed = time.monotonic() - time_start
    task_logger.debug(f"check_for_indexing finished: elapsed={time_elapsed:.2f}")
    task_logger.info(f"check_for_indexing finished: elapsed={time_elapsed:.2f}")
    return tasks_created

@@ -412,9 +364,46 @@ def validate_indexing_fences(
    r_celery: Redis,
    lock_beat: RedisLock,
) -> None:
    reserved_indexing_tasks = celery_get_unacked_task_ids(
        OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
    )
    reserved_indexing_tasks: set[str] = set()
    active_indexing_tasks: set[str] = set()
    indexing_worker_names: list[str] = []

    # filter for and create an indexing-specific inspect object
    inspect = celery_app.control.inspect()
    workers: dict[str, Any] = inspect.ping()  # type: ignore
    if not workers:
        raise ValueError("No workers found!")

    for worker_name in list(workers.keys()):
        if "indexing" in worker_name:
            indexing_worker_names.append(worker_name)

    if len(indexing_worker_names) == 0:
        raise ValueError("No indexing workers found!")

    inspect_indexing = celery_app.control.inspect(destination=indexing_worker_names)

    # NOTE: each dict entry is a map of worker name to a list of tasks
    # we want sets of reserved and active task IDs to optimize
    # subsequent validation lookups

    # get the list of reserved tasks
    reserved_tasks: dict[str, list] | None = inspect_indexing.reserved()  # type: ignore
    if reserved_tasks is None:
        raise ValueError("inspect_indexing.reserved() returned None!")

    for _, task_list in reserved_tasks.items():
        for task in task_list:
            reserved_indexing_tasks.add(task["id"])

    # get the list of active tasks
    active_tasks: dict[str, list] | None = inspect_indexing.active()  # type: ignore
    if active_tasks is None:
        raise ValueError("inspect_indexing.active() returned None!")

    for _, task_list in active_tasks.items():
        for task in task_list:
            active_indexing_tasks.add(task["id"])

    # validate all existing indexing jobs
    for key_bytes in r.scan_iter(RedisConnectorIndex.FENCE_PREFIX + "*"):
@@ -424,6 +413,7 @@ def validate_indexing_fences(
            tenant_id,
            key_bytes,
            reserved_indexing_tasks,
            active_indexing_tasks,
            r_celery,
            db_session,
        )
@@ -434,6 +424,7 @@ def validate_indexing_fence(
    tenant_id: str | None,
    key_bytes: bytes,
    reserved_tasks: set[str],
    active_tasks: set[str],
    r_celery: Redis,
    db_session: Session,
) -> None:
@@ -443,15 +434,11 @@ def validate_indexing_fence(
    gives the help.

    How this works:
    1. This function renews the active signal with a 5 minute TTL under the following conditions
    1. Active signal is renewed with a 5 minute TTL
    1.1 When the fence is created
    1.2. When the task is seen in the redis queue
    1.3. When the task is seen in the reserved / prefetched list

    2. Externally, the active signal is renewed when:
    2.1. The fence is created
    2.2. The indexing watchdog checks the spawned task.

    3. The TTL allows us to get through the transitions on fence startup
    1.3. When the task is seen in the reserved or active list for a worker
    2. The TTL allows us to get through the transitions on fence startup
    and when the task starts executing.

    More TTL clarification: it is seemingly impossible to exactly query Celery for
@@ -479,8 +466,6 @@ def validate_indexing_fence(

    redis_connector = RedisConnector(tenant_id, cc_pair_id)
    redis_connector_index = redis_connector.new_index(search_settings_id)

    # check to see if the fence/payload exists
    if not redis_connector_index.fenced:
        return

@@ -516,24 +501,24 @@ def validate_indexing_fence(
        redis_connector_index.set_active()
        return

    if payload.celery_task_id in active_tasks:
        # the celery task is active (aka currently executing)
        redis_connector_index.set_active()
        return

    # we may want to enable this check if using the active task list somehow isn't good enough
    # if redis_connector_index.generator_locked():
    #     logger.info(f"{payload.celery_task_id} is currently executing.")

    # if we get here, we didn't find any direct indication that the associated celery tasks exist,
    # but they still might be there due to gaps in our ability to check states during transitions
    # Checking the active signal safeguards us against these transition periods
    # (which has a duration that allows us to bridge those gaps)
    # we didn't find any direct indication that associated celery tasks exist, but they still might be there
    # due to gaps in our ability to check states during transitions
    # Rely on the active signal (which has a duration that allows us to bridge those gaps)
    if redis_connector_index.active():
        return

    # celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
    logger.warning(
        f"validate_indexing_fence - Resetting fence because no associated celery tasks were found: "
        f"index_attempt={payload.index_attempt_id} "
        f"cc_pair={cc_pair_id} "
        f"search_settings={search_settings_id} "
        f"fence={fence_key}"
        f"validate_indexing_fence - Resetting fence because no associated celery tasks were found: fence={fence_key}"
    )
    if payload.index_attempt_id:
        try:
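A sketch of the active-signal heartbeat the docstring above describes, using a bare redis client; the key suffix and helper names here are illustrative, only the 5 minute TTL comes from the docstring:

```python
import redis

r = redis.Redis()

ACTIVE_SIGNAL_TTL = 5 * 60  # matches the 5 minute TTL described above


def renew_active_signal(fence_key: str) -> None:
    # each sighting of the task (fence creation, queue, reserved/active list,
    # watchdog check) refreshes the signal; silence lets it lapse
    r.set(f"{fence_key}:active", 1, ex=ACTIVE_SIGNAL_TTL)


def fence_probably_alive(fence_key: str) -> bool:
    # celery cannot be queried atomically across queued/prefetched/active
    # states, so an unexpired signal bridges those observation gaps
    return bool(r.exists(f"{fence_key}:active"))
```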
@@ -798,6 +783,7 @@ def connector_indexing_proxy_task(
        return

    task_logger.info(
        f"Indexing proxy - spawn succeeded: attempt={index_attempt_id} "
        f"Indexing watchdog - spawn succeeded: attempt={index_attempt_id} "
        f"cc_pair={cc_pair_id} "
        f"search_settings={search_settings_id}"
@@ -809,58 +795,6 @@ def connector_indexing_proxy_task(
    while True:
        sleep(5)

        # renew active signal
        redis_connector_index.set_active()

        # if the job is done, clean up and break
        if job.done():
            try:
                if job.status == "error":
                    ignore_exitcode = False

                    exit_code: int | None = None
                    if job.process:
                        exit_code = job.process.exitcode

                    # seeing odd behavior where spawned tasks usually return exit code 1 in the cloud,
                    # even though logging clearly indicates successful completion
                    # to work around this, we ignore the job error state if the completion signal is OK
                    status_int = redis_connector_index.get_completion()
                    if status_int:
                        status_enum = HTTPStatus(status_int)
                        if status_enum == HTTPStatus.OK:
                            ignore_exitcode = True

                    if not ignore_exitcode:
                        raise RuntimeError("Spawned task exceptioned.")

                    task_logger.warning(
                        "Indexing watchdog - spawned task has non-zero exit code "
                        "but completion signal is OK. Continuing...: "
                        f"attempt={index_attempt_id} "
                        f"tenant={tenant_id} "
                        f"cc_pair={cc_pair_id} "
                        f"search_settings={search_settings_id} "
                        f"exit_code={exit_code}"
                    )
            except Exception:
                task_logger.error(
                    "Indexing watchdog - spawned task exceptioned: "
                    f"attempt={index_attempt_id} "
                    f"tenant={tenant_id} "
                    f"cc_pair={cc_pair_id} "
                    f"search_settings={search_settings_id} "
                    f"exit_code={exit_code} "
                    f"error={job.exception()}"
                )

                raise
            finally:
                job.release()

            break

        # if a termination signal is detected, clean up and break
        if self.request.id and redis_connector_index.terminating(self.request.id):
            task_logger.warning(
                "Indexing watchdog - termination signal detected: "
@@ -887,33 +821,75 @@ def connector_indexing_proxy_task(
                f"search_settings={search_settings_id}"
            )

            job.cancel()
            job.cancel()

            break

        # if the spawned task is still running, restart the check once again
        # if the index attempt is not in a finished status
        try:
            with get_session_with_tenant(tenant_id) as db_session:
                index_attempt = get_index_attempt(
                    db_session=db_session, index_attempt_id=index_attempt_id
        if not job.done():
            # if the spawned task is still running, restart the check once again
            # if the index attempt is not in a finished status
            try:
                with get_session_with_tenant(tenant_id) as db_session:
                    index_attempt = get_index_attempt(
                        db_session=db_session, index_attempt_id=index_attempt_id
                    )

                    if not index_attempt:
                        continue

                    if not index_attempt.is_finished():
                        continue
            except Exception:
                # if the DB exceptioned, just restart the check.
                # polling the index attempt status doesn't need to be strongly consistent
                logger.exception(
                    "Indexing watchdog - transient exception looking up index attempt: "
                    f"attempt={index_attempt_id} "
                    f"tenant={tenant_id} "
                    f"cc_pair={cc_pair_id} "
                    f"search_settings={search_settings_id}"
                )
                continue

        if job.status == "error":
            ignore_exitcode = False

            exit_code: int | None = None
            if job.process:
                exit_code = job.process.exitcode

            # seeing odd behavior where spawned tasks usually return exit code 1 in the cloud,
            # even though logging clearly indicates that they completed successfully
            # to work around this, we ignore the job error state if the completion signal is OK
            status_int = redis_connector_index.get_completion()
            if status_int:
                status_enum = HTTPStatus(status_int)
                if status_enum == HTTPStatus.OK:
                    ignore_exitcode = True

            if ignore_exitcode:
                task_logger.warning(
                    "Indexing watchdog - spawned task has non-zero exit code "
                    "but completion signal is OK. Continuing...: "
                    f"attempt={index_attempt_id} "
                    f"tenant={tenant_id} "
                    f"cc_pair={cc_pair_id} "
                    f"search_settings={search_settings_id} "
                    f"exit_code={exit_code}"
                )
            else:
                task_logger.error(
                    "Indexing watchdog - spawned task exceptioned: "
                    f"attempt={index_attempt_id} "
                    f"tenant={tenant_id} "
                    f"cc_pair={cc_pair_id} "
                    f"search_settings={search_settings_id} "
                    f"exit_code={exit_code} "
                    f"error={job.exception()}"
                )

        if not index_attempt:
            continue

        if not index_attempt.is_finished():
            continue
    except Exception:
        # if the DB exceptioned, just restart the check.
        # polling the index attempt status doesn't need to be strongly consistent
        logger.exception(
            "Indexing watchdog - transient exception looking up index attempt: "
            f"attempt={index_attempt_id} "
            f"tenant={tenant_id} "
            f"cc_pair={cc_pair_id} "
            f"search_settings={search_settings_id}"
        )
        continue
        job.release()
        break

    task_logger.info(
        f"Indexing watchdog - finished: attempt={index_attempt_id} "
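The watchdog loop above reduces to a spawn-then-poll skeleton. Below is a self-contained sketch of that shape under stated assumptions: `target`, the poll interval, and the timeout are placeholders, and `watch_spawned_task` is an invented name, not the project's API:

```python
import multiprocessing as mp
import time


def watch_spawned_task(target, poll_seconds: float = 5.0, timeout: float = 3600.0) -> int | None:
    """Spawn a child process and poll it, mirroring the loop above.
    Returns the child's exit code, or None if we had to terminate it."""
    proc = mp.Process(target=target, daemon=True)
    proc.start()

    deadline = time.monotonic() + timeout
    while True:
        time.sleep(poll_seconds)
        if not proc.is_alive():
            return proc.exitcode  # inspected like job.status / exitcode above
        if time.monotonic() > deadline:
            proc.terminate()  # analogous to job.cancel() on a termination signal
            proc.join()
            return None
```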
@@ -942,7 +918,7 @@ def connector_indexing_task_wrapper(
            tenant_id,
            is_ee,
        )
    except Exception:
    except:
        logger.exception(
            f"connector_indexing_task exceptioned: "
            f"tenant={tenant_id} "
@@ -950,14 +926,7 @@ def connector_indexing_task_wrapper(
            f"cc_pair={cc_pair_id} "
            f"search_settings={search_settings_id}"
        )

        # There is a cloud-related bug outside of our code
        # where spawned tasks return with an exit code of 1.
        # Unfortunately, exceptions also return with an exit code of 1,
        # so just raising an exception isn't informative.
        # Exiting with 255 makes it possible to distinguish between normal exits
        # and exceptions.
        sys.exit(255)
        raise

    return result

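The exit-code workaround removed above can be demonstrated in isolation. A sketch, where `flaky_work` is a hypothetical stand-in for the indexing entrypoint:

```python
import multiprocessing as mp
import sys


def flaky_work() -> None:  # hypothetical workload for illustration
    raise RuntimeError("boom")


def child() -> None:
    try:
        flaky_work()
    except Exception:
        # exit code 1 is ambiguous in the environment described above, since
        # clean exits were also reporting 1; 255 unambiguously means "exception"
        sys.exit(255)


if __name__ == "__main__":
    p = mp.Process(target=child)
    p.start()
    p.join()
    if p.exitcode == 255:
        print("child exited via an exception")
    elif p.exitcode == 0:
        print("child completed normally")
    else:
        print(f"ambiguous exit code: {p.exitcode}")
```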
@@ -1029,17 +998,7 @@ def connector_indexing_task(
            f"fence={redis_connector.stop.fence_key}"
        )

    # this wait is needed to avoid a race condition where
    # the primary worker sends the task and it is immediately executed
    # before the primary worker can finalize the fence
    start = time.monotonic()
    while True:
        if time.monotonic() - start > CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT:
            raise ValueError(
                f"connector_indexing_task - timed out waiting for fence to be ready: "
                f"fence={redis_connector.permissions.fence_key}"
            )

    if not redis_connector_index.fenced:  # The fence must exist
        raise ValueError(
            f"connector_indexing_task - fence not found: fence={redis_connector_index.fence_key}"
@@ -1080,9 +1039,7 @@ def connector_indexing_task(
    if not acquired:
        logger.warning(
            f"Indexing task already running, exiting...: "
            f"index_attempt={index_attempt_id} "
            f"cc_pair={cc_pair_id} "
            f"search_settings={search_settings_id}"
            f"index_attempt={index_attempt_id} cc_pair={cc_pair_id} search_settings={search_settings_id}"
        )
        return None

@@ -1118,7 +1075,6 @@ def connector_indexing_task(

        # define a callback class
        callback = IndexingCallback(
            os.getppid(),
            redis_connector.stop.fence_key,
            redis_connector_index.generator_progress_key,
            lock,
@@ -1152,19 +1108,8 @@ def connector_indexing_task(
            f"search_settings={search_settings_id}"
        )
        if attempt_found:
            try:
                with get_session_with_tenant(tenant_id) as db_session:
                    mark_attempt_failed(
                        index_attempt_id, db_session, failure_reason=str(e)
                    )
            except Exception:
                logger.exception(
                    "Indexing watchdog - transient exception looking up index attempt: "
                    f"attempt={index_attempt_id} "
                    f"tenant={tenant_id} "
                    f"cc_pair={cc_pair_id} "
                    f"search_settings={search_settings_id}"
                )
            with get_session_with_tenant(tenant_id) as db_session:
                mark_attempt_failed(index_attempt_id, db_session, failure_reason=str(e))

        raise e
    finally:

@@ -81,10 +81,10 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
    soft_time_limit=JOB_TIMEOUT,
    bind=True,
)
def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
def check_for_pruning(self: Task, *, tenant_id: str | None) -> None:
    r = get_redis_client(tenant_id=tenant_id)

    lock_beat: RedisLock = r.lock(
    lock_beat = r.lock(
        OnyxRedisLocks.CHECK_PRUNE_BEAT_LOCK,
        timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
    )
@@ -92,7 +92,7 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
    try:
        # these tasks should never overlap
        if not lock_beat.acquire(blocking=False):
            return None
            return

        cc_pair_ids: list[int] = []
        with get_session_with_tenant(tenant_id) as db_session:
@@ -127,8 +127,6 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
    if lock_beat.owned():
        lock_beat.release()

    return True


def try_creating_prune_generator_task(
    celery_app: Celery,
@@ -285,7 +283,6 @@ def connector_pruning_generator_task(
        )

        callback = IndexingCallback(
            0,
            redis_connector.stop.fence_key,
            redis_connector.prune.generator_progress_key,
            lock,

@@ -20,7 +20,6 @@ from tenacity import RetryError
from onyx.access.access import get_access_for_document
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
from onyx.background.celery.tasks.shared.tasks import LIGHT_SOFT_TIME_LIMIT
from onyx.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
@@ -88,7 +87,7 @@ logger = setup_logger()
    trail=False,
    bind=True,
)
def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | None:
def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> None:
    """Runs periodically to check if any document needs syncing.
    Generates sets of tasks for Celery if syncing is needed."""
    time_start = time.monotonic()
@@ -103,7 +102,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
    try:
        # these tasks should never overlap
        if not lock_beat.acquire(blocking=False):
            return None
            return

        with get_session_with_tenant(tenant_id) as db_session:
            try_generate_stale_document_sync_tasks(
@@ -165,8 +164,8 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
        lock_beat.release()

    time_elapsed = time.monotonic() - time_start
    task_logger.debug(f"check_for_vespa_sync_task finished: elapsed={time_elapsed:.2f}")
    return True
    task_logger.info(f"check_for_vespa_sync_task finished: elapsed={time_elapsed:.2f}")
    return


def try_generate_stale_document_sync_tasks(
@@ -637,23 +636,15 @@ def monitor_ccpair_indexing_taskset(
    if not payload:
        return

    elapsed_started_str = None
    if payload.started:
        elapsed_started = datetime.now(timezone.utc) - payload.started
        elapsed_started_str = f"{elapsed_started.total_seconds():.2f}"

    elapsed_submitted = datetime.now(timezone.utc) - payload.submitted

    progress = redis_connector_index.get_progress()
    if progress is not None:
        task_logger.info(
            f"Connector indexing progress: "
            f"attempt={payload.index_attempt_id} "
            f"cc_pair={cc_pair_id} "
            f"Connector indexing progress: cc_pair={cc_pair_id} "
            f"search_settings={search_settings_id} "
            f"progress={progress} "
            f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
            f"elapsed_started={elapsed_started_str}"
            f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
        )

    if payload.index_attempt_id is None or payload.celery_task_id is None:
@@ -724,14 +715,11 @@ def monitor_ccpair_indexing_taskset(
        status_enum = HTTPStatus(status_int)

        task_logger.info(
            f"Connector indexing finished: "
            f"attempt={payload.index_attempt_id} "
            f"cc_pair={cc_pair_id} "
            f"Connector indexing finished: cc_pair={cc_pair_id} "
            f"search_settings={search_settings_id} "
            f"progress={progress} "
            f"status={status_enum.name} "
            f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
            f"elapsed_started={elapsed_started_str}"
            f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
        )

        redis_connector_index.reset()
@@ -778,34 +766,31 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
            OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC, r_celery
        )

        prefetched = celery_get_unacked_task_ids(
            OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
        )

        task_logger.info(
            f"Queue lengths: celery={n_celery} "
            f"indexing={n_indexing} "
            f"indexing_prefetched={len(prefetched)} "
            f"sync={n_sync} "
            f"deletion={n_deletion} "
            f"pruning={n_pruning} "
            f"permissions_sync={n_permissions_sync} "
        )

        # scan and monitor activity to completion
        lock_beat.reacquire()
        if r.exists(RedisConnectorCredentialPair.get_fence_key()):
            monitor_connector_taskset(r)

        lock_beat.reacquire()
        for key_bytes in r.scan_iter(RedisConnectorDelete.FENCE_PREFIX + "*"):
            lock_beat.reacquire()
            monitor_connector_deletion_taskset(tenant_id, key_bytes, r)

        lock_beat.reacquire()
        for key_bytes in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"):
            lock_beat.reacquire()
            with get_session_with_tenant(tenant_id) as db_session:
                monitor_document_set_taskset(tenant_id, key_bytes, r, db_session)

        lock_beat.reacquire()
        for key_bytes in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
            lock_beat.reacquire()
            monitor_usergroup_taskset = fetch_versioned_implementation_with_fallback(
@@ -816,21 +801,28 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
            with get_session_with_tenant(tenant_id) as db_session:
                monitor_usergroup_taskset(tenant_id, key_bytes, r, db_session)

        lock_beat.reacquire()
        for key_bytes in r.scan_iter(RedisConnectorPrune.FENCE_PREFIX + "*"):
            lock_beat.reacquire()
            with get_session_with_tenant(tenant_id) as db_session:
                monitor_ccpair_pruning_taskset(tenant_id, key_bytes, r, db_session)

        lock_beat.reacquire()
        for key_bytes in r.scan_iter(RedisConnectorIndex.FENCE_PREFIX + "*"):
            lock_beat.reacquire()
            with get_session_with_tenant(tenant_id) as db_session:
                monitor_ccpair_indexing_taskset(tenant_id, key_bytes, r, db_session)

        lock_beat.reacquire()
        for key_bytes in r.scan_iter(RedisConnectorPermissionSync.FENCE_PREFIX + "*"):
            lock_beat.reacquire()
            with get_session_with_tenant(tenant_id) as db_session:
                monitor_ccpair_permissions_taskset(tenant_id, key_bytes, r, db_session)

        # uncomment for debugging if needed
        # r_celery = celery_app.broker_connection().channel().client
        # length = celery_get_queue_length(OnyxCeleryQueues.VESPA_METADATA_SYNC, r_celery)
        # task_logger.warning(f"queue={OnyxCeleryQueues.VESPA_METADATA_SYNC} length={length}")
    except SoftTimeLimitExceeded:
        task_logger.info(
            "Soft time limit exceeded, task is being terminated gracefully."
@@ -840,7 +832,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
            lock_beat.release()

    time_elapsed = time.monotonic() - time_start
    task_logger.debug(f"monitor_vespa_sync finished: elapsed={time_elapsed:.2f}")
    task_logger.info(f"monitor_vespa_sync finished: elapsed={time_elapsed:.2f}")
    return True

@@ -14,7 +14,6 @@ from onyx.configs.app_configs import POLL_CONNECTOR_OFFSET
from onyx.configs.constants import MilestoneRecordType
from onyx.connectors.connector_runner import ConnectorRunner
from onyx.connectors.factory import instantiate_connector
from onyx.connectors.models import Document
from onyx.connectors.models import IndexAttemptMetadata
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import get_last_successful_attempt_time
@@ -91,35 +90,6 @@ def _get_connector_runner(
    )


def strip_null_characters(doc_batch: list[Document]) -> list[Document]:
    cleaned_batch = []
    for doc in doc_batch:
        cleaned_doc = doc.model_copy()

        if "\x00" in cleaned_doc.id:
            logger.warning(f"NUL characters found in document ID: {cleaned_doc.id}")
            cleaned_doc.id = cleaned_doc.id.replace("\x00", "")

        if "\x00" in cleaned_doc.semantic_identifier:
            logger.warning(
                f"NUL characters found in document semantic identifier: {cleaned_doc.semantic_identifier}"
            )
            cleaned_doc.semantic_identifier = cleaned_doc.semantic_identifier.replace(
                "\x00", ""
            )

        for section in cleaned_doc.sections:
            if section.link and "\x00" in section.link:
                logger.warning(
                    f"NUL characters found in document link for document: {cleaned_doc.id}"
                )
                section.link = section.link.replace("\x00", "")

        cleaned_batch.append(cleaned_doc)

    return cleaned_batch


class ConnectorStopSignal(Exception):
    """A custom exception used to signal a stop in processing."""

@@ -268,9 +238,7 @@ def _run_indexing(
            )

            batch_description = []

            doc_batch_cleaned = strip_null_characters(doc_batch)
            for doc in doc_batch_cleaned:
            for doc in doc_batch:
                batch_description.append(doc.to_short_descriptor())

            doc_size = 0
@@ -290,15 +258,15 @@ def _run_indexing(

            # real work happens here!
            new_docs, total_batch_chunks = indexing_pipeline(
                document_batch=doc_batch_cleaned,
                document_batch=doc_batch,
                index_attempt_metadata=index_attempt_md,
            )

            batch_num += 1
            net_doc_change += new_docs
            chunk_count += total_batch_chunks
            document_count += len(doc_batch_cleaned)
            all_connector_doc_ids.update(doc.id for doc in doc_batch_cleaned)
            document_count += len(doc_batch)
            all_connector_doc_ids.update(doc.id for doc in doc_batch)

            # commit transaction so that the `update` below begins
            # with a brand new transaction. Postgres uses the start
@@ -308,7 +276,7 @@ def _run_indexing(
            db_session.commit()

            if callback:
                callback.progress("_run_indexing", len(doc_batch_cleaned))
                callback.progress("_run_indexing", len(doc_batch))

            # This new value is updated every batch, so UI can refresh per batch update
            update_docs_indexed(
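The rationale for the removed `strip_null_characters` is, presumably, that Postgres text columns reject NUL bytes. A small demonstration of the cleaning step itself, with plain-string stand-ins for the Document fields cleaned above:

```python
# plain-string stand-ins for the Document fields cleaned above
doc_id = "airtable__rec\x00123"
link = "https://example.com/a\x00b"

cleaned_id = doc_id.replace("\x00", "")
cleaned_link = link.replace("\x00", "")

assert cleaned_id == "airtable__rec123"
assert cleaned_link == "https://example.com/ab"

# without this, inserting the raw values into a Postgres text column typically
# fails with an error along the lines of
# "A string literal cannot contain NUL (0x00) characters"
```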
@@ -92,7 +92,6 @@ SMTP_SERVER = os.environ.get("SMTP_SERVER") or "smtp.gmail.com"
SMTP_PORT = int(os.environ.get("SMTP_PORT") or "587")
SMTP_USER = os.environ.get("SMTP_USER", "your-email@gmail.com")
SMTP_PASS = os.environ.get("SMTP_PASS", "your-gmail-password")
EMAIL_CONFIGURED = all([SMTP_SERVER, SMTP_USER, SMTP_PASS])
EMAIL_FROM = os.environ.get("EMAIL_FROM") or SMTP_USER

# If set, Onyx will listen to the `expires_at` returned by the identity
@@ -146,7 +145,7 @@ POSTGRES_PASSWORD = urllib.parse.quote_plus(
POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost"
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432"
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
AWS_REGION_NAME = os.environ.get("AWS_REGION_NAME") or "us-east-2"
AWS_REGION = os.environ.get("AWS_REGION") or "us-east-2"

POSTGRES_API_SERVER_POOL_SIZE = int(
    os.environ.get("POSTGRES_API_SERVER_POOL_SIZE") or 40
@@ -185,25 +184,6 @@ REDIS_HOST = os.environ.get("REDIS_HOST") or "localhost"
REDIS_PORT = int(os.environ.get("REDIS_PORT", 6379))
REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD") or ""

# Rate limiting for auth endpoints

RATE_LIMIT_WINDOW_SECONDS: int | None = None
_rate_limit_window_seconds_str = os.environ.get("RATE_LIMIT_WINDOW_SECONDS")
if _rate_limit_window_seconds_str is not None:
    try:
        RATE_LIMIT_WINDOW_SECONDS = int(_rate_limit_window_seconds_str)
    except ValueError:
        pass

RATE_LIMIT_MAX_REQUESTS: int | None = None
_rate_limit_max_requests_str = os.environ.get("RATE_LIMIT_MAX_REQUESTS")
if _rate_limit_max_requests_str is not None:
    try:
        RATE_LIMIT_MAX_REQUESTS = int(_rate_limit_max_requests_str)
    except ValueError:
        pass

# Used for general redis things
REDIS_DB_NUMBER = int(os.environ.get("REDIS_DB_NUMBER", 0))

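The two removed rate-limit blocks repeat the same parse-or-None dance. A sketch of a helper that would collapse them; the `_int_env` name is invented here for illustration, not part of the codebase:

```python
import os


def _int_env(name: str) -> int | None:
    """Return the env var as an int, or None if unset or malformed."""
    raw = os.environ.get(name)
    if raw is None:
        return None
    try:
        return int(raw)
    except ValueError:
        return None  # malformed values behave the same as unset ones


RATE_LIMIT_WINDOW_SECONDS = _int_env("RATE_LIMIT_WINDOW_SECONDS")
RATE_LIMIT_MAX_REQUESTS = _int_env("RATE_LIMIT_MAX_REQUESTS")
```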
@@ -36,8 +36,6 @@ DISABLED_GEN_AI_MSG = (

DEFAULT_PERSONA_ID = 0

DEFAULT_CC_PAIR_ID = 1

# Postgres connection constants for application_name
POSTGRES_WEB_APP_NAME = "web"
POSTGRES_INDEXER_APP_NAME = "indexer"
@@ -83,9 +81,6 @@ CELERY_PRIMARY_WORKER_LOCK_TIMEOUT = 120
# if we can get callbacks as object bytes download, we could lower this a lot.
CELERY_INDEXING_LOCK_TIMEOUT = 3 * 60 * 60  # 3 hours

# how long a task should wait for associated fence to be ready
CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT = 5 * 60  # 5 min

# needs to be long enough to cover the maximum time it takes to download an object
# if we can get callbacks as object bytes download, we could lower this a lot.
CELERY_PRUNING_LOCK_TIMEOUT = 300  # 5 min
@@ -142,7 +137,6 @@ class DocumentSource(str, Enum):
    FRESHDESK = "freshdesk"
    FIREFLIES = "fireflies"
    EGNYTE = "egnyte"
    AIRTABLE = "airtable"


DocumentSourceRequiringTenantContext: list[DocumentSource] = [DocumentSource.FILE]
@@ -279,7 +273,6 @@ class OnyxRedisLocks:

    SLACK_BOT_LOCK = "da_lock:slack_bot"
    SLACK_BOT_HEARTBEAT_PREFIX = "da_heartbeat:slack_bot"
    ANONYMOUS_USER_ENABLED = "anonymous_user_enabled"


class OnyxRedisSignals:

@@ -1,266 +0,0 @@
|
||||
from io import BytesIO
|
||||
from typing import Any
|
||||
|
||||
import requests
|
||||
from pyairtable import Api as AirtableApi
|
||||
from retry import retry
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.extract_file_text import get_file_ext
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# NOTE: all are made lowercase to avoid case sensitivity issues
|
||||
# these are the field types that are considered metadata rather
|
||||
# than sections
|
||||
_METADATA_FIELD_TYPES = {
|
||||
"singlecollaborator",
|
||||
"collaborator",
|
||||
"createdby",
|
||||
"singleselect",
|
||||
"multipleselects",
|
||||
"checkbox",
|
||||
"date",
|
||||
"datetime",
|
||||
"email",
|
||||
"phone",
|
||||
"url",
|
||||
"number",
|
||||
"currency",
|
||||
"duration",
|
||||
"percent",
|
||||
"rating",
|
||||
"createdtime",
|
||||
"lastmodifiedtime",
|
||||
"autonumber",
|
||||
"rollup",
|
||||
"lookup",
|
||||
"count",
|
||||
"formula",
|
||||
"date",
|
||||
}
|
||||
|
||||
|
||||
class AirtableClientNotSetUpError(PermissionError):
|
||||
def __init__(self) -> None:
|
||||
super().__init__("Airtable Client is not set up, was load_credentials called?")
|
||||
|
||||
|
||||
class AirtableConnector(LoadConnector):
|
||||
def __init__(
|
||||
self,
|
||||
base_id: str,
|
||||
table_name_or_id: str,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
) -> None:
|
||||
self.base_id = base_id
|
||||
self.table_name_or_id = table_name_or_id
|
||||
self.batch_size = batch_size
|
||||
self.airtable_client: AirtableApi | None = None
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
self.airtable_client = AirtableApi(credentials["airtable_access_token"])
|
||||
return None
|
||||
|
||||
def _get_field_value(self, field_info: Any, field_type: str) -> list[str]:
|
||||
"""
|
||||
Extract value(s) from a field regardless of its type.
|
||||
Returns either a single string or list of strings for attachments.
|
||||
"""
|
||||
if field_info is None:
|
||||
return []
|
||||
|
||||
# skip references to other records for now (would need to do another
|
||||
# request to get the actual record name/type)
|
||||
# TODO: support this
|
||||
if field_type == "multipleRecordLinks":
|
||||
return []
|
||||
|
||||
if field_type == "multipleAttachments":
|
||||
attachment_texts: list[str] = []
|
||||
for attachment in field_info:
|
||||
url = attachment.get("url")
|
||||
filename = attachment.get("filename", "")
|
||||
if not url:
|
||||
continue
|
||||
|
||||
@retry(
|
||||
tries=5,
|
||||
delay=1,
|
||||
backoff=2,
|
||||
max_delay=10,
|
||||
)
|
||||
def get_attachment_with_retry(url: str) -> bytes | None:
|
||||
attachment_response = requests.get(url)
|
||||
if attachment_response.status_code == 200:
|
||||
return attachment_response.content
|
||||
return None
|
||||
|
||||
attachment_content = get_attachment_with_retry(url)
|
||||
if attachment_content:
|
||||
try:
|
||||
file_ext = get_file_ext(filename)
|
||||
attachment_text = extract_file_text(
|
||||
BytesIO(attachment_content),
|
||||
filename,
|
||||
break_on_unprocessable=False,
                    extension=file_ext,
                )
                if attachment_text:
                    attachment_texts.append(f"{filename}:\n{attachment_text}")
            except Exception as e:
                logger.warning(
                    f"Failed to process attachment {filename}: {str(e)}"
                )
        return attachment_texts

        if field_type in ["singleCollaborator", "collaborator", "createdBy"]:
            combined = []
            collab_name = field_info.get("name")
            collab_email = field_info.get("email")
            if collab_name:
                combined.append(collab_name)
            if collab_email:
                combined.append(f"({collab_email})")
            return [" ".join(combined) if combined else str(field_info)]

        if isinstance(field_info, list):
            return [str(item) for item in field_info]

        return [str(field_info)]

    def _should_be_metadata(self, field_type: str) -> bool:
        """Determine if a field type should be treated as metadata."""
        return field_type.lower() in _METADATA_FIELD_TYPES

    def _process_field(
        self,
        field_name: str,
        field_info: Any,
        field_type: str,
        table_id: str,
        record_id: str,
    ) -> tuple[list[Section], dict[str, Any]]:
        """
        Process a single Airtable field and return sections or metadata.

        Args:
            field_name: Name of the field
            field_info: Raw field information from Airtable
            field_type: Airtable field type

        Returns:
            (list of Sections, dict of metadata)
        """
        if field_info is None:
            return [], {}

        # Get the value(s) for the field
        field_values = self._get_field_value(field_info, field_type)
        if len(field_values) == 0:
            return [], {}

        # Determine if it should be metadata or a section
        if self._should_be_metadata(field_type):
            if len(field_values) > 1:
                return [], {field_name: field_values}
            return [], {field_name: field_values[0]}

        # Otherwise, create relevant sections
        sections = [
            Section(
                link=f"https://airtable.com/{self.base_id}/{table_id}/{record_id}",
                text=(
                    f"{field_name}:\n"
                    "------------------------\n"
                    f"{text}\n"
                    "------------------------"
                ),
            )
            for text in field_values
        ]
        return sections, {}

    def load_from_state(self) -> GenerateDocumentsOutput:
        """
        Fetch all records from the table.

        NOTE: Airtable does not support filtering by time updated, so
        we have to fetch all records every time.
        """
        if not self.airtable_client:
            raise AirtableClientNotSetUpError()

        table = self.airtable_client.table(self.base_id, self.table_name_or_id)
        table_id = table.id
        record_pages = table.iterate()

        table_schema = table.schema()
        # have to get the name from the schema, since the table object will
        # give back the ID instead of the name if the ID is used to create
        # the table object
        table_name = table_schema.name
        primary_field_name = None

        # Find a primary field from the schema
        for field in table_schema.fields:
            if field.id == table_schema.primary_field_id:
                primary_field_name = field.name
                break

        record_documents: list[Document] = []
        for page in record_pages:
            for record in page:
                record_id = record["id"]
                fields = record["fields"]
                sections: list[Section] = []
                metadata: dict[str, Any] = {}

                # Possibly retrieve the primary field's value
                primary_field_value = (
                    fields.get(primary_field_name) if primary_field_name else None
                )
                for field_schema in table_schema.fields:
                    field_name = field_schema.name
                    field_val = fields.get(field_name)
                    field_type = field_schema.type

                    field_sections, field_metadata = self._process_field(
                        field_name=field_name,
                        field_info=field_val,
                        field_type=field_type,
                        table_id=table_id,
                        record_id=record_id,
                    )

                    sections.extend(field_sections)
                    metadata.update(field_metadata)

                semantic_id = (
                    f"{table_name}: {primary_field_value}"
                    if primary_field_value
                    else table_name
                )

                record_document = Document(
                    id=f"airtable__{record_id}",
                    sections=sections,
                    source=DocumentSource.AIRTABLE,
                    semantic_identifier=semantic_id,
                    metadata=metadata,
                )
                record_documents.append(record_document)

                if len(record_documents) >= self.batch_size:
                    yield record_documents
                    record_documents = []

        if record_documents:
            yield record_documents
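A minimal sketch of what _process_field returns for the two kinds of fields. The record values, table/record IDs, and the assumption that "createdtime" is among _METADATA_FIELD_TYPES are hypothetical:

# Metadata-style field: no Section is produced, the value lands in metadata.
sections, metadata = connector._process_field(
    field_name="Created",
    field_info="2024-01-01T00:00:00Z",
    field_type="createdTime",
    table_id="tblXXX",
    record_id="recYYY",
)
# sections == []  and  metadata == {"Created": "2024-01-01T00:00:00Z"}

# Text-style field: each value becomes a Section that links back to the record.
sections, metadata = connector._process_field(
    field_name="Notes",
    field_info="Quarterly summary",
    field_type="multilineText",
    table_id="tblXXX",
    record_id="recYYY",
)
# sections[0].link == f"https://airtable.com/{connector.base_id}/tblXXX/recYYY"  and  metadata == {}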
@@ -56,23 +56,6 @@ _RESTRICTIONS_EXPANSION_FIELDS = [

_SLIM_DOC_BATCH_SIZE = 5000

_ATTACHMENT_EXTENSIONS_TO_FILTER_OUT = [
    "png",
    "jpg",
    "jpeg",
    "gif",
    "mp4",
    "mov",
    "mp3",
    "wav",
]
_FULL_EXTENSION_FILTER_STRING = "".join(
    [
        f" and title!~'*.{extension}'"
        for extension in _ATTACHMENT_EXTENSIONS_TO_FILTER_OUT
    ]
)
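For reference, with the extension list as defined, the join above expands to a single CQL fragment:

# _FULL_EXTENSION_FILTER_STRING ==
#   " and title!~'*.png' and title!~'*.jpg' and title!~'*.jpeg' and title!~'*.gif'"
#   " and title!~'*.mp4' and title!~'*.mov' and title!~'*.mp3' and title!~'*.wav'"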


class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
    def __init__(
@@ -81,7 +64,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
        is_cloud: bool,
        space: str = "",
        page_id: str = "",
        index_recursively: bool = False,
        index_recursively: bool = True,
        cql_query: str | None = None,
        batch_size: int = INDEX_BATCH_SIZE,
        continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE,
@@ -99,25 +82,23 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
        # Remove trailing slash from wiki_base if present
        self.wiki_base = wiki_base.rstrip("/")

        """
        If nothing is provided, we default to fetching all pages.
        Only one or none of the following options should be specified, so
        the order shouldn't matter.
        However, we use elif to ensure that only one of the following is enforced.
        """
        base_cql_page_query = "type=page"
        # if nothing is provided, we will fetch all pages
        cql_page_query = "type=page"
        if cql_query:
            base_cql_page_query = cql_query
            # if a cql_query is provided, we will use it to fetch the pages
            cql_page_query = cql_query
        elif page_id:
            # if a cql_query is not provided, we will use the page_id to fetch the page
            if index_recursively:
                base_cql_page_query += f" and (ancestor='{page_id}' or id='{page_id}')"
                cql_page_query += f" and ancestor='{page_id}'"
            else:
                base_cql_page_query += f" and id='{page_id}'"
                cql_page_query += f" and id='{page_id}'"
        elif space:
            uri_safe_space = quote(space)
            base_cql_page_query += f" and space='{uri_safe_space}'"
            # if no cql_query or page_id is provided, we will use the space to fetch the pages
            cql_page_query += f" and space='{quote(space)}'"

        self.base_cql_page_query = base_cql_page_query
        self.cql_page_query = cql_page_query
        self.cql_time_filter = ""

        self.cql_label_filter = ""
        if labels_to_skip:
@@ -145,33 +126,6 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
            )
        return None

    def _construct_page_query(
        self,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
    ) -> str:
        page_query = self.base_cql_page_query + self.cql_label_filter

        # Add time filters
        if start:
            formatted_start_time = datetime.fromtimestamp(
                start, tz=self.timezone
            ).strftime("%Y-%m-%d %H:%M")
            page_query += f" and lastmodified >= '{formatted_start_time}'"
        if end:
            formatted_end_time = datetime.fromtimestamp(end, tz=self.timezone).strftime(
                "%Y-%m-%d %H:%M"
            )
            page_query += f" and lastmodified <= '{formatted_end_time}'"

        return page_query

    def _construct_attachment_query(self, confluence_page_id: str) -> str:
        attachment_query = f"type=attachment and container='{confluence_page_id}'"
        attachment_query += self.cql_label_filter
        attachment_query += _FULL_EXTENSION_FILTER_STRING
        return attachment_query

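A sketch of the CQL these helpers produce; the space key, page ID, and timestamps below are hypothetical:

# _construct_page_query(start, end) with base_cql_page_query "type=page and space='ENG'":
#   "type=page and space='ENG' and lastmodified >= '2024-01-01 00:00' and lastmodified <= '2024-01-02 00:00'"
# _construct_attachment_query("12345"):
#   "type=attachment and container='12345'" + label filter + the extension filter shown above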
    def _get_comment_string_for_page_id(self, page_id: str) -> str:
        comment_string = ""

@@ -251,15 +205,11 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
            metadata=doc_metadata,
        )

    def _fetch_document_batches(
        self,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
    ) -> GenerateDocumentsOutput:
    def _fetch_document_batches(self) -> GenerateDocumentsOutput:
        doc_batch: list[Document] = []
        confluence_page_ids: list[str] = []

        page_query = self._construct_page_query(start, end)
        page_query = self.cql_page_query + self.cql_label_filter + self.cql_time_filter
        logger.debug(f"page_query: {page_query}")
        # Fetch pages as Documents
        for page in self.confluence_client.paginated_cql_retrieval(
@@ -278,10 +228,11 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):

        # Fetch attachments as Documents
        for confluence_page_id in confluence_page_ids:
            attachment_query = self._construct_attachment_query(confluence_page_id)
            attachment_cql = f"type=attachment and container='{confluence_page_id}'"
            attachment_cql += self.cql_label_filter
            # TODO: maybe should add time filter as well?
            for attachment in self.confluence_client.paginated_cql_retrieval(
                cql=attachment_query,
                cql=attachment_cql,
                expand=",".join(_ATTACHMENT_EXPANSION_FIELDS),
            ):
                doc = self._convert_object_to_document(attachment)
@@ -297,12 +248,17 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
    def load_from_state(self) -> GenerateDocumentsOutput:
        return self._fetch_document_batches()

    def poll_source(
        self,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
    ) -> GenerateDocumentsOutput:
        return self._fetch_document_batches(start, end)
    def poll_source(self, start: float, end: float) -> GenerateDocumentsOutput:
        # Add time filters
        formatted_start_time = datetime.fromtimestamp(start, tz=self.timezone).strftime(
            "%Y-%m-%d %H:%M"
        )
        formatted_end_time = datetime.fromtimestamp(end, tz=self.timezone).strftime(
            "%Y-%m-%d %H:%M"
        )
        self.cql_time_filter = f" and lastmodified >= '{formatted_start_time}'"
        self.cql_time_filter += f" and lastmodified <= '{formatted_end_time}'"
        return self._fetch_document_batches()

    def retrieve_all_slim_documents(
        self,
@@ -313,7 +269,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):

        restrictions_expand = ",".join(_RESTRICTIONS_EXPANSION_FIELDS)

        page_query = self.base_cql_page_query + self.cql_label_filter
        page_query = self.cql_page_query + self.cql_label_filter
        for page in self.confluence_client.cql_paginate_all_expansions(
            cql=page_query,
            expand=restrictions_expand,
@@ -338,9 +294,10 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
                    perm_sync_data=page_perm_sync_data,
                )
            )
            attachment_query = self._construct_attachment_query(page["id"])
            attachment_cql = f"type=attachment and container='{page['id']}'"
            attachment_cql += self.cql_label_filter
            for attachment in self.confluence_client.cql_paginate_all_expansions(
                cql=attachment_query,
                cql=attachment_cql,
                expand=restrictions_expand,
                limit=_SLIM_DOC_BATCH_SIZE,
            ):

@@ -190,7 +190,7 @@ class DiscourseConnector(PollConnector):
        start: datetime,
        end: datetime,
    ) -> GenerateDocumentsOutput:
        page = 0
        page = 1
        while topic_ids := self._get_latest_topics(start, end, page):
            doc_batch: list[Document] = []
            for topic_id in topic_ids:

@@ -7,7 +7,6 @@ from logging import Logger
from typing import Any
from typing import cast
from typing import IO
from urllib.parse import quote

import requests
from retry import retry
@@ -73,8 +72,7 @@ def _request_with_retries(
        logger.exception(
            f"Failed to call Egnyte API.\n"
            f"URL: {url}\n"
            # NOTE: can't log headers because they contain the access token
            # f"Headers: {headers}\n"
            f"Headers: {headers}\n"
            f"Data: {data}\n"
            f"Params: {params}"
        )
@@ -262,8 +260,7 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
            "list_content": True,
        }

        url_encoded_path = quote(path or "", safe="")
        url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs/{url_encoded_path}"
        url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs/{path or ''}"
        response = _request_with_retries(
            method="GET", url=url, headers=headers, params=params, timeout=_TIMEOUT
        )
@@ -318,8 +315,7 @@ class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
        headers = {
            "Authorization": f"Bearer {self.access_token}",
        }
        url_encoded_path = quote(file["path"], safe="")
        url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs-content/{url_encoded_path}"
        url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs-content/{file['path']}"
        response = _request_with_retries(
            method="GET",
            url=url,
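For context, quote(..., safe="") percent-encodes every reserved character, which matters for Egnyte paths containing spaces or slashes (the path value here is hypothetical):

from urllib.parse import quote

quote("Shared/Q1 Reports/summary.pdf", safe="")
# -> 'Shared%2FQ1%20Reports%2Fsummary.pdf'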

@@ -5,7 +5,6 @@ from sqlalchemy.orm import Session

from onyx.configs.constants import DocumentSource
from onyx.configs.constants import DocumentSourceRequiringTenantContext
from onyx.connectors.airtable.airtable_connector import AirtableConnector
from onyx.connectors.asana.connector import AsanaConnector
from onyx.connectors.axero.connector import AxeroConnector
from onyx.connectors.blob.connector import BlobStorageConnector
@@ -104,7 +103,6 @@ def identify_connector_class(
        DocumentSource.FRESHDESK: FreshdeskConnector,
        DocumentSource.FIREFLIES: FirefliesConnector,
        DocumentSource.EGNYTE: EgnyteConnector,
        DocumentSource.AIRTABLE: AirtableConnector,
    }
    connector_by_source = connector_map.get(source, {})

@@ -4,7 +4,6 @@ from typing import Dict

from google.oauth2.credentials import Credentials as OAuthCredentials  # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials  # type: ignore
from googleapiclient.errors import HttpError  # type: ignore

from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
@@ -250,36 +249,17 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnector):
        return new_creds_dict

    def _get_all_user_emails(self) -> list[str]:
        """
        List all user emails if we are on a Google Workspace domain.
        If the domain is gmail.com, or if we attempt to call the Admin SDK and
        get a 404, fall back to using the single user.
        """

        try:
            admin_service = get_admin_service(self.creds, self.primary_admin_email)
            emails = []
            for user in execute_paginated_retrieval(
                retrieval_function=admin_service.users().list,
                list_key="users",
                fields=USER_FIELDS,
                domain=self.google_domain,
            ):
                if email := user.get("primaryEmail"):
                    emails.append(email)
            return emails

        except HttpError as e:
            if e.resp.status == 404:
                logger.warning(
                    "Received 404 from Admin SDK; this may indicate a personal Gmail account "
                    "with no Workspace domain. Falling back to single user."
                )
                return [self.primary_admin_email]
            raise

        except Exception:
            raise
        admin_service = get_admin_service(self.creds, self.primary_admin_email)
        emails = []
        for user in execute_paginated_retrieval(
            retrieval_function=admin_service.users().list,
            list_key="users",
            fields=USER_FIELDS,
            domain=self.google_domain,
        ):
            if email := user.get("primaryEmail"):
                emails.append(email)
        return emails

    def _fetch_threads(
        self,

@@ -8,7 +8,6 @@ from typing import cast

from google.oauth2.credentials import Credentials as OAuthCredentials  # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials  # type: ignore
from googleapiclient.errors import HttpError  # type: ignore

from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.app_configs import MAX_FILE_SIZE_BYTES
@@ -21,7 +20,6 @@ from onyx.connectors.google_drive.file_retrieval import crawl_folders_for_files
from onyx.connectors.google_drive.file_retrieval import get_all_files_for_oauth
from onyx.connectors.google_drive.file_retrieval import get_all_files_in_my_drive
from onyx.connectors.google_drive.file_retrieval import get_files_in_shared_drive
from onyx.connectors.google_drive.file_retrieval import get_root_folder_id
from onyx.connectors.google_drive.models import GoogleDriveFileType
from onyx.connectors.google_utils.google_auth import get_google_creds
from onyx.connectors.google_utils.google_utils import execute_paginated_retrieval
@@ -43,7 +41,6 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder

logger = setup_logger()
# TODO: Improve this by using the batch utility: https://googleapis.github.io/google-api-python-client/docs/batch.html
@@ -289,30 +286,13 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
    ) -> Iterator[GoogleDriveFileType]:
        logger.info(f"Impersonating user {user_email}")

        drive_service = get_drive_service(self.creds, user_email)

        # validate that the user has access to the drive APIs by performing a simple
        # request and checking for a 401
        try:
            retry_builder()(get_root_folder_id)(drive_service)
        except HttpError as e:
            if e.status_code == 401:
                # fail gracefully, let the other impersonations continue
                # one user without access shouldn't block the entire connector
                logger.exception(
                    f"User '{user_email}' does not have access to the drive APIs."
                )
                return
            raise
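The probe-then-bail pattern above generalizes beyond Drive; a minimal sketch under hypothetical names (build_service, cheap_probe, and expensive_fetch are stand-ins, not real helpers in this codebase):

def validated_fetch(user_email: str) -> Iterator[dict]:
    service = build_service(user_email)  # hypothetical service factory
    try:
        cheap_probe(service)  # any inexpensive request to test access
    except HttpError as e:
        if e.status_code == 401:
            return  # skip this user, let the others keep going
        raise
    yield from expensive_fetch(service)  # hypothetical full crawl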

        # if we are including my drives, try to get the current user's my
        # drive if any of the following are true:
        # - include_my_drives is true
        # - the current user's email is in the requested emails
        if self.include_my_drives or user_email in self._requested_my_drive_emails:
            logger.info(f"Getting all files in my drive as '{user_email}'")
            yield from get_all_files_in_my_drive(
                service=drive_service,
                update_traversed_ids_func=self._update_traversed_parent_ids,
@@ -323,7 +303,6 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):

        remaining_drive_ids = filtered_drive_ids - self._retrieved_ids
        for drive_id in remaining_drive_ids:
            logger.info(f"Getting files in shared drive '{drive_id}' as '{user_email}'")
            yield from get_files_in_shared_drive(
                service=drive_service,
                drive_id=drive_id,
@@ -335,7 +314,6 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):

        remaining_folders = filtered_folder_ids - self._retrieved_ids
        for folder_id in remaining_folders:
            logger.info(f"Getting files in folder '{folder_id}' as '{user_email}'")
            yield from crawl_folders_for_files(
                service=drive_service,
                parent_id=folder_id,
@@ -366,15 +344,6 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
        elif self.include_shared_drives:
            drive_ids_to_retrieve = all_drive_ids

        # checkpoint - we've found all users and drives, now time to actually start
        # fetching stuff
        logger.info(f"Found {len(all_org_emails)} users to impersonate")
        logger.debug(f"Users: {all_org_emails}")
        logger.info(f"Found {len(drive_ids_to_retrieve)} drives to retrieve")
        logger.debug(f"Drives: {drive_ids_to_retrieve}")
        logger.info(f"Found {len(folder_ids_to_retrieve)} folders to retrieve")
        logger.debug(f"Folders: {folder_ids_to_retrieve}")

        # Process users in parallel using ThreadPoolExecutor
        with ThreadPoolExecutor(max_workers=10) as executor:
            future_to_email = {
@@ -411,13 +380,6 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
        drive_service = get_drive_service(self.creds, self.primary_admin_email)

        if self.include_files_shared_with_me or self.include_my_drives:
            logger.info(
                f"Getting shared files/my drive files for OAuth "
                f"with include_files_shared_with_me={self.include_files_shared_with_me}, "
                f"include_my_drives={self.include_my_drives}, "
                f"include_shared_drives={self.include_shared_drives}."
                f"Using '{self.primary_admin_email}' as the account."
            )
        yield from get_all_files_for_oauth(
            service=drive_service,
            include_files_shared_with_me=self.include_files_shared_with_me,
@@ -450,9 +412,6 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
            drive_ids_to_retrieve = all_drive_ids

        for drive_id in drive_ids_to_retrieve:
            logger.info(
                f"Getting files in shared drive '{drive_id}' as '{self.primary_admin_email}'"
            )
            yield from get_files_in_shared_drive(
                service=drive_service,
                drive_id=drive_id,
@@ -466,9 +425,6 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
        # that could be folders.
        remaining_folders = folder_ids_to_retrieve - self._retrieved_ids
        for folder_id in remaining_folders:
            logger.info(
                f"Getting files in folder '{folder_id}' as '{self.primary_admin_email}'"
            )
            yield from crawl_folders_for_files(
                service=drive_service,
                parent_id=folder_id,

@@ -1,7 +1,7 @@
import os
from collections.abc import Iterator
from datetime import datetime
from datetime import UTC
from datetime import timezone
from typing import Any

from simple_salesforce import Salesforce
@@ -19,36 +19,23 @@ from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import SlimDocument
from onyx.connectors.salesforce.doc_conversion import extract_sections
from onyx.connectors.salesforce.utils import extract_dict_text
from onyx.utils.logger import setup_logger
from shared_configs.utils import batch_list


# TODO: this connector does not work well at large scales
# the large query against a large Salesforce instance has been reported to take 1.5 hours.
# Additionally it seems to eat up more memory over time if the connection is long running (again a scale issue).


DEFAULT_PARENT_OBJECT_TYPES = ["Account"]
MAX_QUERY_LENGTH = 10000  # max query length is 20,000 characters
ID_PREFIX = "SALESFORCE_"

logger = setup_logger()

# max query length is 20,000 characters, leave 5000 characters for slop
_MAX_QUERY_LENGTH = 10000
# There are 22 extra characters per ID so 200 * 22 = 4400 characters which is
# still well under the max query length
_MAX_ID_BATCH_SIZE = 200


_DEFAULT_PARENT_OBJECT_TYPES = ["Account"]
_ID_PREFIX = "SALESFORCE_"


def _build_time_filter_for_salesforce(
    start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None
) -> str:
    if start is None or end is None:
        return ""
    start_datetime = datetime.fromtimestamp(start, UTC)
    end_datetime = datetime.fromtimestamp(end, UTC)
    return (
        f" WHERE LastModifiedDate > {start_datetime.isoformat()} "
        f"AND LastModifiedDate < {end_datetime.isoformat()}"
    )
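For concreteness, a sketch of the SOQL WHERE clause this helper emits (the epoch values are hypothetical):

_build_time_filter_for_salesforce(1704067200, 1704153600)
# -> " WHERE LastModifiedDate > 2024-01-01T00:00:00+00:00 AND LastModifiedDate < 2024-01-02T00:00:00+00:00"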


class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
    def __init__(
@@ -57,34 +44,33 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
        requested_objects: list[str] = [],
    ) -> None:
        self.batch_size = batch_size
        self._sf_client: Salesforce | None = None
        self.sf_client: Salesforce | None = None
        self.parent_object_list = (
            [obj.capitalize() for obj in requested_objects]
            if requested_objects
            else _DEFAULT_PARENT_OBJECT_TYPES
            else DEFAULT_PARENT_OBJECT_TYPES
        )

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        self._sf_client = Salesforce(
        self.sf_client = Salesforce(
            username=credentials["sf_username"],
            password=credentials["sf_password"],
            security_token=credentials["sf_security_token"],
        )

        return None

    @property
    def sf_client(self) -> Salesforce:
        if self._sf_client is None:
            raise ConnectorMissingCredentialError("Salesforce")
        return self._sf_client

    def _get_sf_type_object_json(self, type_name: str) -> Any:
        if self.sf_client is None:
            raise ConnectorMissingCredentialError("Salesforce")
        sf_object = SFType(
            type_name, self.sf_client.session_id, self.sf_client.sf_instance
        )
        return sf_object.describe()

    def _get_name_from_id(self, id: str) -> str:
        if self.sf_client is None:
            raise ConnectorMissingCredentialError("Salesforce")
        try:
            user_object_info = self.sf_client.query(
                f"SELECT Name FROM User WHERE Id = '{id}'"
@@ -98,10 +84,14 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
    def _convert_object_instance_to_document(
        self, object_dict: dict[str, Any]
    ) -> Document:
        if self.sf_client is None:
            raise ConnectorMissingCredentialError("Salesforce")

        salesforce_id = object_dict["Id"]
        onyx_salesforce_id = f"{_ID_PREFIX}{salesforce_id}"
        base_url = f"https://{self.sf_client.sf_instance}"
        onyx_salesforce_id = f"{ID_PREFIX}{salesforce_id}"
        extracted_link = f"https://{self.sf_client.sf_instance}/{salesforce_id}"
        extracted_doc_updated_at = time_str_to_utc(object_dict["LastModifiedDate"])
        extracted_object_text = extract_dict_text(object_dict)
        extracted_semantic_identifier = object_dict.get("Name", "Unknown Object")
        extracted_primary_owners = [
            BasicExpertInfo(
@@ -111,7 +101,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):

        doc = Document(
            id=onyx_salesforce_id,
            sections=extract_sections(object_dict, base_url),
            sections=[Section(link=extracted_link, text=extracted_object_text)],
            source=DocumentSource.SALESFORCE,
            semantic_identifier=extracted_semantic_identifier,
            doc_updated_at=extracted_doc_updated_at,
@@ -121,6 +111,9 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
        return doc

    def _is_valid_child_object(self, child_relationship: dict) -> bool:
        if self.sf_client is None:
            raise ConnectorMissingCredentialError("Salesforce")

        if not child_relationship["childSObject"]:
            return False
        if not child_relationship["relationshipName"]:
@@ -149,7 +142,9 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
        return True

    def _get_all_children_of_sf_type(self, sf_type: str) -> list[dict]:
        logger.debug(f"Fetching children for SF type: {sf_type}")
        if self.sf_client is None:
            raise ConnectorMissingCredentialError("Salesforce")

        object_description = self._get_sf_type_object_json(sf_type)

        children_objects: list[dict] = []
@@ -164,6 +159,9 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
        return children_objects

    def _get_all_fields_for_sf_type(self, sf_type: str) -> list[str]:
        if self.sf_client is None:
            raise ConnectorMissingCredentialError("Salesforce")

        object_description = self._get_sf_type_object_json(sf_type)

        fields = [
@@ -174,60 +172,23 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):

        return fields

    def _get_parent_object_ids(
        self, parent_sf_type: str, time_filter_query: str
    ) -> list[str]:
        """Fetch all IDs for a given parent object type."""
        logger.debug(f"Fetching IDs for parent type: {parent_sf_type}")
        query = f"SELECT Id FROM {parent_sf_type}{time_filter_query}"
        query_result = self.sf_client.query_all(query)
        ids = [record["Id"] for record in query_result["records"]]
        logger.debug(f"Found {len(ids)} IDs for parent type: {parent_sf_type}")
        return ids

    def _process_id_batch(
        self,
        id_batch: list[str],
        queries: list[str],
    ) -> dict[str, dict[str, Any]]:
        """Process a batch of IDs using the given queries."""
        # Initialize results dictionary for this batch
        logger.debug(f"Processing batch of {len(id_batch)} IDs")
        query_results: dict[str, dict[str, Any]] = {}

        # For each query, fetch and combine results for the batch
        for query in queries:
            id_filter = f" WHERE Id IN {tuple(id_batch)}"
            batch_query = query + id_filter
            logger.debug(f"Executing query with length: {len(batch_query)}")
            query_result = self.sf_client.query_all(batch_query)
            logger.debug(f"Retrieved {len(query_result['records'])} records for query")

            for record_dict in query_result["records"]:
                query_results.setdefault(record_dict["Id"], {}).update(record_dict)

        # Convert results to documents
        return query_results

    def _generate_query_per_parent_type(self, parent_sf_type: str) -> Iterator[str]:
        """
        parent_sf_type is a string that represents the Salesforce object type.
        This function generates queries that will fetch:
        - all the fields of the parent object type
        - all the fields of the child objects of the parent object type
        This function takes in an object_type and generates query(s) designed to grab
        information associated to objects of that type.
        It does that by getting all the fields of the parent object type.
        Then it gets all the child objects of that object type and all the fields of
        those children as well.
        """
        logger.debug(f"Generating queries for parent type: {parent_sf_type}")
        parent_fields = self._get_all_fields_for_sf_type(parent_sf_type)
        logger.debug(f"Found {len(parent_fields)} fields for parent type")
        child_sf_types = self._get_all_children_of_sf_type(parent_sf_type)
        logger.debug(f"Found {len(child_sf_types)} child types")

        query = f"SELECT {', '.join(parent_fields)}"
        for child_object_dict in child_sf_types:
            fields = self._get_all_fields_for_sf_type(child_object_dict["object_type"])
            query_addition = f", \n(SELECT {', '.join(fields)} FROM {child_object_dict['relationship_name']})"

            if len(query_addition) + len(query) > _MAX_QUERY_LENGTH:
            if len(query_addition) + len(query) > MAX_QUERY_LENGTH:
                query += f"\n FROM {parent_sf_type}"
                yield query
                query = "SELECT Id" + query_addition
@@ -238,43 +199,45 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):

        yield query
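The generated SOQL therefore has this shape (object, field, and relationship names are hypothetical; the trailing WHERE Id IN filter is appended later by _process_id_batch):

# SELECT Id, Name, Industry, ...,
# (SELECT Id, FirstName, LastName FROM Contacts),
# (SELECT Id, Subject FROM Cases)
#  FROM Account WHERE Id IN ('001AAA...', '001BBB...')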

    def _batch_retrieval(
        self,
        id_batches: list[list[str]],
        queries: list[str],
    ) -> GenerateDocumentsOutput:
        doc_batch: list[Document] = []
        # For each batch of IDs, perform all queries and convert to documents
        # so they can be yielded in batches
        for id_batch in id_batches:
            query_results = self._process_id_batch(id_batch, queries)
            for doc in query_results.values():
                doc_batch.append(self._convert_object_instance_to_document(doc))
                if len(doc_batch) >= self.batch_size:
                    yield doc_batch
                    doc_batch = []

        yield doc_batch

    def _fetch_from_salesforce(
        self,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
        start: datetime | None = None,
        end: datetime | None = None,
    ) -> GenerateDocumentsOutput:
        logger.debug(f"Starting Salesforce fetch from {start} to {end}")
        time_filter_query = _build_time_filter_for_salesforce(start, end)
        if self.sf_client is None:
            raise ConnectorMissingCredentialError("Salesforce")

        doc_batch: list[Document] = []
        for parent_object_type in self.parent_object_list:
            logger.info(f"Processing parent object type: {parent_object_type}")
            logger.debug(f"Processing: {parent_object_type}")

            all_ids = self._get_parent_object_ids(parent_object_type, time_filter_query)
            logger.info(f"Found {len(all_ids)} IDs for {parent_object_type}")
            id_batches = batch_list(all_ids, _MAX_ID_BATCH_SIZE)
            query_results: dict = {}
            for query in self._generate_query_per_parent_type(parent_object_type):
                if start is not None and end is not None:
                    if start and start.tzinfo is None:
                        start = start.replace(tzinfo=timezone.utc)
                    if end and end.tzinfo is None:
                        end = end.replace(tzinfo=timezone.utc)
                    query += f" WHERE LastModifiedDate > {start.isoformat()} AND LastModifiedDate < {end.isoformat()}"

            # Generate all queries we'll need
            queries = list(self._generate_query_per_parent_type(parent_object_type))
            logger.info(f"Generated {len(queries)} queries for {parent_object_type}")
            yield from self._batch_retrieval(id_batches, queries)
                query_result = self.sf_client.query_all(query)

                for record_dict in query_result["records"]:
                    query_results.setdefault(record_dict["Id"], {}).update(record_dict)

            logger.info(
                f"Number of {parent_object_type} Objects processed: {len(query_results)}"
            )

            for combined_object_dict in query_results.values():
                doc_batch.append(
                    self._convert_object_instance_to_document(combined_object_dict)
                )

                if len(doc_batch) > self.batch_size:
                    yield doc_batch
                    doc_batch = []
        yield doc_batch

    def load_from_state(self) -> GenerateDocumentsOutput:
        return self._fetch_from_salesforce()
@@ -282,20 +245,26 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
    def poll_source(
        self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
    ) -> GenerateDocumentsOutput:
        return self._fetch_from_salesforce(start=start, end=end)
        if self.sf_client is None:
            raise ConnectorMissingCredentialError("Salesforce")
        start_datetime = datetime.utcfromtimestamp(start)
        end_datetime = datetime.utcfromtimestamp(end)
        return self._fetch_from_salesforce(start=start_datetime, end=end_datetime)

    def retrieve_all_slim_documents(
        self,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
    ) -> GenerateSlimDocumentOutput:
        if self.sf_client is None:
            raise ConnectorMissingCredentialError("Salesforce")
        doc_metadata_list: list[SlimDocument] = []
        for parent_object_type in self.parent_object_list:
            query = f"SELECT Id FROM {parent_object_type}"
            query_result = self.sf_client.query_all(query)
            doc_metadata_list.extend(
                SlimDocument(
                    id=f"{_ID_PREFIX}{instance_dict.get('Id', '')}",
                    id=f"{ID_PREFIX}{instance_dict.get('Id', '')}",
                    perm_sync_data={},
                )
                for instance_dict in query_result["records"]

@@ -1,148 +0,0 @@
import re
from collections import OrderedDict

from onyx.connectors.models import Section

# All of these types of keys are handled by specific fields in the doc
# conversion process (E.g. URLs) or are not useful for the user (E.g. UUIDs)
_SF_JSON_FILTER = r"Id$|Date$|stamp$|url$"


def _clean_salesforce_dict(data: dict | list) -> dict | list:
    """Clean and transform Salesforce API response data by recursively:
    1. Extracting records from the response if present
    2. Merging attributes into the main dictionary
    3. Filtering out keys matching certain patterns (Id, Date, stamp, url)
    4. Removing '__c' suffix from custom field names
    5. Removing None values and empty containers

    Args:
        data: A dictionary or list from Salesforce API response

    Returns:
        Cleaned dictionary or list with transformed keys and filtered values
    """
    if isinstance(data, dict):
        if "records" in data.keys():
            data = data["records"]
    if isinstance(data, dict):
        if "attributes" in data.keys():
            if isinstance(data["attributes"], dict):
                data.update(data.pop("attributes"))

    if isinstance(data, dict):
        filtered_dict = {}
        for key, value in data.items():
            if not re.search(_SF_JSON_FILTER, key, re.IGNORECASE):
                # remove the custom object indicator for display
                if "__c" in key:
                    key = key[:-3]
                if isinstance(value, (dict, list)):
                    filtered_value = _clean_salesforce_dict(value)
                    # Only add non-empty dictionaries or lists
                    if filtered_value:
                        filtered_dict[key] = filtered_value
                elif value is not None:
                    filtered_dict[key] = value
        return filtered_dict
    elif isinstance(data, list):
        filtered_list = []
        for item in data:
            if isinstance(item, (dict, list)):
                filtered_item = _clean_salesforce_dict(item)
                # Only add non-empty dictionaries or lists
                if filtered_item:
                    filtered_list.append(filtered_item)
            elif item is not None:
                filtered_list.append(filtered_item)
        return filtered_list
    else:
        return data


def _json_to_natural_language(data: dict | list, indent: int = 0) -> str:
    """Convert a nested dictionary or list into a human-readable string format.

    Recursively traverses the data structure and formats it with:
    - Key-value pairs on separate lines
    - Nested structures indented for readability
    - Lists and dictionaries handled with appropriate formatting

    Args:
        data: The dictionary or list to convert
        indent: Number of spaces to indent (default: 0)

    Returns:
        A formatted string representation of the data structure
    """
    result = []
    indent_str = " " * indent

    if isinstance(data, dict):
        for key, value in data.items():
            if isinstance(value, (dict, list)):
                result.append(f"{indent_str}{key}:")
                result.append(_json_to_natural_language(value, indent + 2))
            else:
                result.append(f"{indent_str}{key}: {value}")
    elif isinstance(data, list):
        for item in data:
            result.append(_json_to_natural_language(item, indent + 2))

    return "\n".join(result)


def _extract_dict_text(raw_dict: dict) -> str:
    """Extract text from a Salesforce API response dictionary by:
    1. Cleaning the dictionary
    2. Converting the cleaned dictionary to natural language
    """
    processed_dict = _clean_salesforce_dict(raw_dict)
    natural_language_for_dict = _json_to_natural_language(processed_dict)
    return natural_language_for_dict


def _field_value_is_child_object(field_value: dict) -> bool:
    """
    Checks if the field value is a child object.
    """
    return (
        isinstance(field_value, OrderedDict)
        and "records" in field_value.keys()
        and isinstance(field_value["records"], list)
        and len(field_value["records"]) > 0
        and "Id" in field_value["records"][0].keys()
    )


def extract_sections(salesforce_object: dict, base_url: str) -> list[Section]:
    """
    This goes through the salesforce_object and extracts the top level fields as a Section.
    It also goes through the child objects and extracts them as Sections.
    """
    top_level_dict = {}

    child_object_sections = []
    for field_name, field_value in salesforce_object.items():
        # If the field value is not a child object, add it to the top level dict
        # to turn into text for the top level section
        if not _field_value_is_child_object(field_value):
            top_level_dict[field_name] = field_value
            continue

        # If the field value is a child object, extract the child objects and add them as sections
        for record in field_value["records"]:
            child_object_id = record["Id"]
            child_object_sections.append(
                Section(
                    text=f"Child Object(s): {field_name}\n{_extract_dict_text(record)}",
                    link=f"{base_url}/{child_object_id}",
                )
            )

    top_level_id = salesforce_object["Id"]
    top_level_section = Section(
        text=_extract_dict_text(top_level_dict),
        link=f"{base_url}/{top_level_id}",
    )
    return [top_level_section, *child_object_sections]
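An illustration of the split extract_sections performs (instance URL, field names, and IDs are hypothetical):

# Input: an Account dict whose "Contacts" value is an OrderedDict of child records.
# Output: one top-level Section for the Account's own fields, plus one Section per
# child record, each linking to f"{base_url}/{child_id}":
#   [Section(link="https://acme.my.salesforce.com/001AAA..."),   # Account fields
#    Section(link="https://acme.my.salesforce.com/003BBB...")]   # "Child Object(s): Contacts"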
66 backend/onyx/connectors/salesforce/utils.py Normal file
@@ -0,0 +1,66 @@
import re
from typing import Union

SF_JSON_FILTER = r"Id$|Date$|stamp$|url$"


def _clean_salesforce_dict(data: Union[dict, list]) -> Union[dict, list]:
    if isinstance(data, dict):
        if "records" in data.keys():
            data = data["records"]
    if isinstance(data, dict):
        if "attributes" in data.keys():
            if isinstance(data["attributes"], dict):
                data.update(data.pop("attributes"))

    if isinstance(data, dict):
        filtered_dict = {}
        for key, value in data.items():
            if not re.search(SF_JSON_FILTER, key, re.IGNORECASE):
                if "__c" in key:  # remove the custom object indicator for display
                    key = key[:-3]
                if isinstance(value, (dict, list)):
                    filtered_value = _clean_salesforce_dict(value)
                    if filtered_value:  # Only add non-empty dictionaries or lists
                        filtered_dict[key] = filtered_value
                elif value is not None:
                    filtered_dict[key] = value
        return filtered_dict
    elif isinstance(data, list):
        filtered_list = []
        for item in data:
            if isinstance(item, (dict, list)):
                filtered_item = _clean_salesforce_dict(item)
                if filtered_item:  # Only add non-empty dictionaries or lists
                    filtered_list.append(filtered_item)
            elif item is not None:
                filtered_list.append(filtered_item)
        return filtered_list
    else:
        return data


def _json_to_natural_language(data: Union[dict, list], indent: int = 0) -> str:
    result = []
    indent_str = " " * indent

    if isinstance(data, dict):
        for key, value in data.items():
            if isinstance(value, (dict, list)):
                result.append(f"{indent_str}{key}:")
                result.append(_json_to_natural_language(value, indent + 2))
            else:
                result.append(f"{indent_str}{key}: {value}")
    elif isinstance(data, list):
        for item in data:
            result.append(_json_to_natural_language(item, indent))
    else:
        result.append(f"{indent_str}{data}")

    return "\n".join(result)


def extract_dict_text(raw_dict: dict) -> str:
    processed_dict = _clean_salesforce_dict(raw_dict)
    natural_language_dict = _json_to_natural_language(processed_dict)
    return natural_language_dict
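A small worked example of the flattening (the input values are hypothetical):

extract_dict_text({"Name": "Acme", "AnnualRevenue": 1000000, "OwnerId": "005AAA"})
# "OwnerId" matches the "Id$" pattern in SF_JSON_FILTER and is dropped, leaving:
# Name: Acme
# AnnualRevenue: 1000000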
@@ -33,7 +33,6 @@ from onyx.file_processing.extract_file_text import read_pdf_file
from onyx.file_processing.html_utils import web_html_cleanup
from onyx.utils.logger import setup_logger
from onyx.utils.sitemap import list_pages_for_site
from shared_configs.configs import MULTI_TENANT

logger = setup_logger()

@@ -242,12 +241,6 @@ class WebConnector(LoadConnector):
            self.to_visit_list = extract_urls_from_sitemap(_ensure_valid_url(base_url))

        elif web_connector_type == WEB_CONNECTOR_VALID_SETTINGS.UPLOAD:
            # Explicitly check if running in multi-tenant mode to prevent potential security risks
            if MULTI_TENANT:
                raise ValueError(
                    "Upload input for web connector is not supported in cloud environments"
                )

            logger.warning(
                "This is not a UI supported Web Connector flow, "
                "are you sure you want to do this?"

@@ -40,13 +40,6 @@ class ZendeskClient:
        response = requests.get(
            f"{self.base_url}/{endpoint}", auth=self.auth, params=params
        )

        if response.status_code == 429:
            retry_after = response.headers.get("Retry-After")
            if retry_after is not None:
                # Sleep for the duration indicated by the Retry-After header
                time.sleep(int(retry_after))

        response.raise_for_status()
        return response.json()

@@ -54,11 +54,9 @@ def get_total_users_count(db_session: Session) -> int:
    return user_count + invited_users


async def get_user_count(only_admin_users: bool = False) -> int:
async def get_user_count() -> int:
    async with get_async_session_with_tenant() as session:
        stmt = select(func.count(User.id))
        if only_admin_users:
            stmt = stmt.where(User.role == UserRole.ADMIN)
        result = await session.execute(stmt)
        user_count = result.scalar()
        if user_count is None:

@@ -7,7 +7,6 @@ from sqlalchemy import exists
from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy.orm import aliased
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session

from onyx.configs.constants import DocumentSource
@@ -91,22 +90,15 @@ def get_connector_credential_pairs(
    user: User | None = None,
    get_editable: bool = True,
    ids: list[int] | None = None,
    eager_load_connector: bool = False,
) -> list[ConnectorCredentialPair]:
    stmt = select(ConnectorCredentialPair).distinct()

    if eager_load_connector:
        stmt = stmt.options(joinedload(ConnectorCredentialPair.connector))

    stmt = _add_user_filters(stmt, user, get_editable)

    if not include_disabled:
        stmt = stmt.where(
            ConnectorCredentialPair.status == ConnectorCredentialPairStatus.ACTIVE
        )
        )  # noqa
    if ids:
        stmt = stmt.where(ConnectorCredentialPair.id.in_(ids))

    return list(db_session.scalars(stmt).all())


@@ -318,9 +310,6 @@ def associate_default_cc_pair(db_session: Session) -> None:
    if existing_association is not None:
        return

    # DefaultCCPair has id 1 since it is the first CC pair created
    # It is DEFAULT_CC_PAIR_ID, but can't set it explicitly because it messed with the
    # auto-incrementing id
    association = ConnectorCredentialPair(
        connector_id=0,
        credential_id=0,
@@ -361,12 +350,7 @@ def add_credential_to_connector(
    last_successful_index_time: datetime | None = None,
) -> StatusResponse:
    connector = fetch_connector_by_id(connector_id, db_session)
    credential = fetch_credential_by_id(
        credential_id,
        user,
        db_session,
        get_editable=False,
    )
    credential = fetch_credential_by_id(credential_id, user, db_session)

    if connector is None:
        raise HTTPException(status_code=404, detail="Connector does not exist")
@@ -443,12 +427,7 @@ def remove_credential_from_connector(
    db_session: Session,
) -> StatusResponse[int]:
    connector = fetch_connector_by_id(connector_id, db_session)
    credential = fetch_credential_by_id(
        credential_id,
        user,
        db_session,
        get_editable=False,
    )
    credential = fetch_credential_by_id(credential_id, user, db_session)

    if connector is None:
        raise HTTPException(status_code=404, detail="Connector does not exist")

@@ -86,7 +86,7 @@ def _add_user_filters(
    """
    Filter Credentials by:
    - if the user is in the user_group that owns the Credential
    - if the user is a curator, they must also have a curator relationship
    - if the user is not a global_curator, they must also have a curator relationship
      to the user_group
    - if editing is being done, we also filter out Credentials that are owned by groups
      that the user isn't a curator for
@@ -97,7 +97,6 @@ def _add_user_filters(
        where_clause = User__UserGroup.user_id == user.id
        if user.role == UserRole.CURATOR:
            where_clause &= User__UserGroup.is_curator == True  # noqa: E712

        if get_editable:
            user_groups = select(User__UserGroup.user_group_id).where(
                User__UserGroup.user_id == user.id
@@ -153,16 +152,10 @@ def fetch_credential_by_id(
    user: User | None,
    db_session: Session,
    assume_admin: bool = False,
    get_editable: bool = True,
) -> Credential | None:
    stmt = select(Credential).distinct()
    stmt = stmt.where(Credential.id == credential_id)
    stmt = _add_user_filters(
        stmt=stmt,
        user=user,
        assume_admin=assume_admin,
        get_editable=get_editable,
    )
    stmt = _add_user_filters(stmt, user, assume_admin=assume_admin)
    result = db_session.execute(stmt)
    credential = result.scalar_one_or_none()
    return credential

@@ -27,7 +27,7 @@ from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.orm import Session
from sqlalchemy.orm import sessionmaker

from onyx.configs.app_configs import AWS_REGION_NAME
from onyx.configs.app_configs import AWS_REGION
from onyx.configs.app_configs import LOG_POSTGRES_CONN_COUNTS
from onyx.configs.app_configs import LOG_POSTGRES_LATENCY
from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_OVERFLOW
@@ -273,7 +273,7 @@ async def get_async_connection() -> Any:
    port = POSTGRES_PORT
    user = POSTGRES_USER
    db = POSTGRES_DB
    token = get_iam_auth_token(host, port, user, AWS_REGION_NAME)
    token = get_iam_auth_token(host, port, user, AWS_REGION)

    # asyncpg requires 'ssl="require"' if SSL needed
    return await asyncpg.connect(
@@ -315,7 +315,7 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
            host = POSTGRES_HOST
            port = POSTGRES_PORT
            user = POSTGRES_USER
            token = get_iam_auth_token(host, port, user, AWS_REGION_NAME)
            token = get_iam_auth_token(host, port, user, AWS_REGION)
            cparams["password"] = token
            cparams["ssl"] = ssl_context

@@ -525,6 +525,6 @@ def provide_iam_token(dialect: Any, conn_rec: Any, cargs: Any, cparams: Any) ->
    host = POSTGRES_HOST
    port = POSTGRES_PORT
    user = POSTGRES_USER
    region = os.getenv("AWS_REGION_NAME", "us-east-2")
    region = os.getenv("AWS_REGION", "us-east-2")
    # Configure for psycopg2 with IAM token
    configure_psycopg2_iam_auth(cparams, host, port, user, region)

@@ -54,7 +54,6 @@ from onyx.db.enums import IndexingStatus
from onyx.db.enums import IndexModelStatus
from onyx.db.enums import TaskStatus
from onyx.db.pydantic_type import PydanticType
from onyx.utils.logger import setup_logger
from onyx.utils.special_types import JSON_ro
from onyx.file_store.models import FileDescriptor
from onyx.llm.override_models import LLMOverride
@@ -66,8 +65,6 @@ from onyx.utils.headers import HeaderItemDict
from shared_configs.enums import EmbeddingProvider
from shared_configs.enums import RerankerProvider

logger = setup_logger()


class Base(DeclarativeBase):
    __abstract__ = True
@@ -75,8 +72,6 @@ class Base(DeclarativeBase):

class EncryptedString(TypeDecorator):
    impl = LargeBinary
    # This type's behavior is fully deterministic and doesn't depend on any external factors.
    cache_ok = True

    def process_bind_param(self, value: str | None, dialect: Dialect) -> bytes | None:
        if value is not None:
@@ -91,8 +86,6 @@ class EncryptedString(TypeDecorator):

class EncryptedJson(TypeDecorator):
    impl = LargeBinary
    # This type's behavior is fully deterministic and doesn't depend on any external factors.
    cache_ok = True

    def process_bind_param(self, value: dict | None, dialect: Dialect) -> bytes | None:
        if value is not None:
@@ -109,21 +102,6 @@ class EncryptedJson(TypeDecorator):
        return value


class NullFilteredString(TypeDecorator):
    impl = String
    # This type's behavior is fully deterministic and doesn't depend on any external factors.
    cache_ok = True

    def process_bind_param(self, value: str | None, dialect: Dialect) -> str | None:
        if value is not None and "\x00" in value:
            logger.warning(f"NUL characters found in value: {value}")
            return value.replace("\x00", "")
        return value

    def process_result_value(self, value: str | None, dialect: Dialect) -> str | None:
        return value


"""
Auth/Authz (users, permissions, access) Tables
"""
@@ -473,16 +451,16 @@ class Document(Base):

    # this should correspond to the ID of the document
    # (as is passed around in Onyx)
    id: Mapped[str] = mapped_column(NullFilteredString, primary_key=True)
    id: Mapped[str] = mapped_column(String, primary_key=True)
    from_ingestion_api: Mapped[bool] = mapped_column(
        Boolean, default=False, nullable=True
    )
    # 0 for neutral, positive for mostly endorse, negative for mostly reject
    boost: Mapped[int] = mapped_column(Integer, default=DEFAULT_BOOST)
    hidden: Mapped[bool] = mapped_column(Boolean, default=False)
    semantic_id: Mapped[str] = mapped_column(NullFilteredString)
    semantic_id: Mapped[str] = mapped_column(String)
    # First Section's link
    link: Mapped[str | None] = mapped_column(NullFilteredString, nullable=True)
    link: Mapped[str | None] = mapped_column(String, nullable=True)

    # The updated time is also used as a measure of the last successful state of the doc
    # pulled from the source (to help skip reindexing already updated docs in case of

@@ -7,15 +7,8 @@ from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.orm import Session

from onyx.auth.invited_users import get_invited_users
from onyx.auth.invited_users import write_invited_users
from onyx.auth.schemas import UserRole
from onyx.db.models import DocumentSet__User
from onyx.db.models import Persona__User
from onyx.db.models import SamlAccount
from onyx.db.models import User
from onyx.db.models import User__UserGroup
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop


def validate_user_role_update(requested_role: UserRole, current_role: UserRole) -> None:
@@ -192,43 +185,3 @@ def batch_add_ext_perm_user_if_not_exists(
    db_session.commit()

    return found_users + new_users


def delete_user_from_db(
    user_to_delete: User,
    db_session: Session,
) -> None:
    for oauth_account in user_to_delete.oauth_accounts:
        db_session.delete(oauth_account)

    fetch_ee_implementation_or_noop(
        "onyx.db.external_perm",
        "delete_user__ext_group_for_user__no_commit",
    )(
        db_session=db_session,
        user_id=user_to_delete.id,
    )
    db_session.query(SamlAccount).filter(
        SamlAccount.user_id == user_to_delete.id
    ).delete()
    db_session.query(DocumentSet__User).filter(
        DocumentSet__User.user_id == user_to_delete.id
    ).delete()
    db_session.query(Persona__User).filter(
        Persona__User.user_id == user_to_delete.id
    ).delete()
    db_session.query(User__UserGroup).filter(
        User__UserGroup.user_id == user_to_delete.id
    ).delete()
    db_session.delete(user_to_delete)
    db_session.commit()

    # NOTE: edge case may exist with race conditions
    # with this `invited user` scheme generally.
    user_emails = get_invited_users()
    remaining_users = [
        remaining_user_email
        for remaining_user_email in user_emails
        if remaining_user_email != user_to_delete.email
    ]
    write_invited_users(remaining_users)

@@ -57,7 +57,4 @@ def get_uuid_from_chunk(
            for referenced_chunk_id in chunk.large_chunk_reference_ids
        ]
    )
    # Add tenant_id if it exists
    if hasattr(chunk, "tenant_id") and chunk.tenant_id:
        unique_identifier_string += f"_tenant_{chunk.tenant_id}"
    return uuid.uuid5(uuid.NAMESPACE_X500, unique_identifier_string)

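For intuition, uuid5 is deterministic over its input string, so appending the tenant suffix yields a stable but tenant-distinct chunk UUID (the identifier values here are hypothetical):

import uuid

base = "doc-123__chunk-0"
uuid.uuid5(uuid.NAMESPACE_X500, base)                      # same UUID every run for this chunk
uuid.uuid5(uuid.NAMESPACE_X500, base + "_tenant_tenantA")  # a different, equally stable UUID per tenant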
@@ -149,7 +149,6 @@ class Indexable(abc.ABC):
        self,
        chunks: list[DocMetadataAwareIndexChunk],
        fresh_index: bool = False,
        tenant_id: str | None = None,
    ) -> set[DocumentInsertionRecord]:
        """
        Takes a list of document chunks and indexes them in the document index
@@ -175,8 +174,6 @@ class Indexable(abc.ABC):
        - chunks: Document chunks with all of the information needed for indexing to the document
          index.
        - fresh_index: Boolean indicating whether this is a fresh index with no existing documents.
        - tenant_id: The tenant id to index the chunks for. If not provided, the chunks will be
          indexed for the default tenant.

        Returns:
            List of document ids which map to unique documents and are used for deduping chunks

@@ -180,8 +180,6 @@ def _get_chunks_via_visit_api(
        selection += f" and {index_name}.chunk_id<={chunk_request.max_chunk_ind}"
    if not get_large_chunks:
        selection += f" and {index_name}.large_chunk_reference_ids == null"
    if filters.tenant_id:
        selection += f" and {index_name}.tenant_id == '{filters.tenant_id}'"

    # Setting up the selection criteria in the query parameters
    params = {
@@ -239,7 +237,6 @@ def get_all_vespa_ids_for_document_id(
    index_name: str,
    filters: IndexFilters | None = None,
    get_large_chunks: bool = False,
    tenant_id: str | None = None,
) -> list[str]:
    document_chunks = _get_chunks_via_visit_api(
        chunk_request=VespaChunkRequest(document_id=document_id),

@@ -1,7 +1,6 @@
import concurrent.futures

import httpx
from pydantic import BaseModel
from retry import retry

from onyx.document_index.vespa.chunk_retrieval import (
@@ -19,16 +18,12 @@ CONTENT_SUMMARY = "content_summary"

@retry(tries=3, delay=1, backoff=2)
def _delete_vespa_doc_chunks(
    document_id: str,
    index_name: str,
    http_client: httpx.Client,
    tenant_id: str | None = None,
    document_id: str, index_name: str, http_client: httpx.Client
) -> None:
    doc_chunk_ids = get_all_vespa_ids_for_document_id(
        document_id=document_id,
        index_name=index_name,
        get_large_chunks=True,
        tenant_id=tenant_id,
    )

    for chunk_id in doc_chunk_ids:
@@ -42,13 +37,8 @@ def _delete_vespa_doc_chunks(
        raise


class VespaDeletionRequest(BaseModel):
    document_id: str
    tenant_id: str | None


def delete_vespa_docs(
    deletion_requests: list[VespaDeletionRequest],
    document_ids: list[str],
    index_name: str,
    http_client: httpx.Client,
    executor: concurrent.futures.ThreadPoolExecutor | None = None,
@@ -62,13 +52,9 @@ def delete_vespa_docs(
    try:
        doc_deletion_future = {
            executor.submit(
                _delete_vespa_doc_chunks,
                deletion_request.document_id,
                index_name,
                http_client,
                deletion_request.tenant_id,
            ): deletion_request.document_id
            for deletion_request in deletion_requests
                _delete_vespa_doc_chunks, doc_id, index_name, http_client
            ): doc_id
            for doc_id in document_ids
        }
        for future in concurrent.futures.as_completed(doc_deletion_future):
            # Will raise exception if the deletion raised an exception
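The dictionary-of-futures pattern above fans the per-document deletions out across a thread pool and surfaces any worker exception on the main thread via future.result(). A generic, self-contained sketch of the same idiom (the toy delete_one stands in for the real chunk deletion):

import concurrent.futures

def delete_one(doc_id: str) -> None:
    print(f"deleting {doc_id}")  # real code would issue HTTP deletes here

doc_ids = ["a", "b", "c"]
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
    future_to_id = {executor.submit(delete_one, d): d for d in doc_ids}
    for future in concurrent.futures.as_completed(future_to_id):
        future.result()  # re-raises any exception from the worker thread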
@@ -39,7 +39,6 @@ from onyx.document_index.vespa.chunk_retrieval import (
)
from onyx.document_index.vespa.chunk_retrieval import query_vespa
from onyx.document_index.vespa.deletion import delete_vespa_docs
from onyx.document_index.vespa.deletion import VespaDeletionRequest
from onyx.document_index.vespa.indexing_utils import batch_index_vespa_chunks
from onyx.document_index.vespa.indexing_utils import clean_chunk_id_copy
from onyx.document_index.vespa.indexing_utils import (
@@ -317,22 +316,17 @@ class VespaIndex(DocumentIndex):
        # IMPORTANT: This must be done one index at a time, do not use secondary index here
        cleaned_chunks = [clean_chunk_id_copy(chunk) for chunk in chunks]

        # Build a map from doc_id -> tenant_id (if tenant_id is not None).
        # This allows us to create VespaDeletionRequest objects with tenant IDs below.
        doc_id_to_tenant_id = {}
        for chunk in cleaned_chunks:
            doc_id_to_tenant_id[chunk.source_document.id] = chunk.tenant_id

        existing_docs: set[str] = set()

        # NOTE: using `httpx` here since `requests` doesn't support HTTP2. This is beneficial for
        # indexing / updates / deletes since we have to make a large volume of requests.
        with (
            concurrent.futures.ThreadPoolExecutor(max_workers=NUM_THREADS) as executor,
            get_vespa_http_client() as http_client,
        ):
            if not fresh_index:
                print("Checking for existing documents")

                # Determine which documents already exist in Vespa.
                # Check for existing documents, existing documents need to have all of their chunks deleted
                # prior to indexing as the document size (num chunks) may have shrunk
                first_chunks = [
                    chunk for chunk in cleaned_chunks if chunk.chunk_id == 0
                ]
@@ -346,23 +340,14 @@ class VespaIndex(DocumentIndex):
                    )
                )

                # Pass VespaDeletionRequest objects (document_id and tenant_id) instead of just doc_ids.
                for doc_id_batch in batch_generator(existing_docs, BATCH_SIZE):
                    deletion_requests = [
                        VespaDeletionRequest(
                            document_id=doc_id,
                            tenant_id=doc_id_to_tenant_id.get(doc_id),  # Might be None
                        )
                        for doc_id in doc_id_batch
                    ]
                    delete_vespa_docs(
                        deletion_requests=deletion_requests,
                        document_ids=doc_id_batch,
                        index_name=self.index_name,
                        http_client=http_client,
                        executor=executor,
                    )

            # Index the cleaned chunks in batches.
            for chunk_batch in batch_generator(cleaned_chunks, BATCH_SIZE):
                batch_index_vespa_chunks(
                    chunks=chunk_batch,
@@ -603,7 +588,7 @@ class VespaIndex(DocumentIndex):

        return total_chunks_updated

    def delete(self, doc_ids: list[str], tenant_id: str | None = None) -> None:
    def delete(self, doc_ids: list[str]) -> None:
        logger.info(f"Deleting {len(doc_ids)} documents from Vespa")

        doc_ids = [replace_invalid_doc_id_characters(doc_id) for doc_id in doc_ids]
@@ -617,10 +602,7 @@ class VespaIndex(DocumentIndex):

        for index_name in index_names:
            delete_vespa_docs(
                document_ids=doc_ids,
                index_name=index_name,
                http_client=http_client,
                tenant_id=tenant_id,
                document_ids=doc_ids, index_name=index_name, http_client=http_client
            )
        return

@@ -55,7 +55,9 @@ def remove_invalid_unicode_chars(text: str) -> str:
    return _illegal_xml_chars_RE.sub("", text)


def get_vespa_http_client(no_timeout: bool = False, http2: bool = True) -> httpx.Client:
def get_vespa_http_client(
    no_timeout: bool = False, http2: bool = False
) -> httpx.Client:
    """
    Configure and return an HTTP client for communicating with Vespa,
    including authentication if needed.
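One side of this hunk defaults the client to HTTP/2, the other to HTTP/1.1. A minimal sketch of what such a factory might look like (assumed construction, not the actual onyx implementation; HTTP/2 support requires the httpx[http2] extra to be installed):

import httpx

def make_vespa_client(no_timeout: bool = False, http2: bool = False) -> httpx.Client:
    # timeout=None disables httpx's default timeout for long-running indexing calls
    return httpx.Client(http2=http2, timeout=None if no_timeout else 30.0)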
@@ -260,21 +260,6 @@ def index_doc_batch_prepare(
def filter_documents(document_batch: list[Document]) -> list[Document]:
    documents: list[Document] = []
    for document in document_batch:
        # Remove any NUL characters from title/semantic_id
        # This is a known issue with the Zendesk connector
        # Postgres cannot handle NUL characters in text fields
        if document.title:
            document.title = document.title.replace("\x00", "")
        if document.semantic_identifier:
            document.semantic_identifier = document.semantic_identifier.replace(
                "\x00", ""
            )

        # Remove NUL characters from all sections
        for section in document.sections:
            if section.text is not None:
                section.text = section.text.replace("\x00", "")

        empty_contents = not any(section.text.strip() for section in document.sections)
        if (
            (not document.title or not document.title.strip())
@@ -266,27 +266,18 @@ class DefaultMultiLLM(LLM):
        # )
        self._custom_config = custom_config

        # Create a dictionary for model-specific arguments if it's None
        model_kwargs = model_kwargs or {}

        # NOTE: have to set these as environment variables for LiteLLM since
        # not all of them can be passed in, but they are always supported when set as
        # env variables. We'll also try passing them in, since litellm just ignores
        # additional kwargs (and some kwargs MUST be passed in rather than set as
        # env variables)
        if custom_config:
            # Specifically pass in "vertex_credentials" as a model_kwarg to the
            # completion call for vertex AI. More details here:
            # https://docs.litellm.ai/docs/providers/vertex
            vertex_credentials_key = "vertex_credentials"
            vertex_credentials = custom_config.get(vertex_credentials_key)
            if vertex_credentials and model_provider == "vertex_ai":
                model_kwargs[vertex_credentials_key] = vertex_credentials
            else:
                # standard case
                for k, v in custom_config.items():
                    os.environ[k] = v
            for k, v in custom_config.items():
                os.environ[k] = v

        model_kwargs = model_kwargs or {}
        if custom_config:
            model_kwargs.update(custom_config)
        if extra_headers:
            model_kwargs.update({"extra_headers": extra_headers})
        if extra_body:
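Both sides of this hunk route provider settings two ways: into process environment variables (which LiteLLM reads implicitly) and into the kwargs of the completion call. A stripped-down sketch of that double wiring (the config key here is illustrative, not onyx's):

import os

def apply_custom_config(custom_config: dict[str, str], model_kwargs: dict) -> None:
    for key, value in custom_config.items():
        os.environ[key] = value          # picked up implicitly by the LLM library
    model_kwargs.update(custom_config)   # also passed explicitly; extra kwargs are ignored

model_kwargs: dict = {}
apply_custom_config({"VERTEXAI_LOCATION": "us-central1"}, model_kwargs)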
@@ -74,9 +74,6 @@ from onyx.server.manage.search_settings import router as search_settings_router
from onyx.server.manage.slack_bot import router as slack_bot_management_router
from onyx.server.manage.users import router as user_router
from onyx.server.middleware.latency_logging import add_latency_logging_middleware
from onyx.server.middleware.rate_limiting import close_limiter
from onyx.server.middleware.rate_limiting import get_auth_rate_limiters
from onyx.server.middleware.rate_limiting import setup_limiter
from onyx.server.onyx_api.ingestion import router as onyx_api_router
from onyx.server.openai_assistants_api.full_openai_assistants_api import (
    get_full_openai_assistants_api_router,
@@ -156,20 +153,6 @@ def include_router_with_global_prefix_prepended(
    application.include_router(router, **final_kwargs)


def include_auth_router_with_prefix(
    application: FastAPI, router: APIRouter, prefix: str, tags: list[str] | None = None
) -> None:
    """Wrapper function to include an 'auth' router with prefix + rate-limiting dependencies."""
    final_tags = tags or ["auth"]
    include_router_with_global_prefix_prepended(
        application,
        router,
        prefix=prefix,
        tags=final_tags,
        dependencies=get_auth_rate_limiters(),
    )


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator:
    # Set recursion limit
@@ -211,15 +194,8 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
        setup_multitenant_onyx()

    optional_telemetry(record_type=RecordType.VERSION, data={"version": __version__})

    # Set up rate limiter
    await setup_limiter()

    yield

    # Close rate limiter
    await close_limiter()


def log_http_error(_: Request, exc: Exception) -> JSONResponse:
    status_code = getattr(exc, "status_code", 500)
@@ -307,37 +283,42 @@ def get_application() -> FastAPI:
        pass

    if AUTH_TYPE == AuthType.BASIC or AUTH_TYPE == AuthType.CLOUD:
        include_auth_router_with_prefix(
        include_router_with_global_prefix_prepended(
            application,
            fastapi_users.get_auth_router(auth_backend),
            prefix="/auth",
            tags=["auth"],
        )

        include_auth_router_with_prefix(
        include_router_with_global_prefix_prepended(
            application,
            fastapi_users.get_register_router(UserRead, UserCreate),
            prefix="/auth",
            tags=["auth"],
        )

        include_auth_router_with_prefix(
        include_router_with_global_prefix_prepended(
            application,
            fastapi_users.get_reset_password_router(),
            prefix="/auth",
            tags=["auth"],
        )
        include_auth_router_with_prefix(
        include_router_with_global_prefix_prepended(
            application,
            fastapi_users.get_verify_router(UserRead),
            prefix="/auth",
            tags=["auth"],
        )
        include_auth_router_with_prefix(
        include_router_with_global_prefix_prepended(
            application,
            fastapi_users.get_users_router(UserRead, UserUpdate),
            prefix="/users",
            tags=["users"],
        )

    if AUTH_TYPE == AuthType.GOOGLE_OAUTH:
        oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
        include_auth_router_with_prefix(
        include_router_with_global_prefix_prepended(
            application,
            create_onyx_oauth_router(
                oauth_client,
@@ -349,13 +330,15 @@ def get_application() -> FastAPI:
                redirect_url=f"{WEB_DOMAIN}/auth/oauth/callback",
            ),
            prefix="/auth/oauth",
            tags=["auth"],
        )

        # Need basic auth router for `logout` endpoint
        include_auth_router_with_prefix(
        include_router_with_global_prefix_prepended(
            application,
            fastapi_users.get_logout_router(auth_backend),
            prefix="/auth",
            tags=["auth"],
        )

    application.add_exception_handler(
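The wrapper touched in this hunk attaches rate-limiting dependencies at include_router time, so every auth route gets them without per-endpoint annotations. A compact sketch of that FastAPI idiom (toy router and limiter, not the onyx code):

from fastapi import APIRouter, Depends, FastAPI

def fake_rate_limiter() -> None:
    pass  # a real limiter dependency would count requests and raise HTTP 429

app = FastAPI()
auth_router = APIRouter()

@auth_router.post("/login")
def login() -> dict:
    return {"ok": True}

# dependencies passed here run before every route on the router
app.include_router(auth_router, prefix="/auth", dependencies=[Depends(fake_rate_limiter)])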
@@ -118,7 +118,7 @@ class RedisConnectorIndex:

        The slack in timing is needed to avoid race conditions where simply checking
        the celery queue and task status could result in race conditions."""
        self.redis.set(self.active_key, 0, ex=3600)
        self.redis.set(self.active_key, 0, ex=300)

    def active(self) -> bool:
        if self.redis.exists(self.active_key):
@@ -172,9 +172,6 @@ class RedisConnectorIndex:
    @staticmethod
    def reset_all(r: redis.Redis) -> None:
        """Deletes all redis values for all connectors"""
        for key in r.scan_iter(RedisConnectorIndex.ACTIVE_PREFIX + "*"):
            r.delete(key)

        for key in r.scan_iter(RedisConnectorIndex.GENERATOR_LOCK_PREFIX + "*"):
            r.delete(key)


@@ -1,4 +1,3 @@
import asyncio
import functools
import threading
from collections.abc import Callable
@@ -6,7 +5,6 @@ from typing import Any
from typing import Optional

import redis
from redis import asyncio as aioredis
from redis.client import Redis

from onyx.configs.app_configs import REDIS_DB_NUMBER
@@ -198,33 +196,3 @@ def get_redis_client(*, tenant_id: str | None) -> Redis:
# redis_client.set('key', 'value')
# value = redis_client.get('key')
# print(value.decode())  # Output: 'value'

_async_redis_connection = None
_async_lock = asyncio.Lock()


async def get_async_redis_connection() -> aioredis.Redis:
    """
    Provides a shared async Redis connection, using the same configs (host, port, SSL, etc.).
    Ensures that the connection is created only once (lazily) and reused for all future calls.
    """
    global _async_redis_connection

    # If we haven't yet created an async Redis connection, we need to create one
    if _async_redis_connection is None:
        # Acquire the lock to ensure that only one coroutine attempts to create the connection
        async with _async_lock:
            # Double-check inside the lock to avoid race conditions
            if _async_redis_connection is None:
                scheme = "rediss" if REDIS_SSL else "redis"
                url = f"{scheme}://{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER}"

                # Create a new Redis connection (or connection pool) from the URL
                _async_redis_connection = aioredis.from_url(
                    url,
                    password=REDIS_PASSWORD,
                    max_connections=REDIS_POOL_MAX_CONNECTIONS,
                )

    # Return the established connection (or pool) for all future operations
    return _async_redis_connection
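This helper is a classic async double-checked-locking singleton: the cheap unlocked check avoids taking the lock on the hot path, and the second check inside the lock prevents two coroutines from both constructing a client. The same pattern, reduced to its skeleton (placeholder make_client, not tied to Redis):

import asyncio

_instance = None
_lock = asyncio.Lock()

async def make_client() -> object:
    return object()  # stand-in for an expensive connection setup

async def get_shared_client() -> object:
    global _instance
    if _instance is None:            # fast path, no lock taken
        async with _lock:
            if _instance is None:    # re-check: another coroutine may have won the race
                _instance = await make_client()
    return _instance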
@@ -216,8 +216,8 @@ def seed_initial_documents(
    # Retries here because the index may take a few seconds to become ready
    # as we just sent over the Vespa schema and there is a slight delay

    index_with_retries = retry_builder(tries=15)(document_index.index)
    index_with_retries(chunks=chunks, fresh_index=True)
    index_with_retries = retry_builder()(document_index.index)
    index_with_retries(chunks=chunks, fresh_index=cohere_enabled)

    # Mock a run for the UI even though it did not actually call out to anything
    mock_successful_index_attempt(
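retry_builder wraps the index call so transient failures while Vespa warms up are retried instead of failing the seed. A hedged sketch of what such a builder might look like (illustrative only; onyx's retry_builder may differ in signature and backoff policy):

import time
from functools import wraps
from typing import Any, Callable

def retry_builder(tries: int = 5, delay: float = 0.5) -> Callable:
    def decorator(fn: Callable) -> Callable:
        @wraps(fn)
        def wrapped(*args: Any, **kwargs: Any) -> Any:
            for attempt in range(tries):
                try:
                    return fn(*args, **kwargs)
                except Exception:
                    if attempt == tries - 1:
                        raise  # out of attempts, surface the error
                    time.sleep(delay * (2 ** attempt))  # exponential backoff
        return wrapped
    return decorator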
@@ -5,7 +5,6 @@ from fastapi.dependencies.models import Dependant
from starlette.routing import BaseRoute

from onyx.auth.users import current_admin_user
from onyx.auth.users import current_chat_accesssible_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_limited_user
from onyx.auth.users import current_user
@@ -110,7 +109,6 @@ def check_router_auth(
            or depends_fn == current_curator_or_admin_user
            or depends_fn == api_key_dep
            or depends_fn == current_user_with_expired_token
            or depends_fn == current_chat_accesssible_user
            or depends_fn == control_plane_dep
            or depends_fn == current_cloud_superuser
        ):

@@ -510,7 +510,7 @@ def associate_credential_to_connector(
    db_session: Session = Depends(get_session),
) -> StatusResponse[int]:
    fetch_ee_implementation_or_noop(
        "onyx.db.user_group", "validate_object_creation_for_user", None
        "onyx.db.user_group", "validate_user_creation_permissions", None
    )(
        db_session=db_session,
        user=user,
@@ -532,8 +532,7 @@ def associate_credential_to_connector(
        )

        return response
    except IntegrityError as e:
        logger.error(f"IntegrityError: {e}")
    except IntegrityError:
        raise HTTPException(status_code=400, detail="Name must be unique")
@@ -14,7 +14,6 @@ from pydantic import BaseModel
from sqlalchemy.orm import Session

from onyx.auth.users import current_admin_user
from onyx.auth.users import current_chat_accesssible_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_user
from onyx.background.celery.celery_utils import get_deletion_attempt_snapshot
@@ -681,7 +680,7 @@ def create_connector_from_model(
        _validate_connector_allowed(connector_data.source)

        fetch_ee_implementation_or_noop(
            "onyx.db.user_group", "validate_object_creation_for_user", None
            "onyx.db.user_group", "validate_user_creation_permissions", None
        )(
            db_session=db_session,
            user=user,
@@ -717,7 +716,7 @@ def create_connector_with_mock_credential(
    tenant_id: str = Depends(get_current_tenant_id),
) -> StatusResponse:
    fetch_ee_implementation_or_noop(
        "onyx.db.user_group", "validate_object_creation_for_user", None
        "onyx.db.user_group", "validate_user_creation_permissions", None
    )(
        db_session=db_session,
        user=user,
@@ -777,7 +776,7 @@ def update_connector_from_model(
    try:
        _validate_connector_allowed(connector_data.source)
        fetch_ee_implementation_or_noop(
            "onyx.db.user_group", "validate_object_creation_for_user", None
            "onyx.db.user_group", "validate_user_creation_permissions", None
        )(
            db_session=db_session,
            user=user,
@@ -1056,10 +1055,10 @@ class BasicCCPairInfo(BaseModel):

@router.get("/connector-status")
def get_basic_connector_indexing_status(
    _: User = Depends(current_chat_accesssible_user),
    _: User = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> list[BasicCCPairInfo]:
    cc_pairs = get_connector_credential_pairs(db_session, eager_load_connector=True)
    cc_pairs = get_connector_credential_pairs(db_session)
    return [
        BasicCCPairInfo(
            has_successful_run=cc_pair.last_successful_index_time is not None,
@@ -122,7 +122,7 @@ def create_credential_from_model(
) -> ObjectCreationIdResponse:
    if not _ignore_credential_permissions(credential_info.source):
        fetch_ee_implementation_or_noop(
            "onyx.db.user_group", "validate_object_creation_for_user", None
            "onyx.db.user_group", "validate_user_creation_permissions", None
        )(
            db_session=db_session,
            user=user,
@@ -164,12 +164,7 @@ def get_credential_by_id(
    user: User = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> CredentialSnapshot | StatusResponse[int]:
    credential = fetch_credential_by_id(
        credential_id,
        user,
        db_session,
        get_editable=False,
    )
    credential = fetch_credential_by_id(credential_id, user, db_session)
    if credential is None:
        raise HTTPException(
            status_code=401,
@@ -31,7 +31,7 @@ def create_document_set(
    db_session: Session = Depends(get_session),
) -> int:
    fetch_ee_implementation_or_noop(
        "onyx.db.user_group", "validate_object_creation_for_user", None
        "onyx.db.user_group", "validate_user_creation_permissions", None
    )(
        db_session=db_session,
        user=user,
@@ -56,7 +56,7 @@ def patch_document_set(
    db_session: Session = Depends(get_session),
) -> None:
    fetch_ee_implementation_or_noop(
        "onyx.db.user_group", "validate_object_creation_for_user", None
        "onyx.db.user_group", "validate_user_creation_permissions", None
    )(
        db_session=db_session,
        user=user,
@@ -10,7 +10,6 @@ from pydantic import BaseModel
from sqlalchemy.orm import Session

from onyx.auth.users import current_admin_user
from onyx.auth.users import current_chat_accesssible_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_limited_user
from onyx.auth.users import current_user
@@ -324,7 +323,7 @@ def get_image_generation_tool(

@basic_router.get("")
def list_personas(
    user: User | None = Depends(current_chat_accesssible_user),
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
    include_deleted: bool = False,
    persona_ids: list[int] = Query(None),
@@ -1,7 +1,6 @@
from fastapi import APIRouter

from onyx import __version__
from onyx.auth.users import anonymous_user_enabled
from onyx.auth.users import user_needs_to_be_verified
from onyx.configs.app_configs import AUTH_TYPE
from onyx.server.manage.models import AuthTypeResponse
@@ -19,9 +18,7 @@ def healthcheck() -> StatusResponse:
@router.get("/auth/type")
def get_auth_type() -> AuthTypeResponse:
    return AuthTypeResponse(
        auth_type=AUTH_TYPE,
        requires_verification=user_needs_to_be_verified(),
        anonymous_user_enabled=anonymous_user_enabled(),
        auth_type=AUTH_TYPE, requires_verification=user_needs_to_be_verified()
    )
@@ -7,7 +7,7 @@ from fastapi import Query
from sqlalchemy.orm import Session

from onyx.auth.users import current_admin_user
from onyx.auth.users import current_chat_accesssible_user
from onyx.auth.users import current_user
from onyx.db.engine import get_session
from onyx.db.llm import fetch_existing_llm_providers
from onyx.db.llm import fetch_provider
@@ -57,6 +57,7 @@ def test_llm_configuration(
    )

    functions_with_args: list[tuple[Callable, tuple]] = [(test_llm, (llm,))]

    if (
        test_llm_request.fast_default_model_name
        and test_llm_request.fast_default_model_name
@@ -189,7 +190,7 @@ def set_provider_as_default(

@basic_router.get("/provider")
def list_llm_provider_basics(
    user: User | None = Depends(current_chat_accesssible_user),
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> list[LLMProviderDescriptor]:
    return [
@@ -37,7 +37,6 @@ class AuthTypeResponse(BaseModel):
    # specifies whether the current auth setup requires
    # users to have verified emails
    requires_verification: bool
    anonymous_user_enabled: bool | None = None


class UserPreferences(BaseModel):
@@ -62,7 +61,6 @@ class UserInfo(BaseModel):
    current_token_expiry_length: int | None = None
    is_cloud_superuser: bool = False
    organization_name: str | None = None
    is_anonymous_user: bool | None = None

    @classmethod
    def from_model(
@@ -72,7 +70,6 @@ class UserInfo(BaseModel):
        expiry_length: int | None = None,
        is_cloud_superuser: bool = False,
        organization_name: str | None = None,
        is_anonymous_user: bool | None = None,
    ) -> "UserInfo":
        return cls(
            id=str(user.id),
@@ -99,7 +96,6 @@ class UserInfo(BaseModel):
            current_token_created_at=current_token_created_at,
            current_token_expiry_length=expiry_length,
            is_cloud_superuser=is_cloud_superuser,
            is_anonymous_user=is_anonymous_user,
        )
@@ -21,14 +21,12 @@ from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session

from ee.onyx.configs.app_configs import SUPER_USERS
from onyx.auth.email_utils import send_user_email_invite
from onyx.auth.invited_users import get_invited_users
from onyx.auth.invited_users import write_invited_users
from onyx.auth.noauth_user import fetch_no_auth_user
from onyx.auth.noauth_user import set_no_auth_user_preferences
from onyx.auth.schemas import UserRole
from onyx.auth.schemas import UserStatus
from onyx.auth.users import anonymous_user_enabled
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_user
@@ -43,8 +41,11 @@ from onyx.db.auth import get_total_users_count
from onyx.db.engine import CURRENT_TENANT_ID_CONTEXTVAR
from onyx.db.engine import get_session
from onyx.db.models import AccessToken
from onyx.db.models import DocumentSet__User
from onyx.db.models import Persona__User
from onyx.db.models import SamlAccount
from onyx.db.models import User
from onyx.db.users import delete_user_from_db
from onyx.db.models import User__UserGroup
from onyx.db.users import get_user_by_email
from onyx.db.users import list_users
from onyx.db.users import validate_user_role_update
@@ -60,6 +61,7 @@ from onyx.server.models import FullUserSnapshot
from onyx.server.models import InvitedUserSnapshot
from onyx.server.models import MinimalUserSnapshot
from onyx.server.utils import BasicAuthenticationError
from onyx.server.utils import send_user_email_invite
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
from shared_configs.configs import MULTI_TENANT
@@ -368,10 +370,45 @@ async def delete_user(
    db_session.expunge(user_to_delete)

    try:
        delete_user_from_db(user_to_delete, db_session)
        logger.info(f"Deleted user {user_to_delete.email}")
        for oauth_account in user_to_delete.oauth_accounts:
            db_session.delete(oauth_account)

        fetch_ee_implementation_or_noop(
            "onyx.db.external_perm",
            "delete_user__ext_group_for_user__no_commit",
        )(
            db_session=db_session,
            user_id=user_to_delete.id,
        )
        db_session.query(SamlAccount).filter(
            SamlAccount.user_id == user_to_delete.id
        ).delete()
        db_session.query(DocumentSet__User).filter(
            DocumentSet__User.user_id == user_to_delete.id
        ).delete()
        db_session.query(Persona__User).filter(
            Persona__User.user_id == user_to_delete.id
        ).delete()
        db_session.query(User__UserGroup).filter(
            User__UserGroup.user_id == user_to_delete.id
        ).delete()
        db_session.delete(user_to_delete)
        db_session.commit()

        # NOTE: edge case may exist with race conditions
        # with this `invited user` scheme generally.
        user_emails = get_invited_users()
        remaining_users = [
            user for user in user_emails if user != user_email.user_email
        ]
        write_invited_users(remaining_users)

        logger.info(f"Deleted user {user_to_delete.email}")
    except Exception as e:
        import traceback

        full_traceback = traceback.format_exc()
        logger.error(f"Full stack trace:\n{full_traceback}")
        db_session.rollback()
        logger.error(f"Error deleting user {user_to_delete.email}: {str(e)}")
        raise HTTPException(status_code=500, detail="Error deleting user")
@@ -485,15 +522,13 @@ def verify_user_logged_in(
    # NOTE: this does not use `current_user` / `current_admin_user` because we don't want
    # to enforce user verification here - the frontend always wants to get the info about
    # the current user regardless of if they are currently verified

    if user is None:
        # if auth type is disabled, return a dummy user with preferences from
        # the key-value store
        if AUTH_TYPE == AuthType.DISABLED:
            store = get_kv_store()
            return fetch_no_auth_user(store)
        if anonymous_user_enabled():
            store = get_kv_store()
            return fetch_no_auth_user(store, anonymous_user_enabled=True)

        raise BasicAuthenticationError(detail="User Not Authenticated")
    if user.oidc_expiry and user.oidc_expiry < datetime.now(timezone.utc):
@@ -1,47 +0,0 @@
from collections.abc import Callable
from typing import List

from fastapi import Depends
from fastapi import Request
from fastapi_limiter import FastAPILimiter
from fastapi_limiter.depends import RateLimiter

from onyx.configs.app_configs import RATE_LIMIT_MAX_REQUESTS
from onyx.configs.app_configs import RATE_LIMIT_WINDOW_SECONDS
from onyx.redis.redis_pool import get_async_redis_connection


async def setup_limiter() -> None:
    # Use the centralized async Redis connection
    redis = await get_async_redis_connection()
    await FastAPILimiter.init(redis)


async def close_limiter() -> None:
    # This closes the FastAPILimiter connection so we don't leave open connections to Redis.
    await FastAPILimiter.close()


async def rate_limit_key(request: Request) -> str:
    # Uses both IP and User-Agent to make collisions less likely if IP is behind NAT.
    # If request.client is None, a fallback is used to avoid completely unknown keys.
    # This helps ensure we have a unique key for each 'user' in simple scenarios.
    ip_part = request.client.host if request.client else "unknown"
    ua_part = request.headers.get("user-agent", "none").replace(" ", "_")
    return f"{ip_part}-{ua_part}"


def get_auth_rate_limiters() -> List[Callable]:
    if not (RATE_LIMIT_MAX_REQUESTS and RATE_LIMIT_WINDOW_SECONDS):
        return []

    return [
        Depends(
            RateLimiter(
                times=RATE_LIMIT_MAX_REQUESTS,
                seconds=RATE_LIMIT_WINDOW_SECONDS,
                # Use the custom key function to distinguish users
                identifier=rate_limit_key,
            )
        )
    ]
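This module wires fastapi-limiter's Redis-backed counters into route dependencies. A minimal standalone sketch of the same setup (assumes a local Redis and the fastapi-limiter package; the endpoint and limits are illustrative):

import redis.asyncio as aioredis
from fastapi import Depends, FastAPI
from fastapi_limiter import FastAPILimiter
from fastapi_limiter.depends import RateLimiter

app = FastAPI()

@app.on_event("startup")
async def startup() -> None:
    # counters live in Redis so limits hold across app workers
    await FastAPILimiter.init(aioredis.from_url("redis://localhost:6379/0"))

@app.get("/ping", dependencies=[Depends(RateLimiter(times=5, seconds=60))])
async def ping() -> dict:
    return {"pong": True}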
629 backend/onyx/server/oauth.py Normal file
@@ -0,0 +1,629 @@
import base64
import json
import uuid
from typing import Any
from typing import cast

import requests
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from sqlalchemy.orm import Session

from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLIENT_SECRET
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_SECRET
from onyx.auth.users import current_user
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import DocumentSource
from onyx.connectors.google_utils.google_auth import get_google_oauth_creds
from onyx.connectors.google_utils.google_auth import sanitize_oauth_credentials
from onyx.connectors.google_utils.shared_constants import (
    DB_CREDENTIALS_AUTHENTICATION_METHOD,
)
from onyx.connectors.google_utils.shared_constants import (
    DB_CREDENTIALS_DICT_TOKEN_KEY,
)
from onyx.connectors.google_utils.shared_constants import (
    DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
)
from onyx.connectors.google_utils.shared_constants import (
    GoogleOAuthAuthenticationMethod,
)
from onyx.db.credentials import create_credential
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
from onyx.server.documents.models import CredentialBase
from onyx.utils.logger import setup_logger


logger = setup_logger()

router = APIRouter(prefix="/oauth")


class SlackOAuth:
    # https://knock.app/blog/how-to-authenticate-users-in-slack-using-oauth
    # Example: https://api.slack.com/authentication/oauth-v2#exchanging

    class OAuthSession(BaseModel):
        """Stored in redis to be looked up on callback"""

        email: str
        redirect_on_success: str | None  # Where to send the user if OAuth flow succeeds

    CLIENT_ID = OAUTH_SLACK_CLIENT_ID
    CLIENT_SECRET = OAUTH_SLACK_CLIENT_SECRET

    TOKEN_URL = "https://slack.com/api/oauth.v2.access"

    # SCOPE is per https://docs.onyx.app/connectors/slack
    BOT_SCOPE = (
        "channels:history,"
        "channels:read,"
        "groups:history,"
        "groups:read,"
        "channels:join,"
        "im:history,"
        "users:read,"
        "users:read.email,"
        "usergroups:read"
    )

    REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/slack/oauth/callback"
    DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"

    @classmethod
    def generate_oauth_url(cls, state: str) -> str:
        return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)

    @classmethod
    def generate_dev_oauth_url(cls, state: str) -> str:
        """dev mode workaround for localhost testing
        - https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
        """

        return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)

    @classmethod
    def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
        url = (
            "https://slack.com/oauth/v2/authorize"
            f"?client_id={cls.CLIENT_ID}"
            f"&redirect_uri={redirect_uri}"
            f"&scope={cls.BOT_SCOPE}"
            f"&state={state}"
        )
        return url

    @classmethod
    def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
        """Temporary state to store in redis, to be looked up on the auth response.
        Returns a json string.
        """
        session = SlackOAuth.OAuthSession(
            email=email, redirect_on_success=redirect_on_success
        )
        return session.model_dump_json()

    @classmethod
    def parse_session(cls, session_json: str) -> OAuthSession:
        session = SlackOAuth.OAuthSession.model_validate_json(session_json)
        return session


class ConfluenceCloudOAuth:
    """work in progress"""

    # https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps/

    class OAuthSession(BaseModel):
        """Stored in redis to be looked up on callback"""

        email: str
        redirect_on_success: str | None  # Where to send the user if OAuth flow succeeds

    CLIENT_ID = OAUTH_CONFLUENCE_CLIENT_ID
    CLIENT_SECRET = OAUTH_CONFLUENCE_CLIENT_SECRET
    TOKEN_URL = "https://auth.atlassian.com/oauth/token"

    # All read scopes per https://developer.atlassian.com/cloud/confluence/scopes-for-oauth-2-3LO-and-forge-apps/
    CONFLUENCE_OAUTH_SCOPE = (
        "read:confluence-props%20"
        "read:confluence-content.all%20"
        "read:confluence-content.summary%20"
        "read:confluence-content.permission%20"
        "read:confluence-user%20"
        "read:confluence-groups%20"
        "readonly:content.attachment:confluence"
    )

    REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/confluence/oauth/callback"
    DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"

    # eventually for Confluence Data Center
    # oauth_url = (
    #     f"http://localhost:8090/rest/oauth/v2/authorize?client_id={CONFLUENCE_OAUTH_CLIENT_ID}"
    #     f"&scope={CONFLUENCE_OAUTH_SCOPE_2}"
    #     f"&redirect_uri={redirectme_uri}"
    # )

    @classmethod
    def generate_oauth_url(cls, state: str) -> str:
        return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)

    @classmethod
    def generate_dev_oauth_url(cls, state: str) -> str:
        """dev mode workaround for localhost testing
        - https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
        """
        return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)

    @classmethod
    def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
        url = (
            "https://auth.atlassian.com/authorize"
            "?audience=api.atlassian.com"
            f"&client_id={cls.CLIENT_ID}"
            f"&redirect_uri={redirect_uri}"
            f"&scope={cls.CONFLUENCE_OAUTH_SCOPE}"
            f"&state={state}"
            "&response_type=code"
            "&prompt=consent"
        )
        return url

    @classmethod
    def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
        """Temporary state to store in redis, to be looked up on the auth response.
        Returns a json string.
        """
        session = ConfluenceCloudOAuth.OAuthSession(
            email=email, redirect_on_success=redirect_on_success
        )
        return session.model_dump_json()

    @classmethod
    def parse_session(cls, session_json: str) -> OAuthSession:
        session = ConfluenceCloudOAuth.OAuthSession.model_validate_json(session_json)
        return session


class GoogleDriveOAuth:
    # https://developers.google.com/identity/protocols/oauth2
    # https://developers.google.com/identity/protocols/oauth2/web-server

    class OAuthSession(BaseModel):
        """Stored in redis to be looked up on callback"""

        email: str
        redirect_on_success: str | None  # Where to send the user if OAuth flow succeeds

    CLIENT_ID = OAUTH_GOOGLE_DRIVE_CLIENT_ID
    CLIENT_SECRET = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET

    TOKEN_URL = "https://oauth2.googleapis.com/token"

    # SCOPE is per https://docs.onyx.app/connectors/google-drive
    # TODO: Merge with or use google_utils.GOOGLE_SCOPES
    SCOPE = (
        "https://www.googleapis.com/auth/drive.readonly%20"
        "https://www.googleapis.com/auth/drive.metadata.readonly%20"
        "https://www.googleapis.com/auth/admin.directory.user.readonly%20"
        "https://www.googleapis.com/auth/admin.directory.group.readonly"
    )

    REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/google-drive/oauth/callback"
    DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"

    @classmethod
    def generate_oauth_url(cls, state: str) -> str:
        return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)

    @classmethod
    def generate_dev_oauth_url(cls, state: str) -> str:
        """dev mode workaround for localhost testing
        - https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
        """

        return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)

    @classmethod
    def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
        # without prompt=consent, a refresh token is only issued the first time the user approves
        url = (
            "https://accounts.google.com/o/oauth2/v2/auth"
            f"?client_id={cls.CLIENT_ID}"
            f"&redirect_uri={redirect_uri}"
            "&response_type=code"
            f"&scope={cls.SCOPE}"
            "&access_type=offline"
            f"&state={state}"
            "&prompt=consent"
        )
        return url

    @classmethod
    def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
        """Temporary state to store in redis, to be looked up on the auth response.
        Returns a json string.
        """
        session = GoogleDriveOAuth.OAuthSession(
            email=email, redirect_on_success=redirect_on_success
        )
        return session.model_dump_json()

    @classmethod
    def parse_session(cls, session_json: str) -> OAuthSession:
        session = GoogleDriveOAuth.OAuthSession.model_validate_json(session_json)
        return session


@router.post("/prepare-authorization-request")
def prepare_authorization_request(
    connector: DocumentSource,
    redirect_on_success: str | None,
    user: User = Depends(current_user),
    tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
    """Used by the frontend to generate the url for the user's browser during auth request.

    Example: https://www.oauth.com/oauth2-servers/authorization/the-authorization-request/
    """

    # create random oauth state param for security and to retrieve user data later
    oauth_uuid = uuid.uuid4()
    oauth_uuid_str = str(oauth_uuid)

    # urlsafe b64 encode the uuid for the oauth url
    oauth_state = (
        base64.urlsafe_b64encode(oauth_uuid.bytes).rstrip(b"=").decode("utf-8")
    )

    if connector == DocumentSource.SLACK:
        oauth_url = SlackOAuth.generate_oauth_url(oauth_state)
        session = SlackOAuth.session_dump_json(
            email=user.email, redirect_on_success=redirect_on_success
        )
    # elif connector == DocumentSource.CONFLUENCE:
    #     oauth_url = ConfluenceCloudOAuth.generate_oauth_url(oauth_state)
    #     session = ConfluenceCloudOAuth.session_dump_json(
    #         email=user.email, redirect_on_success=redirect_on_success
    #     )
    # elif connector == DocumentSource.JIRA:
    #     oauth_url = JiraCloudOAuth.generate_dev_oauth_url(oauth_state)
    elif connector == DocumentSource.GOOGLE_DRIVE:
        oauth_url = GoogleDriveOAuth.generate_oauth_url(oauth_state)
        session = GoogleDriveOAuth.session_dump_json(
            email=user.email, redirect_on_success=redirect_on_success
        )
    else:
        oauth_url = None

    if not oauth_url:
        raise HTTPException(
            status_code=404,
            detail=f"The document source type {connector} does not have OAuth implemented",
        )

    r = get_redis_client(tenant_id=tenant_id)

    # store important session state to retrieve when the user is redirected back
    # 10 min is the max we want an oauth flow to be valid
    r.set(f"da_oauth:{oauth_uuid_str}", session, ex=600)

    return JSONResponse(content={"url": oauth_url})

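The state parameter round-trips a UUID through URL-safe base64: padding is stripped for the URL and restored on decode, since Python's base64 decoder requires the input length to be a multiple of four. A self-contained demonstration of that round trip:

import base64
import uuid

original = uuid.uuid4()
state = base64.urlsafe_b64encode(original.bytes).rstrip(b"=").decode("utf-8")

padded = state + "=" * (-len(state) % 4)   # restore the stripped padding
recovered = uuid.UUID(bytes=base64.urlsafe_b64decode(padded))
assert recovered == original
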
@router.post("/connector/slack/callback")
|
||||
def handle_slack_oauth_callback(
|
||||
code: str,
|
||||
state: str,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> JSONResponse:
|
||||
if not SlackOAuth.CLIENT_ID or not SlackOAuth.CLIENT_SECRET:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Slack client ID or client secret is not configured.",
|
||||
)
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# recover the state
|
||||
padded_state = state + "=" * (
|
||||
-len(state) % 4
|
||||
) # Add padding back (Base64 decoding requires padding)
|
||||
uuid_bytes = base64.urlsafe_b64decode(
|
||||
padded_state
|
||||
) # Decode the Base64 string back to bytes
|
||||
|
||||
# Convert bytes back to a UUID
|
||||
oauth_uuid = uuid.UUID(bytes=uuid_bytes)
|
||||
oauth_uuid_str = str(oauth_uuid)
|
||||
|
||||
r_key = f"da_oauth:{oauth_uuid_str}"
|
||||
|
||||
session_json_bytes = cast(bytes, r.get(r_key))
|
||||
if not session_json_bytes:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Slack OAuth failed - OAuth state key not found: key={r_key}",
|
||||
)
|
||||
|
||||
session_json = session_json_bytes.decode("utf-8")
|
||||
try:
|
||||
session = SlackOAuth.parse_session(session_json)
|
||||
|
||||
# Exchange the authorization code for an access token
|
||||
response = requests.post(
|
||||
SlackOAuth.TOKEN_URL,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
data={
|
||||
"client_id": SlackOAuth.CLIENT_ID,
|
||||
"client_secret": SlackOAuth.CLIENT_SECRET,
|
||||
"code": code,
|
||||
"redirect_uri": SlackOAuth.REDIRECT_URI,
|
||||
},
|
||||
)
|
||||
|
||||
response_data = response.json()
|
||||
|
||||
if not response_data.get("ok"):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Slack OAuth failed: {response_data.get('error')}",
|
||||
)
|
||||
|
||||
# Extract token and team information
|
||||
access_token: str = response_data.get("access_token")
|
||||
team_id: str = response_data.get("team", {}).get("id")
|
||||
authed_user_id: str = response_data.get("authed_user", {}).get("id")
|
||||
|
||||
credential_info = CredentialBase(
|
||||
credential_json={"slack_bot_token": access_token},
|
||||
admin_public=True,
|
||||
source=DocumentSource.SLACK,
|
||||
name="Slack OAuth",
|
||||
)
|
||||
|
||||
create_credential(credential_info, user, db_session)
|
||||
except Exception as e:
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"success": False,
|
||||
"message": f"An error occurred during Slack OAuth: {str(e)}",
|
||||
},
|
||||
)
|
||||
finally:
|
||||
r.delete(r_key)
|
||||
|
||||
# return the result
|
||||
return JSONResponse(
|
||||
content={
|
||||
"success": True,
|
||||
"message": "Slack OAuth completed successfully.",
|
||||
"team_id": team_id,
|
||||
"authed_user_id": authed_user_id,
|
||||
"redirect_on_success": session.redirect_on_success,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# Work in progress
|
||||
# @router.post("/connector/confluence/callback")
|
||||
# def handle_confluence_oauth_callback(
|
||||
# code: str,
|
||||
# state: str,
|
||||
# user: User = Depends(current_user),
|
||||
# db_session: Session = Depends(get_session),
|
||||
# tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
# ) -> JSONResponse:
|
||||
# if not ConfluenceCloudOAuth.CLIENT_ID or not ConfluenceCloudOAuth.CLIENT_SECRET:
|
||||
# raise HTTPException(
|
||||
# status_code=500,
|
||||
# detail="Confluence client ID or client secret is not configured."
|
||||
# )
|
||||
|
||||
# r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# # recover the state
|
||||
# padded_state = state + '=' * (-len(state) % 4) # Add padding back (Base64 decoding requires padding)
|
||||
# uuid_bytes = base64.urlsafe_b64decode(padded_state) # Decode the Base64 string back to bytes
|
||||
|
||||
# # Convert bytes back to a UUID
|
||||
# oauth_uuid = uuid.UUID(bytes=uuid_bytes)
|
||||
# oauth_uuid_str = str(oauth_uuid)
|
||||
|
||||
# r_key = f"da_oauth:{oauth_uuid_str}"
|
||||
|
||||
# result = r.get(r_key)
|
||||
# if not result:
|
||||
# raise HTTPException(
|
||||
# status_code=400,
|
||||
# detail=f"Confluence OAuth failed - OAuth state key not found: key={r_key}"
|
||||
# )
|
||||
|
||||
# try:
|
||||
# session = ConfluenceCloudOAuth.parse_session(result)
|
||||
|
||||
# # Exchange the authorization code for an access token
|
||||
# response = requests.post(
|
||||
# ConfluenceCloudOAuth.TOKEN_URL,
|
||||
# headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
# data={
|
||||
# "client_id": ConfluenceCloudOAuth.CLIENT_ID,
|
||||
# "client_secret": ConfluenceCloudOAuth.CLIENT_SECRET,
|
||||
# "code": code,
|
||||
# "redirect_uri": ConfluenceCloudOAuth.DEV_REDIRECT_URI,
|
||||
# },
|
||||
# )
|
||||
|
||||
# response_data = response.json()
|
||||
|
||||
# if not response_data.get("ok"):
|
||||
# raise HTTPException(
|
||||
# status_code=400,
|
||||
# detail=f"ConfluenceCloudOAuth OAuth failed: {response_data.get('error')}"
|
||||
# )
|
||||
|
||||
# # Extract token and team information
|
||||
# access_token: str = response_data.get("access_token")
|
||||
# team_id: str = response_data.get("team", {}).get("id")
|
||||
# authed_user_id: str = response_data.get("authed_user", {}).get("id")
|
||||
|
||||
# credential_info = CredentialBase(
|
||||
# credential_json={"slack_bot_token": access_token},
|
||||
# admin_public=True,
|
||||
# source=DocumentSource.CONFLUENCE,
|
||||
# name="Confluence OAuth",
|
||||
# )
|
||||
|
||||
# logger.info(f"Slack access token: {access_token}")
|
||||
|
||||
# credential = create_credential(credential_info, user, db_session)
|
||||
|
||||
# logger.info(f"new_credential_id={credential.id}")
|
||||
# except Exception as e:
|
||||
# return JSONResponse(
|
||||
# status_code=500,
|
||||
# content={
|
||||
# "success": False,
|
||||
# "message": f"An error occurred during Slack OAuth: {str(e)}",
|
||||
# },
|
||||
# )
|
||||
# finally:
|
||||
# r.delete(r_key)
|
||||
|
||||
# # return the result
|
||||
# return JSONResponse(
|
||||
# content={
|
||||
# "success": True,
|
||||
# "message": "Slack OAuth completed successfully.",
|
||||
# "team_id": team_id,
|
||||
# "authed_user_id": authed_user_id,
|
||||
# "redirect_on_success": session.redirect_on_success,
|
||||
# }
|
||||
# )
|
||||
|
||||
|
||||
@router.post("/connector/google-drive/callback")
|
||||
def handle_google_drive_oauth_callback(
|
||||
code: str,
|
||||
state: str,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> JSONResponse:
|
||||
if not GoogleDriveOAuth.CLIENT_ID or not GoogleDriveOAuth.CLIENT_SECRET:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Google Drive client ID or client secret is not configured.",
|
||||
)
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# recover the state
|
||||
padded_state = state + "=" * (
|
||||
-len(state) % 4
|
||||
) # Add padding back (Base64 decoding requires padding)
|
||||
uuid_bytes = base64.urlsafe_b64decode(
|
||||
padded_state
|
||||
) # Decode the Base64 string back to bytes
|
||||
|
||||
# Convert bytes back to a UUID
|
||||
oauth_uuid = uuid.UUID(bytes=uuid_bytes)
|
||||
oauth_uuid_str = str(oauth_uuid)
|
||||
|
||||
r_key = f"da_oauth:{oauth_uuid_str}"
|
||||
|
||||
session_json_bytes = cast(bytes, r.get(r_key))
|
||||
if not session_json_bytes:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Google Drive OAuth failed - OAuth state key not found: key={r_key}",
|
||||
)
|
||||
|
||||
session_json = session_json_bytes.decode("utf-8")
|
||||
try:
|
||||
session = GoogleDriveOAuth.parse_session(session_json)
|
||||
|
||||
# Exchange the authorization code for an access token
|
||||
response = requests.post(
|
||||
GoogleDriveOAuth.TOKEN_URL,
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
data={
|
||||
"client_id": GoogleDriveOAuth.CLIENT_ID,
|
||||
"client_secret": GoogleDriveOAuth.CLIENT_SECRET,
|
||||
"code": code,
|
||||
"redirect_uri": GoogleDriveOAuth.REDIRECT_URI,
|
||||
"grant_type": "authorization_code",
|
||||
},
|
||||
)
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
authorization_response: dict[str, Any] = response.json()
|
||||
|
||||
# the connector wants us to store the json in its authorized_user_info format
|
||||
# returned from OAuthCredentials.get_authorized_user_info().
|
||||
# So refresh immediately via get_google_oauth_creds with the params filled in
|
||||
# from fields in authorization_response to get the json we need
|
||||
authorized_user_info = {}
|
||||
authorized_user_info["client_id"] = OAUTH_GOOGLE_DRIVE_CLIENT_ID
|
||||
authorized_user_info["client_secret"] = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
|
||||
authorized_user_info["refresh_token"] = authorization_response["refresh_token"]
|
||||
|
||||
token_json_str = json.dumps(authorized_user_info)
|
||||
oauth_creds = get_google_oauth_creds(
|
||||
token_json_str=token_json_str, source=DocumentSource.GOOGLE_DRIVE
|
||||
)
|
||||
if not oauth_creds:
|
||||
raise RuntimeError("get_google_oauth_creds returned None.")
|
||||
|
||||
# save off the credentials
|
||||
oauth_creds_sanitized_json_str = sanitize_oauth_credentials(oauth_creds)
|
||||
|
||||
credential_dict: dict[str, str] = {}
|
||||
credential_dict[DB_CREDENTIALS_DICT_TOKEN_KEY] = oauth_creds_sanitized_json_str
|
||||
credential_dict[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] = session.email
|
||||
credential_dict[
|
||||
DB_CREDENTIALS_AUTHENTICATION_METHOD
|
||||
] = GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE.value
|
||||
|
||||
credential_info = CredentialBase(
|
||||
credential_json=credential_dict,
|
||||
admin_public=True,
|
||||
source=DocumentSource.GOOGLE_DRIVE,
|
||||
name="OAuth (interactive)",
|
||||
)
|
||||
|
||||
create_credential(credential_info, user, db_session)
|
||||
except Exception as e:
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={
|
||||
"success": False,
|
||||
"message": f"An error occurred during Google Drive OAuth: {str(e)}",
|
||||
},
|
||||
)
|
||||
finally:
|
||||
r.delete(r_key)
|
||||
|
||||
# return the result
|
||||
return JSONResponse(
|
||||
content={
|
||||
"success": True,
|
||||
"message": "Google Drive OAuth completed successfully.",
|
||||
"redirect_on_success": session.redirect_on_success,
|
||||
}
|
||||
)
|
||||
@@ -4,7 +4,6 @@ from fastapi import HTTPException
from sqlalchemy.orm import Session

from onyx.auth.users import api_key_dep
from onyx.configs.constants import DEFAULT_CC_PAIR_ID
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import Document
from onyx.connectors.models import IndexAttemptMetadata
@@ -80,7 +79,7 @@ def upsert_ingestion_doc(
        document.source = DocumentSource.FILE

    cc_pair = get_connector_credential_pair_from_id(
        cc_pair_id=doc_info.cc_pair_id or DEFAULT_CC_PAIR_ID, db_session=db_session
        cc_pair_id=doc_info.cc_pair_id or 0, db_session=db_session
    )
    if cc_pair is None:
        raise HTTPException(
@@ -19,7 +19,6 @@ from PIL import Image
from pydantic import BaseModel
from sqlalchemy.orm import Session

from onyx.auth.users import current_chat_accesssible_user
from onyx.auth.users import current_limited_user
from onyx.auth.users import current_user
from onyx.chat.chat_utils import create_chat_chain
@@ -146,7 +145,7 @@ def update_chat_session_model(
def get_chat_session(
    session_id: UUID,
    is_shared: bool = False,
    user: User | None = Depends(current_chat_accesssible_user),
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> ChatSessionDetailResponse:
    user_id = user.id if user is not None else None
@@ -183,15 +182,12 @@ def get_chat_session(
        description=chat_session.description,
        persona_id=chat_session.persona_id,
        persona_name=chat_session.persona.name if chat_session.persona else None,
        persona_icon_color=chat_session.persona.icon_color
        if chat_session.persona
        else None,
        persona_icon_shape=chat_session.persona.icon_shape
        if chat_session.persona
        else None,
        current_alternate_model=chat_session.current_alternate_model,
        messages=[
            translate_db_message_to_chat_message_detail(msg) for msg in session_messages
            translate_db_message_to_chat_message_detail(
                msg, remove_doc_content=is_shared  # if shared, don't leak doc content
            )
            for msg in session_messages
        ],
        time_created=chat_session.time_created,
        shared_status=chat_session.shared_status,
@@ -201,7 +197,7 @@ def get_chat_session(
@router.post("/create-chat-session")
def create_new_chat_session(
    chat_session_creation_request: ChatSessionCreationRequest,
    user: User | None = Depends(current_chat_accesssible_user),
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> CreateChatSessionID:
    user_id = user.id if user is not None else None
@@ -334,7 +330,7 @@ async def is_connected(request: Request) -> Callable[[], bool]:
def handle_new_chat_message(
    chat_message_req: CreateChatMessageRequest,
    request: Request,
    user: User | None = Depends(current_chat_accesssible_user),
    user: User | None = Depends(current_limited_user),
    _rate_limit_check: None = Depends(check_token_rate_limits),
    is_connected_func: Callable[[], bool] = Depends(is_connected),
    tenant_id: str = Depends(get_current_tenant_id),
@@ -225,8 +225,6 @@ class ChatSessionDetailResponse(BaseModel):
|
||||
description: str | None
|
||||
persona_id: int | None = None
|
||||
persona_name: str | None
|
||||
persona_icon_color: str | None
|
||||
persona_icon_shape: int | None
|
||||
messages: list[ChatMessageDetail]
|
||||
time_created: datetime
|
||||
shared_status: ChatSessionSharedStatus
|
||||
|
||||
@@ -22,7 +22,6 @@ from onyx.db.chat import get_search_docs_for_chat_message
from onyx.db.chat import get_valid_messages_from_query_sessions
from onyx.db.chat import translate_db_message_to_chat_message_detail
from onyx.db.chat import translate_db_search_doc_to_server_search_doc
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.db.search_settings import get_current_search_settings
@@ -49,7 +48,6 @@ def admin_search(
    question: AdminSearchRequest,
    user: User | None = Depends(current_curator_or_admin_user),
    db_session: Session = Depends(get_session),
    tenant_id: str = Depends(get_current_tenant_id),
) -> AdminSearchResponse:
    query = question.query
    logger.notice(f"Received admin search query: {query}")
@@ -60,7 +58,6 @@ def admin_search(
        time_cutoff=question.filters.time_cutoff,
        tags=question.filters.tags,
        access_control_list=user_acl_filters,
        tenant_id=tenant_id,
    )
    search_settings = get_current_search_settings(db_session)
    document_index = get_default_document_index(

@@ -11,7 +11,7 @@ from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.orm import Session

from onyx.auth.users import current_chat_accesssible_user
from onyx.auth.users import current_user
from onyx.db.engine import get_session_context_manager
from onyx.db.engine import get_session_with_tenant
from onyx.db.models import ChatMessage
@@ -31,7 +31,7 @@ TOKEN_BUDGET_UNIT = 1_000


def check_token_rate_limits(
    user: User | None = Depends(current_chat_accesssible_user),
    user: User | None = Depends(current_user),
) -> None:
    # short circuit if no rate limits are set up
    # NOTE: result of `any_rate_limit_exists` is cached, so this call is fast 99% of the time

@@ -44,7 +44,6 @@ class Settings(BaseModel):
    maximum_chat_retention_days: int | None = None
    gpu_enabled: bool | None = None
    product_gating: GatingType = GatingType.NONE
    anonymous_user_enabled: bool | None = None


class UserSettings(Settings):

@@ -1,38 +1,21 @@
from typing import cast

from onyx.configs.constants import KV_SETTINGS_KEY
from onyx.configs.constants import OnyxRedisLocks
from onyx.key_value_store.factory import get_kv_store
from onyx.redis.redis_pool import get_redis_client
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.server.settings.models import Settings
from shared_configs.configs import MULTI_TENANT


def load_settings() -> Settings:
    if MULTI_TENANT:
        # If multi-tenant, anonymous user is always false
        anonymous_user_enabled = False
    else:
        redis_client = get_redis_client(tenant_id=None)
        value = redis_client.get(OnyxRedisLocks.ANONYMOUS_USER_ENABLED)
        if value is not None:
            assert isinstance(value, bytes)
            anonymous_user_enabled = int(value.decode("utf-8")) == 1
        else:
            # Default to False
            anonymous_user_enabled = False
            # Optionally store the default back to Redis
            redis_client.set(OnyxRedisLocks.ANONYMOUS_USER_ENABLED, "0")
    dynamic_config_store = get_kv_store()
    try:
        settings = Settings(**cast(dict, dynamic_config_store.load(KV_SETTINGS_KEY)))
    except KvKeyNotFoundError:
        settings = Settings()
        dynamic_config_store.store(KV_SETTINGS_KEY, settings.model_dump())

    settings = Settings(anonymous_user_enabled=anonymous_user_enabled)
    return settings


def store_settings(settings: Settings) -> None:
    if not MULTI_TENANT and settings.anonymous_user_enabled is not None:
        # Only non-multi-tenant scenario can set the anonymous user enabled flag
        redis_client = get_redis_client(tenant_id=None)
        redis_client.set(
            OnyxRedisLocks.ANONYMOUS_USER_ENABLED,
            "1" if settings.anonymous_user_enabled else "0",
        )

    get_kv_store().store(KV_SETTINGS_KEY, settings.model_dump())

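The load/store pair above round-trips the anonymous-user flag through Redis as the string "1" or "0" and decodes it back to a bool on load. A minimal sketch of that encoding, using a hypothetical in-memory stand-in for the Redis client (FakeRedis and the key name are illustrative, not part of Onyx):

class FakeRedis:
    """Tiny in-memory stand-in mimicking the two Redis calls used above."""

    def __init__(self) -> None:
        self._data: dict[str, bytes] = {}

    def set(self, key: str, value: str) -> None:
        self._data[key] = value.encode("utf-8")

    def get(self, key: str) -> bytes | None:
        return self._data.get(key)


ANONYMOUS_USER_ENABLED_KEY = "anonymous_user_enabled"  # assumed key name

r = FakeRedis()
r.set(ANONYMOUS_USER_ENABLED_KEY, "1")  # what store_settings writes for True

value = r.get(ANONYMOUS_USER_ENABLED_KEY)  # what load_settings reads back
anonymous_user_enabled = value is not None and int(value.decode("utf-8")) == 1
assert anonymous_user_enabled is True
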
@@ -1,13 +1,20 @@
import json
import smtplib
from datetime import datetime
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from textwrap import dedent
from typing import Any

from fastapi import HTTPException
from fastapi import status

from onyx.connectors.google_utils.shared_constants import (
    DB_CREDENTIALS_AUTHENTICATION_METHOD,
)
from onyx.configs.app_configs import SMTP_PASS
from onyx.configs.app_configs import SMTP_PORT
from onyx.configs.app_configs import SMTP_SERVER
from onyx.configs.app_configs import SMTP_USER
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.db.models import User


class BasicAuthenticationError(HTTPException):
@@ -47,22 +54,39 @@ def mask_string(sensitive_str: str) -> str:
def mask_credential_dict(credential_dict: dict[str, Any]) -> dict[str, str]:
    masked_creds = {}
    for key, val in credential_dict.items():
        if isinstance(val, str):
            # we want to pass the authentication_method field through so the frontend
            # can disambiguate credentials created by different methods
            if key == DB_CREDENTIALS_AUTHENTICATION_METHOD:
                masked_creds[key] = val
            else:
                masked_creds[key] = mask_string(val)
            continue

        if isinstance(val, int):
            masked_creds[key] = "*****"
            continue

        raise ValueError(
            f"Unable to mask credentials of type other than string, cannot process request."
            f"Recieved type: {type(val)}"
        )
        if not isinstance(val, str):
            raise ValueError(
                f"Unable to mask credentials of type other than string, cannot process request."
                f"Recieved type: {type(val)}"
            )

        masked_creds[key] = mask_string(val)
    return masked_creds

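For illustration, a hedged sketch of how the new masking branches behave, reusing the definitions above (the credential values here are made up):

creds = {
    DB_CREDENTIALS_AUTHENTICATION_METHOD: "oauth_interactive",  # passes through unmasked
    "client_secret": "super-secret-value",  # masked via mask_string()
    "smtp_port": 5432,  # ints collapse to "*****"
}
masked = mask_credential_dict(creds)
assert masked[DB_CREDENTIALS_AUTHENTICATION_METHOD] == "oauth_interactive"
assert masked["smtp_port"] == "*****"
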
def send_user_email_invite(user_email: str, current_user: User) -> None:
    msg = MIMEMultipart()
    msg["Subject"] = "Invitation to Join Onyx Workspace"
    msg["From"] = current_user.email
    msg["To"] = user_email

    email_body = dedent(
        f"""\
        Hello,

        You have been invited to join a workspace on Onyx.

        To join the workspace, please visit the following link:

        {WEB_DOMAIN}/auth/login

        Best regards,
        The Onyx Team
        """
    )

    msg.attach(MIMEText(email_body, "plain"))
    with smtplib.SMTP(SMTP_SERVER, SMTP_PORT) as smtp_server:
        smtp_server.starttls()
        smtp_server.login(SMTP_USER, SMTP_PASS)
        smtp_server.send_message(msg)

@@ -15,12 +15,11 @@ F = TypeVar("F", bound=Callable[..., Any])


def retry_builder(
    tries: int = 20,
    tries: int = 10,
    delay: float = 0.1,
    max_delay: float | None = 60,
    max_delay: float | None = None,
    backoff: float = 2,
    jitter: tuple[float, float] | float = 1,
    exceptions: type[Exception] | tuple[type[Exception], ...] = (Exception,),
) -> Callable[[F], F]:
    """Builds a generic wrapper/decorator for calls to external APIs that
    may fail due to rate limiting, flakes, or other reasons. Applies exponential
@@ -34,7 +33,6 @@ def retry_builder(
        backoff=backoff,
        jitter=jitter,
        logger=cast(Logger, logger),
        exceptions=exceptions,
    )
    def wrapped_func(*args: list, **kwargs: dict[str, Any]) -> Any:
        return func(*args, **kwargs)

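As a hedged usage sketch (the import path and the flaky endpoint are assumptions, not taken from this diff), the decorator built above wraps any call that should be retried with exponential backoff:

import requests

from onyx.utils.retry_wrapper import retry_builder  # assumed module path


@retry_builder(tries=5, delay=0.1, backoff=2, exceptions=(requests.RequestException,))
def fetch_json(url: str) -> dict:
    # failed attempts are re-run with exponentially growing sleeps
    # (roughly delay * backoff**n, capped by max_delay when one is set)
    resp = requests.get(url, timeout=10)
    resp.raise_for_status()
    return resp.json()
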
@@ -22,6 +22,7 @@ from onyx.utils.variable_functionality import (
from onyx.utils.variable_functionality import noop_fallback
from shared_configs.configs import MULTI_TENANT


_DANSWER_TELEMETRY_ENDPOINT = "https://telemetry.onyx.app/anonymous_telemetry"
_CACHED_UUID: str | None = None
_CACHED_INSTANCE_DOMAIN: str | None = None
@@ -117,9 +118,12 @@ def mt_cloud_telemetry(
    event: MilestoneRecordType,
    properties: dict | None = None,
) -> None:
    print(f"mt_cloud_telemetry {distinct_id} {event} {properties}")
    if not MULTI_TENANT:
        print("mt_cloud_telemetry not MULTI_TENANT")
        return

    print("mt_cloud_telemetry MULTI_TENANT")
    # MIT version should not need to include any Posthog code
    # This is only for Onyx MT Cloud, this code should also never be hit, no reason for any orgs to
    # be running the Multi Tenant version of Onyx.
@@ -137,8 +141,11 @@ def create_milestone_and_report(
    properties: dict | None,
    db_session: Session,
) -> None:
    print(f"create_milestone_and_report {user} {event_type} {db_session}")
    _, is_new = create_milestone_if_not_exists(user, event_type, db_session)
    print(f"create_milestone_and_report {is_new}")
    if is_new:
        print("create_milestone_and_report is_new")
        mt_cloud_telemetry(
            distinct_id=distinct_id,
            event=event_type,

@@ -23,7 +23,6 @@ httpcore==1.0.5
httpx[http2]==0.27.0
httpx-oauth==0.15.1
huggingface-hub==0.20.1
inflection==0.5.1
jira==3.5.1
jsonref==1.1.0
trafilatura==1.12.2
@@ -44,7 +43,6 @@ openpyxl==3.1.2
playwright==1.41.2
psutil==5.9.5
psycopg2-binary==2.9.9
pyairtable==3.0.1
pycryptodome==3.19.1
pydantic==2.8.2
PyGithub==1.58.2
@@ -83,5 +81,4 @@ stripe==10.12.0
urllib3==2.2.3
mistune==0.8.4
sentry-sdk==2.14.0
prometheus_client==0.21.0
fastapi-limiter==0.1.6
prometheus_client==0.21.0
@@ -31,7 +31,7 @@ INTENT_MODEL_TAG = "v1.0.3"
DOC_EMBEDDING_CONTEXT_SIZE = 512

# Used to distinguish alternative indices
ALT_INDEX_SUFFIX = "__danswer_alt_index"
ALT_INDEX_SUFFIX = "__onyx_alt_index"

# Used for loading defaults for automatic deployments and dev flows
# For local, use: mixedbread-ai/mxbai-rerank-xsmall-v1

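In other words, the alternative index name is just the primary index name plus the suffix; a one-line sketch with a hypothetical base name:

primary_index = "danswer_chunk"  # hypothetical base name, not taken from this diff
alternate_index = f"{primary_index}{ALT_INDEX_SUFFIX}"  # "danswer_chunk__onyx_alt_index"
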
@@ -1,192 +0,0 @@
import os
from unittest.mock import MagicMock
from unittest.mock import patch

import pytest

from onyx.configs.constants import DocumentSource
from onyx.connectors.airtable.airtable_connector import AirtableConnector
from onyx.connectors.models import Document
from onyx.connectors.models import Section


@pytest.fixture(
    params=[
        ("table_name", os.environ["AIRTABLE_TEST_TABLE_NAME"]),
        ("table_id", os.environ["AIRTABLE_TEST_TABLE_ID"]),
    ]
)
def airtable_connector(request: pytest.FixtureRequest) -> AirtableConnector:
    param_type, table_identifier = request.param
    connector = AirtableConnector(
        base_id=os.environ["AIRTABLE_TEST_BASE_ID"],
        table_name_or_id=table_identifier,
    )

    connector.load_credentials(
        {
            "airtable_access_token": os.environ["AIRTABLE_ACCESS_TOKEN"],
        }
    )
    return connector


def create_test_document(
    id: str,
    title: str,
    description: str,
    priority: str,
    status: str,
    # Link to another record is skipped for now
    # category: str,
    ticket_id: str,
    created_time: str,
    status_last_changed: str,
    submitted_by: str,
    assignee: str,
    days_since_status_change: int | None,
    attachments: list | None = None,
) -> Document:
    link_base = f"https://airtable.com/{os.environ['AIRTABLE_TEST_BASE_ID']}/{os.environ['AIRTABLE_TEST_TABLE_ID']}"
    sections = [
        Section(
            text=f"Title:\n------------------------\n{title}\n------------------------",
            link=f"{link_base}/{id}",
        ),
        Section(
            text=f"Description:\n------------------------\n{description}\n------------------------",
            link=f"{link_base}/{id}",
        ),
    ]

    if attachments:
        for attachment in attachments:
            sections.append(
                Section(
                    text=f"Attachment:\n------------------------\n{attachment}\n------------------------",
                    link=f"{link_base}/{id}",
                ),
            )

    return Document(
        id=f"airtable__{id}",
        sections=sections,
        source=DocumentSource.AIRTABLE,
        semantic_identifier=f"{os.environ['AIRTABLE_TEST_TABLE_NAME']}: {title}",
        metadata={
            # "Category": category,
            "Assignee": assignee,
            "Submitted by": submitted_by,
            "Priority": priority,
            "Status": status,
            "Created time": created_time,
            "ID": ticket_id,
            "Status last changed": status_last_changed,
            **(
                {"Days since status change": str(days_since_status_change)}
                if days_since_status_change is not None
                else {}
            ),
        },
        doc_updated_at=None,
        primary_owners=None,
        secondary_owners=None,
        title=None,
        from_ingestion_api=False,
        additional_info=None,
    )


@patch(
    "onyx.file_processing.extract_file_text.get_unstructured_api_key",
    return_value=None,
)
def test_airtable_connector_basic(
    mock_get_api_key: MagicMock, airtable_connector: AirtableConnector
) -> None:
    doc_batch_generator = airtable_connector.load_from_state()

    doc_batch = next(doc_batch_generator)
    with pytest.raises(StopIteration):
        next(doc_batch_generator)

    assert len(doc_batch) == 2

    expected_docs = [
        create_test_document(
            id="rec8BnxDLyWeegOuO",
            title="Slow Internet",
            description="The internet connection is very slow.",
            priority="Medium",
            status="In Progress",
            # Link to another record is skipped for now
            # category="Data Science",
            ticket_id="2",
            created_time="2024-12-24T21:02:49.000Z",
            status_last_changed="2024-12-24T21:02:49.000Z",
            days_since_status_change=0,
            assignee="Chris Weaver (chris@onyx.app)",
            submitted_by="Chris Weaver (chris@onyx.app)",
        ),
        create_test_document(
            id="reccSlIA4pZEFxPBg",
            title="Printer Issue",
            description="The office printer is not working.",
            priority="High",
            status="Open",
            # Link to another record is skipped for now
            # category="Software Development",
            ticket_id="1",
            created_time="2024-12-24T21:02:49.000Z",
            status_last_changed="2024-12-24T21:02:49.000Z",
            days_since_status_change=0,
            assignee="Chris Weaver (chris@onyx.app)",
            submitted_by="Chris Weaver (chris@onyx.app)",
            attachments=["Test.pdf:\ntesting!!!"],
        ),
    ]

    # Compare each document field by field
    for actual, expected in zip(doc_batch, expected_docs):
        assert actual.id == expected.id, f"ID mismatch for document {actual.id}"
        assert (
            actual.source == expected.source
        ), f"Source mismatch for document {actual.id}"
        assert (
            actual.semantic_identifier == expected.semantic_identifier
        ), f"Semantic identifier mismatch for document {actual.id}"
        assert (
            actual.metadata == expected.metadata
        ), f"Metadata mismatch for document {actual.id}"
        assert (
            actual.doc_updated_at == expected.doc_updated_at
        ), f"Updated at mismatch for document {actual.id}"
        assert (
            actual.primary_owners == expected.primary_owners
        ), f"Primary owners mismatch for document {actual.id}"
        assert (
            actual.secondary_owners == expected.secondary_owners
        ), f"Secondary owners mismatch for document {actual.id}"
        assert (
            actual.title == expected.title
        ), f"Title mismatch for document {actual.id}"
        assert (
            actual.from_ingestion_api == expected.from_ingestion_api
        ), f"Ingestion API flag mismatch for document {actual.id}"
        assert (
            actual.additional_info == expected.additional_info
        ), f"Additional info mismatch for document {actual.id}"

        # Compare sections
        assert len(actual.sections) == len(
            expected.sections
        ), f"Number of sections mismatch for document {actual.id}"
        for i, (actual_section, expected_section) in enumerate(
            zip(actual.sections, expected.sections)
        ):
            assert (
                actual_section.text == expected_section.text
            ), f"Section {i} text mismatch for document {actual.id}"
            assert (
                actual_section.link == expected_section.link
            ), f"Section {i} link mismatch for document {actual.id}"
@@ -1,42 +0,0 @@
import os

import pytest

from onyx.connectors.salesforce.connector import SalesforceConnector


@pytest.fixture
def salesforce_connector() -> SalesforceConnector:
    connector = SalesforceConnector(
        requested_objects=["Account", "Contact", "Opportunity"],
    )
    connector.load_credentials(
        {
            "sf_username": os.environ["SF_USERNAME"],
            "sf_password": os.environ["SF_PASSWORD"],
            "sf_security_token": os.environ["SF_SECURITY_TOKEN"],
        }
    )
    return connector


# TODO: make the credentials not expire
@pytest.mark.xfail(
    reason=(
        "Credentials change over time, so this test will fail if run when "
        "the credentials expire."
    )
)
def test_salesforce_connector_slim(salesforce_connector: SalesforceConnector) -> None:
    # Get all doc IDs from the full connector
    all_full_doc_ids = set()
    for doc_batch in salesforce_connector.load_from_state():
        all_full_doc_ids.update([doc.id for doc in doc_batch])

    # Get all doc IDs from the slim connector
    all_slim_doc_ids = set()
    for slim_doc_batch in salesforce_connector.retrieve_all_slim_documents():
        all_slim_doc_ids.update([doc.id for doc in slim_doc_batch])

    # The set of full doc IDs should always be a subset of the slim doc IDs
    assert all_full_doc_ids.issubset(all_slim_doc_ids)
@@ -1,73 +0,0 @@
from typing import Any
from typing import Dict
from typing import Optional

import requests

from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestSettings
from tests.integration.common_utils.test_models import DATestUser


class SettingsManager:
    @staticmethod
    def get_settings(
        user_performing_action: DATestUser | None = None,
    ) -> tuple[Dict[str, Any], str]:
        headers = (
            user_performing_action.headers
            if user_performing_action
            else GENERAL_HEADERS
        )
        headers.pop("Content-Type", None)

        response = requests.get(
            f"{API_SERVER_URL}/api/manage/admin/settings",
            headers=headers,
        )

        if not response.ok:
            return (
                {},
                f"Failed to get settings - {response.json().get('detail', 'Unknown error')}",
            )

        return response.json(), ""

    @staticmethod
    def update_settings(
        settings: DATestSettings,
        user_performing_action: DATestUser | None = None,
    ) -> tuple[Dict[str, Any], str]:
        headers = (
            user_performing_action.headers
            if user_performing_action
            else GENERAL_HEADERS
        )
        headers.pop("Content-Type", None)

        payload = settings.model_dump()
        response = requests.patch(
            f"{API_SERVER_URL}/api/manage/admin/settings",
            json=payload,
            headers=headers,
        )

        if not response.ok:
            return (
                {},
                f"Failed to update settings - {response.json().get('detail', 'Unknown error')}",
            )

        return response.json(), ""

    @staticmethod
    def get_setting(
        key: str,
        user_performing_action: DATestUser | None = None,
    ) -> Optional[Any]:
        settings, error = SettingsManager.get_settings(user_performing_action)
        if error:
            return None
        return settings.get(key)
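A hedged sketch of how a test might drive this helper (the flow mirrors test_limited below; the key name passed to get_setting is an assumption about the settings JSON):

from tests.integration.common_utils.managers.settings import SettingsManager
from tests.integration.common_utils.test_models import DATestSettings

# Flip a setting, then confirm the API reflects it.
updated, error = SettingsManager.update_settings(
    DATestSettings(anonymous_user_enabled=True)
)
assert not error, error

value = SettingsManager.get_setting("anonymous_user_enabled")  # assumed key name
assert value is True
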
@@ -1,4 +1,3 @@
from enum import Enum
from typing import Any
from uuid import UUID

@@ -151,18 +150,3 @@ class StreamedResponse(BaseModel):
    relevance_summaries: list[dict[str, Any]] | None = None
    tool_result: Any | None = None
    user: str | None = None


class DATestGatingType(str, Enum):
    FULL = "full"
    PARTIAL = "partial"
    NONE = "none"


class DATestSettings(BaseModel):
    """General settings"""

    maximum_chat_retention_days: int | None = None
    gpu_enabled: bool | None = None
    product_gating: DATestGatingType = DATestGatingType.NONE
    anonymous_user_enabled: bool | None = None

@@ -1,14 +0,0 @@
from tests.integration.common_utils.managers.settings import SettingsManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestSettings
from tests.integration.common_utils.test_models import DATestUser


def test_limited(reset: None) -> None:
    """Verify that with a limited role key, limited endpoints are accessible and
    others are not."""

    # Creating an admin user (first user created is automatically an admin)
    admin_user: DATestUser = UserManager.create(name="admin_user")
    SettingsManager.update_settings(DATestSettings(anonymous_user_enabled=True))
    print(admin_user.headers)
@@ -50,6 +50,10 @@ services:
      - GEN_AI_API_KEY=${GEN_AI_API_KEY:-}
      # if set, allows for the use of the token budget system
      - TOKEN_BUDGET_GLOBALLY_ENABLED=${TOKEN_BUDGET_GLOBALLY_ENABLED:-}
      # Enables the use of bedrock models
      - AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID:-}
      - AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY:-}
      - AWS_REGION_NAME=${AWS_REGION_NAME:-}
      # Query Options
      - DOC_TIME_DECAY=${DOC_TIME_DECAY:-}  # Recency Bias for search results, decay at 1 / (1 + DOC_TIME_DECAY * x years)
      - HYBRID_ALPHA=${HYBRID_ALPHA:-}  # Hybrid Search Alpha (0 for entirely keyword, 1 for entirely vector)
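To make the DOC_TIME_DECAY comment concrete, a small worked example of the decay formula (the sample value 0.5 is illustrative, not a shipped default):

def recency_multiplier(age_years: float, doc_time_decay: float = 0.5) -> float:
    # decay at 1 / (1 + DOC_TIME_DECAY * x years), per the comment above
    return 1 / (1 + doc_time_decay * age_years)


assert recency_multiplier(0) == 1.0  # fresh documents keep full score
assert recency_multiplier(2) == 0.5  # a two-year-old document scores half
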
@@ -96,16 +100,14 @@ services:
      # Chat Configs
      - HARD_DELETE_CHATS=${HARD_DELETE_CHATS:-}

      # Enables the use of bedrock models or IAM Auth
      - AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID:-}
      - AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY:-}
      - AWS_REGION_NAME=${AWS_REGION_NAME:-}

      # Enterprise Edition only
      - ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${ENABLE_PAID_ENTERPRISE_EDITION_FEATURES:-false}
      - API_KEY_HASH_ROUNDS=${API_KEY_HASH_ROUNDS:-}
      # Seeding configuration
      - USE_IAM_AUTH=${USE_IAM_AUTH:-}
      - AWS_REGION=${AWS_REGION-}
      - AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID-}
      - AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY-}
      # Uncomment the line below to use if IAM_AUTH is true and you are using iam auth for postgres
      # volumes:
      #   - ./bundle.pem:/app/bundle.pem:ro
@@ -230,7 +232,7 @@ services:
      # Enterprise Edition stuff
      - ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${ENABLE_PAID_ENTERPRISE_EDITION_FEATURES:-false}
      - USE_IAM_AUTH=${USE_IAM_AUTH:-}
      - AWS_REGION_NAME=${AWS_REGION_NAME:-}
      - AWS_REGION=${AWS_REGION-}
      - AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID-}
      - AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY-}
      # Uncomment the line below to use if IAM_AUTH is true and you are using iam auth for postgres
@@ -265,7 +267,7 @@ services:
      - NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS=${NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS:-}
      - NEXT_PUBLIC_DISABLE_LOGOUT=${NEXT_PUBLIC_DISABLE_LOGOUT:-}
      - NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN=${NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN:-}
      - NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=${NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED:-}

      # Enterprise Edition only
      - NEXT_PUBLIC_THEME=${NEXT_PUBLIC_THEME:-}
      # DO NOT TURN ON unless you have EXPLICIT PERMISSION from Onyx.

@@ -44,7 +44,10 @@ services:
      - LITELLM_EXTRA_HEADERS=${LITELLM_EXTRA_HEADERS:-}
      # if set, allows for the use of the token budget system
      - TOKEN_BUDGET_GLOBALLY_ENABLED=${TOKEN_BUDGET_GLOBALLY_ENABLED:-}

      # Enables the use of bedrock models
      - AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID:-}
      - AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY:-}
      - AWS_REGION_NAME=${AWS_REGION_NAME:-}
      # Query Options
      - DOC_TIME_DECAY=${DOC_TIME_DECAY:-}  # Recency Bias for search results, decay at 1 / (1 + DOC_TIME_DECAY * x years)
      - HYBRID_ALPHA=${HYBRID_ALPHA:-}  # Hybrid Search Alpha (0 for entirely keyword, 1 for entirely vector)
@@ -83,18 +86,16 @@ services:
      - CELERY_BROKER_POOL_LIMIT=${CELERY_BROKER_POOL_LIMIT:-}
      - LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS=${LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS:-}

      # Enables the use of bedrock models or IAM Auth
      - AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID:-}
      - AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY:-}
      - AWS_REGION_NAME=${AWS_REGION_NAME:-}
      - USE_IAM_AUTH=${USE_IAM_AUTH}

      # Chat Configs
      - HARD_DELETE_CHATS=${HARD_DELETE_CHATS:-}

      # Enterprise Edition only
      - API_KEY_HASH_ROUNDS=${API_KEY_HASH_ROUNDS:-}
      - ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${ENABLE_PAID_ENTERPRISE_EDITION_FEATURES:-false}
      - USE_IAM_AUTH=${USE_IAM_AUTH}
      - AWS_REGION=${AWS_REGION-}
      - AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID-}
      - AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY-}
      # Uncomment the line below to use if IAM_AUTH is true and you are using iam auth for postgres
      # volumes:
      #   - ./bundle.pem:/app/bundle.pem:ro
@@ -200,7 +201,7 @@ services:
      - API_KEY_HASH_ROUNDS=${API_KEY_HASH_ROUNDS:-}
      - ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${ENABLE_PAID_ENTERPRISE_EDITION_FEATURES:-false}
      - USE_IAM_AUTH=${USE_IAM_AUTH}
      - AWS_REGION_NAME=${AWS_REGION_NAME-}
      - AWS_REGION=${AWS_REGION-}
      - AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID-}
      - AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY-}
      # Uncomment the line below to use if IAM_AUTH is true and you are using iam auth for postgres

@@ -72,7 +72,6 @@ services:
      - NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS=${NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS:-}
      - NEXT_PUBLIC_DISABLE_LOGOUT=${NEXT_PUBLIC_DISABLE_LOGOUT:-}
      - NEXT_PUBLIC_THEME=${NEXT_PUBLIC_THEME:-}
      - NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=${NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED:-}
    depends_on:
      - api_server
    restart: always

@@ -23,7 +23,7 @@ services:
      - REDIS_HOST=cache
      - MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server}
      - USE_IAM_AUTH=${USE_IAM_AUTH}
      - AWS_REGION_NAME=${AWS_REGION_NAME-}
      - AWS_REGION=${AWS_REGION-}
      - AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID-}
      - AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY-}
      # Uncomment the line below to use if IAM_AUTH is true and you are using iam auth for postgres
@@ -60,7 +60,7 @@ services:
      - MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server}
      - INDEXING_MODEL_SERVER_HOST=${INDEXING_MODEL_SERVER_HOST:-indexing_model_server}
      - USE_IAM_AUTH=${USE_IAM_AUTH}
      - AWS_REGION_NAME=${AWS_REGION_NAME-}
      - AWS_REGION=${AWS_REGION-}
      - AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID-}
      - AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY-}
      # Uncomment the line below to use if IAM_AUTH is true and you are using iam auth for postgres

@@ -24,7 +24,7 @@ services:
      - REDIS_HOST=cache
      - MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server}
      - USE_IAM_AUTH=${USE_IAM_AUTH}
      - AWS_REGION_NAME=${AWS_REGION_NAME-}
      - AWS_REGION=${AWS_REGION-}
      - AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID-}
      - AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY-}
      # Uncomment the line below to use if IAM_AUTH is true and you are using iam auth for postgres
@@ -65,7 +65,7 @@ services:
      - MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server}
      - INDEXING_MODEL_SERVER_HOST=${INDEXING_MODEL_SERVER_HOST:-indexing_model_server}
      - USE_IAM_AUTH=${USE_IAM_AUTH}
      - AWS_REGION_NAME=${AWS_REGION_NAME-}
      - AWS_REGION=${AWS_REGION-}
      - AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID-}
      - AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY-}
      # Uncomment the line below to use if IAM_AUTH is true and you are using iam auth for postgres
@@ -99,7 +99,6 @@ services:
      - NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS=${NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS:-}
      - NEXT_PUBLIC_DISABLE_LOGOUT=${NEXT_PUBLIC_DISABLE_LOGOUT:-}
      - NEXT_PUBLIC_THEME=${NEXT_PUBLIC_THEME:-}
      - NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=${NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED:-}
    depends_on:
      - api_server
    restart: always
@@ -238,7 +237,7 @@ services:
    volumes:
      - ../data/certbot/conf:/etc/letsencrypt
      - ../data/certbot/www:/var/www/certbot
    logging:
    logging::wq
      driver: json-file
      options:
        max-size: "50m"
@@ -260,3 +259,6 @@ volumes:
  # Created by the container itself
  model_cache_huggingface:
  indexing_huggingface_model_cache:

@@ -75,9 +75,6 @@ ENV NEXT_PUBLIC_SENTRY_DSN=${NEXT_PUBLIC_SENTRY_DSN}
ARG NEXT_PUBLIC_GTM_ENABLED
ENV NEXT_PUBLIC_GTM_ENABLED=${NEXT_PUBLIC_GTM_ENABLED}

ARG NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED
ENV NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=${NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED}

RUN npx next build

# Step 2. Production image, copy all the files and run next
@@ -153,9 +150,6 @@ ENV NEXT_PUBLIC_SENTRY_DSN=${NEXT_PUBLIC_SENTRY_DSN}
ARG NEXT_PUBLIC_GTM_ENABLED
ENV NEXT_PUBLIC_GTM_ENABLED=${NEXT_PUBLIC_GTM_ENABLED}

ARG NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED
ENV NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=${NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED}

# Note: Don't expose ports here, Compose will handle that for us if necessary.
# If you want to run this without compose, specify the ports to
# expose via cli

web/package-lock.json (generated, 214 lines changed)
@@ -40,7 +40,6 @@
        "favicon-fetch": "^1.0.0",
        "formik": "^2.2.9",
        "js-cookie": "^3.0.5",
        "katex": "^0.16.17",
        "lodash": "^4.17.21",
        "lucide-react": "^0.454.0",
        "mdast-util-find-and-replace": "^3.0.1",
@@ -58,10 +57,8 @@
        "react-markdown": "^9.0.1",
        "react-select": "^5.8.0",
        "recharts": "^2.13.1",
        "rehype-katex": "^7.0.1",
        "rehype-prism-plus": "^2.0.0",
        "remark-gfm": "^4.0.0",
        "remark-math": "^6.0.0",
        "semver": "^7.5.4",
        "sharp": "^0.33.5",
        "stripe": "^17.0.0",
@@ -4778,12 +4775,6 @@
      "integrity": "sha512-dRLjCWHYg4oaA77cxO64oO+7JwCwnIzkZPdrrC71jQmQtlhM556pwKo5bUzqvZndkVbeFLIIi+9TC40JNF5hNQ==",
      "dev": true
    },
    "node_modules/@types/katex": {
      "version": "0.16.7",
      "resolved": "https://registry.npmjs.org/@types/katex/-/katex-0.16.7.tgz",
      "integrity": "sha512-HMwFiRujE5PjrgwHQ25+bsLJgowjGjm5Z8FVSf0N6PwgJrwxH0QxzHYDcKsTfV3wva0vzrpqMTJS2jXPr5BMEQ==",
      "license": "MIT"
    },
    "node_modules/@types/lodash": {
      "version": "4.17.4",
      "resolved": "https://registry.npmjs.org/@types/lodash/-/lodash-4.17.4.tgz",
@@ -8125,51 +8116,6 @@
        "node": ">= 0.4"
      }
    },
    "node_modules/hast-util-from-dom": {
      "version": "5.0.1",
      "resolved": "https://registry.npmjs.org/hast-util-from-dom/-/hast-util-from-dom-5.0.1.tgz",
      "integrity": "sha512-N+LqofjR2zuzTjCPzyDUdSshy4Ma6li7p/c3pA78uTwzFgENbgbUrm2ugwsOdcjI1muO+o6Dgzp9p8WHtn/39Q==",
      "license": "ISC",
      "dependencies": {
        "@types/hast": "^3.0.0",
        "hastscript": "^9.0.0",
        "web-namespaces": "^2.0.0"
      },
      "funding": {
        "type": "opencollective",
        "url": "https://opencollective.com/unified"
      }
    },
    "node_modules/hast-util-from-dom/node_modules/hast-util-parse-selector": {
      "version": "4.0.0",
      "resolved": "https://registry.npmjs.org/hast-util-parse-selector/-/hast-util-parse-selector-4.0.0.tgz",
      "integrity": "sha512-wkQCkSYoOGCRKERFWcxMVMOcYE2K1AaNLU8DXS9arxnLOUEWbOXKXiJUNzEpqZ3JOKpnha3jkFrumEjVliDe7A==",
      "license": "MIT",
      "dependencies": {
        "@types/hast": "^3.0.0"
      },
      "funding": {
        "type": "opencollective",
        "url": "https://opencollective.com/unified"
      }
    },
    "node_modules/hast-util-from-dom/node_modules/hastscript": {
      "version": "9.0.0",
      "resolved": "https://registry.npmjs.org/hastscript/-/hastscript-9.0.0.tgz",
      "integrity": "sha512-jzaLBGavEDKHrc5EfFImKN7nZKKBdSLIdGvCwDZ9TfzbF2ffXiov8CKE445L2Z1Ek2t/m4SKQ2j6Ipv7NyUolw==",
      "license": "MIT",
      "dependencies": {
        "@types/hast": "^3.0.0",
        "comma-separated-tokens": "^2.0.0",
        "hast-util-parse-selector": "^4.0.0",
        "property-information": "^6.0.0",
        "space-separated-tokens": "^2.0.0"
      },
      "funding": {
        "type": "opencollective",
        "url": "https://opencollective.com/unified"
      }
    },
    "node_modules/hast-util-from-html": {
      "version": "2.0.1",
      "resolved": "https://registry.npmjs.org/hast-util-from-html/-/hast-util-from-html-2.0.1.tgz",
@@ -8187,22 +8133,6 @@
        "url": "https://opencollective.com/unified"
      }
    },
    "node_modules/hast-util-from-html-isomorphic": {
      "version": "2.0.0",
      "resolved": "https://registry.npmjs.org/hast-util-from-html-isomorphic/-/hast-util-from-html-isomorphic-2.0.0.tgz",
      "integrity": "sha512-zJfpXq44yff2hmE0XmwEOzdWin5xwH+QIhMLOScpX91e/NSGPsAzNCvLQDIEPyO2TXi+lBmU6hjLIhV8MwP2kw==",
      "license": "MIT",
      "dependencies": {
        "@types/hast": "^3.0.0",
        "hast-util-from-dom": "^5.0.0",
        "hast-util-from-html": "^2.0.0",
        "unist-util-remove-position": "^5.0.0"
      },
      "funding": {
        "type": "opencollective",
        "url": "https://opencollective.com/unified"
      }
    },
    "node_modules/hast-util-from-parse5": {
      "version": "8.0.1",
      "resolved": "https://registry.npmjs.org/hast-util-from-parse5/-/hast-util-from-parse5-8.0.1.tgz",
@@ -8250,19 +8180,6 @@
        "url": "https://opencollective.com/unified"
      }
    },
    "node_modules/hast-util-is-element": {
      "version": "3.0.0",
      "resolved": "https://registry.npmjs.org/hast-util-is-element/-/hast-util-is-element-3.0.0.tgz",
      "integrity": "sha512-Val9mnv2IWpLbNPqc/pUem+a7Ipj2aHacCwgNfTiK0vJKl0LF+4Ba4+v1oPHFpf3bLYmreq0/l3Gud9S5OH42g==",
      "license": "MIT",
      "dependencies": {
        "@types/hast": "^3.0.0"
      },
      "funding": {
        "type": "opencollective",
        "url": "https://opencollective.com/unified"
      }
    },
    "node_modules/hast-util-parse-selector": {
      "version": "3.1.1",
      "resolved": "https://registry.npmjs.org/hast-util-parse-selector/-/hast-util-parse-selector-3.1.1.tgz",
@@ -8326,22 +8243,6 @@
        "url": "https://opencollective.com/unified"
      }
    },
    "node_modules/hast-util-to-text": {
      "version": "4.0.2",
      "resolved": "https://registry.npmjs.org/hast-util-to-text/-/hast-util-to-text-4.0.2.tgz",
      "integrity": "sha512-KK6y/BN8lbaq654j7JgBydev7wuNMcID54lkRav1P0CaE1e47P72AWWPiGKXTJU271ooYzcvTAn/Zt0REnvc7A==",
      "license": "MIT",
      "dependencies": {
        "@types/hast": "^3.0.0",
        "@types/unist": "^3.0.0",
        "hast-util-is-element": "^3.0.0",
        "unist-util-find-after": "^5.0.0"
      },
      "funding": {
        "type": "opencollective",
        "url": "https://opencollective.com/unified"
      }
    },
    "node_modules/hast-util-whitespace": {
      "version": "3.0.0",
      "resolved": "https://registry.npmjs.org/hast-util-whitespace/-/hast-util-whitespace-3.0.0.tgz",
@@ -9325,31 +9226,6 @@
        "node": ">=4.0"
      }
    },
    "node_modules/katex": {
      "version": "0.16.17",
      "resolved": "https://registry.npmjs.org/katex/-/katex-0.16.17.tgz",
      "integrity": "sha512-OyzSrXBllz+Jdc9Auiw0kt21gbZ4hkz8Q5srVAb2U9INcYIfGKbxe+bvNvEz1bQ/NrDeRRho5eLCyk/L03maAw==",
      "funding": [
        "https://opencollective.com/katex",
        "https://github.com/sponsors/katex"
      ],
      "license": "MIT",
      "dependencies": {
        "commander": "^8.3.0"
      },
      "bin": {
        "katex": "cli.js"
      }
    },
    "node_modules/katex/node_modules/commander": {
      "version": "8.3.0",
      "resolved": "https://registry.npmjs.org/commander/-/commander-8.3.0.tgz",
      "integrity": "sha512-OkTL9umf+He2DZkUq8f8J9of7yL6RJKI24dVITBmNfZBmri9zYZQrKkuXiKhyfPSu8tUhnVBB1iKXevvnlR4Ww==",
      "license": "MIT",
      "engines": {
        "node": ">= 12"
      }
    },
    "node_modules/keyv": {
      "version": "4.5.4",
      "resolved": "https://registry.npmjs.org/keyv/-/keyv-4.5.4.tgz",
@@ -9683,25 +9559,6 @@
        "url": "https://opencollective.com/unified"
      }
    },
    "node_modules/mdast-util-math": {
      "version": "3.0.0",
      "resolved": "https://registry.npmjs.org/mdast-util-math/-/mdast-util-math-3.0.0.tgz",
      "integrity": "sha512-Tl9GBNeG/AhJnQM221bJR2HPvLOSnLE/T9cJI9tlc6zwQk2nPk/4f0cHkOdEixQPC/j8UtKDdITswvLAy1OZ1w==",
      "license": "MIT",
      "dependencies": {
        "@types/hast": "^3.0.0",
        "@types/mdast": "^4.0.0",
        "devlop": "^1.0.0",
        "longest-streak": "^3.0.0",
        "mdast-util-from-markdown": "^2.0.0",
        "mdast-util-to-markdown": "^2.1.0",
        "unist-util-remove-position": "^5.0.0"
      },
      "funding": {
        "type": "opencollective",
        "url": "https://opencollective.com/unified"
      }
    },
    "node_modules/mdast-util-mdx-expression": {
      "version": "2.0.0",
      "resolved": "https://registry.npmjs.org/mdast-util-mdx-expression/-/mdast-util-mdx-expression-2.0.0.tgz",
@@ -10046,25 +9903,6 @@
        "url": "https://opencollective.com/unified"
      }
    },
    "node_modules/micromark-extension-math": {
      "version": "3.1.0",
      "resolved": "https://registry.npmjs.org/micromark-extension-math/-/micromark-extension-math-3.1.0.tgz",
      "integrity": "sha512-lvEqd+fHjATVs+2v/8kg9i5Q0AP2k85H0WUOwpIVvUML8BapsMvh1XAogmQjOCsLpoKRCVQqEkQBB3NhVBcsOg==",
      "license": "MIT",
      "dependencies": {
        "@types/katex": "^0.16.0",
        "devlop": "^1.0.0",
        "katex": "^0.16.0",
        "micromark-factory-space": "^2.0.0",
        "micromark-util-character": "^2.0.0",
        "micromark-util-symbol": "^2.0.0",
        "micromark-util-types": "^2.0.0"
      },
      "funding": {
        "type": "opencollective",
        "url": "https://opencollective.com/unified"
      }
    },
    "node_modules/micromark-factory-destination": {
      "version": "2.0.0",
      "resolved": "https://registry.npmjs.org/micromark-factory-destination/-/micromark-factory-destination-2.0.0.tgz",
@@ -14149,7 +13987,6 @@
      "version": "9.0.1",
      "resolved": "https://registry.npmjs.org/react-markdown/-/react-markdown-9.0.1.tgz",
      "integrity": "sha512-186Gw/vF1uRkydbsOIkcGXw7aHq0sZOCRFFjGrr7b9+nVZg4UfA4enXCaxm4fUzecU38sWfrNDitGhshuU7rdg==",
      "license": "MIT",
      "dependencies": {
        "@types/hast": "^3.0.0",
        "devlop": "^1.0.0",
@@ -14498,25 +14335,6 @@
        "url": "https://github.com/sponsors/ljharb"
      }
    },
    "node_modules/rehype-katex": {
      "version": "7.0.1",
      "resolved": "https://registry.npmjs.org/rehype-katex/-/rehype-katex-7.0.1.tgz",
      "integrity": "sha512-OiM2wrZ/wuhKkigASodFoo8wimG3H12LWQaH8qSPVJn9apWKFSH3YOCtbKpBorTVw/eI7cuT21XBbvwEswbIOA==",
      "license": "MIT",
      "dependencies": {
        "@types/hast": "^3.0.0",
        "@types/katex": "^0.16.0",
        "hast-util-from-html-isomorphic": "^2.0.0",
        "hast-util-to-text": "^4.0.0",
        "katex": "^0.16.0",
        "unist-util-visit-parents": "^6.0.0",
        "vfile": "^6.0.0"
      },
      "funding": {
        "type": "opencollective",
        "url": "https://opencollective.com/unified"
      }
    },
    "node_modules/rehype-parse": {
      "version": "9.0.0",
      "resolved": "https://registry.npmjs.org/rehype-parse/-/rehype-parse-9.0.0.tgz",
@@ -14535,7 +14353,6 @@
      "version": "2.0.0",
      "resolved": "https://registry.npmjs.org/rehype-prism-plus/-/rehype-prism-plus-2.0.0.tgz",
      "integrity": "sha512-FeM/9V2N7EvDZVdR2dqhAzlw5YI49m9Tgn7ZrYJeYHIahM6gcXpH0K1y2gNnKanZCydOMluJvX2cB9z3lhY8XQ==",
      "license": "MIT",
      "dependencies": {
        "hast-util-to-string": "^3.0.0",
        "parse-numeric-range": "^1.3.0",
@@ -14559,7 +14376,6 @@
      "version": "4.0.0",
      "resolved": "https://registry.npmjs.org/remark-gfm/-/remark-gfm-4.0.0.tgz",
      "integrity": "sha512-U92vJgBPkbw4Zfu/IiW2oTZLSL3Zpv+uI7My2eq8JxKgqraFdU8YUGicEJCEgSbeaG+QDFqIcwwfMTOEelPxuA==",
      "license": "MIT",
      "dependencies": {
        "@types/mdast": "^4.0.0",
        "mdast-util-gfm": "^3.0.0",
@@ -14573,22 +14389,6 @@
        "url": "https://opencollective.com/unified"
      }
    },
    "node_modules/remark-math": {
      "version": "6.0.0",
      "resolved": "https://registry.npmjs.org/remark-math/-/remark-math-6.0.0.tgz",
      "integrity": "sha512-MMqgnP74Igy+S3WwnhQ7kqGlEerTETXMvJhrUzDikVZ2/uogJCb+WHUg97hK9/jcfc0dkD73s3LN8zU49cTEtA==",
      "license": "MIT",
      "dependencies": {
        "@types/mdast": "^4.0.0",
        "mdast-util-math": "^3.0.0",
        "micromark-extension-math": "^3.0.0",
        "unified": "^11.0.0"
      },
      "funding": {
        "type": "opencollective",
        "url": "https://opencollective.com/unified"
      }
    },
    "node_modules/remark-parse": {
      "version": "11.0.0",
      "resolved": "https://registry.npmjs.org/remark-parse/-/remark-parse-11.0.0.tgz",
@@ -15882,20 +15682,6 @@
        "unist-util-visit-parents": "^6.0.0"
      }
    },
    "node_modules/unist-util-find-after": {
      "version": "5.0.0",
      "resolved": "https://registry.npmjs.org/unist-util-find-after/-/unist-util-find-after-5.0.0.tgz",
      "integrity": "sha512-amQa0Ep2m6hE2g72AugUItjbuM8X8cGQnFoHk0pGfrFeT9GZhzN5SW8nRsiGKK7Aif4CrACPENkA6P/Lw6fHGQ==",
      "license": "MIT",
      "dependencies": {
        "@types/unist": "^3.0.0",
        "unist-util-is": "^6.0.0"
      },
      "funding": {
        "type": "opencollective",
        "url": "https://opencollective.com/unified"
      }
    },
    "node_modules/unist-util-is": {
      "version": "6.0.0",
      "resolved": "https://registry.npmjs.org/unist-util-is/-/unist-util-is-6.0.0.tgz",
Some files were not shown because too many files have changed in this diff.