Compare commits

..

73 Commits

Author SHA1 Message Date
Richard Kuo (Onyx)
54e61611c5 prototype for surfacing docs without a query 2025-03-27 16:52:31 -07:00
rkuo-danswer
f08fa878a6 refactor file extension checking and add test for blob s3 (#4369)
* refactor file extension checking and add test for blob s3

* code review

* fix checking ext

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-03-27 18:57:44 +00:00
pablonyx
d307534781 add some debug logging (#4328) 2025-03-27 11:49:32 -07:00
rkuo-danswer
6f54791910 adjust some vars in real time (#4365)
* adjust some vars in real time

* some sanity checking

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-03-27 17:30:08 +00:00
pablonyx
0d5497bb6b Add multi-tenant user invitation flow test (#4360) 2025-03-27 09:53:15 -07:00
Chris Weaver
7648627503 Save all logs + add log persistence to most Onyx-owned containers (#4368)
* Save all logs + add log persistence to most Onyx-owned containers

* Separate volumes for each container

* Small fixes
2025-03-26 22:25:39 -07:00
pablonyx
927554d5ca slight robustification (#4367) 2025-03-27 03:23:36 +00:00
pablonyx
7dcec6caf5 Fix session touching (#4363)
* fix session touching

* Revert "fix session touching"

This reverts commit c473d5c9a2.

* Revert "Revert "fix session touching""

This reverts commit 26a71d40b6.

* update

* quick nit
2025-03-27 01:18:46 +00:00
rkuo-danswer
036648146d possible fix for confluence query filter (#4280)
* possible fix for confluence query filter

* nuke the attachment filter query ... it doesn't work!

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-03-27 00:35:14 +00:00
rkuo-danswer
2aa4697ac8 permission sync runs so often that it starves out other tasks if run at high priority (#4364)
Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-03-27 00:22:53 +00:00
rkuo-danswer
bc9b4e4f45 use slack's built in rate limit handler for the bot (#4362)
Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-03-26 21:55:04 +00:00
evan-danswer
178a64f298 fix issue with drive connector service account indexing (#4356)
* fix issue with drive connector service account indexing

* correct checkpoint resumption

* final set of fixes

* nit

* fix typing

* logging and CW comments

* nit
2025-03-26 20:54:26 +00:00
pablonyx
c79f1edf1d add a flush (#4361) 2025-03-26 14:40:52 -07:00
pablonyx
7c8e23aa54 Fix saml conversion from ext_perm -> basic (#4343)
* fix saml conversion from ext_perm -> basic

* quick nit

* minor fix

* finalize

* update

* quick fix
2025-03-26 20:36:51 +00:00
pablonyx
d37b427d52 fix email flow (#4339) 2025-03-26 18:59:12 +00:00
pablonyx
a65fefd226 test fix 2025-03-26 12:43:38 -07:00
rkuo-danswer
bb09bde519 Bugfix/google drive size threshold 2 (#4355) 2025-03-26 12:06:36 -07:00
Tim Rosenblatt
0f6cf0fc58 Fixes docker logs helper text in run-nginx.sh (#3678)
The docker container name is slightly wrong, and this commit fixes it.
2025-03-26 09:03:35 -07:00
pablonyx
fed06b592d Auto refresh credentials (#4268)
* Auto refresh credentials

* remove dupes

* clean up + tests

* k

* quick nit

* add brief comment

* misc typing
2025-03-26 01:53:31 +00:00
pablonyx
8d92a1524e fix invitation on cloud (#4351)
* fix invitation on cloud

* k
2025-03-26 01:25:17 +00:00
pablonyx
ecfea9f5ed Email formatting devices (#4353)
* update email formatting

* k

* update

* k

* nit
2025-03-25 21:42:32 +00:00
rkuo-danswer
b269f1ba06 fix broken function call (#4354)
Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-03-25 21:07:31 +00:00
pablonyx
30c878efa5 Quick fix (#4341)
* quick fix

* Revert "quick fix"

This reverts commit f113616276.

* smaller chnage
2025-03-25 18:39:55 +00:00
pablonyx
2024776c19 Respect contextvars when parallelizing for Google Drive (#4291)
* k

* k

* fix typing
2025-03-25 17:40:12 +00:00
pablonyx
431316929c k (#4336) 2025-03-25 17:00:35 +00:00
pablonyx
c5b9c6e308 update (#4344) 2025-03-25 16:56:23 +00:00
pablonyx
73dd188b3f update (#4338) 2025-03-25 16:55:25 +00:00
evan-danswer
79b061abbc Daylight savings time handling (#4345)
* confluence timezone improvements

* confluence timezone improvements
2025-03-25 16:11:30 +00:00
rkuo-danswer
552f1ead4f use correct namespace in redis for certain keys (#4340)
Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-03-25 04:10:31 +00:00
evan-danswer
17925b49e8 typing fix (#4342)
* typing fix

* changed type hint to help future coders
2025-03-25 01:01:13 +00:00
rkuo-danswer
55fb5c3ca5 add size threshold for google drive (#4329)
* add size threshold for google drive

* greptile nits

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-03-24 04:09:28 +00:00
evan-danswer
99546e4a4d zendesk checkpointed connector (#4311)
* zendesk v1

* logic fix

* zendesk testing

* add unit tests

* zendesk caching

* CW comments

* fix unit tests
2025-03-23 20:43:13 +00:00
pablonyx
c25d56f4a5 Improved drive flow UX (#4331)
* wip

* k

* looking good

* clenaed up

* quick nit
2025-03-23 19:21:03 +00:00
Chris Weaver
35f3f4f120 Small slack bot fixes (#4333) 2025-03-22 23:22:17 +00:00
Weves
25b69a8aca Adjust spammy log 2025-03-22 14:52:09 -07:00
pablonyx
1b7d710b2a Fix links from file metadata (#4324)
* quick fix

* clarify comment

* fix file metadata

* k
2025-03-22 18:21:47 +00:00
pablonyx
ae3d3db3f4 Update slack bot listing endpoint (#4325)
* update slack bot listing endpoint

* nit
2025-03-22 18:21:31 +00:00
evan-danswer
fb79a9e700 Checkpointed GitHub connector (#4307)
* WIP github checkpointing

* first draft of github checkpointing

* nit

* CW comments

* github basic connector test

* connector test env var

* secrets cant start with GITHUB_

* unit tests and bug fix

* connector failures

* address CW comments

* validation fix

* validation fix

* remove prints

* fixed tests

* 100 items per page
2025-03-22 01:48:05 +00:00
rkuo-danswer
587ba11bbc alembic script logging fixes (#4322)
* log fixing

* fix typos

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-03-22 00:50:58 +00:00
pablonyx
fce81ebb60 Minor ux nits (#4327)
* k

* quick fix
2025-03-21 21:50:56 +00:00
Chris Weaver
61facfb0a8 Fix slack connector (#4326) 2025-03-21 21:30:03 +00:00
Chris Weaver
52b96854a2 Handle move errors (#4317)
* Handle move errors

* Make a warning
2025-03-21 11:11:12 -07:00
Chris Weaver
d123713c00 Fix GPU status request in sync flow (#4318)
* Fix GPU status request in sync flow

* tweak

* Fix test

* Fix more tests
2025-03-21 11:11:00 -07:00
Chris Weaver
775c847f82 Reduce drive retries (#4312)
* Reduce drive retries

* timestamp format fix

---------

Co-authored-by: Evan Lohn <evan@danswer.ai>
2025-03-21 00:23:55 +00:00
rkuo-danswer
6d330131fd wire off image downloading for confluence and gdrive if not enabled i… (#4305)
* wire off image downloading for confluence and gdrive if not enabled in settings

* fix partial func

* fix confluence basic test

* add test for skipping/allowing images

* review comments

* skip allow images test

* mock function using the db

* mock at the proper level

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-03-20 23:10:28 +00:00
Chris Weaver
0292ca2445 Add option to control # of slack threads (#4310) 2025-03-20 16:56:05 +00:00
Weves
15dd1e72ca Remove slack channel validation 2025-03-20 08:34:54 -07:00
Weves
91c9be37c0 Fix loader 2025-03-20 08:30:46 -07:00
Weves
2a01c854a0 Fix cases where the bot is disabled 2025-03-20 08:30:46 -07:00
rkuo-danswer
85ebadc8eb sanitize llm keys and handle updates properly (#4270)
* sanitize llm keys and handle updates properly

* fix llm provider testing

* fix test

* mypy

* fix default model editing

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2025-03-20 01:13:02 +00:00
Chris Weaver
5dda53eec3 Notion improvement (#4306)
* Notion connector improvements

* Enable recursive index by default

* Small tweak
2025-03-19 23:16:05 +00:00
Chris Weaver
72bf427cc2 Address invalid connector state (#4304)
* Address invalid connector state

* Fixes

* Address mypy

* Address RK comment
2025-03-19 21:15:06 +00:00
Chris Weaver
f421c6010b Checkpointed Jira connector (#4286)
* Checkpointed Jira connector

* nit

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>

* typing improvements and test fixes

* cleaner typing

* remove default because it is from the future

* mypy

* Address EL comments

---------

Co-authored-by: evan-danswer <evan@danswer.ai>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2025-03-19 20:41:01 +00:00
rkuo-danswer
0b87549f35 Feature/email whitelabeling (#4260)
* work in progress

* work in progress

* WIP

* refactor, use inline attachment for image (base64 encoding doesn't work)

* pretty sure this belongs behind a multi_tenant check

* code review / refactor

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-03-19 13:08:44 -07:00
evan-danswer
06624a988d Gdrive checkpointed connector (#4262)
* WIP rebased

* style

* WIP, testing theory

* fix type issue

* fixed filtering bug

* fix silliness

* correct serialization and validation of threadsafedict

* concurrent drive access

* nits

* nit

* oauth bug fix

* testing fix

* fix slim retrieval

* fix integration tests

* fix testing change

* CW comments

* nit

* guarantee completion stage existence

* fix default values
2025-03-19 18:49:35 +00:00
Chris Weaver
ae774105e3 Fix slack connector creation (#4303)
* Make it fail fast + succeed validation if rate limiting is happening

* Add logging + reduce spam
2025-03-19 18:26:49 +00:00
evan-danswer
4dafc3aa6d Update README.md 2025-03-18 21:14:05 -07:00
evan-danswer
5d7d471823 Update README.md
fix bullet points
2025-03-18 19:34:08 -07:00
Weves
61366df34c Add execute permission 2025-03-18 12:03:32 -07:00
Chris Weaver
1a444245f6 Memory tracking script (#4297)
* Add simple container-level memory tracking script
2025-03-18 12:00:09 -07:00
rkuo-danswer
c32d234491 xfail highspot connector tests (#4296)
Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-03-18 11:47:17 -07:00
pablonyx
07b68436cf use ONYX_CLOUD_CELERY_TASK_PREFIX for pre provisioning (#4293) 2025-03-18 17:34:22 +00:00
Chris Weaver
293d1a4476 Add process-level memory monitoring (#4294)
* Add process-level memory monitoring

* Switch to every 5 minutes
2025-03-17 22:39:52 -07:00
SubashMohan
ba514aaaa2 Highspot connector (#4277) 2025-03-17 08:36:02 -07:00
Arun Philip
f45798b5dd add overflow-auto to show all content in Modal (#4140) 2025-03-15 11:56:19 -07:00
Weves
64ff5df083 Fix basic auth for non-ee 2025-03-14 11:40:17 -07:00
rkuo-danswer
cf1b7e7a93 add proper boolean validation to field (#4283)
Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-03-14 03:38:25 +00:00
Chris Weaver
63692a6bd3 Fix perm sync memory usage (#4282)
* Fix slack perm sync memory usage

* Make perm syncing run in batches rather than fetching everything

* Update backend/ee/onyx/external_permissions/slack/doc_sync.py

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>

* Update backend/ee/onyx/external_permissions/slack/doc_sync.py

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>

* Loud error on slack doc sync missing permissions

---------

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2025-03-14 02:26:22 +00:00
evan-danswer
934700b928 better drive url cleaning (#4247)
* better drive url cleaning

* nit

* address JR comments
2025-03-13 21:16:24 +00:00
Chris Weaver
b1a7cff9e0 Enable claude 3.7 (#4279) 2025-03-13 18:33:06 +00:00
joachim-danswer
463340b8a1 Reduce ranking scores for short chunks without actual information (#4098)
* remove title for slack

* initial working code

* simplification

* improvements

* name change to information_content_model

* avoid boost_score > 1.0

* nit

* EL comments and improvements

Improvements:
  - proper import of information content model from cache or HF
  - warm up for information content model

Other:
  - EL PR review comments

* nit

* requirements version update

* fixed docker file

* new home for model_server configs

* default off

* small updates

* YS comments - pt 1

* renaming to chunk_boost & chunk table def

* saving and deleting chunk stats in new table

* saving and updating chunk stats

* improved dict score update

* create columns for individual boost factors

* RK comments

* Update migration

* manual import reordering
2025-03-13 17:35:45 +00:00
rkuo-danswer
ba82888e1e change max workers to 2 for the moment (#4278)
Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-03-13 09:58:24 -07:00
rkuo-danswer
39465d3104 change default build info in dockerfile's to something more obviously source only (#4275)
Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2025-03-13 09:42:10 -07:00
178 changed files with 8996 additions and 2169 deletions

View File

@@ -9,6 +9,10 @@ on:
- cron: "0 16 * * *"
env:
# AWS
AWS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS: ${{ secrets.AWS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS }}
AWS_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS: ${{ secrets.AWS_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS }}
# Confluence
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_TEST_SPACE: ${{ secrets.CONFLUENCE_TEST_SPACE }}
@@ -45,11 +49,16 @@ env:
SHAREPOINT_CLIENT_SECRET: ${{ secrets.SHAREPOINT_CLIENT_SECRET }}
SHAREPOINT_CLIENT_DIRECTORY_ID: ${{ secrets.SHAREPOINT_CLIENT_DIRECTORY_ID }}
SHAREPOINT_SITE: ${{ secrets.SHAREPOINT_SITE }}
# Github
ACCESS_TOKEN_GITHUB: ${{ secrets.ACCESS_TOKEN_GITHUB }}
# Gitbook
GITBOOK_SPACE_ID: ${{ secrets.GITBOOK_SPACE_ID }}
GITBOOK_API_KEY: ${{ secrets.GITBOOK_API_KEY }}
# Notion
NOTION_INTEGRATION_TOKEN: ${{ secrets.NOTION_INTEGRATION_TOKEN }}
# Highspot
HIGHSPOT_KEY: ${{ secrets.HIGHSPOT_KEY }}
HIGHSPOT_SECRET: ${{ secrets.HIGHSPOT_SECRET }}
jobs:
connectors-check:

View File

@@ -8,7 +8,7 @@ Edition features outside of personal development or testing purposes. Please rea
founders@onyx.app for more information. Please visit https://github.com/onyx-dot-app/onyx"
# Default ONYX_VERSION, typically overriden during builds by GitHub Actions.
ARG ONYX_VERSION=0.8-dev
ARG ONYX_VERSION=0.0.0-dev
# DO_NOT_TRACK is used to disable telemetry for Unstructured
ENV ONYX_VERSION=${ONYX_VERSION} \
DANSWER_RUNNING_IN_DOCKER="true" \
@@ -102,6 +102,7 @@ COPY ./alembic /app/alembic
COPY ./alembic_tenants /app/alembic_tenants
COPY ./alembic.ini /app/alembic.ini
COPY supervisord.conf /usr/etc/supervisord.conf
COPY ./static /app/static
# Escape hatch scripts
COPY ./scripts/debugging /app/scripts/debugging

View File

@@ -7,7 +7,7 @@ You can find it at https://hub.docker.com/r/onyx/onyx-model-server. For more det
visit https://github.com/onyx-dot-app/onyx."
# Default ONYX_VERSION, typically overriden during builds by GitHub Actions.
ARG ONYX_VERSION=0.8-dev
ARG ONYX_VERSION=0.0.0-dev
ENV ONYX_VERSION=${ONYX_VERSION} \
DANSWER_RUNNING_IN_DOCKER="true"

View File

@@ -84,7 +84,7 @@ keys = console
keys = generic
[logger_root]
level = WARN
level = INFO
handlers = console
qualname =

View File

@@ -25,6 +25,9 @@ from shared_configs.configs import MULTI_TENANT, POSTGRES_DEFAULT_SCHEMA
from onyx.db.models import Base
from celery.backends.database.session import ResultModelBase # type: ignore
# Make sure in alembic.ini [logger_root] level=INFO is set or most logging will be
# hidden! (defaults to level=WARN)
# Alembic Config object
config = context.config
@@ -36,6 +39,7 @@ if config.config_file_name is not None and config.attributes.get(
target_metadata = [Base.metadata, ResultModelBase.metadata]
EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
logger = logging.getLogger(__name__)
ssl_context: ssl.SSLContext | None = None
@@ -64,7 +68,7 @@ def include_object(
return True
def get_schema_options() -> tuple[str, bool, bool]:
def get_schema_options() -> tuple[str, bool, bool, bool]:
x_args_raw = context.get_x_argument()
x_args = {}
for arg in x_args_raw:
@@ -76,6 +80,10 @@ def get_schema_options() -> tuple[str, bool, bool]:
create_schema = x_args.get("create_schema", "true").lower() == "true"
upgrade_all_tenants = x_args.get("upgrade_all_tenants", "false").lower() == "true"
# continue on error with individual tenant
# only applies to online migrations
continue_on_error = x_args.get("continue", "false").lower() == "true"
if (
MULTI_TENANT
and schema_name == POSTGRES_DEFAULT_SCHEMA
@@ -86,14 +94,12 @@ def get_schema_options() -> tuple[str, bool, bool]:
"Please specify a tenant-specific schema."
)
return schema_name, create_schema, upgrade_all_tenants
return schema_name, create_schema, upgrade_all_tenants, continue_on_error
def do_run_migrations(
connection: Connection, schema_name: str, create_schema: bool
) -> None:
logger.info(f"About to migrate schema: {schema_name}")
if create_schema:
connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema_name}"'))
connection.execute(text("COMMIT"))
@@ -134,7 +140,12 @@ def provide_iam_token_for_alembic(
async def run_async_migrations() -> None:
schema_name, create_schema, upgrade_all_tenants = get_schema_options()
(
schema_name,
create_schema,
upgrade_all_tenants,
continue_on_error,
) = get_schema_options()
engine = create_async_engine(
build_connection_string(),
@@ -151,9 +162,15 @@ async def run_async_migrations() -> None:
if upgrade_all_tenants:
tenant_schemas = get_all_tenant_ids()
i_tenant = 0
num_tenants = len(tenant_schemas)
for schema in tenant_schemas:
i_tenant += 1
logger.info(
f"Migrating schema: index={i_tenant} num_tenants={num_tenants} schema={schema}"
)
try:
logger.info(f"Migrating schema: {schema}")
async with engine.connect() as connection:
await connection.run_sync(
do_run_migrations,
@@ -162,7 +179,12 @@ async def run_async_migrations() -> None:
)
except Exception as e:
logger.error(f"Error migrating schema {schema}: {e}")
raise
if not continue_on_error:
logger.error("--continue is not set, raising exception!")
raise
logger.warning("--continue is set, continuing to next schema.")
else:
try:
logger.info(f"Migrating schema: {schema_name}")
@@ -180,7 +202,11 @@ async def run_async_migrations() -> None:
def run_migrations_offline() -> None:
schema_name, _, upgrade_all_tenants = get_schema_options()
"""This doesn't really get used when we migrate in the cloud."""
logger.info("run_migrations_offline starting.")
schema_name, _, upgrade_all_tenants, continue_on_error = get_schema_options()
url = build_connection_string()
if upgrade_all_tenants:
@@ -230,6 +256,7 @@ def run_migrations_offline() -> None:
def run_migrations_online() -> None:
logger.info("run_migrations_online starting.")
asyncio.run(run_async_migrations())

View File

@@ -28,6 +28,20 @@ depends_on = None
def upgrade() -> None:
# First, drop any existing indexes to avoid conflicts
op.execute("COMMIT")
op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_message_tsv;")
op.execute("COMMIT")
op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_session_desc_tsv;")
op.execute("COMMIT")
op.execute("DROP INDEX IF EXISTS idx_chat_message_message_lower;")
# Drop existing columns if they exist
op.execute("ALTER TABLE chat_message DROP COLUMN IF EXISTS message_tsv;")
op.execute("ALTER TABLE chat_session DROP COLUMN IF EXISTS description_tsv;")
# Create a GIN index for full-text search on chat_message.message
op.execute(
"""

View File

@@ -25,6 +25,10 @@ SAML_CONF_DIR = os.environ.get("SAML_CONF_DIR") or "/app/ee/onyx/configs/saml_co
#####
# Auto Permission Sync
#####
DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY = int(
os.environ.get("DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
)
# In seconds, default is 5 minutes
CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY = int(
os.environ.get("CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY") or 5 * 60
@@ -39,6 +43,7 @@ CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC = (
CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY = int(
os.environ.get("CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
)
NUM_PERMISSION_WORKERS = int(os.environ.get("NUM_PERMISSION_WORKERS") or 2)
@@ -72,6 +77,13 @@ OAUTH_GOOGLE_DRIVE_CLIENT_SECRET = os.environ.get(
"OAUTH_GOOGLE_DRIVE_CLIENT_SECRET", ""
)
GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY = int(
os.environ.get("GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY") or 5 * 60
)
SLACK_PERMISSION_DOC_SYNC_FREQUENCY = int(
os.environ.get("SLACK_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
)
# The posthog client does not accept empty API keys or hosts however it fails silently
# when the capture is called. These defaults prevent Posthog issues from breaking the Onyx app

View File

@@ -2,6 +2,7 @@
Rules defined here:
https://confluence.atlassian.com/conf85/check-who-can-view-a-page-1283360557.html
"""
from collections.abc import Generator
from typing import Any
from ee.onyx.configs.app_configs import CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC
@@ -263,13 +264,11 @@ def _fetch_all_page_restrictions(
space_permissions_by_space_key: dict[str, ExternalAccess],
is_cloud: bool,
callback: IndexingHeartbeatInterface | None,
) -> list[DocExternalAccess]:
) -> Generator[DocExternalAccess, None, None]:
"""
For all pages, if a page has restrictions, then use those restrictions.
Otherwise, use the space's restrictions.
"""
document_restrictions: list[DocExternalAccess] = []
for slim_doc in slim_docs:
if callback:
if callback.should_stop():
@@ -286,11 +285,9 @@ def _fetch_all_page_restrictions(
confluence_client=confluence_client,
perm_sync_data=slim_doc.perm_sync_data,
):
document_restrictions.append(
DocExternalAccess(
doc_id=slim_doc.id,
external_access=restrictions,
)
yield DocExternalAccess(
doc_id=slim_doc.id,
external_access=restrictions,
)
# If there are restrictions, then we don't need to use the space's restrictions
continue
@@ -324,11 +321,9 @@ def _fetch_all_page_restrictions(
continue
# If there are no restrictions, then use the space's restrictions
document_restrictions.append(
DocExternalAccess(
doc_id=slim_doc.id,
external_access=space_permissions,
)
yield DocExternalAccess(
doc_id=slim_doc.id,
external_access=space_permissions,
)
if (
not space_permissions.is_public
@@ -342,13 +337,12 @@ def _fetch_all_page_restrictions(
)
logger.debug("Finished fetching all page restrictions for space")
return document_restrictions
def confluence_doc_sync(
cc_pair: ConnectorCredentialPair,
callback: IndexingHeartbeatInterface | None,
) -> list[DocExternalAccess]:
) -> Generator[DocExternalAccess, None, None]:
"""
Adds the external permissions to the documents in postgres
if the document doesn't already exists in postgres, we create
@@ -387,7 +381,7 @@ def confluence_doc_sync(
slim_docs.extend(doc_batch)
logger.debug("Fetching all page restrictions for space")
return _fetch_all_page_restrictions(
yield from _fetch_all_page_restrictions(
confluence_client=confluence_connector.confluence_client,
slim_docs=slim_docs,
space_permissions_by_space_key=space_permissions_by_space_key,

View File

@@ -1,3 +1,4 @@
from collections.abc import Generator
from datetime import datetime
from datetime import timezone
@@ -34,7 +35,7 @@ def _get_slim_doc_generator(
def gmail_doc_sync(
cc_pair: ConnectorCredentialPair,
callback: IndexingHeartbeatInterface | None,
) -> list[DocExternalAccess]:
) -> Generator[DocExternalAccess, None, None]:
"""
Adds the external permissions to the documents in postgres
if the document doesn't already exists in postgres, we create
@@ -48,7 +49,6 @@ def gmail_doc_sync(
cc_pair, gmail_connector, callback=callback
)
document_external_access: list[DocExternalAccess] = []
for slim_doc_batch in slim_doc_generator:
for slim_doc in slim_doc_batch:
if callback:
@@ -60,17 +60,14 @@ def gmail_doc_sync(
if slim_doc.perm_sync_data is None:
logger.warning(f"No permissions found for document {slim_doc.id}")
continue
if user_email := slim_doc.perm_sync_data.get("user_email"):
ext_access = ExternalAccess(
external_user_emails=set([user_email]),
external_user_group_ids=set(),
is_public=False,
)
document_external_access.append(
DocExternalAccess(
doc_id=slim_doc.id,
external_access=ext_access,
)
yield DocExternalAccess(
doc_id=slim_doc.id,
external_access=ext_access,
)
return document_external_access

View File

@@ -1,3 +1,4 @@
from collections.abc import Generator
from datetime import datetime
from datetime import timezone
from typing import Any
@@ -147,7 +148,7 @@ def _get_permissions_from_slim_doc(
def gdrive_doc_sync(
cc_pair: ConnectorCredentialPair,
callback: IndexingHeartbeatInterface | None,
) -> list[DocExternalAccess]:
) -> Generator[DocExternalAccess, None, None]:
"""
Adds the external permissions to the documents in postgres
if the document doesn't already exists in postgres, we create
@@ -161,7 +162,6 @@ def gdrive_doc_sync(
slim_doc_generator = _get_slim_doc_generator(cc_pair, google_drive_connector)
document_external_accesses = []
for slim_doc_batch in slim_doc_generator:
for slim_doc in slim_doc_batch:
if callback:
@@ -174,10 +174,7 @@ def gdrive_doc_sync(
google_drive_connector=google_drive_connector,
slim_doc=slim_doc,
)
document_external_accesses.append(
DocExternalAccess(
external_access=ext_access,
doc_id=slim_doc.id,
)
yield DocExternalAccess(
external_access=ext_access,
doc_id=slim_doc.id,
)
return document_external_accesses

View File

@@ -1,3 +1,5 @@
from collections.abc import Generator
from slack_sdk import WebClient
from ee.onyx.external_permissions.slack.utils import fetch_user_id_to_email_map
@@ -14,35 +16,6 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
def _get_slack_document_ids_and_channels(
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
) -> dict[str, list[str]]:
slack_connector = SlackConnector(**cc_pair.connector.connector_specific_config)
slack_connector.load_credentials(cc_pair.credential.credential_json)
slim_doc_generator = slack_connector.retrieve_all_slim_documents(callback=callback)
channel_doc_map: dict[str, list[str]] = {}
for doc_metadata_batch in slim_doc_generator:
for doc_metadata in doc_metadata_batch:
if doc_metadata.perm_sync_data is None:
continue
channel_id = doc_metadata.perm_sync_data["channel_id"]
if channel_id not in channel_doc_map:
channel_doc_map[channel_id] = []
channel_doc_map[channel_id].append(doc_metadata.id)
if callback:
if callback.should_stop():
raise RuntimeError(
"_get_slack_document_ids_and_channels: Stop signal detected"
)
callback.progress("_get_slack_document_ids_and_channels", 1)
return channel_doc_map
def _fetch_workspace_permissions(
user_id_to_email_map: dict[str, str],
) -> ExternalAccess:
@@ -122,10 +95,37 @@ def _fetch_channel_permissions(
return channel_permissions
def _get_slack_document_access(
cc_pair: ConnectorCredentialPair,
channel_permissions: dict[str, ExternalAccess],
callback: IndexingHeartbeatInterface | None,
) -> Generator[DocExternalAccess, None, None]:
slack_connector = SlackConnector(**cc_pair.connector.connector_specific_config)
slack_connector.load_credentials(cc_pair.credential.credential_json)
slim_doc_generator = slack_connector.retrieve_all_slim_documents(callback=callback)
for doc_metadata_batch in slim_doc_generator:
for doc_metadata in doc_metadata_batch:
if doc_metadata.perm_sync_data is None:
continue
channel_id = doc_metadata.perm_sync_data["channel_id"]
yield DocExternalAccess(
external_access=channel_permissions[channel_id],
doc_id=doc_metadata.id,
)
if callback:
if callback.should_stop():
raise RuntimeError("_get_slack_document_access: Stop signal detected")
callback.progress("_get_slack_document_access", 1)
def slack_doc_sync(
cc_pair: ConnectorCredentialPair,
callback: IndexingHeartbeatInterface | None,
) -> list[DocExternalAccess]:
) -> Generator[DocExternalAccess, None, None]:
"""
Adds the external permissions to the documents in postgres
if the document doesn't already exists in postgres, we create
@@ -136,9 +136,12 @@ def slack_doc_sync(
token=cc_pair.credential.credential_json["slack_bot_token"]
)
user_id_to_email_map = fetch_user_id_to_email_map(slack_client)
channel_doc_map = _get_slack_document_ids_and_channels(
cc_pair=cc_pair, callback=callback
)
if not user_id_to_email_map:
raise ValueError(
"No user id to email map found. Please check to make sure that "
"your Slack bot token has the `users:read.email` scope"
)
workspace_permissions = _fetch_workspace_permissions(
user_id_to_email_map=user_id_to_email_map,
)
@@ -148,18 +151,8 @@ def slack_doc_sync(
user_id_to_email_map=user_id_to_email_map,
)
document_external_accesses = []
for channel_id, ext_access in channel_permissions.items():
doc_ids = channel_doc_map.get(channel_id)
if not doc_ids:
# No documents found for channel the channel_id
continue
for doc_id in doc_ids:
document_external_accesses.append(
DocExternalAccess(
external_access=ext_access,
doc_id=doc_id,
)
)
return document_external_accesses
yield from _get_slack_document_access(
cc_pair=cc_pair,
channel_permissions=channel_permissions,
callback=callback,
)

View File

@@ -1,7 +1,10 @@
from collections.abc import Callable
from collections.abc import Generator
from ee.onyx.configs.app_configs import CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import SLACK_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.db.external_perm import ExternalUserGroup
from ee.onyx.external_permissions.confluence.doc_sync import confluence_doc_sync
from ee.onyx.external_permissions.confluence.group_sync import confluence_group_sync
@@ -23,7 +26,7 @@ DocSyncFuncType = Callable[
ConnectorCredentialPair,
IndexingHeartbeatInterface | None,
],
list[DocExternalAccess],
Generator[DocExternalAccess, None, None],
]
GroupSyncFuncType = Callable[
@@ -65,13 +68,13 @@ GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC: set[DocumentSource] = {
DOC_PERMISSION_SYNC_PERIODS: dict[DocumentSource, int] = {
# Polling is not supported so we fetch all doc permissions every 5 minutes
DocumentSource.CONFLUENCE: CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY,
DocumentSource.SLACK: 5 * 60,
DocumentSource.SLACK: SLACK_PERMISSION_DOC_SYNC_FREQUENCY,
}
# If nothing is specified here, we run the doc_sync every time the celery beat runs
EXTERNAL_GROUP_SYNC_PERIODS: dict[DocumentSource, int] = {
# Polling is not supported so we fetch all group permissions every 30 minutes
DocumentSource.GOOGLE_DRIVE: 5 * 60,
DocumentSource.GOOGLE_DRIVE: GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY,
DocumentSource.CONFLUENCE: CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY,
}

View File

@@ -64,7 +64,15 @@ def get_application() -> FastAPI:
add_tenant_id_middleware(application, logger)
if AUTH_TYPE == AuthType.CLOUD:
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
# For Google OAuth, refresh tokens are requested by:
# 1. Adding the right scopes
# 2. Properly configuring OAuth in Google Cloud Console to allow offline access
oauth_client = GoogleOAuth2(
OAUTH_CLIENT_ID,
OAUTH_CLIENT_SECRET,
# Use standard scopes that include profile and email
scopes=["openid", "email", "profile"],
)
include_auth_router_with_prefix(
application,
create_onyx_oauth_router(
@@ -87,6 +95,16 @@ def get_application() -> FastAPI:
)
if AUTH_TYPE == AuthType.OIDC:
# Ensure we request offline_access for refresh tokens
try:
oidc_scopes = list(OIDC_SCOPE_OVERRIDE or BASE_SCOPES)
if "offline_access" not in oidc_scopes:
oidc_scopes.append("offline_access")
except Exception as e:
logger.warning(f"Error configuring OIDC scopes: {e}")
# Fall back to default scopes if there's an error
oidc_scopes = BASE_SCOPES
include_auth_router_with_prefix(
application,
create_onyx_oauth_router(
@@ -94,8 +112,8 @@ def get_application() -> FastAPI:
OAUTH_CLIENT_ID,
OAUTH_CLIENT_SECRET,
OPENID_CONFIG_URL,
# BASE_SCOPES is the same as not setting this
base_scopes=OIDC_SCOPE_OVERRIDE or BASE_SCOPES,
# Use the configured scopes
base_scopes=oidc_scopes,
),
auth_backend,
USER_AUTH_SECRET,

View File

@@ -15,8 +15,8 @@ from sqlalchemy.orm import Session
from ee.onyx.server.enterprise_settings.models import AnalyticsScriptUpload
from ee.onyx.server.enterprise_settings.models import EnterpriseSettings
from ee.onyx.server.enterprise_settings.store import _LOGO_FILENAME
from ee.onyx.server.enterprise_settings.store import _LOGOTYPE_FILENAME
from ee.onyx.server.enterprise_settings.store import get_logo_filename
from ee.onyx.server.enterprise_settings.store import get_logotype_filename
from ee.onyx.server.enterprise_settings.store import load_analytics_script
from ee.onyx.server.enterprise_settings.store import load_settings
from ee.onyx.server.enterprise_settings.store import store_analytics_script
@@ -28,7 +28,7 @@ from onyx.auth.users import get_user_manager
from onyx.auth.users import UserManager
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.file_store.file_store import get_default_file_store
from onyx.file_store.file_store import PostgresBackedFileStore
from onyx.utils.logger import setup_logger
admin_router = APIRouter(prefix="/admin/enterprise-settings")
@@ -131,31 +131,49 @@ def put_logo(
upload_logo(file=file, db_session=db_session, is_logotype=is_logotype)
def fetch_logo_or_logotype(is_logotype: bool, db_session: Session) -> Response:
def fetch_logo_helper(db_session: Session) -> Response:
try:
file_store = get_default_file_store(db_session)
filename = _LOGOTYPE_FILENAME if is_logotype else _LOGO_FILENAME
file_io = file_store.read_file(filename, mode="b")
# NOTE: specifying "image/jpeg" here, but it still works for pngs
# TODO: do this properly
return Response(content=file_io.read(), media_type="image/jpeg")
file_store = PostgresBackedFileStore(db_session)
onyx_file = file_store.get_file_with_mime_type(get_logo_filename())
if not onyx_file:
raise ValueError("get_onyx_file returned None!")
except Exception:
raise HTTPException(
status_code=404,
detail=f"No {'logotype' if is_logotype else 'logo'} file found",
detail="No logo file found",
)
else:
return Response(content=onyx_file.data, media_type=onyx_file.mime_type)
def fetch_logotype_helper(db_session: Session) -> Response:
try:
file_store = PostgresBackedFileStore(db_session)
onyx_file = file_store.get_file_with_mime_type(get_logotype_filename())
if not onyx_file:
raise ValueError("get_onyx_file returned None!")
except Exception:
raise HTTPException(
status_code=404,
detail="No logotype file found",
)
else:
return Response(content=onyx_file.data, media_type=onyx_file.mime_type)
@basic_router.get("/logotype")
def fetch_logotype(db_session: Session = Depends(get_session)) -> Response:
return fetch_logo_or_logotype(is_logotype=True, db_session=db_session)
return fetch_logotype_helper(db_session)
@basic_router.get("/logo")
def fetch_logo(
is_logotype: bool = False, db_session: Session = Depends(get_session)
) -> Response:
return fetch_logo_or_logotype(is_logotype=is_logotype, db_session=db_session)
if is_logotype:
return fetch_logotype_helper(db_session)
return fetch_logo_helper(db_session)
@admin_router.put("/custom-analytics-script")

View File

@@ -13,6 +13,7 @@ from ee.onyx.server.enterprise_settings.models import EnterpriseSettings
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import KV_CUSTOM_ANALYTICS_SCRIPT_KEY
from onyx.configs.constants import KV_ENTERPRISE_SETTINGS_KEY
from onyx.configs.constants import ONYX_DEFAULT_APPLICATION_NAME
from onyx.file_store.file_store import get_default_file_store
from onyx.key_value_store.factory import get_kv_store
from onyx.key_value_store.interface import KvKeyNotFoundError
@@ -21,8 +22,18 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
_LOGO_FILENAME = "__logo__"
_LOGOTYPE_FILENAME = "__logotype__"
def load_settings() -> EnterpriseSettings:
"""Loads settings data directly from DB. This should be used primarily
for checking what is actually in the DB, aka for editing and saving back settings.
Runtime settings actually used by the application should be checked with
load_runtime_settings as defaults may be applied at runtime.
"""
dynamic_config_store = get_kv_store()
try:
settings = EnterpriseSettings(
@@ -36,9 +47,24 @@ def load_settings() -> EnterpriseSettings:
def store_settings(settings: EnterpriseSettings) -> None:
"""Stores settings directly to the kv store / db."""
get_kv_store().store(KV_ENTERPRISE_SETTINGS_KEY, settings.model_dump())
def load_runtime_settings() -> EnterpriseSettings:
"""Loads settings from DB and applies any defaults or transformations for use
at runtime.
Should not be stored back to the DB.
"""
enterprise_settings = load_settings()
if not enterprise_settings.application_name:
enterprise_settings.application_name = ONYX_DEFAULT_APPLICATION_NAME
return enterprise_settings
_CUSTOM_ANALYTICS_SECRET_KEY = os.environ.get("CUSTOM_ANALYTICS_SECRET_KEY")
@@ -60,10 +86,6 @@ def store_analytics_script(analytics_script_upload: AnalyticsScriptUpload) -> No
get_kv_store().store(KV_CUSTOM_ANALYTICS_SCRIPT_KEY, analytics_script_upload.script)
_LOGO_FILENAME = "__logo__"
_LOGOTYPE_FILENAME = "__logotype__"
def is_valid_file_type(filename: str) -> bool:
valid_extensions = (".png", ".jpg", ".jpeg")
return filename.endswith(valid_extensions)
@@ -116,3 +138,11 @@ def upload_logo(
file_type=file_type,
)
return True
def get_logo_filename() -> str:
return _LOGO_FILENAME
def get_logotype_filename() -> str:
return _LOGOTYPE_FILENAME

View File

@@ -36,8 +36,12 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
router = APIRouter(prefix="/auth/saml")
# Define non-authenticated user roles that should be re-created during SAML login
NON_AUTHENTICATED_ROLES = {UserRole.SLACK_USER, UserRole.EXT_PERM_USER}
async def upsert_saml_user(email: str) -> User:
logger.debug(f"Attempting to upsert SAML user with email: {email}")
get_async_session_context = contextlib.asynccontextmanager(
get_async_session
) # type:ignore
@@ -48,9 +52,13 @@ async def upsert_saml_user(email: str) -> User:
async with get_user_db_context(session) as user_db:
async with get_user_manager_context(user_db) as user_manager:
try:
return await user_manager.get_by_email(email)
user = await user_manager.get_by_email(email)
# If user has a non-authenticated role, treat as non-existent
if user.role in NON_AUTHENTICATED_ROLES:
raise exceptions.UserNotExists()
return user
except exceptions.UserNotExists:
logger.notice("Creating user from SAML login")
logger.info("Creating user from SAML login")
user_count = await get_user_count()
role = UserRole.ADMIN if user_count == 0 else UserRole.BASIC
@@ -59,11 +67,10 @@ async def upsert_saml_user(email: str) -> User:
password = fastapi_users_pw_helper.generate()
hashed_pass = fastapi_users_pw_helper.hash(password)
user: User = await user_manager.create(
user = await user_manager.create(
UserCreate(
email=email,
password=hashed_pass,
is_verified=True,
role=role,
)
)

View File

@@ -87,11 +87,14 @@ async def get_or_provision_tenant(
# If we have a pre-provisioned tenant, assign it to the user
await assign_tenant_to_user(tenant_id, email, referral_source)
logger.info(f"Assigned pre-provisioned tenant {tenant_id} to user {email}")
return tenant_id
else:
# If no pre-provisioned tenant is available, create a new one on-demand
tenant_id = await create_tenant(email, referral_source)
return tenant_id
# Notify control plane if we have created / assigned a new tenant
if not DEV_MODE:
await notify_control_plane(tenant_id, email, referral_source)
return tenant_id
except Exception as e:
# If we've encountered an error, log and raise an exception
@@ -116,10 +119,6 @@ async def create_tenant(email: str, referral_source: str | None = None) -> str:
# Provision tenant on data plane
await provision_tenant(tenant_id, email)
# Notify control plane if not already done in provision_tenant
if not DEV_MODE and referral_source:
await notify_control_plane(tenant_id, email, referral_source)
except Exception as e:
logger.exception(f"Tenant provisioning failed: {str(e)}")
# Attempt to rollback the tenant provisioning
@@ -271,6 +270,7 @@ def configure_default_api_keys(db_session: Session) -> None:
fast_default_model_name="claude-3-5-sonnet-20241022",
model_names=ANTHROPIC_MODEL_NAMES,
display_model_names=["claude-3-5-sonnet-20241022"],
api_key_changed=True,
)
try:
full_provider = upsert_llm_provider(anthropic_provider, db_session)
@@ -283,7 +283,7 @@ def configure_default_api_keys(db_session: Session) -> None:
)
if OPENAI_DEFAULT_API_KEY:
open_provider = LLMProviderUpsertRequest(
openai_provider = LLMProviderUpsertRequest(
name="OpenAI",
provider=OPENAI_PROVIDER_NAME,
api_key=OPENAI_DEFAULT_API_KEY,
@@ -291,9 +291,10 @@ def configure_default_api_keys(db_session: Session) -> None:
fast_default_model_name="gpt-4o-mini",
model_names=OPEN_AI_MODEL_NAMES,
display_model_names=["o1", "o3-mini", "gpt-4o", "gpt-4o-mini"],
api_key_changed=True,
)
try:
full_provider = upsert_llm_provider(open_provider, db_session)
full_provider = upsert_llm_provider(openai_provider, db_session)
update_default_provider(full_provider.id, db_session)
except Exception as e:
logger.error(f"Failed to configure OpenAI provider: {e}")
@@ -559,7 +560,3 @@ async def assign_tenant_to_user(
except Exception:
logger.exception(f"Failed to assign tenant {tenant_id} to user {email}")
raise Exception("Failed to assign tenant to user")
# Notify control plane with retry logic
if not DEV_MODE:
await notify_control_plane(tenant_id, email, referral_source)

View File

@@ -70,6 +70,7 @@ def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
"""
Add users to a tenant with proper transaction handling.
Checks if users already have a tenant mapping to avoid duplicates.
If a user already has an active mapping to any tenant, the new mapping will be added as inactive.
"""
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
try:
@@ -88,9 +89,25 @@ def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
.first()
)
# If user already has an active mapping, add this one as inactive
if not existing_mapping:
# Only add if mapping doesn't exist
db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))
# Check if the user already has an active mapping to any tenant
has_active_mapping = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email == email,
UserTenantMapping.active == True, # noqa: E712
)
.first()
)
db_session.add(
UserTenantMapping(
email=email,
tenant_id=tenant_id,
active=False if has_active_mapping else True,
)
)
# Commit the transaction
db_session.commit()

View File

@@ -65,11 +65,17 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
app.state.gpu_type = gpu_type
if TEMP_HF_CACHE_PATH.is_dir():
logger.notice("Moving contents of temp_huggingface to huggingface cache.")
_move_files_recursively(TEMP_HF_CACHE_PATH, HF_CACHE_PATH)
shutil.rmtree(TEMP_HF_CACHE_PATH, ignore_errors=True)
logger.notice("Moved contents of temp_huggingface to huggingface cache.")
try:
if TEMP_HF_CACHE_PATH.is_dir():
logger.notice("Moving contents of temp_huggingface to huggingface cache.")
_move_files_recursively(TEMP_HF_CACHE_PATH, HF_CACHE_PATH)
shutil.rmtree(TEMP_HF_CACHE_PATH, ignore_errors=True)
logger.notice("Moved contents of temp_huggingface to huggingface cache.")
except Exception as e:
logger.warning(
f"Error moving contents of temp_huggingface to huggingface cache: {e}. "
"This is not a critical error and the model server will continue to run."
)
torch.set_num_threads(max(MIN_THREADS_ML_MODELS, torch.get_num_threads()))
logger.notice(f"Torch Threads: {torch.get_num_threads()}")

View File

@@ -20,7 +20,7 @@ class ExternalAccess:
class DocExternalAccess:
"""
This is just a class to wrap the external access and the document ID
together. It's used for syncing document permissions to Redis.
together. It's used for syncing document permissions to Vespa.
"""
external_access: ExternalAccess

View File

@@ -1,5 +1,6 @@
import smtplib
from datetime import datetime
from email.mime.image import MIMEImage
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from email.utils import formatdate
@@ -13,8 +14,13 @@ from onyx.configs.app_configs import SMTP_SERVER
from onyx.configs.app_configs import SMTP_USER
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import AuthType
from onyx.configs.constants import TENANT_ID_COOKIE_NAME
from onyx.configs.constants import ONYX_DEFAULT_APPLICATION_NAME
from onyx.configs.constants import ONYX_SLACK_URL
from onyx.db.models import User
from onyx.server.runtime.onyx_runtime import OnyxRuntime
from onyx.utils.file import FileWithMimeType
from onyx.utils.url import add_url_params
from onyx.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import MULTI_TENANT
HTML_EMAIL_TEMPLATE = """\
@@ -56,6 +62,11 @@ HTML_EMAIL_TEMPLATE = """\
}}
.header img {{
max-width: 140px;
width: 140px;
height: auto;
filter: brightness(1.1) contrast(1.2);
border-radius: 8px;
padding: 5px;
}}
.body-content {{
padding: 20px 30px;
@@ -72,12 +83,16 @@ HTML_EMAIL_TEMPLATE = """\
}}
.cta-button {{
display: inline-block;
padding: 12px 20px;
background-color: #000000;
padding: 14px 24px;
background-color: #0055FF;
color: #ffffff !important;
text-decoration: none;
border-radius: 4px;
font-weight: 500;
font-weight: 600;
font-size: 16px;
margin-top: 10px;
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
text-align: center;
}}
.footer {{
font-size: 13px;
@@ -97,8 +112,8 @@ HTML_EMAIL_TEMPLATE = """\
<td class="header">
<img
style="background-color: #ffffff; border-radius: 8px;"
src="https://www.onyx.app/logos/customer/onyx.png"
alt="Onyx Logo"
src="cid:logo.png"
alt="{application_name} Logo"
>
</td>
</tr>
@@ -113,9 +128,8 @@ HTML_EMAIL_TEMPLATE = """\
</tr>
<tr>
<td class="footer">
© {year} Onyx. All rights reserved.
<br>
Have questions? Join our Slack community <a href="https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA">here</a>.
© {year} {application_name}. All rights reserved.
{slack_fragment}
</td>
</tr>
</table>
@@ -125,17 +139,27 @@ HTML_EMAIL_TEMPLATE = """\
def build_html_email(
heading: str, message: str, cta_text: str | None = None, cta_link: str | None = None
application_name: str | None,
heading: str,
message: str,
cta_text: str | None = None,
cta_link: str | None = None,
) -> str:
slack_fragment = ""
if application_name == ONYX_DEFAULT_APPLICATION_NAME:
slack_fragment = f'<br>Have questions? Join our Slack community <a href="{ONYX_SLACK_URL}">here</a>.'
if cta_text and cta_link:
cta_block = f'<a class="cta-button" href="{cta_link}">{cta_text}</a>'
else:
cta_block = ""
return HTML_EMAIL_TEMPLATE.format(
application_name=application_name,
title=heading,
heading=heading,
message=message,
cta_block=cta_block,
slack_fragment=slack_fragment,
year=datetime.now().year,
)
@@ -146,10 +170,12 @@ def send_email(
html_body: str,
text_body: str,
mail_from: str = EMAIL_FROM,
inline_png: tuple[str, bytes] | None = None,
) -> None:
if not EMAIL_CONFIGURED:
raise ValueError("Email is not configured.")
# Create a multipart/alternative message - this indicates these are alternative versions of the same content
msg = MIMEMultipart("alternative")
msg["Subject"] = subject
msg["To"] = user_email
@@ -158,11 +184,30 @@ def send_email(
msg["Date"] = formatdate(localtime=True)
msg["Message-ID"] = make_msgid(domain="onyx.app")
part_text = MIMEText(text_body, "plain")
part_html = MIMEText(html_body, "html")
# Add text part first (lowest priority)
text_part = MIMEText(text_body, "plain")
msg.attach(text_part)
msg.attach(part_text)
msg.attach(part_html)
if inline_png:
# For HTML with images, create a multipart/related container
related = MIMEMultipart("related")
# Add the HTML part to the related container
html_part = MIMEText(html_body, "html")
related.attach(html_part)
# Add image with proper Content-ID to the related container
img = MIMEImage(inline_png[1], _subtype="png")
img.add_header("Content-ID", f"<{inline_png[0]}>")
img.add_header("Content-Disposition", "inline", filename=inline_png[0])
related.attach(img)
# Add the related part to the message (higher priority than text)
msg.attach(related)
else:
# No images, just add HTML directly (higher priority than text)
html_part = MIMEText(html_body, "html")
msg.attach(html_part)
try:
with smtplib.SMTP(SMTP_SERVER, SMTP_PORT) as s:
@@ -174,8 +219,21 @@ def send_email(
def send_subscription_cancellation_email(user_email: str) -> None:
"""This is templated but isn't meaningful for whitelabeling."""
# Example usage of the reusable HTML
subject = "Your Onyx Subscription Has Been Canceled"
try:
load_runtime_settings_fn = fetch_versioned_implementation(
"onyx.server.enterprise_settings.store", "load_runtime_settings"
)
settings = load_runtime_settings_fn()
application_name = settings.application_name
except ModuleNotFoundError:
application_name = ONYX_DEFAULT_APPLICATION_NAME
onyx_file = OnyxRuntime.get_emailable_logo()
subject = f"Your {application_name} Subscription Has Been Canceled"
heading = "Subscription Canceled"
message = (
"<p>We're sorry to see you go.</p>"
@@ -184,23 +242,48 @@ def send_subscription_cancellation_email(user_email: str) -> None:
)
cta_text = "Renew Subscription"
cta_link = "https://www.onyx.app/pricing"
html_content = build_html_email(heading, message, cta_text, cta_link)
html_content = build_html_email(
application_name,
heading,
message,
cta_text,
cta_link,
)
text_content = (
"We're sorry to see you go.\n"
"Your subscription has been canceled and will end on your next billing date.\n"
"If you change your mind, visit https://www.onyx.app/pricing"
)
send_email(user_email, subject, html_content, text_content)
send_email(
user_email,
subject,
html_content,
text_content,
inline_png=("logo.png", onyx_file.data),
)
def send_user_email_invite(
user_email: str, current_user: User, auth_type: AuthType
) -> None:
subject = "Invitation to Join Onyx Organization"
onyx_file: FileWithMimeType | None = None
try:
load_runtime_settings_fn = fetch_versioned_implementation(
"onyx.server.enterprise_settings.store", "load_runtime_settings"
)
settings = load_runtime_settings_fn()
application_name = settings.application_name
except ModuleNotFoundError:
application_name = ONYX_DEFAULT_APPLICATION_NAME
onyx_file = OnyxRuntime.get_emailable_logo()
subject = f"Invitation to Join {application_name} Organization"
heading = "You've Been Invited!"
# the exact action taken by the user, and thus the message, depends on the auth type
message = f"<p>You have been invited by {current_user.email} to join an organization on Onyx.</p>"
message = f"<p>You have been invited by {current_user.email} to join an organization on {application_name}.</p>"
if auth_type == AuthType.CLOUD:
message += (
"<p>To join the organization, please click the button below to set a password "
@@ -226,19 +309,32 @@ def send_user_email_invite(
cta_text = "Join Organization"
cta_link = f"{WEB_DOMAIN}/auth/signup?email={user_email}"
html_content = build_html_email(heading, message, cta_text, cta_link)
html_content = build_html_email(
application_name,
heading,
message,
cta_text,
cta_link,
)
# text content is the fallback for clients that don't support HTML
# not as critical, so not having special cases for each auth type
text_content = (
f"You have been invited by {current_user.email} to join an organization on Onyx.\n"
f"You have been invited by {current_user.email} to join an organization on {application_name}.\n"
"To join the organization, please visit the following link:\n"
f"{WEB_DOMAIN}/auth/signup?email={user_email}\n"
)
if auth_type == AuthType.CLOUD:
text_content += "You'll be asked to set a password or login with Google to complete your registration."
send_email(user_email, subject, html_content, text_content)
send_email(
user_email,
subject,
html_content,
text_content,
inline_png=("logo.png", onyx_file.data),
)
def send_forgot_password_email(
@@ -248,27 +344,80 @@ def send_forgot_password_email(
mail_from: str = EMAIL_FROM,
) -> None:
# Builds a forgot password email with or without fancy HTML
subject = "Onyx Forgot Password"
link = f"{WEB_DOMAIN}/auth/reset-password?token={token}"
if MULTI_TENANT:
link += f"&{TENANT_ID_COOKIE_NAME}={tenant_id}"
message = f"<p>Click the following link to reset your password:</p><p>{link}</p>"
html_content = build_html_email("Reset Your Password", message)
text_content = f"Click the following link to reset your password: {link}"
send_email(user_email, subject, html_content, text_content, mail_from)
try:
load_runtime_settings_fn = fetch_versioned_implementation(
"onyx.server.enterprise_settings.store", "load_runtime_settings"
)
settings = load_runtime_settings_fn()
application_name = settings.application_name
except ModuleNotFoundError:
application_name = ONYX_DEFAULT_APPLICATION_NAME
onyx_file = OnyxRuntime.get_emailable_logo()
subject = f"Reset Your {application_name} Password"
heading = "Reset Your Password"
tenant_param = f"&tenant={tenant_id}" if tenant_id and MULTI_TENANT else ""
message = "<p>Please click the button below to reset your password. This link will expire in 24 hours.</p>"
cta_text = "Reset Password"
cta_link = f"{WEB_DOMAIN}/auth/reset-password?token={token}{tenant_param}"
html_content = build_html_email(
application_name,
heading,
message,
cta_text,
cta_link,
)
text_content = (
f"Please click the following link to reset your password. This link will expire in 24 hours.\n"
f"{WEB_DOMAIN}/auth/reset-password?token={token}{tenant_param}"
)
send_email(
user_email,
subject,
html_content,
text_content,
mail_from,
inline_png=("logo.png", onyx_file.data),
)
def send_user_verification_email(
user_email: str,
token: str,
new_organization: bool = False,
mail_from: str = EMAIL_FROM,
) -> None:
# Builds a verification email
subject = "Onyx Email Verification"
try:
load_runtime_settings_fn = fetch_versioned_implementation(
"onyx.server.enterprise_settings.store", "load_runtime_settings"
)
settings = load_runtime_settings_fn()
application_name = settings.application_name
except ModuleNotFoundError:
application_name = ONYX_DEFAULT_APPLICATION_NAME
onyx_file = OnyxRuntime.get_emailable_logo()
subject = f"{application_name} Email Verification"
link = f"{WEB_DOMAIN}/auth/verify-email?token={token}"
if new_organization:
link = add_url_params(link, {"first_user": "true"})
message = (
f"<p>Click the following link to verify your email address:</p><p>{link}</p>"
)
html_content = build_html_email("Verify Your Email", message)
html_content = build_html_email(
application_name,
"Verify Your Email",
message,
)
text_content = f"Click the following link to verify your email address: {link}"
send_email(user_email, subject, html_content, text_content, mail_from)
send_email(
user_email,
subject,
html_content,
text_content,
mail_from,
inline_png=("logo.png", onyx_file.data),
)

View File

@@ -0,0 +1,211 @@
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
from typing import Dict
from typing import List
from typing import Optional
import httpx
from fastapi_users.manager import BaseUserManager
from sqlalchemy.ext.asyncio import AsyncSession
from onyx.configs.app_configs import OAUTH_CLIENT_ID
from onyx.configs.app_configs import OAUTH_CLIENT_SECRET
from onyx.configs.app_configs import TRACK_EXTERNAL_IDP_EXPIRY
from onyx.db.models import OAuthAccount
from onyx.db.models import User
from onyx.utils.logger import setup_logger
logger = setup_logger()
# Standard OAuth refresh token endpoints
REFRESH_ENDPOINTS = {
"google": "https://oauth2.googleapis.com/token",
}
# NOTE: Keeping this as a utility function for potential future debugging,
# but not using it in production code
async def _test_expire_oauth_token(
user: User,
oauth_account: OAuthAccount,
db_session: AsyncSession,
user_manager: BaseUserManager[User, Any],
expire_in_seconds: int = 10,
) -> bool:
"""
Utility function for testing - Sets an OAuth token to expire in a short time
to facilitate testing of the refresh flow.
Not used in production code.
"""
try:
new_expires_at = int(
(datetime.now(timezone.utc).timestamp() + expire_in_seconds)
)
updated_data: Dict[str, Any] = {"expires_at": new_expires_at}
await user_manager.user_db.update_oauth_account(
user, cast(Any, oauth_account), updated_data
)
return True
except Exception as e:
logger.exception(f"Error setting artificial expiration: {str(e)}")
return False
async def refresh_oauth_token(
user: User,
oauth_account: OAuthAccount,
db_session: AsyncSession,
user_manager: BaseUserManager[User, Any],
) -> bool:
"""
Attempt to refresh an OAuth token that's about to expire or has expired.
Returns True if successful, False otherwise.
"""
if not oauth_account.refresh_token:
logger.warning(
f"No refresh token available for {user.email}'s {oauth_account.oauth_name} account"
)
return False
provider = oauth_account.oauth_name
if provider not in REFRESH_ENDPOINTS:
logger.warning(f"Refresh endpoint not configured for provider: {provider}")
return False
try:
logger.info(f"Refreshing OAuth token for {user.email}'s {provider} account")
async with httpx.AsyncClient() as client:
response = await client.post(
REFRESH_ENDPOINTS[provider],
data={
"client_id": OAUTH_CLIENT_ID,
"client_secret": OAUTH_CLIENT_SECRET,
"refresh_token": oauth_account.refresh_token,
"grant_type": "refresh_token",
},
headers={"Content-Type": "application/x-www-form-urlencoded"},
)
if response.status_code != 200:
logger.error(
f"Failed to refresh OAuth token: Status {response.status_code}"
)
return False
token_data = response.json()
new_access_token = token_data.get("access_token")
new_refresh_token = token_data.get(
"refresh_token", oauth_account.refresh_token
)
expires_in = token_data.get("expires_in")
# Calculate new expiry time if provided
new_expires_at: Optional[int] = None
if expires_in:
new_expires_at = int(
(datetime.now(timezone.utc).timestamp() + expires_in)
)
# Update the OAuth account
updated_data: Dict[str, Any] = {
"access_token": new_access_token,
"refresh_token": new_refresh_token,
}
if new_expires_at:
updated_data["expires_at"] = new_expires_at
# Update oidc_expiry in user model if we're tracking it
if TRACK_EXTERNAL_IDP_EXPIRY:
oidc_expiry = datetime.fromtimestamp(
new_expires_at, tz=timezone.utc
)
await user_manager.user_db.update(
user, {"oidc_expiry": oidc_expiry}
)
# Update the OAuth account
await user_manager.user_db.update_oauth_account(
user, cast(Any, oauth_account), updated_data
)
logger.info(f"Successfully refreshed OAuth token for {user.email}")
return True
except Exception as e:
logger.exception(f"Error refreshing OAuth token: {str(e)}")
return False
async def check_and_refresh_oauth_tokens(
user: User,
db_session: AsyncSession,
user_manager: BaseUserManager[User, Any],
) -> None:
"""
Check if any OAuth tokens are expired or about to expire and refresh them.
"""
if not hasattr(user, "oauth_accounts") or not user.oauth_accounts:
return
now_timestamp = datetime.now(timezone.utc).timestamp()
# Buffer time to refresh tokens before they expire (in seconds)
buffer_seconds = 300 # 5 minutes
for oauth_account in user.oauth_accounts:
# Skip accounts without refresh tokens
if not oauth_account.refresh_token:
continue
# If token is about to expire, refresh it
if (
oauth_account.expires_at
and oauth_account.expires_at - now_timestamp < buffer_seconds
):
logger.info(f"OAuth token for {user.email} is about to expire - refreshing")
success = await refresh_oauth_token(
user, oauth_account, db_session, user_manager
)
if not success:
logger.warning(
"Failed to refresh OAuth token. User may need to re-authenticate."
)
async def check_oauth_account_has_refresh_token(
user: User,
oauth_account: OAuthAccount,
) -> bool:
"""
Check if an OAuth account has a refresh token.
Returns True if a refresh token exists, False otherwise.
"""
return bool(oauth_account.refresh_token)
async def get_oauth_accounts_requiring_refresh_token(user: User) -> List[OAuthAccount]:
"""
Returns a list of OAuth accounts for a user that are missing refresh tokens.
These accounts will need re-authentication to get refresh tokens.
"""
if not hasattr(user, "oauth_accounts") or not user.oauth_accounts:
return []
accounts_needing_refresh = []
for oauth_account in user.oauth_accounts:
has_refresh_token = await check_oauth_account_has_refresh_token(
user, oauth_account
)
if not has_refresh_token:
accounts_needing_refresh.append(oauth_account)
return accounts_needing_refresh

View File

@@ -5,12 +5,16 @@ import string
import uuid
from collections.abc import AsyncGenerator
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import Any
from typing import cast
from typing import Dict
from typing import List
from typing import Optional
from typing import Protocol
from typing import Tuple
from typing import TypeVar
import jwt
from email_validator import EmailNotValidError
@@ -105,6 +109,7 @@ from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
from onyx.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import async_return_default_schema
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id
@@ -580,8 +585,10 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
logger.notice(
f"Verification requested for user {user.id}. Verification token: {token}"
)
send_user_verification_email(user.email, token)
user_count = await get_user_count()
send_user_verification_email(
user.email, token, new_organization=user_count == 1
)
async def authenticate(
self, credentials: OAuth2PasswordRequestForm
@@ -593,7 +600,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
tenant_id = fetch_ee_implementation_or_noop(
"onyx.server.tenants.provisioning",
"get_tenant_id_for_email",
None,
POSTGRES_DEFAULT_SCHEMA,
)(
email=email,
)
@@ -687,16 +694,20 @@ cookie_transport = CookieTransport(
)
def get_redis_strategy() -> RedisStrategy:
return TenantAwareRedisStrategy()
T = TypeVar("T", covariant=True)
ID = TypeVar("ID", contravariant=True)
def get_database_strategy(
access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db),
) -> DatabaseStrategy:
return DatabaseStrategy(
access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS
)
# Protocol for strategies that support token refreshing without inheritance.
class RefreshableStrategy(Protocol):
"""Protocol for authentication strategies that support token refreshing."""
async def refresh_token(self, token: Optional[str], user: Any) -> str:
"""
Refresh an existing token by extending its lifetime.
Returns either the same token with extended expiration or a new token.
"""
...
class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]):
@@ -755,6 +766,75 @@ class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]):
redis = await get_async_redis_connection()
await redis.delete(f"{self.key_prefix}{token}")
async def refresh_token(self, token: Optional[str], user: User) -> str:
"""Refresh a token by extending its expiration time in Redis."""
if token is None:
# If no token provided, create a new one
return await self.write_token(user)
redis = await get_async_redis_connection()
token_key = f"{self.key_prefix}{token}"
# Check if token exists
token_data_str = await redis.get(token_key)
if not token_data_str:
# Token not found, create new one
return await self.write_token(user)
# Token exists, extend its lifetime
token_data = json.loads(token_data_str)
await redis.set(
token_key,
json.dumps(token_data),
ex=self.lifetime_seconds,
)
return token
class RefreshableDatabaseStrategy(DatabaseStrategy[User, uuid.UUID, AccessToken]):
"""Database strategy with token refreshing capabilities."""
def __init__(
self,
access_token_db: AccessTokenDatabase[AccessToken],
lifetime_seconds: Optional[int] = None,
):
super().__init__(access_token_db, lifetime_seconds)
self._access_token_db = access_token_db
async def refresh_token(self, token: Optional[str], user: User) -> str:
"""Refresh a token by updating its expiration time in the database."""
if token is None:
return await self.write_token(user)
# Find the token in database
access_token = await self._access_token_db.get_by_token(token)
if access_token is None:
# Token not found, create new one
return await self.write_token(user)
# Update expiration time
new_expires = datetime.now(timezone.utc) + timedelta(
seconds=float(self.lifetime_seconds or SESSION_EXPIRE_TIME_SECONDS)
)
await self._access_token_db.update(access_token, {"expires": new_expires})
return token
def get_redis_strategy() -> TenantAwareRedisStrategy:
return TenantAwareRedisStrategy()
def get_database_strategy(
access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db),
) -> RefreshableDatabaseStrategy:
return RefreshableDatabaseStrategy(
access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS
)
if AUTH_BACKEND == AuthBackend.REDIS:
auth_backend = AuthenticationBackend(
@@ -805,6 +885,88 @@ class FastAPIUserWithLogoutRouter(FastAPIUsers[models.UP, models.ID]):
return router
def get_refresh_router(
self,
backend: AuthenticationBackend,
requires_verification: bool = REQUIRE_EMAIL_VERIFICATION,
) -> APIRouter:
"""
Provide a router for session token refreshing.
"""
# Import the oauth_refresher here to avoid circular imports
from onyx.auth.oauth_refresher import check_and_refresh_oauth_tokens
router = APIRouter()
get_current_user_token = self.authenticator.current_user_token(
active=True, verified=requires_verification
)
refresh_responses: OpenAPIResponseType = {
**{
status.HTTP_401_UNAUTHORIZED: {
"description": "Missing token or inactive user."
}
},
**backend.transport.get_openapi_login_responses_success(),
}
@router.post(
"/refresh", name=f"auth:{backend.name}.refresh", responses=refresh_responses
)
async def refresh(
user_token: Tuple[models.UP, str] = Depends(get_current_user_token),
strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy),
user_manager: BaseUserManager[models.UP, models.ID] = Depends(
get_user_manager
),
db_session: AsyncSession = Depends(get_async_session),
) -> Response:
try:
user, token = user_token
logger.info(f"Processing token refresh request for user {user.email}")
# Check if user has OAuth accounts that need refreshing
await check_and_refresh_oauth_tokens(
user=cast(User, user),
db_session=db_session,
user_manager=cast(Any, user_manager),
)
# Check if strategy supports refreshing
supports_refresh = hasattr(strategy, "refresh_token") and callable(
getattr(strategy, "refresh_token")
)
if supports_refresh:
try:
refresh_method = getattr(strategy, "refresh_token")
new_token = await refresh_method(token, user)
logger.info(
f"Successfully refreshed session token for user {user.email}"
)
return await backend.transport.get_login_response(new_token)
except Exception as e:
logger.error(f"Error refreshing session token: {str(e)}")
# Fallback to logout and login if refresh fails
await backend.logout(strategy, user, token)
return await backend.login(strategy, user)
# Fallback: logout and login again
logger.info(
"Strategy doesn't support refresh - using logout/login flow"
)
await backend.logout(strategy, user, token)
return await backend.login(strategy, user)
except Exception as e:
logger.error(f"Unexpected error in refresh endpoint: {str(e)}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Token refresh failed: {str(e)}",
)
return router
fastapi_users = FastAPIUserWithLogoutRouter[User, uuid.UUID](
get_user_manager, [auth_backend]
@@ -1038,12 +1200,20 @@ def get_oauth_router(
"referral_source": referral_source or "default_referral",
}
state = generate_state_token(state_data, state_secret)
# Get the basic authorization URL
authorization_url = await oauth_client.get_authorization_url(
authorize_redirect_url,
state,
scopes,
)
# For Google OAuth, add parameters to request refresh tokens
if oauth_client.name == "google":
authorization_url = add_url_params(
authorization_url, {"access_type": "offline", "prompt": "consent"}
)
return OAuth2AuthorizeResponse(authorization_url=authorization_url)
@router.get(

View File

@@ -34,7 +34,6 @@ from onyx.redis.redis_connector_ext_group_sync import RedisConnectorExternalGrou
from onyx.redis.redis_connector_prune import RedisConnectorPrune
from onyx.redis.redis_document_set import RedisDocumentSet
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_shared_redis_client
from onyx.redis.redis_usergroup import RedisUserGroup
from onyx.utils.logger import ColoredFormatter
from onyx.utils.logger import PlainFormatter
@@ -225,7 +224,7 @@ def wait_for_redis(sender: Any, **kwargs: Any) -> None:
Will raise WorkerShutdown to kill the celery worker if the timeout
is reached."""
r = get_shared_redis_client()
r = get_redis_client(tenant_id=POSTGRES_DEFAULT_SCHEMA)
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
@@ -311,7 +310,7 @@ def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
# Set up variables for waiting on primary worker
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
r = get_shared_redis_client()
r = get_redis_client(tenant_id=POSTGRES_DEFAULT_SCHEMA)
time_start = time.monotonic()
logger.info("Waiting for primary worker to be ready...")

View File

@@ -1,6 +1,5 @@
from datetime import timedelta
from typing import Any
from typing import cast
from celery import Celery
from celery import signals
@@ -10,12 +9,10 @@ from celery.utils.log import get_task_logger
import onyx.background.celery.apps.app_base as app_base
from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
from onyx.configs.constants import ONYX_CLOUD_REDIS_RUNTIME
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
from onyx.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
from onyx.db.engine import get_all_tenant_ids
from onyx.db.engine import SqlEngine
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.server.runtime.onyx_runtime import OnyxRuntime
from onyx.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST
from shared_configs.configs import MULTI_TENANT
@@ -141,8 +138,6 @@ class DynamicTenantScheduler(PersistentScheduler):
"""Only updates the actual beat schedule on the celery app when it changes"""
do_update = False
r = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
task_logger.debug("_try_updating_schedule starting")
tenant_ids = get_all_tenant_ids()
@@ -152,16 +147,7 @@ class DynamicTenantScheduler(PersistentScheduler):
current_schedule = self.schedule.items()
# get potential new state
beat_multiplier = CLOUD_BEAT_MULTIPLIER_DEFAULT
beat_multiplier_raw = r.get(f"{ONYX_CLOUD_REDIS_RUNTIME}:beat_multiplier")
if beat_multiplier_raw is not None:
try:
beat_multiplier_bytes = cast(bytes, beat_multiplier_raw)
beat_multiplier = float(beat_multiplier_bytes.decode())
except ValueError:
task_logger.error(
f"Invalid beat_multiplier value: {beat_multiplier_raw}"
)
beat_multiplier = OnyxRuntime.get_beat_multiplier()
new_schedule = self._generate_schedule(tenant_ids, beat_multiplier)

View File

@@ -38,10 +38,11 @@ from onyx.redis.redis_connector_index import RedisConnectorIndex
from onyx.redis.redis_connector_prune import RedisConnectorPrune
from onyx.redis.redis_connector_stop import RedisConnectorStop
from onyx.redis.redis_document_set import RedisDocumentSet
from onyx.redis.redis_pool import get_shared_redis_client
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_usergroup import RedisUserGroup
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
logger = setup_logger()
@@ -102,7 +103,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
# This is singleton work that should be done on startup exactly once
# by the primary worker. This is unnecessary in the multi tenant scenario
r = get_shared_redis_client()
r = get_redis_client(tenant_id=POSTGRES_DEFAULT_SCHEMA)
# Log the role and slave count - being connected to a slave or slave count > 0 could be problematic
info: dict[str, Any] = cast(dict, r.info("replication"))
@@ -235,7 +236,7 @@ class HubPeriodicTask(bootsteps.StartStopStep):
lock: RedisLock = worker.primary_worker_lock
r = get_shared_redis_client()
r = get_redis_client(tenant_id=POSTGRES_DEFAULT_SCHEMA)
if lock.owned():
task_logger.debug("Reacquiring primary worker lock.")

View File

@@ -14,7 +14,7 @@ logger = setup_logger()
# Only set up memory monitoring in container environment
if is_running_in_container():
# Set up a dedicated memory monitoring logger
MEMORY_LOG_DIR = "/var/log/persisted-logs/memory"
MEMORY_LOG_DIR = "/var/log/memory"
MEMORY_LOG_FILE = os.path.join(MEMORY_LOG_DIR, "memory_usage.log")
MEMORY_LOG_MAX_BYTES = 10 * 1024 * 1024 # 10MB
MEMORY_LOG_BACKUP_COUNT = 5 # Keep 5 backup files

View File

@@ -21,6 +21,7 @@ BEAT_EXPIRES_DEFAULT = 15 * 60 # 15 minutes (in seconds)
# we have a better implementation (backpressure, etc)
# Note that DynamicTenantScheduler can adjust the runtime value for this via Redis
CLOUD_BEAT_MULTIPLIER_DEFAULT = 8.0
CLOUD_DOC_PERMISSION_SYNC_MULTIPLIER_DEFAULT = 1.0
# tasks that run in either self-hosted on cloud
beat_task_templates: list[dict] = []
@@ -194,6 +195,16 @@ if not MULTI_TENANT:
"queue": OnyxCeleryQueues.MONITORING,
},
},
{
"name": "monitor-process-memory",
"task": OnyxCeleryTask.MONITOR_PROCESS_MEMORY,
"schedule": timedelta(minutes=5),
"options": {
"priority": OnyxCeleryPriority.LOW,
"expires": BEAT_EXPIRES_DEFAULT,
"queue": OnyxCeleryQueues.MONITORING,
},
},
]
)

View File

@@ -30,6 +30,9 @@ from onyx.db.connector_credential_pair import (
)
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import get_connector_credential_pairs
from onyx.db.document import (
delete_all_documents_by_connector_credential_pair__no_commit,
)
from onyx.db.document import get_document_ids_for_connector_credential_pair
from onyx.db.document_set import delete_document_set_cc_pair_relationship__no_commit
from onyx.db.engine import get_session_with_current_tenant
@@ -440,6 +443,16 @@ def monitor_connector_deletion_taskset(
db_session=db_session,
)
# Explicitly delete document by connector credential pair records before deleting the connector
# This is needed because connector_id is a primary key in that table and cascading deletes won't work
delete_all_documents_by_connector_credential_pair__no_commit(
db_session=db_session,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
)
db_session.flush()
# finally, delete the cc-pair
delete_connector_credential_pair__no_commit(
db_session=db_session,

View File

@@ -17,6 +17,7 @@ from redis.exceptions import LockError
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from ee.onyx.configs.app_configs import DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs
from ee.onyx.db.document import upsert_document_external_perms
from ee.onyx.external_permissions.sync_params import DOC_PERMISSION_SYNC_PERIODS
@@ -46,7 +47,6 @@ from onyx.configs.constants import OnyxRedisSignals
from onyx.connectors.factory import validate_ccpair_for_user
from onyx.db.connector import mark_cc_pair_as_permissions_synced
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import update_connector_credential_pair
from onyx.db.document import upsert_document_by_connector_credential_pair
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.enums import AccessType
@@ -64,11 +64,14 @@ from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSyn
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.redis.redis_pool import redis_lock_dump
from onyx.server.runtime.onyx_runtime import OnyxRuntime
from onyx.server.utils import make_short_id
from onyx.utils.logger import doc_permission_sync_ctx
from onyx.utils.logger import format_error_for_logging
from onyx.utils.logger import LoggerContextVars
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
logger = setup_logger()
@@ -105,9 +108,10 @@ def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> b
source_sync_period = DOC_PERMISSION_SYNC_PERIODS.get(cc_pair.connector.source)
# If RESTRICTED_FETCH_PERIOD[source] is None, we always run the sync.
if not source_sync_period:
return True
source_sync_period = DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY
source_sync_period *= int(OnyxRuntime.get_doc_permission_sync_multiplier())
# If the last sync is greater than the full fetch period, we run the sync
next_sync = last_perm_sync + timedelta(seconds=source_sync_period)
@@ -285,7 +289,7 @@ def try_creating_permissions_sync_task(
),
queue=OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC,
task_id=custom_task_id,
priority=OnyxCeleryPriority.HIGH,
priority=OnyxCeleryPriority.MEDIUM,
)
# fill in the celery task id
@@ -420,12 +424,7 @@ def connector_permission_sync_generator_task(
task_logger.exception(
f"validate_ccpair_permissions_sync exceptioned: cc_pair={cc_pair_id}"
)
update_connector_credential_pair(
db_session=db_session,
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
status=ConnectorCredentialPairStatus.INVALID,
)
# TODO: add some notification to the admins here
raise
source_type = cc_pair.connector.source
@@ -453,23 +452,23 @@ def connector_permission_sync_generator_task(
redis_connector.permissions.set_fence(new_payload)
callback = PermissionSyncCallback(redis_connector, lock, r)
document_external_accesses: list[DocExternalAccess] = doc_sync_func(
cc_pair, callback
)
document_external_accesses = doc_sync_func(cc_pair, callback)
task_logger.info(
f"RedisConnector.permissions.generate_tasks starting. cc_pair={cc_pair_id}"
)
tasks_generated = redis_connector.permissions.generate_tasks(
celery_app=self.app,
lock=lock,
new_permissions=document_external_accesses,
source_string=source_type,
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
)
if tasks_generated is None:
return None
tasks_generated = 0
for doc_external_access in document_external_accesses:
redis_connector.permissions.generate_tasks(
celery_app=self.app,
lock=lock,
new_permissions=[doc_external_access],
source_string=source_type,
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
)
tasks_generated += 1
task_logger.info(
f"RedisConnector.permissions.generate_tasks finished. "
@@ -881,6 +880,21 @@ def monitor_ccpair_permissions_taskset(
f"remaining={remaining} "
f"initial={initial}"
)
# Add telemetry for permission syncing progress
optional_telemetry(
record_type=RecordType.PERMISSION_SYNC_PROGRESS,
data={
"cc_pair_id": cc_pair_id,
"id": payload.id if payload else None,
"total_docs": initial if initial is not None else 0,
"remaining_docs": remaining,
"synced_docs": (initial - remaining) if initial is not None else 0,
"is_complete": remaining == 0,
},
tenant_id=tenant_id,
)
if remaining > 0:
return

View File

@@ -41,7 +41,6 @@ from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.factory import validate_ccpair_for_user
from onyx.db.connector import mark_cc_pair_as_external_group_synced
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import update_connector_credential_pair
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
@@ -272,7 +271,7 @@ def try_creating_external_group_sync_task(
),
queue=OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC,
task_id=custom_task_id,
priority=OnyxCeleryPriority.HIGH,
priority=OnyxCeleryPriority.MEDIUM,
)
payload.celery_task_id = result.id
@@ -402,12 +401,7 @@ def connector_external_group_sync_generator_task(
task_logger.exception(
f"validate_ccpair_permissions_sync exceptioned: cc_pair={cc_pair_id}"
)
update_connector_credential_pair(
db_session=db_session,
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
status=ConnectorCredentialPairStatus.INVALID,
)
# TODO: add some notification to the admins here
raise
source_type = cc_pair.connector.source
@@ -425,12 +419,9 @@ def connector_external_group_sync_generator_task(
try:
external_user_groups = ext_group_sync_func(tenant_id, cc_pair)
except ConnectorValidationError as e:
msg = f"Error syncing external groups for {source_type} for cc_pair: {cc_pair_id} {e}"
update_connector_credential_pair(
db_session=db_session,
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
status=ConnectorCredentialPairStatus.INVALID,
# TODO: add some notification to the admins here
logger.exception(
f"Error syncing external groups for {source_type} for cc_pair: {cc_pair_id} {e}"
)
raise e

View File

@@ -72,6 +72,7 @@ from onyx.redis.redis_pool import get_redis_replica_client
from onyx.redis.redis_pool import redis_lock_dump
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
from onyx.redis.redis_utils import is_fence
from onyx.server.runtime.onyx_runtime import OnyxRuntime
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import global_version
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
@@ -401,7 +402,11 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
logger.warning(f"Adding {key_bytes} to the lookup table.")
redis_client.sadd(OnyxRedisConstants.ACTIVE_FENCES, key_bytes)
redis_client.set(OnyxRedisSignals.BLOCK_BUILD_FENCE_LOOKUP_TABLE, 1, ex=300)
redis_client.set(
OnyxRedisSignals.BLOCK_BUILD_FENCE_LOOKUP_TABLE,
1,
ex=OnyxRuntime.get_build_fence_lookup_table_interval(),
)
# 1/3: KICKOFF

View File

@@ -6,6 +6,7 @@ from itertools import islice
from typing import Any
from typing import Literal
import psutil
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
@@ -19,6 +20,7 @@ from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.celery.memory_monitoring import emit_process_memory
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
from onyx.configs.constants import OnyxCeleryQueues
@@ -39,8 +41,10 @@ from onyx.db.models import UserGroup
from onyx.db.search_settings import get_active_search_settings_list
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import redis_lock_dump
from onyx.utils.logger import is_running_in_container
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
_MONITORING_SOFT_TIME_LIMIT = 60 * 5 # 5 minutes
@@ -904,3 +908,93 @@ def monitor_celery_queues_helper(
f"external_group_sync={n_external_group_sync} "
f"permissions_upsert={n_permissions_upsert} "
)
"""Memory monitoring"""
def _get_cmdline_for_process(process: psutil.Process) -> str | None:
try:
return " ".join(process.cmdline())
except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess):
return None
@shared_task(
name=OnyxCeleryTask.MONITOR_PROCESS_MEMORY,
ignore_result=True,
soft_time_limit=_MONITORING_SOFT_TIME_LIMIT,
time_limit=_MONITORING_TIME_LIMIT,
queue=OnyxCeleryQueues.MONITORING,
bind=True,
)
def monitor_process_memory(self: Task, *, tenant_id: str) -> None:
"""
Task to monitor memory usage of supervisor-managed processes.
This periodically checks the memory usage of processes and logs information
in a standardized format.
The task looks for processes managed by supervisor and logs their
memory usage statistics. This is useful for monitoring memory consumption
over time and identifying potential memory leaks.
"""
# don't run this task in multi-tenant mode, have other, better means of monitoring
if MULTI_TENANT:
return
# Skip memory monitoring if not in container
if not is_running_in_container():
return
try:
# Get all supervisor-managed processes
supervisor_processes: dict[int, str] = {}
# Map cmd line elements to more readable process names
process_type_mapping = {
"--hostname=primary": "primary",
"--hostname=light": "light",
"--hostname=heavy": "heavy",
"--hostname=indexing": "indexing",
"--hostname=monitoring": "monitoring",
"beat": "beat",
"slack/listener.py": "slack",
}
# Find all python processes that are likely celery workers
for proc in psutil.process_iter():
cmdline = _get_cmdline_for_process(proc)
if not cmdline:
continue
# Match supervisor-managed processes
for process_name, process_type in process_type_mapping.items():
if process_name in cmdline:
if process_type in supervisor_processes.values():
task_logger.error(
f"Duplicate process type for type {process_type} "
f"with cmd {cmdline} with pid={proc.pid}."
)
continue
supervisor_processes[proc.pid] = process_type
break
if len(supervisor_processes) != len(process_type_mapping):
task_logger.error(
"Missing processes: "
f"{set(process_type_mapping.keys()).symmetric_difference(supervisor_processes.values())}"
)
# Log memory usage for each process
for pid, process_type in supervisor_processes.items():
try:
emit_process_memory(pid, process_type, {})
except psutil.NoSuchProcess:
# Process may have terminated since we obtained the list
continue
except Exception as e:
task_logger.exception(f"Error monitoring process {pid}: {str(e)}")
except Exception:
task_logger.exception("Error in monitor_process_memory task")

View File

@@ -6,6 +6,8 @@ from sqlalchemy import and_
from sqlalchemy.orm import Session
from onyx.configs.constants import FileOrigin
from onyx.connectors.interfaces import BaseConnector
from onyx.connectors.interfaces import CheckpointConnector
from onyx.connectors.models import ConnectorCheckpoint
from onyx.db.engine import get_db_current_time
from onyx.db.index_attempt import get_index_attempt
@@ -16,7 +18,6 @@ from onyx.file_store.file_store import get_default_file_store
from onyx.utils.logger import setup_logger
from onyx.utils.object_size_check import deep_getsizeof
logger = setup_logger()
_NUM_RECENT_ATTEMPTS_TO_CONSIDER = 20
@@ -52,7 +53,7 @@ def save_checkpoint(
def load_checkpoint(
db_session: Session, index_attempt_id: int
db_session: Session, index_attempt_id: int, connector: BaseConnector
) -> ConnectorCheckpoint | None:
"""Load a checkpoint for a given index attempt from the file store"""
checkpoint_pointer = _build_checkpoint_pointer(index_attempt_id)
@@ -60,6 +61,8 @@ def load_checkpoint(
try:
checkpoint_io = file_store.read_file(checkpoint_pointer, mode="rb")
checkpoint_data = checkpoint_io.read().decode("utf-8")
if isinstance(connector, CheckpointConnector):
return connector.validate_checkpoint_json(checkpoint_data)
return ConnectorCheckpoint.model_validate_json(checkpoint_data)
except RuntimeError:
return None
@@ -71,6 +74,7 @@ def get_latest_valid_checkpoint(
search_settings_id: int,
window_start: datetime,
window_end: datetime,
connector: BaseConnector,
) -> ConnectorCheckpoint:
"""Get the latest valid checkpoint for a given connector credential pair"""
checkpoint_candidates = get_recent_completed_attempts_for_cc_pair(
@@ -105,7 +109,7 @@ def get_latest_valid_checkpoint(
f"for cc_pair={cc_pair_id}. Ignoring checkpoint to let the run start "
"from scratch."
)
return ConnectorCheckpoint.build_dummy_checkpoint()
return connector.build_dummy_checkpoint()
# assumes latest checkpoint is the furthest along. This only isn't true
# if something else has gone wrong.
@@ -113,12 +117,13 @@ def get_latest_valid_checkpoint(
checkpoint_candidates[0] if checkpoint_candidates else None
)
checkpoint = ConnectorCheckpoint.build_dummy_checkpoint()
checkpoint = connector.build_dummy_checkpoint()
if latest_valid_checkpoint_candidate:
try:
previous_checkpoint = load_checkpoint(
db_session=db_session,
index_attempt_id=latest_valid_checkpoint_candidate.id,
connector=connector,
)
except Exception:
logger.exception(
@@ -193,7 +198,7 @@ def cleanup_checkpoint(db_session: Session, index_attempt_id: int) -> None:
def check_checkpoint_size(checkpoint: ConnectorCheckpoint) -> None:
"""Check if the checkpoint content size exceeds the limit (200MB)"""
content_size = deep_getsizeof(checkpoint.checkpoint_content)
content_size = deep_getsizeof(checkpoint.model_dump())
if content_size > 200_000_000: # 200MB in bytes
raise ValueError(
f"Checkpoint content size ({content_size} bytes) exceeds 200MB limit"

View File

@@ -24,7 +24,6 @@ from onyx.connectors.connector_runner import ConnectorRunner
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.factory import instantiate_connector
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import Document
from onyx.connectors.models import IndexAttemptMetadata
@@ -32,8 +31,11 @@ from onyx.connectors.models import TextSection
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import get_last_successful_attempt_time
from onyx.db.connector_credential_pair import update_connector_credential_pair
from onyx.db.constants import CONNECTOR_VALIDATION_ERROR_MESSAGE_PREFIX
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import IndexingStatus
from onyx.db.enums import IndexModelStatus
from onyx.db.index_attempt import create_index_attempt_error
from onyx.db.index_attempt import get_index_attempt
from onyx.db.index_attempt import get_index_attempt_errors_for_cc_pair
@@ -46,8 +48,6 @@ from onyx.db.index_attempt import transition_attempt_to_in_progress
from onyx.db.index_attempt import update_docs_indexed
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexAttemptError
from onyx.db.models import IndexingStatus
from onyx.db.models import IndexModelStatus
from onyx.document_index.factory import get_default_document_index
from onyx.httpx.httpx_pool import HttpxPool
from onyx.indexing.embedder import DefaultIndexingEmbedder
@@ -56,9 +56,12 @@ from onyx.indexing.indexing_pipeline import build_indexing_pipeline
from onyx.natural_language_processing.search_nlp_models import (
InformationContentClassificationModel,
)
from onyx.redis.redis_connector import RedisConnector
from onyx.utils.logger import setup_logger
from onyx.utils.logger import TaskAttemptSingleton
from onyx.utils.telemetry import create_milestone_and_report
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
from onyx.utils.variable_functionality import global_version
from shared_configs.configs import MULTI_TENANT
@@ -387,6 +390,7 @@ def _run_indexing(
net_doc_change = 0
document_count = 0
chunk_count = 0
index_attempt: IndexAttempt | None = None
try:
with get_session_with_current_tenant() as db_session_temp:
index_attempt = get_index_attempt(db_session_temp, index_attempt_id)
@@ -405,7 +409,7 @@ def _run_indexing(
# the beginning in order to avoid weird interactions between
# checkpointing / failure handling.
if index_attempt.from_beginning:
checkpoint = ConnectorCheckpoint.build_dummy_checkpoint()
checkpoint = connector_runner.connector.build_dummy_checkpoint()
else:
checkpoint = get_latest_valid_checkpoint(
db_session=db_session_temp,
@@ -413,6 +417,7 @@ def _run_indexing(
search_settings_id=index_attempt.search_settings_id,
window_start=window_start,
window_end=window_end,
connector=connector_runner.connector,
)
unresolved_errors = get_index_attempt_errors_for_cc_pair(
@@ -433,7 +438,7 @@ def _run_indexing(
while checkpoint.has_more:
logger.info(
f"Running '{ctx.source}' connector with checkpoint: {checkpoint}"
f"Running '{ctx.source.value}' connector with checkpoint: {checkpoint}"
)
for document_batch, failure, next_checkpoint in connector_runner.run(
checkpoint
@@ -568,6 +573,22 @@ def _run_indexing(
if callback:
callback.progress("_run_indexing", len(doc_batch_cleaned))
# Add telemetry for indexing progress
optional_telemetry(
record_type=RecordType.INDEXING_PROGRESS,
data={
"index_attempt_id": index_attempt_id,
"cc_pair_id": ctx.cc_pair_id,
"connector_id": ctx.connector_id,
"credential_id": ctx.credential_id,
"total_docs_indexed": document_count,
"total_chunks": chunk_count,
"batch_num": batch_num,
"source": ctx.source.value,
},
tenant_id=tenant_id,
)
memory_tracer.increment_and_maybe_trace()
# `make sure the checkpoints aren't getting too large`at some regular interval
@@ -583,6 +604,30 @@ def _run_indexing(
checkpoint=checkpoint,
)
# Add telemetry for completed indexing
redis_connector = RedisConnector(tenant_id, ctx.cc_pair_id)
redis_connector_index = redis_connector.new_index(
index_attempt_start.search_settings_id
)
final_progress = redis_connector_index.get_progress() or 0
optional_telemetry(
record_type=RecordType.INDEXING_COMPLETE,
data={
"index_attempt_id": index_attempt_id,
"cc_pair_id": ctx.cc_pair_id,
"connector_id": ctx.connector_id,
"credential_id": ctx.credential_id,
"total_docs_indexed": document_count,
"total_chunks": chunk_count,
"batch_count": batch_num,
"time_elapsed_seconds": time.monotonic() - start_time,
"source": ctx.source.value,
"redis_progress": final_progress,
},
tenant_id=tenant_id,
)
except Exception as e:
logger.exception(
"Connector run exceptioned after elapsed time: "
@@ -596,16 +641,44 @@ def _run_indexing(
mark_attempt_canceled(
index_attempt_id,
db_session_temp,
reason=str(e),
reason=f"{CONNECTOR_VALIDATION_ERROR_MESSAGE_PREFIX}{str(e)}",
)
if ctx.is_primary:
update_connector_credential_pair(
if not index_attempt:
# should always be set by now
raise RuntimeError("Should never happen.")
VALIDATION_ERROR_THRESHOLD = 5
recent_index_attempts = get_recent_completed_attempts_for_cc_pair(
cc_pair_id=ctx.cc_pair_id,
search_settings_id=index_attempt.search_settings_id,
limit=VALIDATION_ERROR_THRESHOLD,
db_session=db_session_temp,
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
status=ConnectorCredentialPairStatus.INVALID,
)
num_validation_errors = len(
[
index_attempt
for index_attempt in recent_index_attempts
if index_attempt.error_msg
and index_attempt.error_msg.startswith(
CONNECTOR_VALIDATION_ERROR_MESSAGE_PREFIX
)
]
)
if num_validation_errors >= VALIDATION_ERROR_THRESHOLD:
logger.warning(
f"Connector {ctx.connector_id} has {num_validation_errors} consecutive validation"
f" errors. Marking the CC Pair as invalid."
)
update_connector_credential_pair(
db_session=db_session_temp,
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
status=ConnectorCredentialPairStatus.INVALID,
)
memory_tracer.stop()
raise e

View File

@@ -30,7 +30,7 @@ from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_tool import QUERY_FIELD
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.utils import explicit_tool_calling_supported
from onyx.utils.gpu_utils import gpu_status_request
from onyx.utils.gpu_utils import fast_gpu_status_request
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -88,7 +88,9 @@ class Answer:
rerank_settings is not None
and rerank_settings.rerank_provider_type is not None
)
allow_agent_reranking = gpu_status_request() or using_cloud_reranking
allow_agent_reranking = (
fast_gpu_status_request(indexing=False) or using_cloud_reranking
)
# TODO: this is a hack to force the query to be used for the search tool
# this should be removed once we fully unify graph inputs (i.e.

View File

@@ -73,6 +73,7 @@ from onyx.db.chat import get_or_create_root_message
from onyx.db.chat import reserve_message_id
from onyx.db.chat import translate_db_message_to_chat_message_detail
from onyx.db.chat import translate_db_search_doc_to_server_search_doc
from onyx.db.chat import update_chat_session_updated_at_timestamp
from onyx.db.engine import get_session_context_manager
from onyx.db.milestone import check_multi_assistant_milestone
from onyx.db.milestone import create_milestone_if_not_exists
@@ -1069,6 +1070,8 @@ def stream_chat_message_objects(
prev_message = next_answer_message
logger.debug("Committing messages")
# Explicitly update the timestamp on the chat session
update_chat_session_updated_at_timestamp(chat_session_id, db_session)
db_session.commit() # actually save user / assistant message
yield AgenticMessageResponseIDInfo(agentic_message_ids=agentic_message_ids)

View File

@@ -1,6 +1,8 @@
import json
import os
import urllib.parse
from datetime import datetime
from datetime import timezone
from typing import cast
from onyx.auth.schemas import AuthBackend
@@ -33,6 +35,10 @@ GENERATIVE_MODEL_ACCESS_CHECK_FREQ = int(
) # 1 day
DISABLE_GENERATIVE_AI = os.environ.get("DISABLE_GENERATIVE_AI", "").lower() == "true"
# Controls whether to allow admin query history reports with:
# 1. associated user emails
# 2. anonymized user emails
# 3. no queries
ONYX_QUERY_HISTORY_TYPE = QueryHistoryType(
(os.environ.get("ONYX_QUERY_HISTORY_TYPE") or QueryHistoryType.NORMAL.value).lower()
)
@@ -153,10 +159,9 @@ VESPA_CLOUD_CERT_PATH = os.environ.get("VESPA_CLOUD_CERT_PATH")
VESPA_CLOUD_KEY_PATH = os.environ.get("VESPA_CLOUD_KEY_PATH")
# Number of documents in a batch during indexing (further batching done by chunks before passing to bi-encoder)
try:
INDEX_BATCH_SIZE = int(os.environ.get("INDEX_BATCH_SIZE", 16))
except ValueError:
INDEX_BATCH_SIZE = 16
INDEX_BATCH_SIZE = int(os.environ.get("INDEX_BATCH_SIZE") or 16)
MAX_DRIVE_WORKERS = int(os.environ.get("MAX_DRIVE_WORKERS", 4))
# Below are intended to match the env variables names used by the official postgres docker image
# https://hub.docker.com/_/postgres
@@ -341,8 +346,8 @@ HTML_BASED_CONNECTOR_TRANSFORM_LINKS_STRATEGY = os.environ.get(
HtmlBasedConnectorTransformLinksStrategy.STRIP,
)
NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP = (
os.environ.get("NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP", "").lower()
NOTION_CONNECTOR_DISABLE_RECURSIVE_PAGE_LOOKUP = (
os.environ.get("NOTION_CONNECTOR_DISABLE_RECURSIVE_PAGE_LOOKUP", "").lower()
== "true"
)
@@ -380,10 +385,27 @@ CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD = int(
# https://community.developer.atlassian.com/t/confluence-cloud-time-zone-get-via-rest-api/35954/16
# https://jira.atlassian.com/browse/CONFCLOUD-69670
def get_current_tz_offset() -> int:
# datetime now() gets local time, datetime.now(timezone.utc) gets UTC time.
# remove tzinfo to compare non-timezone-aware objects.
time_diff = datetime.now() - datetime.now(timezone.utc).replace(tzinfo=None)
return round(time_diff.total_seconds() / 3600)
# enter as a floating point offset from UTC in hours (-24 < val < 24)
# this will be applied globally, so it probably makes sense to transition this to per
# connector as some point.
CONFLUENCE_TIMEZONE_OFFSET = float(os.environ.get("CONFLUENCE_TIMEZONE_OFFSET", 0.0))
# For the default value, we assume that the user's local timezone is more likely to be
# correct (i.e. the configured user's timezone or the default server one) than UTC.
# https://developer.atlassian.com/cloud/confluence/cql-fields/#created
CONFLUENCE_TIMEZONE_OFFSET = float(
os.environ.get("CONFLUENCE_TIMEZONE_OFFSET", get_current_tz_offset())
)
GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD = int(
os.environ.get("GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD", 10 * 1024 * 1024)
)
JIRA_CONNECTOR_LABELS_TO_SKIP = [
ignored_tag
@@ -414,6 +436,9 @@ EGNYTE_CLIENT_SECRET = os.getenv("EGNYTE_CLIENT_SECRET")
LINEAR_CLIENT_ID = os.getenv("LINEAR_CLIENT_ID")
LINEAR_CLIENT_SECRET = os.getenv("LINEAR_CLIENT_SECRET")
# Slack specific configs
SLACK_NUM_THREADS = int(os.getenv("SLACK_NUM_THREADS") or 2)
DASK_JOB_CLIENT_ENABLED = (
os.environ.get("DASK_JOB_CLIENT_ENABLED", "").lower() == "true"
)
@@ -667,3 +692,7 @@ IMAGE_ANALYSIS_SYSTEM_PROMPT = os.environ.get(
"IMAGE_ANALYSIS_SYSTEM_PROMPT",
DEFAULT_IMAGE_ANALYSIS_SYSTEM_PROMPT,
)
DISABLE_AUTO_AUTH_REFRESH = (
os.environ.get("DISABLE_AUTO_AUTH_REFRESH", "").lower() == "true"
)

View File

@@ -3,6 +3,10 @@ import socket
from enum import auto
from enum import Enum
ONYX_DEFAULT_APPLICATION_NAME = "Onyx"
ONYX_SLACK_URL = "https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA"
ONYX_EMAILABLE_LOGO_MAX_DIM = 512
SOURCE_TYPE = "source_type"
# stored in the `metadata` of a chunk. Used to signify that this chunk should
# not be used for QA. For example, Google Drive file types which can't be parsed
@@ -40,6 +44,7 @@ DISABLED_GEN_AI_MSG = (
"You can still use Onyx as a search engine."
)
DEFAULT_PERSONA_ID = 0
DEFAULT_CC_PAIR_ID = 1
@@ -174,6 +179,7 @@ class DocumentSource(str, Enum):
FIREFLIES = "fireflies"
EGNYTE = "egnyte"
AIRTABLE = "airtable"
HIGHSPOT = "highspot"
# Special case just for integration tests
MOCK_CONNECTOR = "mock_connector"
@@ -376,6 +382,7 @@ ONYX_CLOUD_TENANT_ID = "cloud"
# the redis namespace for runtime variables
ONYX_CLOUD_REDIS_RUNTIME = "runtime"
CLOUD_BUILD_FENCE_LOOKUP_TABLE_INTERVAL_DEFAULT = 600
class OnyxCeleryTask:
@@ -388,6 +395,9 @@ class OnyxCeleryTask:
)
CHECK_AVAILABLE_TENANTS = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check_available_tenants"
# Tenant pre-provisioning
PRE_PROVISION_TENANT = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_pre_provision_tenant"
CHECK_FOR_CONNECTOR_DELETION = "check_for_connector_deletion_task"
CHECK_FOR_VESPA_SYNC_TASK = "check_for_vespa_sync_task"
CHECK_FOR_INDEXING = "check_for_indexing"
@@ -402,9 +412,7 @@ class OnyxCeleryTask:
MONITOR_BACKGROUND_PROCESSES = "monitor_background_processes"
MONITOR_CELERY_QUEUES = "monitor_celery_queues"
# Tenant pre-provisioning
PRE_PROVISION_TENANT = "pre_provision_tenant"
MONITOR_PROCESS_MEMORY = "monitor_process_memory"
KOMBU_MESSAGE_CLEANUP_TASK = "kombu_message_cleanup_task"
CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK = (

View File

@@ -87,7 +87,7 @@ class BlobStorageConnector(LoadConnector, PollConnector):
credentials.get(key)
for key in ["aws_access_key_id", "aws_secret_access_key"]
):
raise ConnectorMissingCredentialError("Google Cloud Storage")
raise ConnectorMissingCredentialError("Amazon S3")
session = boto3.Session(
aws_access_key_id=credentials["aws_access_key_id"],

View File

@@ -65,19 +65,7 @@ _RESTRICTIONS_EXPANSION_FIELDS = [
_SLIM_DOC_BATCH_SIZE = 5000
_ATTACHMENT_EXTENSIONS_TO_FILTER_OUT = [
"gif",
"mp4",
"mov",
"mp3",
"wav",
]
_FULL_EXTENSION_FILTER_STRING = "".join(
[
f" and title!~'*.{extension}'"
for extension in _ATTACHMENT_EXTENSIONS_TO_FILTER_OUT
]
)
ONE_HOUR = 3600
class ConfluenceConnector(
@@ -114,6 +102,7 @@ class ConfluenceConnector(
self.timezone_offset = timezone_offset
self._confluence_client: OnyxConfluence | None = None
self._fetched_titles: set[str] = set()
self.allow_images = False
# Remove trailing slash from wiki_base if present
self.wiki_base = wiki_base.rstrip("/")
@@ -158,6 +147,9 @@ class ConfluenceConnector(
"max_backoff_seconds": 60,
}
def set_allow_images(self, value: bool) -> None:
self.allow_images = value
@property
def confluence_client(self) -> OnyxConfluence:
if self._confluence_client is None:
@@ -203,7 +195,6 @@ class ConfluenceConnector(
def _construct_attachment_query(self, confluence_page_id: str) -> str:
attachment_query = f"type=attachment and container='{confluence_page_id}'"
attachment_query += self.cql_label_filter
attachment_query += _FULL_EXTENSION_FILTER_STRING
return attachment_query
def _get_comment_string_for_page_id(self, page_id: str) -> str:
@@ -233,7 +224,9 @@ class ConfluenceConnector(
# Extract basic page information
page_id = page["id"]
page_title = page["title"]
page_url = f"{self.wiki_base}{page['_links']['webui']}"
page_url = build_confluence_document_id(
self.wiki_base, page["_links"]["webui"], self.is_cloud
)
# Get the page content
page_content = extract_text_from_confluence_html(
@@ -264,6 +257,7 @@ class ConfluenceConnector(
self.confluence_client,
attachment,
page_id,
self.allow_images,
)
if result and result.text:
@@ -304,13 +298,14 @@ class ConfluenceConnector(
if "version" in page and "by" in page["version"]:
author = page["version"]["by"]
display_name = author.get("displayName", "Unknown")
primary_owners.append(BasicExpertInfo(display_name=display_name))
email = author.get("email", "unknown@domain.invalid")
primary_owners.append(
BasicExpertInfo(display_name=display_name, email=email)
)
# Create the document
return Document(
id=build_confluence_document_id(
self.wiki_base, page["_links"]["webui"], self.is_cloud
),
id=page_url,
sections=sections,
source=DocumentSource.CONFLUENCE,
semantic_identifier=page_title,
@@ -364,15 +359,18 @@ class ConfluenceConnector(
if not validate_attachment_filetype(
attachment,
):
logger.info(f"Skipping attachment: {attachment['title']}")
continue
logger.info(f"Processing attachment: {attachment['title']}")
# Attempt to get textual content or image summarization:
try:
logger.info(f"Processing attachment: {attachment['title']}")
response = convert_attachment_to_content(
confluence_client=self.confluence_client,
attachment=attachment,
page_id=page["id"],
allow_images=self.allow_images,
)
if response is None:
continue
@@ -420,7 +418,17 @@ class ConfluenceConnector(
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateDocumentsOutput:
return self._fetch_document_batches(start, end)
try:
return self._fetch_document_batches(start, end)
except Exception as e:
if "field 'updated' is invalid" in str(e) and start is not None:
logger.warning(
"Confluence says we provided an invalid 'updated' field. This may indicate"
"a real issue, but can also appear during edge cases like daylight"
f"savings time changes. Retrying with a 1 hour offset. Error: {e}"
)
return self._fetch_document_batches(start - ONE_HOUR, end)
raise
def retrieve_all_slim_documents(
self,

View File

@@ -498,10 +498,12 @@ class OnyxConfluence:
new_start = get_start_param_from_url(url_suffix)
previous_start = get_start_param_from_url(old_url_suffix)
if new_start - previous_start > len(results):
logger.warning(
logger.debug(
f"Start was updated by more than the amount of results "
f"retrieved. This is a bug with Confluence. Start: {new_start}, "
f"Previous Start: {previous_start}, Len Results: {len(results)}."
f"retrieved for `{url_suffix}`. This is a bug with Confluence, "
"but we have logic to work around it - don't worry this isn't"
f" causing an issue. Start: {new_start}, Previous Start: "
f"{previous_start}, Len Results: {len(results)}."
)
# Update the url_suffix to use the adjusted start

View File

@@ -112,6 +112,7 @@ def process_attachment(
confluence_client: "OnyxConfluence",
attachment: dict[str, Any],
parent_content_id: str | None,
allow_images: bool,
) -> AttachmentProcessingResult:
"""
Processes a Confluence attachment. If it's a document, extracts text,
@@ -119,7 +120,7 @@ def process_attachment(
"""
try:
# Get the media type from the attachment metadata
media_type = attachment.get("metadata", {}).get("mediaType", "")
media_type: str = attachment.get("metadata", {}).get("mediaType", "")
# Validate the attachment type
if not validate_attachment_filetype(attachment):
return AttachmentProcessingResult(
@@ -138,7 +139,14 @@ def process_attachment(
attachment_size = attachment["extensions"]["fileSize"]
if not media_type.startswith("image/"):
if media_type.startswith("image/"):
if not allow_images:
return AttachmentProcessingResult(
text=None,
file_name=None,
error="Image downloading is not enabled",
)
else:
if attachment_size > CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD:
logger.warning(
f"Skipping {attachment_link} due to size. "
@@ -294,6 +302,7 @@ def convert_attachment_to_content(
confluence_client: "OnyxConfluence",
attachment: dict[str, Any],
page_id: str,
allow_images: bool,
) -> tuple[str | None, str | None] | None:
"""
Facade function which:
@@ -309,7 +318,7 @@ def convert_attachment_to_content(
)
return None
result = process_attachment(confluence_client, attachment, page_id)
result = process_attachment(confluence_client, attachment, page_id, allow_images)
if result.error is not None:
logger.warning(
f"Attachment {attachment['title']} encountered error: {result.error}"

View File

@@ -2,6 +2,8 @@ import sys
import time
from collections.abc import Generator
from datetime import datetime
from typing import Generic
from typing import TypeVar
from onyx.connectors.interfaces import BaseConnector
from onyx.connectors.interfaces import CheckpointConnector
@@ -19,8 +21,10 @@ logger = setup_logger()
TimeRange = tuple[datetime, datetime]
CT = TypeVar("CT", bound=ConnectorCheckpoint)
class CheckpointOutputWrapper:
class CheckpointOutputWrapper(Generic[CT]):
"""
Wraps a CheckpointOutput generator to give things back in a more digestible format.
The connector format is easier for the connector implementor (e.g. it enforces exactly
@@ -29,20 +33,20 @@ class CheckpointOutputWrapper:
"""
def __init__(self) -> None:
self.next_checkpoint: ConnectorCheckpoint | None = None
self.next_checkpoint: CT | None = None
def __call__(
self,
checkpoint_connector_generator: CheckpointOutput,
checkpoint_connector_generator: CheckpointOutput[CT],
) -> Generator[
tuple[Document | None, ConnectorFailure | None, ConnectorCheckpoint | None],
tuple[Document | None, ConnectorFailure | None, CT | None],
None,
None,
]:
# grabs the final return value and stores it in the `next_checkpoint` variable
def _inner_wrapper(
checkpoint_connector_generator: CheckpointOutput,
) -> CheckpointOutput:
checkpoint_connector_generator: CheckpointOutput[CT],
) -> CheckpointOutput[CT]:
self.next_checkpoint = yield from checkpoint_connector_generator
return self.next_checkpoint # not used
@@ -64,7 +68,7 @@ class CheckpointOutputWrapper:
yield None, None, self.next_checkpoint
class ConnectorRunner:
class ConnectorRunner(Generic[CT]):
"""
Handles:
- Batching
@@ -85,11 +89,9 @@ class ConnectorRunner:
self.doc_batch: list[Document] = []
def run(
self, checkpoint: ConnectorCheckpoint
self, checkpoint: CT
) -> Generator[
tuple[
list[Document] | None, ConnectorFailure | None, ConnectorCheckpoint | None
],
tuple[list[Document] | None, ConnectorFailure | None, CT | None],
None,
None,
]:
@@ -105,9 +107,9 @@ class ConnectorRunner:
end=self.time_range[1].timestamp(),
checkpoint=checkpoint,
)
next_checkpoint: ConnectorCheckpoint | None = None
next_checkpoint: CT | None = None
# this is guaranteed to always run at least once with next_checkpoint being non-None
for document, failure, next_checkpoint in CheckpointOutputWrapper()(
for document, failure, next_checkpoint in CheckpointOutputWrapper[CT]()(
checkpoint_connector_generator
):
if document is not None:
@@ -132,7 +134,7 @@ class ConnectorRunner:
)
else:
finished_checkpoint = ConnectorCheckpoint.build_dummy_checkpoint()
finished_checkpoint = self.connector.build_dummy_checkpoint()
finished_checkpoint.has_more = False
if isinstance(self.connector, PollConnector):

View File

@@ -28,8 +28,9 @@ from onyx.connectors.models import TextSection
from onyx.file_processing.extract_file_text import detect_encoding
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.extract_file_text import is_accepted_file_ext
from onyx.file_processing.extract_file_text import is_text_file_extension
from onyx.file_processing.extract_file_text import is_valid_file_ext
from onyx.file_processing.extract_file_text import OnyxExtensionType
from onyx.file_processing.extract_file_text import read_text_file
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import request_with_retries
@@ -69,7 +70,9 @@ def _process_egnyte_file(
file_name = file_metadata["name"]
extension = get_file_ext(file_name)
if not is_valid_file_ext(extension):
if not is_accepted_file_ext(
extension, OnyxExtensionType.Plain | OnyxExtensionType.Document
):
logger.warning(f"Skipping file '{file_name}' with extension '{extension}'")
return None

View File

@@ -5,6 +5,7 @@ from sqlalchemy.orm import Session
from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
from onyx.configs.constants import DocumentSource
from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
from onyx.connectors.airtable.airtable_connector import AirtableConnector
from onyx.connectors.asana.connector import AsanaConnector
from onyx.connectors.axero.connector import AxeroConnector
@@ -30,6 +31,7 @@ from onyx.connectors.gong.connector import GongConnector
from onyx.connectors.google_drive.connector import GoogleDriveConnector
from onyx.connectors.google_site.connector import GoogleSitesConnector
from onyx.connectors.guru.connector import GuruConnector
from onyx.connectors.highspot.connector import HighspotConnector
from onyx.connectors.hubspot.connector import HubSpotConnector
from onyx.connectors.interfaces import BaseConnector
from onyx.connectors.interfaces import CheckpointConnector
@@ -117,6 +119,7 @@ def identify_connector_class(
DocumentSource.FIREFLIES: FirefliesConnector,
DocumentSource.EGNYTE: EgnyteConnector,
DocumentSource.AIRTABLE: AirtableConnector,
DocumentSource.HIGHSPOT: HighspotConnector,
# just for integration tests
DocumentSource.MOCK_CONNECTOR: MockConnector,
}
@@ -182,6 +185,8 @@ def instantiate_connector(
if new_credentials is not None:
backend_update_credential_json(credential, new_credentials, db_session)
connector.set_allow_images(get_image_extraction_and_analysis_enabled())
return connector

View File

@@ -22,8 +22,9 @@ from onyx.db.engine import get_session_with_current_tenant
from onyx.db.pg_file_store import get_pgfilestore_by_file_name
from onyx.file_processing.extract_file_text import extract_text_and_images
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.extract_file_text import is_valid_file_ext
from onyx.file_processing.extract_file_text import is_accepted_file_ext
from onyx.file_processing.extract_file_text import load_files_from_zip
from onyx.file_processing.extract_file_text import OnyxExtensionType
from onyx.file_processing.image_utils import store_image_and_create_section
from onyx.file_store.file_store import get_default_file_store
from onyx.utils.logger import setup_logger
@@ -51,7 +52,7 @@ def _read_files_and_metadata(
file_content, ignore_dirs=True
):
yield os.path.join(directory_path, file_info.filename), subfile, metadata
elif is_valid_file_ext(extension):
elif is_accepted_file_ext(extension, OnyxExtensionType.All):
yield file_name, file_content, metadata
else:
logger.warning(f"Skipping file '{file_name}' with extension '{extension}'")
@@ -122,7 +123,7 @@ def _process_file(
logger.warning(f"No file record found for '{file_name}' in PG; skipping.")
return []
if not is_valid_file_ext(extension):
if not is_accepted_file_ext(extension, OnyxExtensionType.All):
logger.warning(
f"Skipping file '{file_name}' with unrecognized extension '{extension}'"
)
@@ -219,24 +220,34 @@ def _process_file(
# 2) Otherwise: text-based approach. Possibly with embedded images.
file.seek(0)
text_content = ""
embedded_images: list[tuple[bytes, str]] = []
# Extract text and images from the file
text_content, embedded_images = extract_text_and_images(
extraction_result = extract_text_and_images(
file=file,
file_name=file_name,
pdf_pass=pdf_pass,
)
# Merge file-specific metadata (from file content) with provided metadata
if extraction_result.metadata:
logger.debug(
f"Found file-specific metadata for {file_name}: {extraction_result.metadata}"
)
metadata.update(extraction_result.metadata)
# Build sections: first the text as a single Section
sections: list[TextSection | ImageSection] = []
link_in_meta = metadata.get("link")
if text_content.strip():
sections.append(TextSection(link=link_in_meta, text=text_content.strip()))
if extraction_result.text_content.strip():
logger.debug(f"Creating TextSection for {file_name} with link: {link_in_meta}")
sections.append(
TextSection(link=link_in_meta, text=extraction_result.text_content.strip())
)
# Then any extracted images from docx, etc.
for idx, (img_data, img_name) in enumerate(embedded_images, start=1):
for idx, (img_data, img_name) in enumerate(
extraction_result.embedded_images, start=1
):
# Store each embedded image as a separate file in PGFileStore
# and create a section with the image reference
try:

View File

@@ -1,8 +1,10 @@
import copy
import time
from collections.abc import Iterator
from collections.abc import Generator
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from enum import Enum
from typing import Any
from typing import cast
@@ -13,26 +15,30 @@ from github.GithubException import GithubException
from github.Issue import Issue
from github.PaginatedList import PaginatedList
from github.PullRequest import PullRequest
from github.Requester import Requester
from pydantic import BaseModel
from typing_extensions import override
from onyx.configs.app_configs import GITHUB_CONNECTOR_BASE_URL
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import CheckpointConnector
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import ConnectorCheckpoint
from onyx.connectors.interfaces import ConnectorFailure
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import TextSection
from onyx.utils.batching import batch_generator
from onyx.utils.logger import setup_logger
logger = setup_logger()
ITEMS_PER_PAGE = 100
_MAX_NUM_RATE_LIMIT_RETRIES = 5
@@ -48,7 +54,7 @@ def _sleep_after_rate_limit_exception(github_client: Github) -> None:
def _get_batch_rate_limited(
git_objs: PaginatedList, page_num: int, github_client: Github, attempt_num: int = 0
) -> list[Any]:
) -> list[PullRequest | Issue]:
if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES:
raise RuntimeError(
"Re-tried fetching batch too many times. Something is going wrong with fetching objects from Github"
@@ -69,21 +75,6 @@ def _get_batch_rate_limited(
)
def _batch_github_objects(
git_objs: PaginatedList, github_client: Github, batch_size: int
) -> Iterator[list[Any]]:
page_num = 0
while True:
batch = _get_batch_rate_limited(git_objs, page_num, github_client)
page_num += 1
if not batch:
break
for mini_batch in batch_generator(batch, batch_size=batch_size):
yield mini_batch
def _convert_pr_to_document(pull_request: PullRequest) -> Document:
return Document(
id=pull_request.html_url,
@@ -95,7 +86,9 @@ def _convert_pr_to_document(pull_request: PullRequest) -> Document:
# updated_at is UTC time but is timezone unaware, explicitly add UTC
# as there is logic in indexing to prevent wrong timestamped docs
# due to local time discrepancies with UTC
doc_updated_at=pull_request.updated_at.replace(tzinfo=timezone.utc),
doc_updated_at=pull_request.updated_at.replace(tzinfo=timezone.utc)
if pull_request.updated_at
else None,
metadata={
"merged": str(pull_request.merged),
"state": pull_request.state,
@@ -122,31 +115,58 @@ def _convert_issue_to_document(issue: Issue) -> Document:
)
class GithubConnector(LoadConnector, PollConnector):
class SerializedRepository(BaseModel):
# id is part of the raw_data as well, just pulled out for convenience
id: int
headers: dict[str, str | int]
raw_data: dict[str, Any]
def to_Repository(self, requester: Requester) -> Repository.Repository:
return Repository.Repository(
requester, self.headers, self.raw_data, completed=True
)
class GithubConnectorStage(Enum):
START = "start"
PRS = "prs"
ISSUES = "issues"
class GithubConnectorCheckpoint(ConnectorCheckpoint):
stage: GithubConnectorStage
curr_page: int
cached_repo_ids: list[int] | None = None
cached_repo: SerializedRepository | None = None
class GithubConnector(CheckpointConnector[GithubConnectorCheckpoint]):
def __init__(
self,
repo_owner: str,
repositories: str | None = None,
batch_size: int = INDEX_BATCH_SIZE,
state_filter: str = "all",
include_prs: bool = True,
include_issues: bool = False,
) -> None:
self.repo_owner = repo_owner
self.repositories = repositories
self.batch_size = batch_size
self.state_filter = state_filter
self.include_prs = include_prs
self.include_issues = include_issues
self.github_client: Github | None = None
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
# defaults to 30 items per page, can be set to as high as 100
self.github_client = (
Github(
credentials["github_access_token"], base_url=GITHUB_CONNECTOR_BASE_URL
credentials["github_access_token"],
base_url=GITHUB_CONNECTOR_BASE_URL,
per_page=ITEMS_PER_PAGE,
)
if GITHUB_CONNECTOR_BASE_URL
else Github(credentials["github_access_token"])
else Github(credentials["github_access_token"], per_page=ITEMS_PER_PAGE)
)
return None
@@ -217,85 +237,193 @@ class GithubConnector(LoadConnector, PollConnector):
return self._get_all_repos(github_client, attempt_num + 1)
def _fetch_from_github(
self, start: datetime | None = None, end: datetime | None = None
) -> GenerateDocumentsOutput:
self,
checkpoint: GithubConnectorCheckpoint,
start: datetime | None = None,
end: datetime | None = None,
) -> Generator[Document | ConnectorFailure, None, GithubConnectorCheckpoint]:
if self.github_client is None:
raise ConnectorMissingCredentialError("GitHub")
repos = []
if self.repositories:
if "," in self.repositories:
# Multiple repositories specified
repos = self._get_github_repos(self.github_client)
checkpoint = copy.deepcopy(checkpoint)
# First run of the connector, fetch all repos and store in checkpoint
if checkpoint.cached_repo_ids is None:
repos = []
if self.repositories:
if "," in self.repositories:
# Multiple repositories specified
repos = self._get_github_repos(self.github_client)
else:
# Single repository (backward compatibility)
repos = [self._get_github_repo(self.github_client)]
else:
# Single repository (backward compatibility)
repos = [self._get_github_repo(self.github_client)]
else:
# All repositories
repos = self._get_all_repos(self.github_client)
# All repositories
repos = self._get_all_repos(self.github_client)
if not repos:
checkpoint.has_more = False
return checkpoint
for repo in repos:
if self.include_prs:
logger.info(f"Fetching PRs for repo: {repo.name}")
pull_requests = repo.get_pulls(
state=self.state_filter, sort="updated", direction="desc"
)
checkpoint.cached_repo_ids = sorted([repo.id for repo in repos])
checkpoint.cached_repo = SerializedRepository(
id=checkpoint.cached_repo_ids[0],
headers=repos[0].raw_headers,
raw_data=repos[0].raw_data,
)
checkpoint.stage = GithubConnectorStage.PRS
checkpoint.curr_page = 0
# save checkpoint with repo ids retrieved
return checkpoint
for pr_batch in _batch_github_objects(
pull_requests, self.github_client, self.batch_size
assert checkpoint.cached_repo is not None, "No repo saved in checkpoint"
repo = checkpoint.cached_repo.to_Repository(self.github_client.requester)
if self.include_prs and checkpoint.stage == GithubConnectorStage.PRS:
logger.info(f"Fetching PRs for repo: {repo.name}")
pull_requests = repo.get_pulls(
state=self.state_filter, sort="updated", direction="desc"
)
doc_batch: list[Document] = []
pr_batch = _get_batch_rate_limited(
pull_requests, checkpoint.curr_page, self.github_client
)
checkpoint.curr_page += 1
done_with_prs = False
for pr in pr_batch:
# we iterate backwards in time, so at this point we stop processing prs
if (
start is not None
and pr.updated_at
and pr.updated_at.replace(tzinfo=timezone.utc) < start
):
doc_batch: list[Document] = []
for pr in pr_batch:
if start is not None and pr.updated_at < start:
yield doc_batch
break
if end is not None and pr.updated_at > end:
continue
doc_batch.append(_convert_pr_to_document(cast(PullRequest, pr)))
yield doc_batch
if self.include_issues:
logger.info(f"Fetching issues for repo: {repo.name}")
issues = repo.get_issues(
state=self.state_filter, sort="updated", direction="desc"
)
for issue_batch in _batch_github_objects(
issues, self.github_client, self.batch_size
yield from doc_batch
done_with_prs = True
break
# Skip PRs updated after the end date
if (
end is not None
and pr.updated_at
and pr.updated_at.replace(tzinfo=timezone.utc) > end
):
doc_batch = []
for issue in issue_batch:
issue = cast(Issue, issue)
if start is not None and issue.updated_at < start:
yield doc_batch
break
if end is not None and issue.updated_at > end:
continue
if issue.pull_request is not None:
# PRs are handled separately
continue
doc_batch.append(_convert_issue_to_document(issue))
yield doc_batch
continue
try:
doc_batch.append(_convert_pr_to_document(cast(PullRequest, pr)))
except Exception as e:
error_msg = f"Error converting PR to document: {e}"
logger.exception(error_msg)
yield ConnectorFailure(
failed_document=DocumentFailure(
document_id=str(pr.id), document_link=pr.html_url
),
failure_message=error_msg,
exception=e,
)
continue
def load_from_state(self) -> GenerateDocumentsOutput:
return self._fetch_from_github()
# if we found any PRs on the page, yield any associated documents and return the checkpoint
if not done_with_prs and len(pr_batch) > 0:
yield from doc_batch
return checkpoint
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
start_datetime = datetime.utcfromtimestamp(start)
end_datetime = datetime.utcfromtimestamp(end)
# if we went past the start date during the loop or there are no more
# prs to get, we move on to issues
checkpoint.stage = GithubConnectorStage.ISSUES
checkpoint.curr_page = 0
checkpoint.stage = GithubConnectorStage.ISSUES
if self.include_issues and checkpoint.stage == GithubConnectorStage.ISSUES:
logger.info(f"Fetching issues for repo: {repo.name}")
issues = repo.get_issues(
state=self.state_filter, sort="updated", direction="desc"
)
doc_batch = []
issue_batch = _get_batch_rate_limited(
issues, checkpoint.curr_page, self.github_client
)
checkpoint.curr_page += 1
done_with_issues = False
for issue in cast(list[Issue], issue_batch):
# we iterate backwards in time, so at this point we stop processing prs
if (
start is not None
and issue.updated_at.replace(tzinfo=timezone.utc) < start
):
yield from doc_batch
done_with_issues = True
break
# Skip PRs updated after the end date
if (
end is not None
and issue.updated_at.replace(tzinfo=timezone.utc) > end
):
continue
if issue.pull_request is not None:
# PRs are handled separately
continue
try:
doc_batch.append(_convert_issue_to_document(issue))
except Exception as e:
error_msg = f"Error converting issue to document: {e}"
logger.exception(error_msg)
yield ConnectorFailure(
failed_document=DocumentFailure(
document_id=str(issue.id),
document_link=issue.html_url,
),
failure_message=error_msg,
exception=e,
)
continue
# if we found any issues on the page, yield them and return the checkpoint
if not done_with_issues and len(issue_batch) > 0:
yield from doc_batch
return checkpoint
# if we went past the start date during the loop or there are no more
# issues to get, we move on to the next repo
checkpoint.stage = GithubConnectorStage.PRS
checkpoint.curr_page = 0
checkpoint.has_more = len(checkpoint.cached_repo_ids) > 1
if checkpoint.cached_repo_ids:
next_id = checkpoint.cached_repo_ids.pop()
next_repo = self.github_client.get_repo(next_id)
checkpoint.cached_repo = SerializedRepository(
id=next_id,
headers=next_repo.raw_headers,
raw_data=next_repo.raw_data,
)
return checkpoint
@override
def load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: GithubConnectorCheckpoint,
) -> CheckpointOutput[GithubConnectorCheckpoint]:
start_datetime = datetime.fromtimestamp(start, tz=timezone.utc)
end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)
# Move start time back by 3 hours, since some Issues/PRs are getting dropped
# Could be due to delayed processing on GitHub side
# The non-updated issues since last poll will be shortcut-ed and not embedded
adjusted_start_datetime = start_datetime - timedelta(hours=3)
epoch = datetime.utcfromtimestamp(0)
epoch = datetime.fromtimestamp(0, tz=timezone.utc)
if adjusted_start_datetime < epoch:
adjusted_start_datetime = epoch
return self._fetch_from_github(adjusted_start_datetime, end_datetime)
return self._fetch_from_github(
checkpoint, start=adjusted_start_datetime, end=end_datetime
)
def validate_connector_settings(self) -> None:
if self.github_client is None:
@@ -397,6 +525,16 @@ class GithubConnector(LoadConnector, PollConnector):
f"Unexpected error during GitHub settings validation: {exc}"
)
def validate_checkpoint_json(
self, checkpoint_json: str
) -> GithubConnectorCheckpoint:
return GithubConnectorCheckpoint.model_validate_json(checkpoint_json)
def build_dummy_checkpoint(self) -> GithubConnectorCheckpoint:
return GithubConnectorCheckpoint(
stage=GithubConnectorStage.PRS, curr_page=0, has_more=True
)
if __name__ == "__main__":
import os
@@ -406,7 +544,9 @@ if __name__ == "__main__":
repositories=os.environ["REPOSITORIES"],
)
connector.load_credentials(
{"github_access_token": os.environ["GITHUB_ACCESS_TOKEN"]}
{"github_access_token": os.environ["ACCESS_TOKEN_GITHUB"]}
)
document_batches = connector.load_from_checkpoint(
0, time.time(), connector.build_dummy_checkpoint()
)
document_batches = connector.load_from_state()
print(next(document_batches))

File diff suppressed because it is too large Load Diff

View File

@@ -1,4 +1,5 @@
import io
from collections.abc import Callable
from datetime import datetime
from typing import cast
@@ -13,7 +14,9 @@ from onyx.connectors.google_drive.models import GoogleDriveFileType
from onyx.connectors.google_drive.section_extraction import get_document_sections
from onyx.connectors.google_utils.resources import GoogleDocsService
from onyx.connectors.google_utils.resources import GoogleDriveService
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import ImageSection
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
@@ -73,9 +76,10 @@ def is_gdrive_image_mime_type(mime_type: str) -> bool:
return is_valid_image_type(mime_type)
def _extract_sections_basic(
def _download_and_extract_sections_basic(
file: dict[str, str],
service: GoogleDriveService,
allow_images: bool,
) -> list[TextSection | ImageSection]:
"""Extract text and images from a Google Drive file."""
file_id = file["id"]
@@ -84,6 +88,10 @@ def _extract_sections_basic(
link = file.get("webViewLink", "")
try:
# skip images if not explicitly enabled
if not allow_images and is_gdrive_image_mime_type(mime_type):
return []
# For Google Docs, Sheets, and Slides, export as plain text
if mime_type in GOOGLE_MIME_TYPES_TO_EXPORT:
export_mime_type = GOOGLE_MIME_TYPES_TO_EXPORT[mime_type]
@@ -202,12 +210,17 @@ def _extract_sections_basic(
def convert_drive_item_to_document(
file: GoogleDriveFileType,
drive_service: GoogleDriveService,
docs_service: GoogleDocsService,
) -> Document | None:
drive_service: Callable[[], GoogleDriveService],
docs_service: Callable[[], GoogleDocsService],
allow_images: bool,
size_threshold: int,
) -> Document | ConnectorFailure | None:
"""
Main entry point for converting a Google Drive file => Document object.
"""
doc_id = ""
sections: list[TextSection | ImageSection] = []
try:
# skip shortcuts or folders
if file.get("mimeType") in [DRIVE_SHORTCUT_TYPE, DRIVE_FOLDER_TYPE]:
@@ -215,13 +228,11 @@ def convert_drive_item_to_document(
return None
# If it's a Google Doc, we might do advanced parsing
sections: list[TextSection | ImageSection] = []
# Try to get sections using the advanced method first
if file.get("mimeType") == GDriveMimeType.DOC.value:
try:
# get_document_sections is the advanced approach for Google Docs
doc_sections = get_document_sections(
docs_service=docs_service, doc_id=file.get("id", "")
docs_service=docs_service(), doc_id=file.get("id", "")
)
if doc_sections:
sections = cast(list[TextSection | ImageSection], doc_sections)
@@ -230,9 +241,24 @@ def convert_drive_item_to_document(
f"Error in advanced parsing: {e}. Falling back to basic extraction."
)
size_str = file.get("size")
if size_str:
try:
size_int = int(size_str)
except ValueError:
logger.warning(f"Parsing string to int failed: size_str={size_str}")
else:
if size_int > size_threshold:
logger.warning(
f"{file.get('name')} exceeds size threshold of {size_threshold}. Skipping."
)
return None
# If we don't have sections yet, use the basic extraction method
if not sections:
sections = _extract_sections_basic(file, drive_service)
sections = _download_and_extract_sections_basic(
file, drive_service(), allow_images
)
# If we still don't have any sections, skip this file
if not sections:
@@ -257,8 +283,19 @@ def convert_drive_item_to_document(
),
)
except Exception as e:
logger.error(f"Error converting file {file.get('name')}: {e}")
return None
error_str = f"Error converting file '{file.get('name')}' to Document: {e}"
logger.exception(error_str)
return ConnectorFailure(
failed_document=DocumentFailure(
document_id=doc_id,
document_link=sections[0].link
if sections
else None, # TODO: see if this is the best way to get a link
),
failed_entity=None,
failure_message=error_str,
exception=e,
)
def build_slim_document(file: GoogleDriveFileType) -> SlimDocument | None:

View File

@@ -1,17 +1,23 @@
from collections.abc import Callable
from collections.abc import Iterator
from datetime import datetime
from typing import Any
from datetime import timezone
from googleapiclient.discovery import Resource # type: ignore
from onyx.connectors.google_drive.constants import DRIVE_FOLDER_TYPE
from onyx.connectors.google_drive.constants import DRIVE_SHORTCUT_TYPE
from onyx.connectors.google_drive.models import DriveRetrievalStage
from onyx.connectors.google_drive.models import GoogleDriveFileType
from onyx.connectors.google_drive.models import RetrievedDriveFile
from onyx.connectors.google_utils.google_utils import execute_paginated_retrieval
from onyx.connectors.google_utils.google_utils import GoogleFields
from onyx.connectors.google_utils.google_utils import ORDER_BY_KEY
from onyx.connectors.google_utils.resources import GoogleDriveService
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.utils.logger import setup_logger
logger = setup_logger()
FILE_FIELDS = (
@@ -31,11 +37,13 @@ def _generate_time_range_filter(
) -> str:
time_range_filter = ""
if start is not None:
time_start = datetime.utcfromtimestamp(start).isoformat() + "Z"
time_range_filter += f" and modifiedTime >= '{time_start}'"
time_start = datetime.fromtimestamp(start, tz=timezone.utc).isoformat()
time_range_filter += (
f" and {GoogleFields.MODIFIED_TIME.value} >= '{time_start}'"
)
if end is not None:
time_stop = datetime.utcfromtimestamp(end).isoformat() + "Z"
time_range_filter += f" and modifiedTime <= '{time_stop}'"
time_stop = datetime.fromtimestamp(end, tz=timezone.utc).isoformat()
time_range_filter += f" and {GoogleFields.MODIFIED_TIME.value} <= '{time_stop}'"
return time_range_filter
@@ -66,9 +74,9 @@ def _get_folders_in_parent(
def _get_files_in_parent(
service: Resource,
parent_id: str,
is_slim: bool,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
is_slim: bool = False,
) -> Iterator[GoogleDriveFileType]:
query = f"mimeType != '{DRIVE_FOLDER_TYPE}' and '{parent_id}' in parents"
query += " and trashed = false"
@@ -83,6 +91,7 @@ def _get_files_in_parent(
includeItemsFromAllDrives=True,
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
q=query,
**({} if is_slim else {ORDER_BY_KEY: GoogleFields.MODIFIED_TIME.value}),
):
yield file
@@ -90,30 +99,50 @@ def _get_files_in_parent(
def crawl_folders_for_files(
service: Resource,
parent_id: str,
is_slim: bool,
user_email: str,
traversed_parent_ids: set[str],
update_traversed_ids_func: Callable[[str], None],
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
) -> Iterator[RetrievedDriveFile]:
"""
This function starts crawling from any folder. It is slower though.
"""
if parent_id in traversed_parent_ids:
logger.info(f"Skipping subfolder since already traversed: {parent_id}")
return
found_files = False
for file in _get_files_in_parent(
service=service,
start=start,
end=end,
parent_id=parent_id,
):
found_files = True
yield file
if found_files:
update_traversed_ids_func(parent_id)
logger.info("Entered crawl_folders_for_files with parent_id: " + parent_id)
if parent_id not in traversed_parent_ids:
logger.info("Parent id not in traversed parent ids, getting files")
found_files = False
file = {}
try:
for file in _get_files_in_parent(
service=service,
parent_id=parent_id,
is_slim=is_slim,
start=start,
end=end,
):
found_files = True
logger.info(f"Found file: {file['name']}, user email: {user_email}")
yield RetrievedDriveFile(
drive_file=file,
user_email=user_email,
parent_id=parent_id,
completion_stage=DriveRetrievalStage.FOLDER_FILES,
)
except Exception as e:
logger.error(f"Error getting files in parent {parent_id}: {e}")
yield RetrievedDriveFile(
drive_file=file,
user_email=user_email,
parent_id=parent_id,
completion_stage=DriveRetrievalStage.FOLDER_FILES,
error=e,
)
if found_files:
update_traversed_ids_func(parent_id)
else:
logger.info(f"Skipping subfolder files since already traversed: {parent_id}")
for subfolder in _get_folders_in_parent(
service=service,
@@ -123,6 +152,8 @@ def crawl_folders_for_files(
yield from crawl_folders_for_files(
service=service,
parent_id=subfolder["id"],
is_slim=is_slim,
user_email=user_email,
traversed_parent_ids=traversed_parent_ids,
update_traversed_ids_func=update_traversed_ids_func,
start=start,
@@ -133,16 +164,19 @@ def crawl_folders_for_files(
def get_files_in_shared_drive(
service: Resource,
drive_id: str,
is_slim: bool = False,
is_slim: bool,
update_traversed_ids_func: Callable[[str], None] = lambda _: None,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
kwargs = {}
if not is_slim:
kwargs[ORDER_BY_KEY] = GoogleFields.MODIFIED_TIME.value
# If we know we are going to folder crawl later, we can cache the folders here
# Get all folders being queried and add them to the traversed set
folder_query = f"mimeType = '{DRIVE_FOLDER_TYPE}'"
folder_query += " and trashed = false"
found_folders = False
for file in execute_paginated_retrieval(
retrieval_function=service.files().list,
list_key="files",
@@ -155,15 +189,13 @@ def get_files_in_shared_drive(
q=folder_query,
):
update_traversed_ids_func(file["id"])
found_folders = True
if found_folders:
update_traversed_ids_func(drive_id)
# Get all files in the shared drive
file_query = f"mimeType != '{DRIVE_FOLDER_TYPE}'"
file_query += " and trashed = false"
file_query += _generate_time_range_filter(start, end)
yield from execute_paginated_retrieval(
for file in execute_paginated_retrieval(
retrieval_function=service.files().list,
list_key="files",
continue_on_404_or_403=True,
@@ -173,16 +205,26 @@ def get_files_in_shared_drive(
includeItemsFromAllDrives=True,
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
q=file_query,
)
**kwargs,
):
# If we found any files, mark this drive as traversed. When a user has access to a drive,
# they have access to all the files in the drive. Also not a huge deal if we re-traverse
# empty drives.
update_traversed_ids_func(drive_id)
yield file
def get_all_files_in_my_drive(
service: Any,
service: GoogleDriveService,
update_traversed_ids_func: Callable,
is_slim: bool = False,
is_slim: bool,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
kwargs = {}
if not is_slim:
kwargs[ORDER_BY_KEY] = GoogleFields.MODIFIED_TIME.value
# If we know we are going to folder crawl later, we can cache the folders here
# Get all folders being queried and add them to the traversed set
folder_query = f"mimeType = '{DRIVE_FOLDER_TYPE}'"
@@ -196,7 +238,7 @@ def get_all_files_in_my_drive(
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
q=folder_query,
):
update_traversed_ids_func(file["id"])
update_traversed_ids_func(file[GoogleFields.ID])
found_folders = True
if found_folders:
update_traversed_ids_func(get_root_folder_id(service))
@@ -209,22 +251,28 @@ def get_all_files_in_my_drive(
yield from execute_paginated_retrieval(
retrieval_function=service.files().list,
list_key="files",
continue_on_404_or_403=False,
corpora="user",
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
q=file_query,
**kwargs,
)
def get_all_files_for_oauth(
service: Any,
service: GoogleDriveService,
include_files_shared_with_me: bool,
include_my_drives: bool,
# One of the above 2 should be true
include_shared_drives: bool,
is_slim: bool = False,
is_slim: bool,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
kwargs = {}
if not is_slim:
kwargs[ORDER_BY_KEY] = GoogleFields.MODIFIED_TIME.value
should_get_all = (
include_shared_drives and include_my_drives and include_files_shared_with_me
)
@@ -243,11 +291,13 @@ def get_all_files_for_oauth(
yield from execute_paginated_retrieval(
retrieval_function=service.files().list,
list_key="files",
continue_on_404_or_403=False,
corpora=corpora,
includeItemsFromAllDrives=should_get_all,
supportsAllDrives=should_get_all,
fields=SLIM_FILE_FIELDS if is_slim else FILE_FIELDS,
q=file_query,
**kwargs,
)
@@ -255,4 +305,8 @@ def get_all_files_for_oauth(
def get_root_folder_id(service: Resource) -> str:
# we dont paginate here because there is only one root folder per user
# https://developers.google.com/drive/api/guides/v2-to-v3-reference
return service.files().get(fileId="root", fields="id").execute()["id"]
return (
service.files()
.get(fileId="root", fields=GoogleFields.ID.value)
.execute()[GoogleFields.ID.value]
)

View File

@@ -1,6 +1,15 @@
from enum import Enum
from typing import Any
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import field_serializer
from pydantic import field_validator
from onyx.connectors.interfaces import ConnectorCheckpoint
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.utils.threadpool_concurrency import ThreadSafeDict
class GDriveMimeType(str, Enum):
DOC = "application/vnd.google-apps.document"
@@ -20,3 +29,128 @@ class GDriveMimeType(str, Enum):
GoogleDriveFileType = dict[str, Any]
TOKEN_EXPIRATION_TIME = 3600 # 1 hour
# These correspond to The major stages of retrieval for google drive.
# The stages for the oauth flow are:
# get_all_files_for_oauth(),
# get_all_drive_ids(),
# get_files_in_shared_drive(),
# crawl_folders_for_files()
#
# The stages for the service account flow are roughly:
# get_all_user_emails(),
# get_all_drive_ids(),
# get_files_in_shared_drive(),
# Then for each user:
# get_files_in_my_drive()
# get_files_in_shared_drive()
# crawl_folders_for_files()
class DriveRetrievalStage(str, Enum):
START = "start"
DONE = "done"
# OAuth specific stages
OAUTH_FILES = "oauth_files"
# Service account specific stages
USER_EMAILS = "user_emails"
MY_DRIVE_FILES = "my_drive_files"
# Used for both oauth and service account flows
DRIVE_IDS = "drive_ids"
SHARED_DRIVE_FILES = "shared_drive_files"
FOLDER_FILES = "folder_files"
class StageCompletion(BaseModel):
"""
Describes the point in the retrieval+indexing process that the
connector is at. completed_until is the timestamp of the latest
file that has been retrieved or error that has been yielded.
Optional fields are used for retrieval stages that need more information
for resuming than just the timestamp of the latest file.
"""
stage: DriveRetrievalStage
completed_until: SecondsSinceUnixEpoch
completed_until_parent_id: str | None = None
# only used for shared drives
processed_drive_ids: set[str] = set()
def update(
self,
stage: DriveRetrievalStage,
completed_until: SecondsSinceUnixEpoch,
completed_until_parent_id: str | None = None,
) -> None:
self.stage = stage
self.completed_until = completed_until
self.completed_until_parent_id = completed_until_parent_id
class RetrievedDriveFile(BaseModel):
"""
Describes a file that has been retrieved from google drive.
user_email is the email of the user that the file was retrieved
by impersonating. If an error worthy of being reported is encountered,
error should be set and later propagated as a ConnectorFailure.
"""
# The stage at which this file was retrieved
completion_stage: DriveRetrievalStage
# The file that was retrieved
drive_file: GoogleDriveFileType
# The email of the user that the file was retrieved by impersonating
user_email: str
# The id of the parent folder or drive of the file
parent_id: str | None = None
# Any unexpected error that occurred while retrieving the file.
# In particular, this is not used for 403/404 errors, which are expected
# in the context of impersonating all the users to try to retrieve all
# files from all their Drives and Folders.
error: Exception | None = None
model_config = ConfigDict(arbitrary_types_allowed=True)
class GoogleDriveCheckpoint(ConnectorCheckpoint):
# Checkpoint version of _retrieved_ids
retrieved_folder_and_drive_ids: set[str]
# Describes the point in the retrieval+indexing process that the
# checkpoint is at. when this is set to a given stage, the connector
# has finished yielding all values from the previous stage.
completion_stage: DriveRetrievalStage
# The latest timestamp of a file that has been retrieved per user email.
# StageCompletion is used to track the completion of each stage, but the
# timestamp part is not used for folder crawling.
completion_map: ThreadSafeDict[str, StageCompletion]
# cached version of the drive and folder ids to retrieve
drive_ids_to_retrieve: list[str] | None = None
folder_ids_to_retrieve: list[str] | None = None
# cached user emails
user_emails: list[str] | None = None
@field_serializer("completion_map")
def serialize_completion_map(
self, completion_map: ThreadSafeDict[str, StageCompletion], _info: Any
) -> dict[str, StageCompletion]:
return completion_map._dict
@field_validator("completion_map", mode="before")
def validate_completion_map(cls, v: Any) -> ThreadSafeDict[str, StageCompletion]:
assert isinstance(v, dict) or isinstance(v, ThreadSafeDict)
return ThreadSafeDict(
{k: StageCompletion.model_validate(v) for k, v in v.items()}
)

View File

@@ -4,6 +4,7 @@ from collections.abc import Callable
from collections.abc import Iterator
from datetime import datetime
from datetime import timezone
from enum import Enum
from typing import Any
from googleapiclient.errors import HttpError # type: ignore
@@ -16,20 +17,37 @@ logger = setup_logger()
# Google Drive APIs are quite flakey and may 500 for an
# extended period of time. Trying to combat here by adding a very
# long retry period (~20 minutes of trying every minute)
add_retries = retry_builder(tries=50, max_delay=30)
# extended period of time. This is now addressed by checkpointing.
#
# NOTE: We previously tried to combat this here by adding a very
# long retry period (~20 minutes of trying, one request a minute.)
# This is no longer necessary due to checkpointing.
add_retries = retry_builder(tries=5, max_delay=10)
NEXT_PAGE_TOKEN_KEY = "nextPageToken"
PAGE_TOKEN_KEY = "pageToken"
ORDER_BY_KEY = "orderBy"
# See https://developers.google.com/drive/api/reference/rest/v3/files/list for more
class GoogleFields(str, Enum):
ID = "id"
CREATED_TIME = "createdTime"
MODIFIED_TIME = "modifiedTime"
NAME = "name"
SIZE = "size"
PARENTS = "parents"
def _execute_with_retry(request: Any) -> Any:
max_attempts = 10
max_attempts = 6
attempt = 1
while attempt < max_attempts:
# Note for reasons unknown, the Google API will sometimes return a 429
# and even after waiting the retry period, it will return another 429.
# It could be due to a few possibilities:
# 1. Other things are also requesting from the Gmail API with the same key
# 1. Other things are also requesting from the Drive/Gmail API with the same key
# 2. It's a rolling rate limit so the moment we get some amount of requests cleared, we hit it again very quickly
# 3. The retry-after has a maximum and we've already hit the limit for the day
# or it's something else...
@@ -90,11 +108,11 @@ def execute_paginated_retrieval(
retrieval_function: The specific list function to call (e.g., service.files().list)
**kwargs: Arguments to pass to the list function
"""
next_page_token = ""
next_page_token = kwargs.get(PAGE_TOKEN_KEY, "")
while next_page_token is not None:
request_kwargs = kwargs.copy()
if next_page_token:
request_kwargs["pageToken"] = next_page_token
request_kwargs[PAGE_TOKEN_KEY] = next_page_token
try:
results = retrieval_function(**request_kwargs).execute()
@@ -117,7 +135,7 @@ def execute_paginated_retrieval(
logger.exception("Error executing request:")
raise e
next_page_token = results.get("nextPageToken")
next_page_token = results.get(NEXT_PAGE_TOKEN_KEY)
if list_key:
for item in results.get(list_key, []):
yield item

View File

@@ -0,0 +1,4 @@
"""
Highspot connector package for Onyx.
Enables integration with Highspot's knowledge base.
"""

View File

@@ -0,0 +1,280 @@
import base64
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from urllib.parse import urljoin
import requests
from requests.adapters import HTTPAdapter
from requests.exceptions import HTTPError
from requests.exceptions import RequestException
from requests.exceptions import Timeout
from urllib3.util.retry import Retry
from onyx.utils.logger import setup_logger
logger = setup_logger()
class HighspotClientError(Exception):
"""Base exception for Highspot API client errors."""
def __init__(self, message: str, status_code: Optional[int] = None):
self.message = message
self.status_code = status_code
super().__init__(self.message)
class HighspotAuthenticationError(HighspotClientError):
"""Exception raised for authentication errors."""
class HighspotRateLimitError(HighspotClientError):
"""Exception raised when rate limit is exceeded."""
def __init__(self, message: str, retry_after: Optional[str] = None):
self.retry_after = retry_after
super().__init__(message)
class HighspotClient:
"""
Client for interacting with the Highspot API.
Uses basic authentication with provided key (username) and secret (password).
Implements retry logic, error handling, and connection pooling.
"""
BASE_URL = "https://api-su2.highspot.com/v1.0/"
def __init__(
self,
key: str,
secret: str,
base_url: str = BASE_URL,
timeout: int = 30,
max_retries: int = 3,
backoff_factor: float = 0.5,
status_forcelist: Optional[List[int]] = None,
):
"""
Initialize the Highspot API client.
Args:
key: API key (used as username)
secret: API secret (used as password)
base_url: Base URL for the Highspot API
timeout: Request timeout in seconds
max_retries: Maximum number of retries for failed requests
backoff_factor: Backoff factor for retries
status_forcelist: HTTP status codes to retry on
"""
if not key or not secret:
raise ValueError("API key and secret are required")
self.key = key
self.secret = secret
self.base_url = base_url
self.timeout = timeout
# Set up session with retry logic
self.session = requests.Session()
retry_strategy = Retry(
total=max_retries,
backoff_factor=backoff_factor,
status_forcelist=status_forcelist or [429, 500, 502, 503, 504],
allowed_methods=["GET", "POST", "PUT", "DELETE"],
)
adapter = HTTPAdapter(max_retries=retry_strategy)
self.session.mount("http://", adapter)
self.session.mount("https://", adapter)
# Set up authentication
self._setup_auth()
def _setup_auth(self) -> None:
"""Set up basic authentication for the session."""
auth = f"{self.key}:{self.secret}"
encoded_auth = base64.b64encode(auth.encode()).decode()
self.session.headers.update(
{
"Authorization": f"Basic {encoded_auth}",
"Content-Type": "application/json",
"Accept": "application/json",
}
)
def _make_request(
self,
method: str,
endpoint: str,
params: Optional[Dict[str, Any]] = None,
data: Optional[Dict[str, Any]] = None,
json_data: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, str]] = None,
) -> Dict[str, Any]:
"""
Make a request to the Highspot API.
Args:
method: HTTP method (GET, POST, etc.)
endpoint: API endpoint
params: URL parameters
data: Form data
json_data: JSON data
headers: Additional headers
Returns:
API response as a dictionary
Raises:
HighspotClientError: On API errors
HighspotAuthenticationError: On authentication errors
HighspotRateLimitError: On rate limiting
requests.exceptions.RequestException: On request failures
"""
url = urljoin(self.base_url, endpoint)
request_headers = {}
if headers:
request_headers.update(headers)
try:
logger.debug(f"Making {method} request to {url}")
response = self.session.request(
method=method,
url=url,
params=params,
data=data,
json=json_data,
headers=request_headers,
timeout=self.timeout,
)
response.raise_for_status()
if response.content and response.content.strip():
return response.json()
return {}
except HTTPError as e:
status_code = e.response.status_code
error_msg = str(e)
try:
error_data = e.response.json()
if isinstance(error_data, dict):
error_msg = error_data.get("message", str(e))
except (ValueError, KeyError):
pass
if status_code == 401:
raise HighspotAuthenticationError(f"Authentication failed: {error_msg}")
elif status_code == 429:
retry_after = e.response.headers.get("Retry-After")
raise HighspotRateLimitError(
f"Rate limit exceeded: {error_msg}", retry_after=retry_after
)
else:
raise HighspotClientError(
f"API error {status_code}: {error_msg}", status_code=status_code
)
except Timeout:
raise HighspotClientError("Request timed out")
except RequestException as e:
raise HighspotClientError(f"Request failed: {str(e)}")
def get_spots(self) -> List[Dict[str, Any]]:
"""
Get all available spots.
Returns:
List of spots with their names and IDs
"""
params = {"right": "view"}
response = self._make_request("GET", "spots", params=params)
logger.info(f"Received {response} spots")
total_counts = response.get("counts_total")
# Fix comparison to handle None value
if total_counts is not None and total_counts > 0:
return response.get("collection", [])
return []
def get_spot(self, spot_id: str) -> Dict[str, Any]:
"""
Get details for a specific spot.
Args:
spot_id: ID of the spot
Returns:
Spot details
"""
if not spot_id:
raise ValueError("spot_id is required")
return self._make_request("GET", f"spots/{spot_id}")
def get_spot_items(
self, spot_id: str, offset: int = 0, page_size: int = 100
) -> Dict[str, Any]:
"""
Get items in a specific spot.
Args:
spot_id: ID of the spot
offset: offset number
page_size: Number of items per page
Returns:
Items in the spot
"""
if not spot_id:
raise ValueError("spot_id is required")
params = {"spot": spot_id, "start": offset, "limit": page_size}
return self._make_request("GET", "items", params=params)
def get_item(self, item_id: str) -> Dict[str, Any]:
"""
Get details for a specific item.
Args:
item_id: ID of the item
Returns:
Item details
"""
if not item_id:
raise ValueError("item_id is required")
return self._make_request("GET", f"items/{item_id}")
def get_item_content(self, item_id: str) -> bytes:
"""
Get the raw content of an item.
Args:
item_id: ID of the item
Returns:
Raw content bytes
"""
if not item_id:
raise ValueError("item_id is required")
url = urljoin(self.base_url, f"items/{item_id}/content")
response = self.session.get(url, timeout=self.timeout)
response.raise_for_status()
return response.content
def health_check(self) -> bool:
"""
Check if the API is accessible and credentials are valid.
Returns:
True if API is accessible, False otherwise
"""
try:
self._make_request("GET", "spots", params={"limit": 1})
return True
except (HighspotClientError, HighspotAuthenticationError):
return False

View File

@@ -0,0 +1,431 @@
from datetime import datetime
from io import BytesIO
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.highspot.client import HighspotClient
from onyx.connectors.highspot.client import HighspotClientError
from onyx.connectors.highspot.utils import scrape_url_content
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.file_processing.extract_file_text import ALL_ACCEPTED_FILE_EXTENSIONS
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
logger = setup_logger()
_SLIM_BATCH_SIZE = 1000
class HighspotConnector(LoadConnector, PollConnector, SlimConnector):
"""
Connector for loading data from Highspot.
Retrieves content from specified spots using the Highspot API.
If no spots are specified, retrieves content from all available spots.
"""
def __init__(
self,
spot_names: List[str] = [],
batch_size: int = INDEX_BATCH_SIZE,
):
"""
Initialize the Highspot connector.
Args:
spot_names: List of spot names to retrieve content from (if empty, gets all spots)
batch_size: Number of items to retrieve in each batch
"""
self.spot_names = spot_names
self.batch_size = batch_size
self._client: Optional[HighspotClient] = None
self._spot_id_map: Dict[str, str] = {} # Maps spot names to spot IDs
self._all_spots_fetched = False
self.highspot_url: Optional[str] = None
self.key: Optional[str] = None
self.secret: Optional[str] = None
@property
def client(self) -> HighspotClient:
if self._client is None:
if not self.key or not self.secret:
raise ConnectorMissingCredentialError("Highspot")
# Ensure highspot_url is a string, use default if None
base_url = (
self.highspot_url
if self.highspot_url is not None
else HighspotClient.BASE_URL
)
self._client = HighspotClient(self.key, self.secret, base_url=base_url)
return self._client
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
logger.info("Loading Highspot credentials")
self.highspot_url = credentials.get("highspot_url")
self.key = credentials.get("highspot_key")
self.secret = credentials.get("highspot_secret")
return None
def _populate_spot_id_map(self) -> None:
"""
Populate the spot ID map with all available spots.
Keys are stored as lowercase for case-insensitive lookups.
"""
spots = self.client.get_spots()
for spot in spots:
if "title" in spot and "id" in spot:
spot_name = spot["title"]
self._spot_id_map[spot_name.lower()] = spot["id"]
self._all_spots_fetched = True
logger.info(f"Retrieved {len(self._spot_id_map)} spots from Highspot")
def _get_all_spot_names(self) -> List[str]:
"""
Retrieve all available spot names.
Returns:
List of all spot names
"""
if not self._all_spots_fetched:
self._populate_spot_id_map()
return [spot_name for spot_name in self._spot_id_map.keys()]
def _get_spot_id_from_name(self, spot_name: str) -> str:
"""
Get spot ID from a spot name.
Args:
spot_name: Name of the spot
Returns:
ID of the spot
Raises:
ValueError: If spot name is not found
"""
if not self._all_spots_fetched:
self._populate_spot_id_map()
spot_name_lower = spot_name.lower()
if spot_name_lower not in self._spot_id_map:
raise ValueError(f"Spot '{spot_name}' not found")
return self._spot_id_map[spot_name_lower]
def load_from_state(self) -> GenerateDocumentsOutput:
"""
Load content from configured spots in Highspot.
If no spots are configured, loads from all spots.
Yields:
Batches of Document objects
"""
return self.poll_source(None, None)
def poll_source(
self, start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None
) -> GenerateDocumentsOutput:
"""
Poll Highspot for content updated since the start time.
Args:
start: Start time as seconds since Unix epoch
end: End time as seconds since Unix epoch
Yields:
Batches of Document objects
"""
doc_batch: list[Document] = []
# If no spots specified, get all spots
spot_names_to_process = self.spot_names
if not spot_names_to_process:
spot_names_to_process = self._get_all_spot_names()
logger.info(
f"No spots specified, using all {len(spot_names_to_process)} available spots"
)
for spot_name in spot_names_to_process:
try:
spot_id = self._get_spot_id_from_name(spot_name)
if spot_id is None:
logger.warning(f"Spot ID not found for spot {spot_name}")
continue
offset = 0
has_more = True
while has_more:
logger.info(
f"Retrieving items from spot {spot_name}, offset {offset}"
)
response = self.client.get_spot_items(
spot_id=spot_id, offset=offset, page_size=self.batch_size
)
items = response.get("collection", [])
logger.info(f"Received Items: {items}")
if not items:
has_more = False
continue
for item in items:
try:
item_id = item.get("id")
if not item_id:
logger.warning("Item without ID found, skipping")
continue
item_details = self.client.get_item(item_id)
if not item_details:
logger.warning(
f"Item {item_id} details not found, skipping"
)
continue
# Apply time filter if specified
if start or end:
updated_at = item_details.get("date_updated")
if updated_at:
# Convert to datetime for comparison
try:
updated_time = datetime.fromisoformat(
updated_at.replace("Z", "+00:00")
)
if (
start and updated_time.timestamp() < start
) or (end and updated_time.timestamp() > end):
continue
except (ValueError, TypeError):
# Skip if date cannot be parsed
logger.warning(
f"Invalid date format for item {item_id}: {updated_at}"
)
continue
content = self._get_item_content(item_details)
title = item_details.get("title", "")
doc_batch.append(
Document(
id=f"HIGHSPOT_{item_id}",
sections=[
TextSection(
link=item_details.get(
"url",
f"https://www.highspot.com/items/{item_id}",
),
text=content,
)
],
source=DocumentSource.HIGHSPOT,
semantic_identifier=title,
metadata={
"spot_name": spot_name,
"type": item_details.get("content_type", ""),
"created_at": item_details.get(
"date_added", ""
),
"author": item_details.get("author", ""),
"language": item_details.get("language", ""),
"can_download": str(
item_details.get("can_download", False)
),
},
doc_updated_at=item_details.get("date_updated"),
)
)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
except HighspotClientError as e:
item_id = "ID" if not item_id else item_id
logger.error(f"Error retrieving item {item_id}: {str(e)}")
has_more = len(items) >= self.batch_size
offset += self.batch_size
except (HighspotClientError, ValueError) as e:
logger.error(f"Error processing spot {spot_name}: {str(e)}")
if doc_batch:
yield doc_batch
def _get_item_content(self, item_details: Dict[str, Any]) -> str:
"""
Get the text content of an item.
Args:
item_details: Item details from the API
Returns:
Text content of the item
"""
item_id = item_details.get("id", "")
content_name = item_details.get("content_name", "")
is_valid_format = content_name and "." in content_name
file_extension = content_name.split(".")[-1].lower() if is_valid_format else ""
file_extension = "." + file_extension if file_extension else ""
can_download = item_details.get("can_download", False)
content_type = item_details.get("content_type", "")
# Extract title and description once at the beginning
title, description = self._extract_title_and_description(item_details)
default_content = f"{title}\n{description}"
logger.info(f"Processing item {item_id} with extension {file_extension}")
try:
if content_type == "WebLink":
url = item_details.get("url")
if not url:
return default_content
content = scrape_url_content(url, True)
return content if content else default_content
elif (
is_valid_format
and file_extension in ALL_ACCEPTED_FILE_EXTENSIONS
and can_download
):
# For documents, try to get the text content
if not item_id: # Ensure item_id is defined
return default_content
content_response = self.client.get_item_content(item_id)
# Process and extract text from binary content based on type
if content_response:
text_content = extract_file_text(
BytesIO(content_response), content_name
)
return text_content
return default_content
else:
return default_content
except HighspotClientError as e:
# Use item_id safely in the warning message
error_context = f"item {item_id}" if item_id else "item"
logger.warning(f"Could not retrieve content for {error_context}: {str(e)}")
return ""
def _extract_title_and_description(
self, item_details: Dict[str, Any]
) -> tuple[str, str]:
"""
Extract the title and description from item details.
Args:
item_details: Item details from the API
Returns:
Tuple of title and description
"""
title = item_details.get("title", "")
description = item_details.get("description", "")
return title, description
def retrieve_all_slim_documents(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
"""
Retrieve all document IDs from the configured spots.
If no spots are configured, retrieves from all spots.
Args:
start: Optional start time filter
end: Optional end time filter
callback: Optional indexing heartbeat callback
Yields:
Batches of SlimDocument objects
"""
slim_doc_batch: list[SlimDocument] = []
# If no spots specified, get all spots
spot_names_to_process = self.spot_names
if not spot_names_to_process:
spot_names_to_process = self._get_all_spot_names()
logger.info(
f"No spots specified, using all {len(spot_names_to_process)} available spots for slim documents"
)
for spot_name in spot_names_to_process:
try:
spot_id = self._get_spot_id_from_name(spot_name)
offset = 0
has_more = True
while has_more:
logger.info(
f"Retrieving slim documents from spot {spot_name}, offset {offset}"
)
response = self.client.get_spot_items(
spot_id=spot_id, offset=offset, page_size=self.batch_size
)
items = response.get("collection", [])
if not items:
has_more = False
continue
for item in items:
item_id = item.get("id")
if not item_id:
continue
slim_doc_batch.append(SlimDocument(id=f"HIGHSPOT_{item_id}"))
if len(slim_doc_batch) >= _SLIM_BATCH_SIZE:
yield slim_doc_batch
slim_doc_batch = []
has_more = len(items) >= self.batch_size
offset += self.batch_size
except (HighspotClientError, ValueError) as e:
logger.error(
f"Error retrieving slim documents from spot {spot_name}: {str(e)}"
)
if slim_doc_batch:
yield slim_doc_batch
def validate_credentials(self) -> bool:
"""
Validate that the provided credentials can access the Highspot API.
Returns:
True if credentials are valid, False otherwise
"""
try:
return self.client.health_check()
except Exception as e:
logger.error(f"Failed to validate credentials: {str(e)}")
return False
if __name__ == "__main__":
spot_names: List[str] = []
connector = HighspotConnector(spot_names)
credentials = {"highspot_key": "", "highspot_secret": ""}
connector.load_credentials(credentials=credentials)
for doc in connector.load_from_state():
print(doc)

View File

@@ -0,0 +1,122 @@
from typing import Optional
from urllib.parse import urlparse
from bs4 import BeautifulSoup
from playwright.sync_api import sync_playwright
from onyx.file_processing.html_utils import web_html_cleanup
from onyx.utils.logger import setup_logger
logger = setup_logger()
# Constants
WEB_CONNECTOR_MAX_SCROLL_ATTEMPTS = 20
JAVASCRIPT_DISABLED_MESSAGE = "You have JavaScript disabled in your browser"
DEFAULT_TIMEOUT = 60000 # 60 seconds
def scrape_url_content(
url: str, scroll_before_scraping: bool = False, timeout_ms: int = DEFAULT_TIMEOUT
) -> Optional[str]:
"""
Scrapes content from a given URL and returns the cleaned text.
Args:
url: The URL to scrape
scroll_before_scraping: Whether to scroll through the page to load lazy content
timeout_ms: Timeout in milliseconds for page navigation and loading
Returns:
The cleaned text content of the page or None if scraping fails
"""
playwright = None
browser = None
try:
validate_url(url)
playwright = sync_playwright().start()
browser = playwright.chromium.launch(headless=True)
context = browser.new_context()
page = context.new_page()
logger.info(f"Navigating to URL: {url}")
try:
page.goto(url, timeout=timeout_ms)
except Exception as e:
logger.error(f"Failed to navigate to {url}: {str(e)}")
return None
if scroll_before_scraping:
logger.debug("Scrolling page to load lazy content")
scroll_attempts = 0
previous_height = page.evaluate("document.body.scrollHeight")
while scroll_attempts < WEB_CONNECTOR_MAX_SCROLL_ATTEMPTS:
page.evaluate("window.scrollTo(0, document.body.scrollHeight)")
try:
page.wait_for_load_state("networkidle", timeout=timeout_ms)
except Exception as e:
logger.warning(f"Network idle wait timed out: {str(e)}")
break
new_height = page.evaluate("document.body.scrollHeight")
if new_height == previous_height:
break
previous_height = new_height
scroll_attempts += 1
content = page.content()
soup = BeautifulSoup(content, "html.parser")
parsed_html = web_html_cleanup(soup)
if JAVASCRIPT_DISABLED_MESSAGE in parsed_html.cleaned_text:
logger.debug("JavaScript disabled message detected, checking iframes")
try:
iframe_count = page.frame_locator("iframe").locator("html").count()
if iframe_count > 0:
iframe_texts = (
page.frame_locator("iframe").locator("html").all_inner_texts()
)
iframe_content = "\n".join(iframe_texts)
if len(parsed_html.cleaned_text) < 700:
parsed_html.cleaned_text = iframe_content
else:
parsed_html.cleaned_text += "\n" + iframe_content
except Exception as e:
logger.warning(f"Error processing iframes: {str(e)}")
return parsed_html.cleaned_text
except Exception as e:
logger.error(f"Error scraping URL {url}: {str(e)}")
return None
finally:
if browser:
try:
browser.close()
except Exception as e:
logger.debug(f"Error closing browser: {str(e)}")
if playwright:
try:
playwright.stop()
except Exception as e:
logger.debug(f"Error stopping playwright: {str(e)}")
def validate_url(url: str) -> None:
"""
Validates that a URL is properly formatted.
Args:
url: The URL to validate
Raises:
ValueError: If URL is not valid
"""
parse = urlparse(url)
if parse.scheme != "http" and parse.scheme != "https":
raise ValueError("URL must be of scheme https?://")
if not parse.hostname:
raise ValueError("URL must include a hostname")

View File

@@ -4,6 +4,7 @@ from collections.abc import Iterator
from types import TracebackType
from typing import Any
from typing import Generic
from typing import TypeAlias
from typing import TypeVar
from pydantic import BaseModel
@@ -19,10 +20,11 @@ SecondsSinceUnixEpoch = float
GenerateDocumentsOutput = Iterator[list[Document]]
GenerateSlimDocumentOutput = Iterator[list[SlimDocument]]
CheckpointOutput = Generator[Document | ConnectorFailure, None, ConnectorCheckpoint]
CT = TypeVar("CT", bound=ConnectorCheckpoint)
class BaseConnector(abc.ABC):
class BaseConnector(abc.ABC, Generic[CT]):
REDIS_KEY_PREFIX = "da_connector_data:"
# Common image file extensions supported across connectors
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".webp", ".gif"}
@@ -57,6 +59,14 @@ class BaseConnector(abc.ABC):
Default is a no-op (always successful).
"""
def set_allow_images(self, value: bool) -> None:
"""Implement if the underlying connector wants to skip/allow image downloading
based on the application level image analysis setting."""
def build_dummy_checkpoint(self) -> CT:
# TODO: find a way to make this work without type: ignore
return ConnectorCheckpoint(has_more=True) # type: ignore
# Large set update or reindex, generally pulling a complete state or from a savestate file
class LoadConnector(BaseConnector):
@@ -74,6 +84,8 @@ class PollConnector(BaseConnector):
raise NotImplementedError
# Slim connectors can retrieve just the ids and
# permission syncing information for connected documents
class SlimConnector(BaseConnector):
@abc.abstractmethod
def retrieve_all_slim_documents(
@@ -186,14 +198,17 @@ class EventConnector(BaseConnector):
raise NotImplementedError
class CheckpointConnector(BaseConnector):
CheckpointOutput: TypeAlias = Generator[Document | ConnectorFailure, None, CT]
class CheckpointConnector(BaseConnector[CT]):
@abc.abstractmethod
def load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: ConnectorCheckpoint,
) -> CheckpointOutput:
checkpoint: CT,
) -> CheckpointOutput[CT]:
"""Yields back documents or failures. Final return is the new checkpoint.
Final return can be access via either:
@@ -214,3 +229,12 @@ class CheckpointConnector(BaseConnector):
```
"""
raise NotImplementedError
@abc.abstractmethod
def build_dummy_checkpoint(self) -> CT:
raise NotImplementedError
@abc.abstractmethod
def validate_checkpoint_json(self, checkpoint_json: str) -> CT:
"""Validate the checkpoint json and return the checkpoint object"""
raise NotImplementedError

View File

@@ -2,6 +2,7 @@ from typing import Any
import httpx
from pydantic import BaseModel
from typing_extensions import override
from onyx.connectors.interfaces import CheckpointConnector
from onyx.connectors.interfaces import CheckpointOutput
@@ -15,14 +16,18 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
class MockConnectorCheckpoint(ConnectorCheckpoint):
last_document_id: str | None = None
class SingleConnectorYield(BaseModel):
documents: list[Document]
checkpoint: ConnectorCheckpoint
checkpoint: MockConnectorCheckpoint
failures: list[ConnectorFailure]
unhandled_exception: str | None = None
class MockConnector(CheckpointConnector):
class MockConnector(CheckpointConnector[MockConnectorCheckpoint]):
def __init__(
self,
mock_server_host: str,
@@ -48,7 +53,7 @@ class MockConnector(CheckpointConnector):
def _get_mock_server_url(self, endpoint: str) -> str:
return f"http://{self.mock_server_host}:{self.mock_server_port}/{endpoint}"
def _save_checkpoint(self, checkpoint: ConnectorCheckpoint) -> None:
def _save_checkpoint(self, checkpoint: MockConnectorCheckpoint) -> None:
response = self.client.post(
self._get_mock_server_url("add-checkpoint"),
json=checkpoint.model_dump(mode="json"),
@@ -59,8 +64,8 @@ class MockConnector(CheckpointConnector):
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: ConnectorCheckpoint,
) -> CheckpointOutput:
checkpoint: MockConnectorCheckpoint,
) -> CheckpointOutput[MockConnectorCheckpoint]:
if self.connector_yields is None:
raise ValueError("No connector yields configured")
@@ -84,3 +89,13 @@ class MockConnector(CheckpointConnector):
yield failure
return current_yield.checkpoint
@override
def build_dummy_checkpoint(self) -> MockConnectorCheckpoint:
return MockConnectorCheckpoint(
has_more=True,
last_document_id=None,
)
def validate_checkpoint_json(self, checkpoint_json: str) -> MockConnectorCheckpoint:
return MockConnectorCheckpoint.model_validate_json(checkpoint_json)

View File

@@ -1,4 +1,3 @@
import json
from datetime import datetime
from enum import Enum
from typing import Any
@@ -232,21 +231,16 @@ class IndexAttemptMetadata(BaseModel):
class ConnectorCheckpoint(BaseModel):
# TODO: maybe move this to something disk-based to handle extremely large checkpoints?
checkpoint_content: dict
has_more: bool
@classmethod
def build_dummy_checkpoint(cls) -> "ConnectorCheckpoint":
return ConnectorCheckpoint(checkpoint_content={}, has_more=True)
def __str__(self) -> str:
"""String representation of the checkpoint, with truncation for large checkpoint content."""
MAX_CHECKPOINT_CONTENT_CHARS = 1000
content_str = json.dumps(self.checkpoint_content)
content_str = self.model_dump_json()
if len(content_str) > MAX_CHECKPOINT_CONTENT_CHARS:
content_str = content_str[: MAX_CHECKPOINT_CONTENT_CHARS - 3] + "..."
return f"ConnectorCheckpoint(checkpoint_content={content_str}, has_more={self.has_more})"
return content_str
class DocumentFailure(BaseModel):

View File

@@ -1,16 +1,16 @@
from collections.abc import Generator
from dataclasses import dataclass
from dataclasses import fields
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
from typing import Optional
import requests
from pydantic import BaseModel
from retry import retry
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.app_configs import NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP
from onyx.configs.app_configs import NOTION_CONNECTOR_DISABLE_RECURSIVE_PAGE_LOOKUP
from onyx.configs.constants import DocumentSource
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
rl_requests,
@@ -25,6 +25,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.utils.batching import batch_generator
from onyx.utils.logger import setup_logger
@@ -38,8 +39,7 @@ _NOTION_CALL_TIMEOUT = 30 # 30 seconds
# TODO: Tables need to be ingested, Pages need to have their metadata ingested
@dataclass
class NotionPage:
class NotionPage(BaseModel):
"""Represents a Notion Page object"""
id: str
@@ -49,17 +49,10 @@ class NotionPage:
properties: dict[str, Any]
url: str
database_name: str | None # Only applicable to the database type page (wiki)
def __init__(self, **kwargs: dict[str, Any]) -> None:
names = set([f.name for f in fields(self)])
for k, v in kwargs.items():
if k in names:
setattr(self, k, v)
database_name: str | None = None # Only applicable to the database type page (wiki)
@dataclass
class NotionBlock:
class NotionBlock(BaseModel):
"""Represents a Notion Block object"""
id: str # Used for the URL
@@ -69,20 +62,13 @@ class NotionBlock:
prefix: str
@dataclass
class NotionSearchResponse:
class NotionSearchResponse(BaseModel):
"""Represents the response from the Notion Search API"""
results: list[dict[str, Any]]
next_cursor: Optional[str]
has_more: bool = False
def __init__(self, **kwargs: dict[str, Any]) -> None:
names = set([f.name for f in fields(self)])
for k, v in kwargs.items():
if k in names:
setattr(self, k, v)
class NotionConnector(LoadConnector, PollConnector):
"""Notion Page connector that reads all Notion pages
@@ -95,7 +81,7 @@ class NotionConnector(LoadConnector, PollConnector):
def __init__(
self,
batch_size: int = INDEX_BATCH_SIZE,
recursive_index_enabled: bool = NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP,
recursive_index_enabled: bool = not NOTION_CONNECTOR_DISABLE_RECURSIVE_PAGE_LOOKUP,
root_page_id: str | None = None,
) -> None:
"""Initialize with parameters."""
@@ -464,23 +450,53 @@ class NotionConnector(LoadConnector, PollConnector):
page_blocks, child_page_ids = self._read_blocks(page.id)
all_child_page_ids.extend(child_page_ids)
if not page_blocks:
continue
# okay to mark here since there's no way for this to not succeed
# without a critical failure
self.indexed_pages.add(page.id)
page_title = (
self._read_page_title(page) or f"Untitled Page with ID {page.id}"
)
raw_page_title = self._read_page_title(page)
page_title = raw_page_title or f"Untitled Page with ID {page.id}"
if not page_blocks:
if not raw_page_title:
logger.warning(
f"No blocks OR title found for page with ID '{page.id}'. Skipping."
)
continue
logger.debug(f"No blocks found for page with ID '{page.id}'")
"""
Something like:
TITLE
PROP1: PROP1_VALUE
PROP2: PROP2_VALUE
"""
text = page_title
if page.properties:
text += "\n\n" + "\n".join(
[f"{key}: {value}" for key, value in page.properties.items()]
)
sections = [
TextSection(
link=f"{page.url}",
text=text,
)
]
else:
sections = [
TextSection(
link=f"{page.url}#{block.id.replace('-', '')}",
text=block.prefix + block.text,
)
for block in page_blocks
]
yield (
Document(
id=page.id,
sections=[
TextSection(
link=f"{page.url}#{block.id.replace('-', '')}",
text=block.prefix + block.text,
)
for block in page_blocks
],
sections=cast(list[TextSection | ImageSection], sections),
source=DocumentSource.NOTION,
semantic_identifier=page_title,
doc_updated_at=datetime.fromisoformat(

View File

@@ -6,6 +6,7 @@ from typing import Any
from jira import JIRA
from jira.resources import Issue
from typing_extensions import override
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.app_configs import JIRA_CONNECTOR_LABELS_TO_SKIP
@@ -15,14 +16,16 @@ from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_t
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import CheckpointConnector
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.connectors.onyx_jira.utils import best_effort_basic_expert_info
@@ -42,121 +45,112 @@ _JIRA_SLIM_PAGE_SIZE = 500
_JIRA_FULL_PAGE_SIZE = 50
def _paginate_jql_search(
def _perform_jql_search(
jira_client: JIRA,
jql: str,
start: int,
max_results: int,
fields: str | None = None,
) -> Iterable[Issue]:
start = 0
while True:
logger.debug(
f"Fetching Jira issues with JQL: {jql}, "
f"starting at {start}, max results: {max_results}"
)
issues = jira_client.search_issues(
jql_str=jql,
startAt=start,
maxResults=max_results,
fields=fields,
)
logger.debug(
f"Fetching Jira issues with JQL: {jql}, "
f"starting at {start}, max results: {max_results}"
)
issues = jira_client.search_issues(
jql_str=jql,
startAt=start,
maxResults=max_results,
fields=fields,
)
for issue in issues:
if isinstance(issue, Issue):
yield issue
else:
raise Exception(f"Found Jira object not of type Issue: {issue}")
if len(issues) < max_results:
break
start += max_results
for issue in issues:
if isinstance(issue, Issue):
yield issue
else:
raise RuntimeError(f"Found Jira object not of type Issue: {issue}")
def fetch_jira_issues_batch(
def process_jira_issue(
jira_client: JIRA,
jql: str,
batch_size: int,
issue: Issue,
comment_email_blacklist: tuple[str, ...] = (),
labels_to_skip: set[str] | None = None,
) -> Iterable[Document]:
for issue in _paginate_jql_search(
jira_client=jira_client,
jql=jql,
max_results=batch_size,
):
if labels_to_skip:
if any(label in issue.fields.labels for label in labels_to_skip):
logger.info(
f"Skipping {issue.key} because it has a label to skip. Found "
f"labels: {issue.fields.labels}. Labels to skip: {labels_to_skip}."
)
continue
description = (
issue.fields.description
if JIRA_API_VERSION == "2"
else extract_text_from_adf(issue.raw["fields"]["description"])
)
comments = get_comment_strs(
issue=issue,
comment_email_blacklist=comment_email_blacklist,
)
ticket_content = f"{description}\n" + "\n".join(
[f"Comment: {comment}" for comment in comments if comment]
)
# Check ticket size
if len(ticket_content.encode("utf-8")) > JIRA_CONNECTOR_MAX_TICKET_SIZE:
) -> Document | None:
if labels_to_skip:
if any(label in issue.fields.labels for label in labels_to_skip):
logger.info(
f"Skipping {issue.key} because it exceeds the maximum size of "
f"{JIRA_CONNECTOR_MAX_TICKET_SIZE} bytes."
f"Skipping {issue.key} because it has a label to skip. Found "
f"labels: {issue.fields.labels}. Labels to skip: {labels_to_skip}."
)
continue
return None
page_url = f"{jira_client.client_info()}/browse/{issue.key}"
description = (
issue.fields.description
if JIRA_API_VERSION == "2"
else extract_text_from_adf(issue.raw["fields"]["description"])
)
comments = get_comment_strs(
issue=issue,
comment_email_blacklist=comment_email_blacklist,
)
ticket_content = f"{description}\n" + "\n".join(
[f"Comment: {comment}" for comment in comments if comment]
)
people = set()
try:
creator = best_effort_get_field_from_issue(issue, "creator")
if basic_expert_info := best_effort_basic_expert_info(creator):
people.add(basic_expert_info)
except Exception:
# Author should exist but if not, doesn't matter
pass
try:
assignee = best_effort_get_field_from_issue(issue, "assignee")
if basic_expert_info := best_effort_basic_expert_info(assignee):
people.add(basic_expert_info)
except Exception:
# Author should exist but if not, doesn't matter
pass
metadata_dict = {}
if priority := best_effort_get_field_from_issue(issue, "priority"):
metadata_dict["priority"] = priority.name
if status := best_effort_get_field_from_issue(issue, "status"):
metadata_dict["status"] = status.name
if resolution := best_effort_get_field_from_issue(issue, "resolution"):
metadata_dict["resolution"] = resolution.name
if labels := best_effort_get_field_from_issue(issue, "labels"):
metadata_dict["label"] = labels
yield Document(
id=page_url,
sections=[TextSection(link=page_url, text=ticket_content)],
source=DocumentSource.JIRA,
semantic_identifier=f"{issue.key}: {issue.fields.summary}",
title=f"{issue.key} {issue.fields.summary}",
doc_updated_at=time_str_to_utc(issue.fields.updated),
primary_owners=list(people) or None,
# TODO add secondary_owners (commenters) if needed
metadata=metadata_dict,
# Check ticket size
if len(ticket_content.encode("utf-8")) > JIRA_CONNECTOR_MAX_TICKET_SIZE:
logger.info(
f"Skipping {issue.key} because it exceeds the maximum size of "
f"{JIRA_CONNECTOR_MAX_TICKET_SIZE} bytes."
)
return None
page_url = build_jira_url(jira_client, issue.key)
people = set()
try:
creator = best_effort_get_field_from_issue(issue, "creator")
if basic_expert_info := best_effort_basic_expert_info(creator):
people.add(basic_expert_info)
except Exception:
# Author should exist but if not, doesn't matter
pass
try:
assignee = best_effort_get_field_from_issue(issue, "assignee")
if basic_expert_info := best_effort_basic_expert_info(assignee):
people.add(basic_expert_info)
except Exception:
# Author should exist but if not, doesn't matter
pass
metadata_dict = {}
if priority := best_effort_get_field_from_issue(issue, "priority"):
metadata_dict["priority"] = priority.name
if status := best_effort_get_field_from_issue(issue, "status"):
metadata_dict["status"] = status.name
if resolution := best_effort_get_field_from_issue(issue, "resolution"):
metadata_dict["resolution"] = resolution.name
if labels := best_effort_get_field_from_issue(issue, "labels"):
metadata_dict["labels"] = labels
return Document(
id=page_url,
sections=[TextSection(link=page_url, text=ticket_content)],
source=DocumentSource.JIRA,
semantic_identifier=f"{issue.key}: {issue.fields.summary}",
title=f"{issue.key} {issue.fields.summary}",
doc_updated_at=time_str_to_utc(issue.fields.updated),
primary_owners=list(people) or None,
metadata=metadata_dict,
)
class JiraConnector(LoadConnector, PollConnector, SlimConnector):
class JiraConnectorCheckpoint(ConnectorCheckpoint):
offset: int | None = None
class JiraConnector(CheckpointConnector[JiraConnectorCheckpoint], SlimConnector):
def __init__(
self,
jira_base_url: str,
@@ -200,33 +194,10 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
)
return None
def _get_jql_query(self) -> str:
"""Get the JQL query based on whether a specific project is set"""
if self.jira_project:
return f"project = {self.quoted_jira_project}"
return "" # Empty string means all accessible projects
def load_from_state(self) -> GenerateDocumentsOutput:
jql = self._get_jql_query()
document_batch = []
for doc in fetch_jira_issues_batch(
jira_client=self.jira_client,
jql=jql,
batch_size=_JIRA_FULL_PAGE_SIZE,
comment_email_blacklist=self.comment_email_blacklist,
labels_to_skip=self.labels_to_skip,
):
document_batch.append(doc)
if len(document_batch) >= self.batch_size:
yield document_batch
document_batch = []
yield document_batch
def poll_source(
def _get_jql_query(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
) -> str:
"""Get the JQL query based on whether a specific project is set and time range"""
start_date_str = datetime.fromtimestamp(start, tz=timezone.utc).strftime(
"%Y-%m-%d %H:%M"
)
@@ -234,25 +205,61 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
"%Y-%m-%d %H:%M"
)
base_jql = self._get_jql_query()
jql = (
f"{base_jql} AND " if base_jql else ""
) + f"updated >= '{start_date_str}' AND updated <= '{end_date_str}'"
time_jql = f"updated >= '{start_date_str}' AND updated <= '{end_date_str}'"
document_batch = []
for doc in fetch_jira_issues_batch(
if self.jira_project:
base_jql = f"project = {self.quoted_jira_project}"
return f"{base_jql} AND {time_jql}"
return time_jql
def load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: JiraConnectorCheckpoint,
) -> CheckpointOutput[JiraConnectorCheckpoint]:
jql = self._get_jql_query(start, end)
# Get the current offset from checkpoint or start at 0
starting_offset = checkpoint.offset or 0
current_offset = starting_offset
for issue in _perform_jql_search(
jira_client=self.jira_client,
jql=jql,
batch_size=_JIRA_FULL_PAGE_SIZE,
comment_email_blacklist=self.comment_email_blacklist,
labels_to_skip=self.labels_to_skip,
start=current_offset,
max_results=_JIRA_FULL_PAGE_SIZE,
):
document_batch.append(doc)
if len(document_batch) >= self.batch_size:
yield document_batch
document_batch = []
issue_key = issue.key
try:
if document := process_jira_issue(
jira_client=self.jira_client,
issue=issue,
comment_email_blacklist=self.comment_email_blacklist,
labels_to_skip=self.labels_to_skip,
):
yield document
yield document_batch
except Exception as e:
yield ConnectorFailure(
failed_document=DocumentFailure(
document_id=issue_key,
document_link=build_jira_url(self.jira_client, issue_key),
),
failure_message=f"Failed to process Jira issue: {str(e)}",
exception=e,
)
current_offset += 1
# Update checkpoint
checkpoint = JiraConnectorCheckpoint(
offset=current_offset,
# if we didn't retrieve a full batch, we're done
has_more=current_offset - starting_offset == _JIRA_FULL_PAGE_SIZE,
)
return checkpoint
def retrieve_all_slim_documents(
self,
@@ -260,12 +267,13 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
jql = self._get_jql_query()
jql = self._get_jql_query(start or 0, end or float("inf"))
slim_doc_batch = []
for issue in _paginate_jql_search(
for issue in _perform_jql_search(
jira_client=self.jira_client,
jql=jql,
start=0,
max_results=_JIRA_SLIM_PAGE_SIZE,
fields="key",
):
@@ -334,6 +342,16 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
raise RuntimeError(f"Unexpected Jira error during validation: {e}")
@override
def validate_checkpoint_json(self, checkpoint_json: str) -> JiraConnectorCheckpoint:
return JiraConnectorCheckpoint.model_validate_json(checkpoint_json)
@override
def build_dummy_checkpoint(self) -> JiraConnectorCheckpoint:
return JiraConnectorCheckpoint(
has_more=True,
)
if __name__ == "__main__":
import os
@@ -350,5 +368,7 @@ if __name__ == "__main__":
"jira_api_token": os.environ["JIRA_API_TOKEN"],
}
)
document_batches = connector.load_from_state()
document_batches = connector.load_from_checkpoint(
0, float("inf"), JiraConnectorCheckpoint(has_more=True)
)
print(next(document_batches))

View File

@@ -10,13 +10,15 @@ from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
from typing import TypedDict
from pydantic import BaseModel
from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from typing_extensions import override
from onyx.configs.app_configs import ENABLE_EXPENSIVE_EXPERT_CALLS
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.app_configs import SLACK_NUM_THREADS
from onyx.configs.constants import DocumentSource
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
@@ -56,8 +58,8 @@ MessageType = dict[str, Any]
ThreadType = list[MessageType]
class SlackCheckpointContent(TypedDict):
channel_ids: list[str]
class SlackCheckpoint(ConnectorCheckpoint):
channel_ids: list[str] | None
channel_completion_map: dict[str, str]
current_channel: ChannelType | None
seen_thread_ts: list[str]
@@ -412,8 +414,8 @@ def _get_all_doc_ids(
callback=callback,
)
message_ts_set: set[str] = set()
for message_batch in channel_message_batches:
slim_doc_batch: list[SlimDocument] = []
for message in message_batch:
if msg_filter_func(message):
continue
@@ -421,18 +423,27 @@ def _get_all_doc_ids(
# The document id is the channel id and the ts of the first message in the thread
# Since we already have the first message of the thread, we dont have to
# fetch the thread for id retrieval, saving time and API calls
message_ts_set.add(message["ts"])
channel_metadata_list: list[SlimDocument] = []
for message_ts in message_ts_set:
channel_metadata_list.append(
SlimDocument(
id=_build_doc_id(channel_id=channel_id, thread_ts=message_ts),
perm_sync_data={"channel_id": channel_id},
slim_doc_batch.append(
SlimDocument(
id=_build_doc_id(
channel_id=channel_id, thread_ts=message["ts"]
),
perm_sync_data={"channel_id": channel_id},
)
)
)
yield channel_metadata_list
yield slim_doc_batch
class ProcessedSlackMessage(BaseModel):
doc: Document | None
# if the message is part of a thread, this is the thread_ts
# otherwise, this is the message_ts. Either way, will be a unique identifier.
# In the future, if the message becomes a thread, then the thread_ts
# will be set to the message_ts.
thread_or_message_ts: str
failure: ConnectorFailure | None
def _process_message(
@@ -443,8 +454,9 @@ def _process_message(
user_cache: dict[str, BasicExpertInfo | None],
seen_thread_ts: set[str],
msg_filter_func: Callable[[MessageType], bool] = default_msg_filter,
) -> tuple[Document | None, str | None, ConnectorFailure | None]:
) -> ProcessedSlackMessage:
thread_ts = message.get("thread_ts")
thread_or_message_ts = thread_ts or message["ts"]
try:
# causes random failures for testing checkpointing / continue on failure
# import random
@@ -460,16 +472,18 @@ def _process_message(
seen_thread_ts=seen_thread_ts,
msg_filter_func=msg_filter_func,
)
return (doc, thread_ts, None)
return ProcessedSlackMessage(
doc=doc, thread_or_message_ts=thread_or_message_ts, failure=None
)
except Exception as e:
logger.exception(f"Error processing message {message['ts']}")
return (
None,
thread_ts,
ConnectorFailure(
return ProcessedSlackMessage(
doc=None,
thread_or_message_ts=thread_or_message_ts,
failure=ConnectorFailure(
failed_document=DocumentFailure(
document_id=_build_doc_id(
channel_id=channel["id"], thread_ts=(thread_ts or message["ts"])
channel_id=channel["id"], thread_ts=thread_or_message_ts
),
document_link=get_message_link(message, client, channel["id"]),
),
@@ -479,7 +493,9 @@ def _process_message(
)
class SlackConnector(SlimConnector, CheckpointConnector):
class SlackConnector(SlimConnector, CheckpointConnector[SlackCheckpoint]):
FAST_TIMEOUT = 1
def __init__(
self,
channels: list[str] | None = None,
@@ -487,12 +503,14 @@ class SlackConnector(SlimConnector, CheckpointConnector):
# regexes, and will only index channels that fully match the regexes
channel_regex_enabled: bool = False,
batch_size: int = INDEX_BATCH_SIZE,
num_threads: int = SLACK_NUM_THREADS,
) -> None:
self.channels = channels
self.channel_regex_enabled = channel_regex_enabled
self.batch_size = batch_size
self.num_threads = num_threads
self.client: WebClient | None = None
self.fast_client: WebClient | None = None
# just used for efficiency
self.text_cleaner: SlackTextCleaner | None = None
self.user_cache: dict[str, BasicExpertInfo | None] = {}
@@ -500,6 +518,10 @@ class SlackConnector(SlimConnector, CheckpointConnector):
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
bot_token = credentials["slack_bot_token"]
self.client = WebClient(token=bot_token)
# use for requests that must return quickly (e.g. realtime flows where user is waiting)
self.fast_client = WebClient(
token=bot_token, timeout=SlackConnector.FAST_TIMEOUT
)
self.text_cleaner = SlackTextCleaner(client=self.client)
return None
@@ -523,8 +545,8 @@ class SlackConnector(SlimConnector, CheckpointConnector):
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: ConnectorCheckpoint,
) -> CheckpointOutput:
checkpoint: SlackCheckpoint,
) -> CheckpointOutput[SlackCheckpoint]:
"""Rough outline:
Step 1: Get all channels, yield back Checkpoint.
@@ -540,49 +562,36 @@ class SlackConnector(SlimConnector, CheckpointConnector):
if self.client is None or self.text_cleaner is None:
raise ConnectorMissingCredentialError("Slack")
checkpoint_content = cast(
SlackCheckpointContent,
(
copy.deepcopy(checkpoint.checkpoint_content)
or {
"channel_ids": None,
"channel_completion_map": {},
"current_channel": None,
"seen_thread_ts": [],
}
),
)
checkpoint = cast(SlackCheckpoint, copy.deepcopy(checkpoint))
# if this is the very first time we've called this, need to
# get all relevant channels and save them into the checkpoint
if checkpoint_content["channel_ids"] is None:
if checkpoint.channel_ids is None:
raw_channels = get_channels(self.client)
filtered_channels = filter_channels(
raw_channels, self.channels, self.channel_regex_enabled
)
checkpoint.channel_ids = [c["id"] for c in filtered_channels]
if len(filtered_channels) == 0:
checkpoint.has_more = False
return checkpoint
checkpoint_content["channel_ids"] = [c["id"] for c in filtered_channels]
checkpoint_content["current_channel"] = filtered_channels[0]
checkpoint = ConnectorCheckpoint(
checkpoint_content=checkpoint_content, # type: ignore
has_more=True,
)
checkpoint.current_channel = filtered_channels[0]
checkpoint.has_more = True
return checkpoint
final_channel_ids = checkpoint_content["channel_ids"]
channel = checkpoint_content["current_channel"]
final_channel_ids = checkpoint.channel_ids
channel = checkpoint.current_channel
if channel is None:
raise ValueError("current_channel key not found in checkpoint")
raise ValueError("current_channel key not set in checkpoint")
channel_id = channel["id"]
if channel_id not in final_channel_ids:
raise ValueError(f"Channel {channel_id} not found in checkpoint")
oldest = str(start) if start else None
latest = checkpoint_content["channel_completion_map"].get(channel_id, str(end))
seen_thread_ts = set(checkpoint_content["seen_thread_ts"])
latest = checkpoint.channel_completion_map.get(channel_id, str(end))
seen_thread_ts = set(checkpoint.seen_thread_ts)
try:
logger.debug(
f"Getting messages for channel {channel} within range {oldest} - {latest}"
@@ -593,8 +602,8 @@ class SlackConnector(SlimConnector, CheckpointConnector):
new_latest = message_batch[-1]["ts"] if message_batch else latest
# Process messages in parallel using ThreadPoolExecutor
with ThreadPoolExecutor(max_workers=8) as executor:
futures: list[Future] = []
with ThreadPoolExecutor(max_workers=self.num_threads) as executor:
futures: list[Future[ProcessedSlackMessage]] = []
for message in message_batch:
# Capture the current context so that the thread gets the current tenant ID
current_context = contextvars.copy_context()
@@ -612,46 +621,46 @@ class SlackConnector(SlimConnector, CheckpointConnector):
)
for future in as_completed(futures):
doc, thread_ts, failures = future.result()
processed_slack_message = future.result()
doc = processed_slack_message.doc
thread_or_message_ts = processed_slack_message.thread_or_message_ts
failure = processed_slack_message.failure
if doc:
# handle race conditions here since this is single
# threaded. Multi-threaded _process_message reads from this
# but since this is single threaded, we won't run into simul
# writes. At worst, we can duplicate a thread, which will be
# deduped later on.
if thread_ts not in seen_thread_ts:
if thread_or_message_ts not in seen_thread_ts:
yield doc
if thread_ts:
seen_thread_ts.add(thread_ts)
elif failures:
for failure in failures:
yield failure
assert (
thread_or_message_ts
), "found non-None doc with None thread_or_message_ts"
seen_thread_ts.add(thread_or_message_ts)
elif failure:
yield failure
checkpoint_content["seen_thread_ts"] = list(seen_thread_ts)
checkpoint_content["channel_completion_map"][channel["id"]] = new_latest
checkpoint.seen_thread_ts = list(seen_thread_ts)
checkpoint.channel_completion_map[channel["id"]] = new_latest
if has_more_in_channel:
checkpoint_content["current_channel"] = channel
checkpoint.current_channel = channel
else:
new_channel_id = next(
(
channel_id
for channel_id in final_channel_ids
if channel_id
not in checkpoint_content["channel_completion_map"]
if channel_id not in checkpoint.channel_completion_map
),
None,
)
if new_channel_id:
new_channel = _get_channel_by_id(self.client, new_channel_id)
checkpoint_content["current_channel"] = new_channel
checkpoint.current_channel = new_channel
else:
checkpoint_content["current_channel"] = None
checkpoint.current_channel = None
checkpoint = ConnectorCheckpoint(
checkpoint_content=checkpoint_content, # type: ignore
has_more=checkpoint_content["current_channel"] is not None,
)
checkpoint.has_more = checkpoint.current_channel is not None
return checkpoint
except Exception as e:
@@ -675,12 +684,12 @@ class SlackConnector(SlimConnector, CheckpointConnector):
2. Ensure the bot has enough scope to list channels.
3. Check that every channel specified in self.channels exists (only when regex is not enabled).
"""
if self.client is None:
if self.fast_client is None:
raise ConnectorMissingCredentialError("Slack credentials not loaded.")
try:
# 1) Validate connection to workspace
auth_response = self.client.auth_test()
auth_response = self.fast_client.auth_test()
if not auth_response.get("ok", False):
error_msg = auth_response.get(
"error", "Unknown error from Slack auth_test"
@@ -688,7 +697,7 @@ class SlackConnector(SlimConnector, CheckpointConnector):
raise ConnectorValidationError(f"Failed Slack auth_test: {error_msg}")
# 2) Minimal test to confirm listing channels works
test_resp = self.client.conversations_list(
test_resp = self.fast_client.conversations_list(
limit=1, types=["public_channel"]
)
if not test_resp.get("ok", False):
@@ -706,29 +715,41 @@ class SlackConnector(SlimConnector, CheckpointConnector):
)
# 3) If channels are specified and regex is not enabled, verify each is accessible
if self.channels and not self.channel_regex_enabled:
accessible_channels = get_channels(
client=self.client,
exclude_archived=True,
get_public=True,
get_private=True,
)
# For quick lookups by name or ID, build a map:
accessible_channel_names = {ch["name"] for ch in accessible_channels}
accessible_channel_ids = {ch["id"] for ch in accessible_channels}
# NOTE: removed this for now since it may be too slow for large workspaces which may
# have some automations which create a lot of channels (100k+)
for user_channel in self.channels:
if (
user_channel not in accessible_channel_names
and user_channel not in accessible_channel_ids
):
raise ConnectorValidationError(
f"Channel '{user_channel}' not found or inaccessible in this workspace."
)
# if self.channels and not self.channel_regex_enabled:
# accessible_channels = get_channels(
# client=self.fast_client,
# exclude_archived=True,
# get_public=True,
# get_private=True,
# )
# # For quick lookups by name or ID, build a map:
# accessible_channel_names = {ch["name"] for ch in accessible_channels}
# accessible_channel_ids = {ch["id"] for ch in accessible_channels}
# for user_channel in self.channels:
# if (
# user_channel not in accessible_channel_names
# and user_channel not in accessible_channel_ids
# ):
# raise ConnectorValidationError(
# f"Channel '{user_channel}' not found or inaccessible in this workspace."
# )
except SlackApiError as e:
slack_error = e.response.get("error", "")
if slack_error == "missing_scope":
if slack_error == "ratelimited":
# Handle rate limiting specifically
retry_after = int(e.response.headers.get("Retry-After", 1))
logger.warning(
f"Slack API rate limited during validation. Retry suggested after {retry_after} seconds. "
"Proceeding with validation, but be aware that connector operations might be throttled."
)
# Continue validation without failing - the connector is likely valid but just rate limited
return
elif slack_error == "missing_scope":
raise InsufficientPermissionsError(
"Slack bot token lacks the necessary scope to list/access channels. "
"Please ensure your Slack app has 'channels:read' (and/or 'groups:read' for private channels)."
@@ -751,6 +772,20 @@ class SlackConnector(SlimConnector, CheckpointConnector):
f"Unexpected error during Slack settings validation: {e}"
)
@override
def build_dummy_checkpoint(self) -> SlackCheckpoint:
return SlackCheckpoint(
channel_ids=None,
channel_completion_map={},
current_channel=None,
seen_thread_ts=[],
has_more=True,
)
@override
def validate_checkpoint_json(self, checkpoint_json: str) -> SlackCheckpoint:
return SlackCheckpoint.model_validate_json(checkpoint_json)
if __name__ == "__main__":
import os
@@ -765,9 +800,11 @@ if __name__ == "__main__":
current = time.time()
one_day_ago = current - 24 * 60 * 60 # 1 day
checkpoint = ConnectorCheckpoint.build_dummy_checkpoint()
checkpoint = connector.build_dummy_checkpoint()
gen = connector.load_from_checkpoint(one_day_ago, current, checkpoint)
gen = connector.load_from_checkpoint(
one_day_ago, current, cast(SlackCheckpoint, checkpoint)
)
try:
for document_or_failure in gen:
if isinstance(document_or_failure, Document):

View File

@@ -1,23 +1,32 @@
import copy
import time
from collections.abc import Iterator
from typing import Any
from typing import cast
import requests
from pydantic import BaseModel
from requests.exceptions import HTTPError
from typing_extensions import override
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.app_configs import ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS
from onyx.configs.constants import DocumentSource
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
time_str_to_utc,
)
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.interfaces import CheckpointConnector
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import ConnectorFailure
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.file_processing.html_utils import parse_html_page_basic
@@ -26,6 +35,7 @@ from onyx.utils.retry_wrapper import retry_builder
MAX_PAGE_SIZE = 30 # Zendesk API maximum
MAX_AUTHOR_MAP_SIZE = 50_000 # Reset author map cache if it gets too large
_SLIM_BATCH_SIZE = 1000
@@ -53,10 +63,22 @@ class ZendeskClient:
# Sleep for the duration indicated by the Retry-After header
time.sleep(int(retry_after))
elif (
response.status_code == 403
and response.json().get("error") == "SupportProductInactive"
):
return response.json()
response.raise_for_status()
return response.json()
class ZendeskPageResponse(BaseModel):
data: list[dict[str, Any]]
meta: dict[str, Any]
has_more: bool
def _get_content_tag_mapping(client: ZendeskClient) -> dict[str, str]:
content_tags: dict[str, str] = {}
params = {"page[size]": MAX_PAGE_SIZE}
@@ -82,11 +104,9 @@ def _get_content_tag_mapping(client: ZendeskClient) -> dict[str, str]:
def _get_articles(
client: ZendeskClient, start_time: int | None = None, page_size: int = MAX_PAGE_SIZE
) -> Iterator[dict[str, Any]]:
params = (
{"start_time": start_time, "page[size]": page_size}
if start_time
else {"page[size]": page_size}
)
params = {"page[size]": page_size, "sort_by": "updated_at", "sort_order": "asc"}
if start_time is not None:
params["start_time"] = start_time
while True:
data = client.make_request("help_center/articles", params)
@@ -98,10 +118,30 @@ def _get_articles(
params["page[after]"] = data["meta"]["after_cursor"]
def _get_article_page(
client: ZendeskClient,
start_time: int | None = None,
after_cursor: str | None = None,
page_size: int = MAX_PAGE_SIZE,
) -> ZendeskPageResponse:
params = {"page[size]": page_size, "sort_by": "updated_at", "sort_order": "asc"}
if start_time is not None:
params["start_time"] = start_time
if after_cursor is not None:
params["page[after]"] = after_cursor
data = client.make_request("help_center/articles", params)
return ZendeskPageResponse(
data=data["articles"],
meta=data["meta"],
has_more=bool(data["meta"].get("has_more", False)),
)
def _get_tickets(
client: ZendeskClient, start_time: int | None = None
) -> Iterator[dict[str, Any]]:
params = {"start_time": start_time} if start_time else {"start_time": 0}
params = {"start_time": start_time or 0}
while True:
data = client.make_request("incremental/tickets.json", params)
@@ -114,9 +154,33 @@ def _get_tickets(
break
def _fetch_author(client: ZendeskClient, author_id: str) -> BasicExpertInfo | None:
# TODO: maybe these don't need to be their own functions?
def _get_tickets_page(
client: ZendeskClient, start_time: int | None = None
) -> ZendeskPageResponse:
params = {"start_time": start_time or 0}
# NOTE: for some reason zendesk doesn't seem to be respecting the start_time param
# in my local testing with very few tickets. We'll look into it if this becomes an
# issue in larger deployments
data = client.make_request("incremental/tickets.json", params)
if data.get("error") == "SupportProductInactive":
raise ValueError(
"Zendesk Support Product is not active for this account, No tickets to index"
)
return ZendeskPageResponse(
data=data["tickets"],
meta={"end_time": data["end_time"]},
has_more=not bool(data.get("end_of_stream", False)),
)
def _fetch_author(
client: ZendeskClient, author_id: str | int
) -> BasicExpertInfo | None:
# Skip fetching if author_id is invalid
if not author_id or author_id == "-1":
# cast to str to avoid issues with zendesk changing their types
if not author_id or str(author_id) == "-1":
return None
try:
@@ -278,13 +342,22 @@ def _ticket_to_document(
)
class ZendeskConnector(LoadConnector, PollConnector, SlimConnector):
class ZendeskConnectorCheckpoint(ConnectorCheckpoint):
# We use cursor-based paginated retrieval for articles
after_cursor_articles: str | None
# We use timestamp-based paginated retrieval for tickets
next_start_time_tickets: int | None
cached_author_map: dict[str, BasicExpertInfo] | None
cached_content_tags: dict[str, str] | None
class ZendeskConnector(SlimConnector, CheckpointConnector[ZendeskConnectorCheckpoint]):
def __init__(
self,
batch_size: int = INDEX_BATCH_SIZE,
content_type: str = "articles",
) -> None:
self.batch_size = batch_size
self.content_type = content_type
self.subdomain = ""
# Fetch all tags ahead of time
@@ -304,33 +377,50 @@ class ZendeskConnector(LoadConnector, PollConnector, SlimConnector):
)
return None
def load_from_state(self) -> GenerateDocumentsOutput:
return self.poll_source(None, None)
def poll_source(
self, start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None
) -> GenerateDocumentsOutput:
@override
def load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: ZendeskConnectorCheckpoint,
) -> CheckpointOutput[ZendeskConnectorCheckpoint]:
if self.client is None:
raise ZendeskCredentialsNotSetUpError()
self.content_tags = _get_content_tag_mapping(self.client)
if checkpoint.cached_content_tags is None:
checkpoint.cached_content_tags = _get_content_tag_mapping(self.client)
return checkpoint # save the content tags to the checkpoint
self.content_tags = checkpoint.cached_content_tags
if self.content_type == "articles":
yield from self._poll_articles(start)
checkpoint = yield from self._retrieve_articles(start, end, checkpoint)
return checkpoint
elif self.content_type == "tickets":
yield from self._poll_tickets(start)
checkpoint = yield from self._retrieve_tickets(start, end, checkpoint)
return checkpoint
else:
raise ValueError(f"Unsupported content_type: {self.content_type}")
def _poll_articles(
self, start: SecondsSinceUnixEpoch | None
) -> GenerateDocumentsOutput:
articles = _get_articles(self.client, start_time=int(start) if start else None)
def _retrieve_articles(
self,
start: SecondsSinceUnixEpoch | None,
end: SecondsSinceUnixEpoch | None,
checkpoint: ZendeskConnectorCheckpoint,
) -> CheckpointOutput[ZendeskConnectorCheckpoint]:
checkpoint = copy.deepcopy(checkpoint)
# This one is built on the fly as there may be more many more authors than tags
author_map: dict[str, BasicExpertInfo] = {}
author_map: dict[str, BasicExpertInfo] = checkpoint.cached_author_map or {}
after_cursor = checkpoint.after_cursor_articles
doc_batch: list[Document] = []
doc_batch = []
response = _get_article_page(
self.client,
start_time=int(start) if start else None,
after_cursor=after_cursor,
)
articles = response.data
has_more = response.has_more
after_cursor = response.meta.get("after_cursor")
for article in articles:
if (
article.get("body") is None
@@ -342,66 +432,109 @@ class ZendeskConnector(LoadConnector, PollConnector, SlimConnector):
):
continue
new_author_map, documents = _article_to_document(
article, self.content_tags, author_map, self.client
)
try:
new_author_map, document = _article_to_document(
article, self.content_tags, author_map, self.client
)
except Exception as e:
yield ConnectorFailure(
failed_document=DocumentFailure(
document_id=f"{article.get('id')}",
document_link=article.get("html_url", ""),
),
failure_message=str(e),
exception=e,
)
continue
if new_author_map:
author_map.update(new_author_map)
doc_batch.append(documents)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch.clear()
doc_batch.append(document)
if doc_batch:
yield doc_batch
if not has_more:
yield from doc_batch
checkpoint.has_more = False
return checkpoint
def _poll_tickets(
self, start: SecondsSinceUnixEpoch | None
) -> GenerateDocumentsOutput:
# Sometimes no documents are retrieved, but the cursor
# is still updated so the connector makes progress.
yield from doc_batch
checkpoint.after_cursor_articles = after_cursor
last_doc_updated_at = doc_batch[-1].doc_updated_at if doc_batch else None
checkpoint.has_more = bool(
end is None
or last_doc_updated_at is None
or last_doc_updated_at.timestamp() <= end
)
checkpoint.cached_author_map = (
author_map if len(author_map) <= MAX_AUTHOR_MAP_SIZE else None
)
return checkpoint
def _retrieve_tickets(
self,
start: SecondsSinceUnixEpoch | None,
end: SecondsSinceUnixEpoch | None,
checkpoint: ZendeskConnectorCheckpoint,
) -> CheckpointOutput[ZendeskConnectorCheckpoint]:
checkpoint = copy.deepcopy(checkpoint)
if self.client is None:
raise ZendeskCredentialsNotSetUpError()
author_map: dict[str, BasicExpertInfo] = {}
author_map: dict[str, BasicExpertInfo] = checkpoint.cached_author_map or {}
ticket_generator = _get_tickets(
self.client, start_time=int(start) if start else None
doc_batch: list[Document] = []
next_start_time = int(checkpoint.next_start_time_tickets or start or 0)
ticket_response = _get_tickets_page(self.client, start_time=next_start_time)
tickets = ticket_response.data
has_more = ticket_response.has_more
next_start_time = ticket_response.meta["end_time"]
for ticket in tickets:
if ticket.get("status") == "deleted":
continue
try:
new_author_map, document = _ticket_to_document(
ticket=ticket,
author_map=author_map,
client=self.client,
default_subdomain=self.subdomain,
)
except Exception as e:
yield ConnectorFailure(
failed_document=DocumentFailure(
document_id=f"{ticket.get('id')}",
document_link=ticket.get("url", ""),
),
failure_message=str(e),
exception=e,
)
continue
if new_author_map:
author_map.update(new_author_map)
doc_batch.append(document)
if not has_more:
yield from doc_batch
checkpoint.has_more = False
return checkpoint
yield from doc_batch
checkpoint.next_start_time_tickets = next_start_time
last_doc_updated_at = doc_batch[-1].doc_updated_at if doc_batch else None
checkpoint.has_more = bool(
end is None
or last_doc_updated_at is None
or last_doc_updated_at.timestamp() <= end
)
while True:
doc_batch = []
for _ in range(self.batch_size):
try:
ticket = next(ticket_generator)
# Check if the ticket status is deleted and skip it if so
if ticket.get("status") == "deleted":
continue
new_author_map, documents = _ticket_to_document(
ticket=ticket,
author_map=author_map,
client=self.client,
default_subdomain=self.subdomain,
)
if new_author_map:
author_map.update(new_author_map)
doc_batch.append(documents)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch.clear()
except StopIteration:
# No more tickets to process
if doc_batch:
yield doc_batch
return
if doc_batch:
yield doc_batch
checkpoint.cached_author_map = (
author_map if len(author_map) <= MAX_AUTHOR_MAP_SIZE else None
)
return checkpoint
def retrieve_all_slim_documents(
self,
@@ -441,10 +574,51 @@ class ZendeskConnector(LoadConnector, PollConnector, SlimConnector):
if slim_doc_batch:
yield slim_doc_batch
@override
def validate_connector_settings(self) -> None:
if self.client is None:
raise ZendeskCredentialsNotSetUpError()
try:
_get_article_page(self.client, start_time=0)
except HTTPError as e:
# Check for HTTP status codes
if e.response.status_code == 401:
raise CredentialExpiredError(
"Your Zendesk credentials appear to be invalid or expired (HTTP 401)."
) from e
elif e.response.status_code == 403:
raise InsufficientPermissionsError(
"Your Zendesk token does not have sufficient permissions (HTTP 403)."
) from e
elif e.response.status_code == 404:
raise ConnectorValidationError(
"Zendesk resource not found (HTTP 404)."
) from e
else:
raise ConnectorValidationError(
f"Unexpected Zendesk error (status={e.response.status_code}): {e}"
) from e
@override
def validate_checkpoint_json(
self, checkpoint_json: str
) -> ZendeskConnectorCheckpoint:
return ZendeskConnectorCheckpoint.model_validate_json(checkpoint_json)
@override
def build_dummy_checkpoint(self) -> ZendeskConnectorCheckpoint:
return ZendeskConnectorCheckpoint(
after_cursor_articles=None,
next_start_time_tickets=None,
cached_author_map=None,
cached_content_tags=None,
has_more=True,
)
if __name__ == "__main__":
import os
import time
connector = ZendeskConnector()
connector.load_credentials(
@@ -457,6 +631,8 @@ if __name__ == "__main__":
current = time.time()
one_day_ago = current - 24 * 60 * 60 # 1 day
document_batches = connector.poll_source(one_day_ago, current)
document_batches = connector.load_from_checkpoint(
one_day_ago, current, connector.build_dummy_checkpoint()
)
print(next(document_batches))

View File

@@ -1089,3 +1089,20 @@ def log_agent_sub_question_results(
db_session.commit()
return None
def update_chat_session_updated_at_timestamp(
chat_session_id: UUID, db_session: Session
) -> None:
"""
Explicitly update the timestamp on a chat session without modifying other fields.
This is useful when adding messages to a chat session to reflect recent activity.
"""
# Direct SQL update to avoid loading the entire object if it's not already loaded
db_session.execute(
update(ChatSession)
.where(ChatSession.id == chat_session_id)
.values(time_updated=func.now())
)
# No commit - the caller is responsible for committing the transaction

View File

@@ -1,2 +1,4 @@
SLACK_BOT_PERSONA_PREFIX = "__slack_bot_persona__"
DEFAULT_PERSONA_SLACK_CHANNEL_NAME = "DEFAULT_SLACK_CHANNEL"
CONNECTOR_VALIDATION_ERROR_MESSAGE_PREFIX = "ConnectorValidationError:"

View File

@@ -555,6 +555,28 @@ def delete_documents_by_connector_credential_pair__no_commit(
db_session.execute(stmt)
def delete_all_documents_by_connector_credential_pair__no_commit(
db_session: Session,
connector_id: int,
credential_id: int,
) -> None:
"""Deletes all document by connector credential pair entries for a specific connector and credential.
This is primarily used during connector deletion to ensure all references are removed
before deleting the connector itself. This is crucial because connector_id is part of the
primary key in DocumentByConnectorCredentialPair, and attempting to delete the Connector
would otherwise try to set the foreign key to NULL, which fails for primary keys.
NOTE: Does not commit the transaction, this must be done by the caller.
"""
stmt = delete(DocumentByConnectorCredentialPair).where(
and_(
DocumentByConnectorCredentialPair.connector_id == connector_id,
DocumentByConnectorCredentialPair.credential_id == credential_id,
)
)
db_session.execute(stmt)
def delete_documents__no_commit(db_session: Session, document_ids: list[str]) -> None:
db_session.execute(delete(DbDocument).where(DbDocument.id.in_(document_ids)))

View File

@@ -8,23 +8,31 @@ from sqlalchemy import and_
from sqlalchemy import delete
from sqlalchemy import desc
from sqlalchemy import func
from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.orm import contains_eager
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from sqlalchemy.sql import Select
from onyx.connectors.models import ConnectorFailure
from onyx.db.engine import get_session_context_manager
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.enums import IndexingStatus
from onyx.db.enums import IndexModelStatus
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexAttemptError
from onyx.db.models import IndexingStatus
from onyx.db.models import IndexModelStatus
from onyx.db.models import SearchSettings
from onyx.server.documents.models import ConnectorCredentialPair
from onyx.server.documents.models import ConnectorCredentialPairIdentifier
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
# Comment out unused imports that cause mypy errors
# from onyx.auth.models import UserRole
# from onyx.configs.constants import MAX_LAST_VALID_CHECKPOINT_AGE_SECONDS
# from onyx.db.connector_credential_pair import ConnectorCredentialPairIdentifier
# from onyx.db.engine import async_query_for_dms
logger = setup_logger()
@@ -201,6 +209,17 @@ def mark_attempt_in_progress(
attempt.status = IndexingStatus.IN_PROGRESS
attempt.time_started = index_attempt.time_started or func.now() # type: ignore
db_session.commit()
# Add telemetry for index attempt status change
optional_telemetry(
record_type=RecordType.INDEX_ATTEMPT_STATUS,
data={
"index_attempt_id": index_attempt.id,
"status": IndexingStatus.IN_PROGRESS.value,
"cc_pair_id": index_attempt.connector_credential_pair_id,
"search_settings_id": index_attempt.search_settings_id,
},
)
except Exception:
db_session.rollback()
raise
@@ -219,6 +238,19 @@ def mark_attempt_succeeded(
attempt.status = IndexingStatus.SUCCESS
db_session.commit()
# Add telemetry for index attempt status change
optional_telemetry(
record_type=RecordType.INDEX_ATTEMPT_STATUS,
data={
"index_attempt_id": index_attempt_id,
"status": IndexingStatus.SUCCESS.value,
"cc_pair_id": attempt.connector_credential_pair_id,
"search_settings_id": attempt.search_settings_id,
"total_docs_indexed": attempt.total_docs_indexed,
"new_docs_indexed": attempt.new_docs_indexed,
},
)
except Exception:
db_session.rollback()
raise
@@ -237,6 +269,19 @@ def mark_attempt_partially_succeeded(
attempt.status = IndexingStatus.COMPLETED_WITH_ERRORS
db_session.commit()
# Add telemetry for index attempt status change
optional_telemetry(
record_type=RecordType.INDEX_ATTEMPT_STATUS,
data={
"index_attempt_id": index_attempt_id,
"status": IndexingStatus.COMPLETED_WITH_ERRORS.value,
"cc_pair_id": attempt.connector_credential_pair_id,
"search_settings_id": attempt.search_settings_id,
"total_docs_indexed": attempt.total_docs_indexed,
"new_docs_indexed": attempt.new_docs_indexed,
},
)
except Exception:
db_session.rollback()
raise
@@ -259,6 +304,20 @@ def mark_attempt_canceled(
attempt.status = IndexingStatus.CANCELED
attempt.error_msg = reason
db_session.commit()
# Add telemetry for index attempt status change
optional_telemetry(
record_type=RecordType.INDEX_ATTEMPT_STATUS,
data={
"index_attempt_id": index_attempt_id,
"status": IndexingStatus.CANCELED.value,
"cc_pair_id": attempt.connector_credential_pair_id,
"search_settings_id": attempt.search_settings_id,
"reason": reason,
"total_docs_indexed": attempt.total_docs_indexed,
"new_docs_indexed": attempt.new_docs_indexed,
},
)
except Exception:
db_session.rollback()
raise
@@ -283,6 +342,20 @@ def mark_attempt_failed(
attempt.error_msg = failure_reason
attempt.full_exception_trace = full_exception_trace
db_session.commit()
# Add telemetry for index attempt status change
optional_telemetry(
record_type=RecordType.INDEX_ATTEMPT_STATUS,
data={
"index_attempt_id": index_attempt_id,
"status": IndexingStatus.FAILED.value,
"cc_pair_id": attempt.connector_credential_pair_id,
"search_settings_id": attempt.search_settings_id,
"reason": failure_reason,
"total_docs_indexed": attempt.total_docs_indexed,
"new_docs_indexed": attempt.new_docs_indexed,
},
)
except Exception:
db_session.rollback()
raise
@@ -434,7 +507,7 @@ def get_latest_index_attempts_parallel(
eager_load_cc_pair: bool = False,
only_finished: bool = False,
) -> Sequence[IndexAttempt]:
with get_session_context_manager() as db_session:
with get_session_with_current_tenant() as db_session:
return get_latest_index_attempts(
secondary_index,
db_session,

View File

@@ -16,8 +16,8 @@ from onyx.db.models import User__UserGroup
from onyx.llm.utils import model_supports_image_input
from onyx.server.manage.embedding.models import CloudEmbeddingProvider
from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
from onyx.server.manage.llm.models import FullLLMProvider
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import LLMProviderView
from shared_configs.enums import EmbeddingProvider
@@ -67,7 +67,7 @@ def upsert_cloud_embedding_provider(
def upsert_llm_provider(
llm_provider: LLMProviderUpsertRequest,
db_session: Session,
) -> FullLLMProvider:
) -> LLMProviderView:
existing_llm_provider = db_session.scalar(
select(LLMProviderModel).where(LLMProviderModel.name == llm_provider.name)
)
@@ -98,7 +98,7 @@ def upsert_llm_provider(
group_ids=llm_provider.groups,
db_session=db_session,
)
full_llm_provider = FullLLMProvider.from_model(existing_llm_provider)
full_llm_provider = LLMProviderView.from_model(existing_llm_provider)
db_session.commit()
@@ -132,6 +132,16 @@ def fetch_existing_llm_providers(
return list(db_session.scalars(stmt).all())
def fetch_existing_llm_provider(
provider_name: str, db_session: Session
) -> LLMProviderModel | None:
provider_model = db_session.scalar(
select(LLMProviderModel).where(LLMProviderModel.name == provider_name)
)
return provider_model
def fetch_existing_llm_providers_for_user(
db_session: Session,
user: User | None = None,
@@ -177,7 +187,7 @@ def fetch_embedding_provider(
)
def fetch_default_provider(db_session: Session) -> FullLLMProvider | None:
def fetch_default_provider(db_session: Session) -> LLMProviderView | None:
provider_model = db_session.scalar(
select(LLMProviderModel).where(
LLMProviderModel.is_default_provider == True # noqa: E712
@@ -185,10 +195,10 @@ def fetch_default_provider(db_session: Session) -> FullLLMProvider | None:
)
if not provider_model:
return None
return FullLLMProvider.from_model(provider_model)
return LLMProviderView.from_model(provider_model)
def fetch_default_vision_provider(db_session: Session) -> FullLLMProvider | None:
def fetch_default_vision_provider(db_session: Session) -> LLMProviderView | None:
provider_model = db_session.scalar(
select(LLMProviderModel).where(
LLMProviderModel.is_default_vision_provider == True # noqa: E712
@@ -196,16 +206,18 @@ def fetch_default_vision_provider(db_session: Session) -> FullLLMProvider | None
)
if not provider_model:
return None
return FullLLMProvider.from_model(provider_model)
return LLMProviderView.from_model(provider_model)
def fetch_provider(db_session: Session, provider_name: str) -> FullLLMProvider | None:
def fetch_llm_provider_view(
db_session: Session, provider_name: str
) -> LLMProviderView | None:
provider_model = db_session.scalar(
select(LLMProviderModel).where(LLMProviderModel.name == provider_name)
)
if not provider_model:
return None
return FullLLMProvider.from_model(provider_model)
return LLMProviderView.from_model(provider_model)
def remove_embedding_provider(

View File

@@ -24,7 +24,9 @@ from onyx.db.models import User__UserGroup
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
def validate_user_role_update(requested_role: UserRole, current_role: UserRole) -> None:
def validate_user_role_update(
requested_role: UserRole, current_role: UserRole, explicit_override: bool = False
) -> None:
"""
Validate that a user role update is valid.
Assumed only admins can hit this endpoint.
@@ -57,6 +59,9 @@ def validate_user_role_update(requested_role: UserRole, current_role: UserRole)
detail="To change a Limited User's role, they must first login to Onyx via the web app.",
)
if explicit_override:
return
if requested_role == UserRole.CURATOR:
# This shouldn't happen, but just in case
raise HTTPException(

View File

@@ -821,26 +821,30 @@ class VespaIndex(DocumentIndex):
num_to_retrieve: int = NUM_RETURNED_HITS,
offset: int = 0,
) -> list[InferenceChunkUncleaned]:
vespa_where_clauses = build_vespa_filters(filters, include_hidden=True)
yql = (
YQL_BASE.format(index_name=self.index_name)
+ vespa_where_clauses
+ '({grammar: "weakAnd"}userInput(@query) '
# `({defaultIndex: "content_summary"}userInput(@query))` section is
# needed for highlighting while the N-gram highlighting is broken /
# not working as desired
+ f'or ({{defaultIndex: "{CONTENT_SUMMARY}"}}userInput(@query)))'
vespa_where_clauses = build_vespa_filters(
filters, include_hidden=True, remove_trailing_and=True
)
yql = YQL_BASE.format(index_name=self.index_name) + vespa_where_clauses
params: dict[str, str | int] = {
"yql": yql,
"query": query,
"hits": num_to_retrieve,
"offset": 0,
"ranking.profile": "admin_search",
"timeout": VESPA_TIMEOUT,
}
if len(query.strip()) > 0:
yql += (
' and ({grammar: "weakAnd"}userInput(@query) '
# `({defaultIndex: "content_summary"}userInput(@query))` section is
# needed for highlighting while the N-gram highlighting is broken /
# not working as desired
+ f'or ({{defaultIndex: "{CONTENT_SUMMARY}"}}userInput(@query)))'
)
params["yql"] = yql
params["query"] = query
return query_vespa(params)
# Retrieves chunk information for a document:

View File

@@ -5,13 +5,15 @@ import re
import zipfile
from collections.abc import Callable
from collections.abc import Iterator
from collections.abc import Sequence
from email.parser import Parser as EmailParser
from enum import auto
from enum import IntFlag
from io import BytesIO
from pathlib import Path
from typing import Any
from typing import IO
from typing import List
from typing import Tuple
from typing import NamedTuple
import chardet
import docx # type: ignore
@@ -35,7 +37,7 @@ logger = setup_logger()
TEXT_SECTION_SEPARATOR = "\n\n"
PLAIN_TEXT_FILE_EXTENSIONS = [
ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS = [
".txt",
".md",
".mdx",
@@ -49,7 +51,7 @@ PLAIN_TEXT_FILE_EXTENSIONS = [
".yaml",
]
VALID_FILE_EXTENSIONS = PLAIN_TEXT_FILE_EXTENSIONS + [
ACCEPTED_DOCUMENT_FILE_EXTENSIONS = [
".pdf",
".docx",
".pptx",
@@ -57,12 +59,21 @@ VALID_FILE_EXTENSIONS = PLAIN_TEXT_FILE_EXTENSIONS + [
".eml",
".epub",
".html",
]
ACCEPTED_IMAGE_FILE_EXTENSIONS = [
".png",
".jpg",
".jpeg",
".webp",
]
ALL_ACCEPTED_FILE_EXTENSIONS = (
ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS
+ ACCEPTED_DOCUMENT_FILE_EXTENSIONS
+ ACCEPTED_IMAGE_FILE_EXTENSIONS
)
IMAGE_MEDIA_TYPES = [
"image/png",
"image/jpeg",
@@ -70,8 +81,15 @@ IMAGE_MEDIA_TYPES = [
]
class OnyxExtensionType(IntFlag):
Plain = auto()
Document = auto()
Multimedia = auto()
All = Plain | Document | Multimedia
def is_text_file_extension(file_name: str) -> bool:
return any(file_name.endswith(ext) for ext in PLAIN_TEXT_FILE_EXTENSIONS)
return any(file_name.endswith(ext) for ext in ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS)
def get_file_ext(file_path_or_name: str | Path) -> str:
@@ -83,8 +101,20 @@ def is_valid_media_type(media_type: str) -> bool:
return media_type in IMAGE_MEDIA_TYPES
def is_valid_file_ext(ext: str) -> bool:
return ext in VALID_FILE_EXTENSIONS
def is_accepted_file_ext(ext: str, ext_type: OnyxExtensionType) -> bool:
if ext_type & OnyxExtensionType.Plain:
if ext in ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS:
return True
if ext_type & OnyxExtensionType.Document:
if ext in ACCEPTED_DOCUMENT_FILE_EXTENSIONS:
return True
if ext_type & OnyxExtensionType.Multimedia:
if ext in ACCEPTED_IMAGE_FILE_EXTENSIONS:
return True
return False
def is_text_file(file: IO[bytes]) -> bool:
@@ -219,7 +249,7 @@ def pdf_to_text(file: IO[Any], pdf_pass: str | None = None) -> str:
def read_pdf_file(
file: IO[Any], pdf_pass: str | None = None, extract_images: bool = False
) -> tuple[str, dict, list[tuple[bytes, str]]]:
) -> tuple[str, dict[str, Any], Sequence[tuple[bytes, str]]]:
"""
Returns the text, basic PDF metadata, and optionally extracted images.
"""
@@ -282,13 +312,13 @@ def read_pdf_file(
def docx_to_text_and_images(
file: IO[Any],
) -> Tuple[str, List[Tuple[bytes, str]]]:
) -> tuple[str, Sequence[tuple[bytes, str]]]:
"""
Extract text from a docx. If embed_images=True, also extract inline images.
Return (text_content, list_of_images).
"""
paragraphs = []
embedded_images: List[Tuple[bytes, str]] = []
embedded_images: list[tuple[bytes, str]] = []
doc = docx.Document(file)
@@ -382,6 +412,9 @@ def extract_file_text(
"""
Legacy function that returns *only text*, ignoring embedded images.
For backward-compatibility in code that only wants text.
NOTE: Ignoring seems to be defined as returning an empty string for files it can't
handle (such as images).
"""
extension_to_function: dict[str, Callable[[IO[Any]], str]] = {
".pdf": pdf_to_text,
@@ -405,7 +438,9 @@ def extract_file_text(
if extension is None:
extension = get_file_ext(file_name)
if is_valid_file_ext(extension):
if is_accepted_file_ext(
extension, OnyxExtensionType.Plain | OnyxExtensionType.Document
):
func = extension_to_function.get(extension, file_io_to_text)
file.seek(0)
return func(file)
@@ -426,14 +461,22 @@ def extract_file_text(
return ""
class ExtractionResult(NamedTuple):
"""Structured result from text and image extraction from various file types."""
text_content: str
embedded_images: Sequence[tuple[bytes, str]]
metadata: dict[str, Any]
def extract_text_and_images(
file: IO[Any],
file_name: str,
pdf_pass: str | None = None,
) -> Tuple[str, List[Tuple[bytes, str]]]:
) -> ExtractionResult:
"""
Primary new function for the updated connector.
Returns (text_content, [(embedded_img_bytes, embedded_img_name), ...]).
Returns structured extraction result with text content, embedded images, and metadata.
"""
try:
@@ -442,7 +485,9 @@ def extract_text_and_images(
# If the user doesn't want embedded images, unstructured is fine
file.seek(0)
text_content = unstructured_to_text(file, file_name)
return (text_content, [])
return ExtractionResult(
text_content=text_content, embedded_images=[], metadata={}
)
extension = get_file_ext(file_name)
@@ -450,54 +495,76 @@ def extract_text_and_images(
if extension == ".docx":
file.seek(0)
text_content, images = docx_to_text_and_images(file)
return (text_content, images)
return ExtractionResult(
text_content=text_content, embedded_images=images, metadata={}
)
# PDF example: we do not show complicated PDF image extraction here
# so we simply extract text for now and skip images.
if extension == ".pdf":
file.seek(0)
text_content, _, images = read_pdf_file(file, pdf_pass, extract_images=True)
return (text_content, images)
text_content, pdf_metadata, images = read_pdf_file(
file, pdf_pass, extract_images=True
)
return ExtractionResult(
text_content=text_content, embedded_images=images, metadata=pdf_metadata
)
# For PPTX, XLSX, EML, etc., we do not show embedded image logic here.
# You can do something similar to docx if needed.
if extension == ".pptx":
file.seek(0)
return (pptx_to_text(file), [])
return ExtractionResult(
text_content=pptx_to_text(file), embedded_images=[], metadata={}
)
if extension == ".xlsx":
file.seek(0)
return (xlsx_to_text(file), [])
return ExtractionResult(
text_content=xlsx_to_text(file), embedded_images=[], metadata={}
)
if extension == ".eml":
file.seek(0)
return (eml_to_text(file), [])
return ExtractionResult(
text_content=eml_to_text(file), embedded_images=[], metadata={}
)
if extension == ".epub":
file.seek(0)
return (epub_to_text(file), [])
return ExtractionResult(
text_content=epub_to_text(file), embedded_images=[], metadata={}
)
if extension == ".html":
file.seek(0)
return (parse_html_page_basic(file), [])
return ExtractionResult(
text_content=parse_html_page_basic(file),
embedded_images=[],
metadata={},
)
# If we reach here and it's a recognized text extension
if is_text_file_extension(file_name):
file.seek(0)
encoding = detect_encoding(file)
text_content_raw, _ = read_text_file(
text_content_raw, file_metadata = read_text_file(
file, encoding=encoding, ignore_onyx_metadata=False
)
return (text_content_raw, [])
return ExtractionResult(
text_content=text_content_raw,
embedded_images=[],
metadata=file_metadata,
)
# If it's an image file or something else, we do not parse embedded images from them
# just return empty text
file.seek(0)
return ("", [])
return ExtractionResult(text_content="", embedded_images=[], metadata={})
except Exception as e:
logger.exception(f"Failed to extract text/images from {file_name}: {e}")
return ("", [])
return ExtractionResult(text_content="", embedded_images=[], metadata={})
def convert_docx_to_txt(

View File

@@ -15,6 +15,7 @@ EXCLUDED_IMAGE_TYPES = [
"image/tiff",
"image/gif",
"image/svg+xml",
"image/avif",
]

View File

@@ -1,7 +1,9 @@
from abc import ABC
from abc import abstractmethod
from typing import cast
from typing import IO
import puremagic
from sqlalchemy.orm import Session
from onyx.configs.constants import FileOrigin
@@ -12,6 +14,7 @@ from onyx.db.pg_file_store import delete_pgfilestore_by_file_name
from onyx.db.pg_file_store import get_pgfilestore_by_file_name
from onyx.db.pg_file_store import read_lobj
from onyx.db.pg_file_store import upsert_pgfilestore
from onyx.utils.file import FileWithMimeType
class FileStore(ABC):
@@ -140,6 +143,18 @@ class PostgresBackedFileStore(FileStore):
self.db_session.rollback()
raise
def get_file_with_mime_type(self, filename: str) -> FileWithMimeType | None:
mime_type: str = "application/octet-stream"
try:
file_io = self.read_file(filename, mode="b")
file_content = file_io.read()
matches = puremagic.magic_string(file_content)
if matches:
mime_type = cast(str, matches[0].mime_type)
return FileWithMimeType(data=file_content, mime_type=mime_type)
except Exception:
return None
def get_default_file_store(db_session: Session) -> FileStore:
# The only supported file store now is the Postgres File Store

View File

@@ -9,14 +9,14 @@ from onyx.db.engine import get_session_with_current_tenant
from onyx.db.llm import fetch_default_provider
from onyx.db.llm import fetch_default_vision_provider
from onyx.db.llm import fetch_existing_llm_providers
from onyx.db.llm import fetch_provider
from onyx.db.llm import fetch_llm_provider_view
from onyx.db.models import Persona
from onyx.llm.chat_llm import DefaultMultiLLM
from onyx.llm.exceptions import GenAIDisabledException
from onyx.llm.interfaces import LLM
from onyx.llm.override_models import LLMOverride
from onyx.llm.utils import model_supports_image_input
from onyx.server.manage.llm.models import FullLLMProvider
from onyx.server.manage.llm.models import LLMProviderView
from onyx.utils.headers import build_llm_extra_headers
from onyx.utils.logger import setup_logger
from onyx.utils.long_term_log import LongTermLogger
@@ -62,7 +62,7 @@ def get_llms_for_persona(
)
with get_session_context_manager() as db_session:
llm_provider = fetch_provider(db_session, provider_name)
llm_provider = fetch_llm_provider_view(db_session, provider_name)
if not llm_provider:
raise ValueError("No LLM provider found")
@@ -106,7 +106,7 @@ def get_default_llm_with_vision(
if DISABLE_GENERATIVE_AI:
raise GenAIDisabledException()
def create_vision_llm(provider: FullLLMProvider, model: str) -> LLM:
def create_vision_llm(provider: LLMProviderView, model: str) -> LLM:
"""Helper to create an LLM if the provider supports image input."""
return get_llm(
provider=provider.provider,
@@ -148,7 +148,7 @@ def get_default_llm_with_vision(
provider.default_vision_model, provider.provider
):
return create_vision_llm(
FullLLMProvider.from_model(provider), provider.default_vision_model
LLMProviderView.from_model(provider), provider.default_vision_model
)
return None

View File

@@ -56,7 +56,9 @@ BEDROCK_PROVIDER_NAME = "bedrock"
# models
BEDROCK_MODEL_NAMES = [
model
for model in litellm.bedrock_models
# bedrock_converse_models are just extensions of the bedrock_models, not sure why
# litellm has split them into two lists :(
for model in litellm.bedrock_models + litellm.bedrock_converse_models
if "/" not in model and "embed" not in model
][::-1]

View File

@@ -361,7 +361,15 @@ def get_application() -> FastAPI:
)
if AUTH_TYPE == AuthType.GOOGLE_OAUTH:
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
# For Google OAuth, refresh tokens are requested by:
# 1. Adding the right scopes
# 2. Properly configuring OAuth in Google Cloud Console to allow offline access
oauth_client = GoogleOAuth2(
OAUTH_CLIENT_ID,
OAUTH_CLIENT_SECRET,
# Use standard scopes that include profile and email
scopes=["openid", "email", "profile"],
)
include_auth_router_with_prefix(
application,
create_onyx_oauth_router(
@@ -383,6 +391,13 @@ def get_application() -> FastAPI:
prefix="/auth",
)
# Add refresh token endpoint for OAuth as well
include_auth_router_with_prefix(
application,
fastapi_users.get_refresh_router(auth_backend),
prefix="/auth",
)
application.add_exception_handler(
RequestValidationError, validation_exception_handler
)

View File

@@ -15,7 +15,6 @@ from onyx.configs.constants import MessageType
from onyx.configs.constants import SearchFeedbackType
from onyx.configs.onyxbot_configs import DANSWER_FOLLOWUP_EMOJI
from onyx.connectors.slack.utils import expert_info_from_slack_id
from onyx.connectors.slack.utils import make_slack_api_rate_limited
from onyx.context.search.models import SavedSearchDoc
from onyx.db.chat import get_chat_message
from onyx.db.chat import translate_db_message_to_chat_message_detail
@@ -553,8 +552,7 @@ def handle_followup_resolved_button(
# Delete the message with the option to mark resolved
if not immediate:
slack_call = make_slack_api_rate_limited(client.web_client.chat_delete)
response = slack_call(
response = client.web_client.chat_delete(
channel=channel_id,
ts=message_ts,
)

View File

@@ -170,7 +170,8 @@ def handle_message(
respond_tag_only = channel_conf.get("respond_tag_only") or False
respond_member_group_list = channel_conf.get("respond_member_group_list", None)
if respond_tag_only and not bypass_filters:
# NOTE: always respond in the DMs, as long the default config is not disabled.
if respond_tag_only and not bypass_filters and not is_bot_dm:
logger.info(
"Skipping message since the channel is configured such that "
"OnyxBot only responds to tags"

View File

@@ -18,6 +18,9 @@ from prometheus_client import start_http_server
from redis.lock import Lock
from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from slack_sdk.http_retry import ConnectionErrorRetryHandler
from slack_sdk.http_retry import RateLimitErrorRetryHandler
from slack_sdk.http_retry import RetryHandler
from slack_sdk.socket_mode.request import SocketModeRequest
from slack_sdk.socket_mode.response import SocketModeResponse
from sqlalchemy.orm import Session
@@ -41,6 +44,7 @@ from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_tenant
from onyx.db.models import SlackBot
from onyx.db.search_settings import get_current_search_settings
from onyx.db.slack_bot import fetch_slack_bot
from onyx.db.slack_bot import fetch_slack_bots
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
@@ -519,6 +523,25 @@ class SlackbotHandler:
def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -> bool:
"""True to keep going, False to ignore this Slack request"""
# skip cases where the bot is disabled in the web UI
bot_tag_id = get_onyx_bot_slack_bot_id(client.web_client)
with get_session_with_current_tenant() as db_session:
slack_bot = fetch_slack_bot(
db_session=db_session, slack_bot_id=client.slack_bot_id
)
if not slack_bot:
logger.error(
f"Slack bot with ID '{client.slack_bot_id}' not found. Skipping request."
)
return False
if not slack_bot.enabled:
logger.info(
f"Slack bot with ID '{client.slack_bot_id}' is disabled. Skipping request."
)
return False
if req.type == "events_api":
# Verify channel is valid
event = cast(dict[str, Any], req.payload.get("event", {}))
@@ -924,10 +947,21 @@ def _get_socket_client(
) -> TenantSocketModeClient:
# For more info on how to set this up, checkout the docs:
# https://docs.onyx.app/slack_bot_setup
# use the retry handlers built into the slack sdk
connection_error_retry_handler = ConnectionErrorRetryHandler()
rate_limit_error_retry_handler = RateLimitErrorRetryHandler(max_retry_count=7)
slack_retry_handlers: list[RetryHandler] = [
connection_error_retry_handler,
rate_limit_error_retry_handler,
]
return TenantSocketModeClient(
# This app-level token will be used only for establishing a connection
app_token=slack_bot_tokens.app_token,
web_client=WebClient(token=slack_bot_tokens.bot_token),
web_client=WebClient(
token=slack_bot_tokens.bot_token, retry_handlers=slack_retry_handlers
),
tenant_id=tenant_id,
slack_bot_id=slack_bot_id,
)

View File

@@ -30,7 +30,6 @@ from onyx.configs.onyxbot_configs import (
from onyx.configs.onyxbot_configs import (
DANSWER_BOT_RESPONSE_LIMIT_TIME_PERIOD_SECONDS,
)
from onyx.connectors.slack.utils import make_slack_api_rate_limited
from onyx.connectors.slack.utils import SlackTextCleaner
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.users import get_user_by_email
@@ -125,13 +124,18 @@ def update_emote_react(
)
return
func = client.reactions_remove if remove else client.reactions_add
slack_call = make_slack_api_rate_limited(func) # type: ignore
slack_call(
name=emoji,
channel=channel,
timestamp=message_ts,
)
if remove:
client.reactions_remove(
name=emoji,
channel=channel,
timestamp=message_ts,
)
else:
client.reactions_add(
name=emoji,
channel=channel,
timestamp=message_ts,
)
except SlackApiError as e:
if remove:
logger.error(f"Failed to remove Reaction due to: {e}")
@@ -200,9 +204,8 @@ def respond_in_thread_or_channel(
message_ids: list[str] = []
if not receiver_ids:
slack_call = make_slack_api_rate_limited(client.chat_postMessage)
try:
response = slack_call(
response = client.chat_postMessage(
channel=channel,
text=text,
blocks=blocks,
@@ -224,7 +227,7 @@ def respond_in_thread_or_channel(
blocks_without_urls.append(_build_error_block(str(e)))
# Try again wtihout blocks containing url
response = slack_call(
response = client.chat_postMessage(
channel=channel,
text=text,
blocks=blocks_without_urls,
@@ -236,11 +239,9 @@ def respond_in_thread_or_channel(
message_ids.append(response["message_ts"])
else:
slack_call = make_slack_api_rate_limited(client.chat_postEphemeral)
for receiver in receiver_ids:
try:
response = slack_call(
response = client.chat_postEphemeral(
channel=channel,
user=receiver,
text=text,
@@ -263,7 +264,7 @@ def respond_in_thread_or_channel(
blocks_without_urls.append(_build_error_block(str(e)))
# Try again wtihout blocks containing url
response = slack_call(
response = client.chat_postEphemeral(
channel=channel,
user=receiver,
text=text,
@@ -500,7 +501,7 @@ def fetch_user_semantic_id_from_id(
if not user_id:
return None
response = make_slack_api_rate_limited(client.users_info)(user=user_id)
response = client.users_info(user=user_id)
if not response["ok"]:
return None

View File

@@ -31,6 +31,7 @@ PUBLIC_ENDPOINT_SPECS = [
# just gets the version of Onyx (e.g. 0.3.11)
("/version", {"GET"}),
# stuff related to basic auth
("/auth/refresh", {"POST"}),
("/auth/register", {"POST"}),
("/auth/login", {"POST"}),
("/auth/logout", {"POST"}),

View File

@@ -9,9 +9,9 @@ from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_chat_accessible_user
from onyx.db.engine import get_session
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import fetch_existing_llm_providers
from onyx.db.llm import fetch_existing_llm_providers_for_user
from onyx.db.llm import fetch_provider
from onyx.db.llm import remove_llm_provider
from onyx.db.llm import update_default_provider
from onyx.db.llm import update_default_vision_provider
@@ -24,9 +24,9 @@ from onyx.llm.llm_provider_options import WellKnownLLMProviderDescriptor
from onyx.llm.utils import litellm_exception_to_error_msg
from onyx.llm.utils import model_supports_image_input
from onyx.llm.utils import test_llm
from onyx.server.manage.llm.models import FullLLMProvider
from onyx.server.manage.llm.models import LLMProviderDescriptor
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import LLMProviderView
from onyx.server.manage.llm.models import TestLLMRequest
from onyx.server.manage.llm.models import VisionProviderResponse
from onyx.utils.logger import setup_logger
@@ -49,11 +49,27 @@ def fetch_llm_options(
def test_llm_configuration(
test_llm_request: TestLLMRequest,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
"""Test regular llm and fast llm settings"""
# the api key is sanitized if we are testing a provider already in the system
test_api_key = test_llm_request.api_key
if test_llm_request.name:
# NOTE: we are querying by name. we probably should be querying by an invariant id, but
# as it turns out the name is not editable in the UI and other code also keys off name,
# so we won't rock the boat just yet.
existing_provider = fetch_existing_llm_provider(
test_llm_request.name, db_session
)
if existing_provider:
test_api_key = existing_provider.api_key
llm = get_llm(
provider=test_llm_request.provider,
model=test_llm_request.default_model_name,
api_key=test_llm_request.api_key,
api_key=test_api_key,
api_base=test_llm_request.api_base,
api_version=test_llm_request.api_version,
custom_config=test_llm_request.custom_config,
@@ -69,7 +85,7 @@ def test_llm_configuration(
fast_llm = get_llm(
provider=test_llm_request.provider,
model=test_llm_request.fast_default_model_name,
api_key=test_llm_request.api_key,
api_key=test_api_key,
api_base=test_llm_request.api_base,
api_version=test_llm_request.api_version,
custom_config=test_llm_request.custom_config,
@@ -119,11 +135,17 @@ def test_default_provider(
def list_llm_providers(
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[FullLLMProvider]:
return [
FullLLMProvider.from_model(llm_provider_model)
for llm_provider_model in fetch_existing_llm_providers(db_session)
]
) -> list[LLMProviderView]:
llm_provider_list: list[LLMProviderView] = []
for llm_provider_model in fetch_existing_llm_providers(db_session):
full_llm_provider = LLMProviderView.from_model(llm_provider_model)
if full_llm_provider.api_key:
full_llm_provider.api_key = (
full_llm_provider.api_key[:4] + "****" + full_llm_provider.api_key[-4:]
)
llm_provider_list.append(full_llm_provider)
return llm_provider_list
@admin_router.put("/provider")
@@ -135,11 +157,11 @@ def put_llm_provider(
),
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> FullLLMProvider:
) -> LLMProviderView:
# validate request (e.g. if we're intending to create but the name already exists we should throw an error)
# NOTE: may involve duplicate fetching to Postgres, but we're assuming SQLAlchemy is smart enough to cache
# the result
existing_provider = fetch_provider(db_session, llm_provider.name)
existing_provider = fetch_existing_llm_provider(llm_provider.name, db_session)
if existing_provider and is_creation:
raise HTTPException(
status_code=400,
@@ -161,6 +183,11 @@ def put_llm_provider(
llm_provider.fast_default_model_name
)
# the llm api key is sanitized when returned to clients, so the only time we
# should get a real key is when it is explicitly changed
if existing_provider and not llm_provider.api_key_changed:
llm_provider.api_key = existing_provider.api_key
try:
return upsert_llm_provider(
llm_provider=llm_provider,
@@ -234,7 +261,7 @@ def get_vision_capable_providers(
# Only include providers with at least one vision-capable model
if vision_models:
provider_dict = FullLLMProvider.from_model(provider).model_dump()
provider_dict = LLMProviderView.from_model(provider).model_dump()
provider_dict["vision_models"] = vision_models
logger.info(
f"Vision provider: {provider.provider} with models: {vision_models}"

View File

@@ -12,6 +12,7 @@ if TYPE_CHECKING:
class TestLLMRequest(BaseModel):
# provider level
name: str | None = None
provider: str
api_key: str | None = None
api_base: str | None = None
@@ -76,16 +77,19 @@ class LLMProviderUpsertRequest(LLMProvider):
# should only be used for a "custom" provider
# for default providers, the built-in model names are used
model_names: list[str] | None = None
api_key_changed: bool = False
class FullLLMProvider(LLMProvider):
class LLMProviderView(LLMProvider):
"""Stripped down representation of LLMProvider for display / limited access info only"""
id: int
is_default_provider: bool | None = None
is_default_vision_provider: bool | None = None
model_names: list[str]
@classmethod
def from_model(cls, llm_provider_model: "LLMProviderModel") -> "FullLLMProvider":
def from_model(cls, llm_provider_model: "LLMProviderModel") -> "LLMProviderView":
return cls(
id=llm_provider_model.id,
name=llm_provider_model.name,
@@ -111,7 +115,7 @@ class FullLLMProvider(LLMProvider):
)
class VisionProviderResponse(FullLLMProvider):
class VisionProviderResponse(LLMProviderView):
"""Response model for vision providers endpoint, including vision-specific fields."""
vision_models: list[str]

View File

@@ -132,6 +132,7 @@ class UserByEmail(BaseModel):
class UserRoleUpdateRequest(BaseModel):
user_email: str
new_role: UserRole
explicit_override: bool = False
class UserRoleResponse(BaseModel):

View File

@@ -32,10 +32,14 @@ from onyx.server.manage.models import SlackChannelConfig
from onyx.server.manage.models import SlackChannelConfigCreationRequest
from onyx.server.manage.validate_tokens import validate_app_token
from onyx.server.manage.validate_tokens import validate_bot_token
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import create_milestone_and_report
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
router = APIRouter(prefix="/manage")
@@ -257,9 +261,6 @@ def create_bot(
# Create a default Slack channel config
default_channel_config = ChannelConfig(
channel_name=None,
respond_member_group_list=[],
answer_filters=[],
follow_up_tags=[],
respond_tag_only=True,
)
insert_slack_channel_config(
@@ -367,7 +368,9 @@ def get_all_channels_from_slack_api(
_: User | None = Depends(current_admin_user),
) -> list[SlackChannel]:
"""
Fetches channels the bot is a member of from the Slack API.
Fetches all channels in the Slack workspace using the conversations_list API.
This includes both public and private channels that are visible to the app,
not just the ones the bot is a member of.
Handles pagination with a limit to avoid excessive API calls.
"""
tokens = fetch_slack_bot_tokens(db_session, bot_id)
@@ -376,26 +379,26 @@ def get_all_channels_from_slack_api(
status_code=404, detail="Bot token not found for the given bot ID"
)
client = WebClient(token=tokens["bot_token"])
client = WebClient(token=tokens["bot_token"], timeout=1)
all_channels = []
next_cursor = None
current_page = 0
try:
# Use users_conversations with limited pagination
# Use conversations_list to get all channels in the workspace (including ones the bot is not a member of)
while current_page < MAX_SLACK_PAGES:
current_page += 1
# Make API call with cursor if we have one
if next_cursor:
response = client.users_conversations(
response = client.conversations_list(
types="public_channel,private_channel",
exclude_archived=True,
cursor=next_cursor,
limit=SLACK_API_CHANNELS_PER_PAGE,
)
else:
response = client.users_conversations(
response = client.conversations_list(
types="public_channel,private_channel",
exclude_archived=True,
limit=SLACK_API_CHANNELS_PER_PAGE,
@@ -431,6 +434,7 @@ def get_all_channels_from_slack_api(
except SlackApiError as e:
# Handle rate limiting or other API errors
logger.exception("Error fetching channels from Slack API")
raise HTTPException(
status_code=500,
detail=f"Error fetching channels from Slack API: {str(e)}",

View File

@@ -102,6 +102,7 @@ def set_user_role(
validate_user_role_update(
requested_role=requested_role,
current_role=current_role,
explicit_override=user_role_update_request.explicit_override,
)
if user_to_update.id == current_user.id:
@@ -122,6 +123,22 @@ def set_user_role(
db_session.commit()
class TestUpsertRequest(BaseModel):
email: str
@router.post("/manage/users/test-upsert-user")
async def test_upsert_user(
request: TestUpsertRequest,
_: User = Depends(current_admin_user),
) -> None | FullUserSnapshot:
"""Test endpoint for upsert_saml_user. Only used for integration testing."""
user = await fetch_ee_implementation_or_noop(
"onyx.server.saml", "upsert_saml_user", None
)(email=request.email)
return FullUserSnapshot.from_user_model(user) if user else None
@router.get("/manage/users/accepted")
def list_accepted_users(
q: str | None = Query(default=None),
@@ -296,7 +313,7 @@ def bulk_invite_users(
detail=f"Invalid email address: {email} - {str(e)}",
)
if MULTI_TENANT and not DEV_MODE:
if MULTI_TENANT:
try:
fetch_ee_implementation_or_noop(
"onyx.server.tenants.provisioning", "add_users_to_tenant", None
@@ -318,7 +335,7 @@ def bulk_invite_users(
except Exception as e:
logger.error(f"Error sending email invite to invited users: {e}")
if not MULTI_TENANT:
if not MULTI_TENANT or DEV_MODE:
return number_of_invited_users
# for billing purposes, write to the control plane about the number of new users
@@ -351,13 +368,15 @@ def remove_invited_user(
user_emails = get_invited_users()
remaining_users = [user for user in user_emails if user != user_email.user_email]
fetch_ee_implementation_or_noop(
"onyx.server.tenants.user_mapping", "remove_users_from_tenant", None
)([user_email.user_email], tenant_id)
if MULTI_TENANT:
fetch_ee_implementation_or_noop(
"onyx.server.tenants.user_mapping", "remove_users_from_tenant", None
)([user_email.user_email], tenant_id)
number_of_invited_users = write_invited_users(remaining_users)
try:
if MULTI_TENANT:
if MULTI_TENANT and not DEV_MODE:
fetch_ee_implementation_or_noop(
"onyx.server.tenants.billing", "register_tenant_users", None
)(tenant_id, get_total_users_count(db_session))

View File

@@ -0,0 +1,167 @@
import io
from typing import cast
from PIL import Image
from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
from onyx.background.celery.tasks.beat_schedule import (
CLOUD_DOC_PERMISSION_SYNC_MULTIPLIER_DEFAULT,
)
from onyx.configs.constants import CLOUD_BUILD_FENCE_LOOKUP_TABLE_INTERVAL_DEFAULT
from onyx.configs.constants import ONYX_CLOUD_REDIS_RUNTIME
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
from onyx.configs.constants import ONYX_EMAILABLE_LOGO_MAX_DIM
from onyx.db.engine import get_session_with_shared_schema
from onyx.file_store.file_store import PostgresBackedFileStore
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.utils.file import FileWithMimeType
from onyx.utils.file import OnyxStaticFileManager
from onyx.utils.variable_functionality import (
fetch_ee_implementation_or_noop,
)
class OnyxRuntime:
"""Used by the application to get the final runtime value of a setting.
Rationale: Settings and overrides may be persisted in multiple places, including the
DB, Redis, env vars, and default constants, etc. The logic to present a final
setting to the application should be centralized and in one place.
Example: To get the logo for the application, one must check the DB for an override,
use the override if present, fall back to the filesystem if not present, and worry
about enterprise or not enterprise.
"""
@staticmethod
def _get_with_static_fallback(
db_filename: str | None, static_filename: str
) -> FileWithMimeType:
onyx_file: FileWithMimeType | None = None
if db_filename:
with get_session_with_shared_schema() as db_session:
file_store = PostgresBackedFileStore(db_session)
onyx_file = file_store.get_file_with_mime_type(db_filename)
if not onyx_file:
onyx_file = OnyxStaticFileManager.get_static(static_filename)
if not onyx_file:
raise RuntimeError(
f"Resource not found: db={db_filename} static={static_filename}"
)
return onyx_file
@staticmethod
def get_logo() -> FileWithMimeType:
STATIC_FILENAME = "static/images/logo.png"
db_filename: str | None = fetch_ee_implementation_or_noop(
"onyx.server.enterprise_settings.store", "get_logo_filename", None
)
return OnyxRuntime._get_with_static_fallback(db_filename, STATIC_FILENAME)
@staticmethod
def get_emailable_logo() -> FileWithMimeType:
onyx_file = OnyxRuntime.get_logo()
# check dimensions and resize downwards if necessary or if not PNG
image = Image.open(io.BytesIO(onyx_file.data))
if (
image.size[0] > ONYX_EMAILABLE_LOGO_MAX_DIM
or image.size[1] > ONYX_EMAILABLE_LOGO_MAX_DIM
or image.format != "PNG"
):
image.thumbnail(
(ONYX_EMAILABLE_LOGO_MAX_DIM, ONYX_EMAILABLE_LOGO_MAX_DIM),
Image.LANCZOS,
) # maintains aspect ratio
output_buffer = io.BytesIO()
image.save(output_buffer, format="PNG")
onyx_file = FileWithMimeType(
data=output_buffer.getvalue(), mime_type="image/png"
)
return onyx_file
@staticmethod
def get_logotype() -> FileWithMimeType:
STATIC_FILENAME = "static/images/logotype.png"
db_filename: str | None = fetch_ee_implementation_or_noop(
"onyx.server.enterprise_settings.store", "get_logotype_filename", None
)
return OnyxRuntime._get_with_static_fallback(db_filename, STATIC_FILENAME)
@staticmethod
def get_beat_multiplier() -> float:
"""the beat multiplier is used to scale up or down the frequency of certain beat
tasks in the cloud. It has a significant effect on load and is useful to adjust
in real time."""
beat_multiplier: float = CLOUD_BEAT_MULTIPLIER_DEFAULT
r = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
beat_multiplier_raw = r.get(f"{ONYX_CLOUD_REDIS_RUNTIME}:beat_multiplier")
if beat_multiplier_raw is not None:
try:
beat_multiplier_bytes = cast(bytes, beat_multiplier_raw)
beat_multiplier = float(beat_multiplier_bytes.decode())
except ValueError:
pass
if beat_multiplier <= 0.0:
return 1.0
return beat_multiplier
@staticmethod
def get_doc_permission_sync_multiplier() -> float:
"""Permission syncs are a significant source of load / queueing in the cloud."""
value: float = CLOUD_DOC_PERMISSION_SYNC_MULTIPLIER_DEFAULT
r = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
value_raw = r.get(f"{ONYX_CLOUD_REDIS_RUNTIME}:doc_permission_sync_multiplier")
if value_raw is not None:
try:
value_bytes = cast(bytes, value_raw)
value = float(value_bytes.decode())
except ValueError:
pass
if value <= 0.0:
return 1.0
return value
@staticmethod
def get_build_fence_lookup_table_interval() -> int:
"""We maintain an active fence table to make lookups of existing fences efficient.
However, reconstructing the table is expensive, so adjusting it in realtime is useful.
"""
interval: int = CLOUD_BUILD_FENCE_LOOKUP_TABLE_INTERVAL_DEFAULT
r = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
interval_raw = r.get(
f"{ONYX_CLOUD_REDIS_RUNTIME}:build_fence_lookup_table_interval"
)
if interval_raw is not None:
try:
interval_bytes = cast(bytes, interval_raw)
interval = int(interval_bytes.decode())
except ValueError:
pass
if interval <= 0.0:
return CLOUD_BUILD_FENCE_LOOKUP_TABLE_INTERVAL_DEFAULT
return interval

View File

@@ -307,6 +307,7 @@ def setup_postgres(db_session: Session) -> None:
groups=[],
display_model_names=OPEN_AI_MODEL_NAMES,
model_names=OPEN_AI_MODEL_NAMES,
api_key_changed=True,
)
new_llm_provider = upsert_llm_provider(
llm_provider=model_req, db_session=db_session
@@ -323,7 +324,7 @@ def update_default_multipass_indexing(db_session: Session) -> None:
logger.info(
"No existing docs or connectors found. Checking GPU availability for multipass indexing."
)
gpu_available = gpu_status_request()
gpu_available = gpu_status_request(indexing=True)
logger.info(f"GPU available: {gpu_available}")
current_settings = get_current_search_settings(db_session)

View File

@@ -21,7 +21,6 @@ def build_tool_message(
)
# TODO: does this NEED to be BaseModel__v1?
class ToolCallSummary(BaseModel):
tool_call_request: AIMessage
tool_call_result: ToolMessage

View File

@@ -0,0 +1,36 @@
from typing import cast
import puremagic
from pydantic import BaseModel
from onyx.utils.logger import setup_logger
logger = setup_logger()
class FileWithMimeType(BaseModel):
data: bytes
mime_type: str
class OnyxStaticFileManager:
"""Retrieve static resources with this class. Currently, these should all be located
in the static directory ... e.g. static/images/logo.png"""
@staticmethod
def get_static(filename: str) -> FileWithMimeType | None:
try:
mime_type: str = "application/octet-stream"
with open(filename, "rb") as f:
file_content = f.read()
matches = puremagic.magic_string(file_content)
if matches:
mime_type = cast(str, matches[0].mime_type)
except (OSError, FileNotFoundError, PermissionError) as e:
logger.error(f"Failed to read file {filename}: {e}")
return None
except Exception as e:
logger.error(f"Unexpected exception reading file {filename}: {e}")
return None
return FileWithMimeType(data=file_content, mime_type=mime_type)

View File

@@ -1,3 +1,5 @@
from functools import lru_cache
import requests
from retry import retry
@@ -10,8 +12,7 @@ from shared_configs.configs import MODEL_SERVER_PORT
logger = setup_logger()
@retry(tries=5, delay=5)
def gpu_status_request(indexing: bool = True) -> bool:
def _get_gpu_status_from_model_server(indexing: bool) -> bool:
if indexing:
model_server_url = f"{INDEXING_MODEL_SERVER_HOST}:{INDEXING_MODEL_SERVER_PORT}"
else:
@@ -28,3 +29,14 @@ def gpu_status_request(indexing: bool = True) -> bool:
except requests.RequestException as e:
logger.error(f"Error: Unable to fetch GPU status. Error: {str(e)}")
raise # Re-raise exception to trigger a retry
@retry(tries=5, delay=5)
def gpu_status_request(indexing: bool) -> bool:
return _get_gpu_status_from_model_server(indexing)
@lru_cache(maxsize=1)
def fast_gpu_status_request(indexing: bool) -> bool:
"""For use in sync flows, where we don't want to retry / we want to cache this."""
return gpu_status_request(indexing=indexing)

View File

@@ -0,0 +1,13 @@
from collections.abc import Callable
from functools import lru_cache
from typing import TypeVar
R = TypeVar("R")
def lazy_eval(func: Callable[[], R]) -> Callable[[], R]:
@lru_cache(maxsize=1)
def lazy_func() -> R:
return func()
return lazy_func

View File

@@ -36,6 +36,10 @@ class RecordType(str, Enum):
LATENCY = "latency"
FAILURE = "failure"
METRIC = "metric"
INDEXING_PROGRESS = "indexing_progress"
INDEXING_COMPLETE = "indexing_complete"
PERMISSION_SYNC_PROGRESS = "permission_sync_progress"
INDEX_ATTEMPT_STATUS = "index_attempt_status"
def _get_or_generate_customer_id_mt(tenant_id: str) -> str:

View File

@@ -1,27 +1,167 @@
import collections.abc
import contextvars
import copy
import threading
import uuid
from collections.abc import Callable
from collections.abc import Iterator
from collections.abc import MutableMapping
from collections.abc import Sequence
from concurrent.futures import as_completed
from concurrent.futures import FIRST_COMPLETED
from concurrent.futures import Future
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import wait
from typing import Any
from typing import cast
from typing import Generic
from typing import overload
from typing import Protocol
from typing import TypeVar
from pydantic import GetCoreSchemaHandler
from pydantic_core import core_schema
from onyx.utils.logger import setup_logger
logger = setup_logger()
R = TypeVar("R")
KT = TypeVar("KT") # Key type
VT = TypeVar("VT") # Value type
_T = TypeVar("_T") # Default type
class ThreadSafeDict(MutableMapping[KT, VT]):
"""
A thread-safe dictionary implementation that uses a lock to ensure thread safety.
Implements the MutableMapping interface to provide a complete dictionary-like interface.
Example usage:
# Create a thread-safe dictionary
safe_dict: ThreadSafeDict[str, int] = ThreadSafeDict()
# Basic operations (atomic)
safe_dict["key"] = 1
value = safe_dict["key"]
del safe_dict["key"]
# Bulk operations (atomic)
safe_dict.update({"key1": 1, "key2": 2})
"""
def __init__(self, input_dict: dict[KT, VT] | None = None) -> None:
self._dict: dict[KT, VT] = input_dict or {}
self.lock = threading.Lock()
def __getitem__(self, key: KT) -> VT:
with self.lock:
return self._dict[key]
def __setitem__(self, key: KT, value: VT) -> None:
with self.lock:
self._dict[key] = value
def __delitem__(self, key: KT) -> None:
with self.lock:
del self._dict[key]
def __iter__(self) -> Iterator[KT]:
# Return a snapshot of keys to avoid potential modification during iteration
with self.lock:
return iter(list(self._dict.keys()))
def __len__(self) -> int:
with self.lock:
return len(self._dict)
@classmethod
def __get_pydantic_core_schema__(
cls, source_type: Any, handler: GetCoreSchemaHandler
) -> core_schema.CoreSchema:
return core_schema.no_info_after_validator_function(
cls.validate, handler(dict[KT, VT])
)
@classmethod
def validate(cls, v: Any) -> "ThreadSafeDict[KT, VT]":
if isinstance(v, dict):
return ThreadSafeDict(v)
return v
def __deepcopy__(self, memo: Any) -> "ThreadSafeDict[KT, VT]":
return ThreadSafeDict(copy.deepcopy(self._dict))
def clear(self) -> None:
"""Remove all items from the dictionary atomically."""
with self.lock:
self._dict.clear()
def copy(self) -> dict[KT, VT]:
"""Return a shallow copy of the dictionary atomically."""
with self.lock:
return self._dict.copy()
@overload
def get(self, key: KT) -> VT | None:
...
@overload
def get(self, key: KT, default: VT | _T) -> VT | _T:
...
def get(self, key: KT, default: Any = None) -> Any:
"""Get a value with a default, atomically."""
with self.lock:
return self._dict.get(key, default)
def pop(self, key: KT, default: Any = None) -> Any:
"""Remove and return a value with optional default, atomically."""
with self.lock:
if default is None:
return self._dict.pop(key)
return self._dict.pop(key, default)
def setdefault(self, key: KT, default: VT) -> VT:
"""Set a default value if key is missing, atomically."""
with self.lock:
return self._dict.setdefault(key, default)
def update(self, *args: Any, **kwargs: VT) -> None:
"""Update the dictionary atomically from another mapping or from kwargs."""
with self.lock:
self._dict.update(*args, **kwargs)
def items(self) -> collections.abc.ItemsView[KT, VT]:
"""Return a view of (key, value) pairs atomically."""
with self.lock:
return collections.abc.ItemsView(self)
def keys(self) -> collections.abc.KeysView[KT]:
"""Return a view of keys atomically."""
with self.lock:
return collections.abc.KeysView(self)
def values(self) -> collections.abc.ValuesView[VT]:
"""Return a view of values atomically."""
with self.lock:
return collections.abc.ValuesView(self)
class CallableProtocol(Protocol):
def __call__(self, *args: Any, **kwargs: Any) -> Any:
...
def run_functions_tuples_in_parallel(
functions_with_args: list[tuple[Callable, tuple]],
functions_with_args: Sequence[tuple[CallableProtocol, tuple[Any, ...]]],
allow_failures: bool = False,
max_workers: int | None = None,
) -> list[Any]:
"""
Executes multiple functions in parallel and returns a list of the results for each function.
This function preserves contextvars across threads, which is important for maintaining
context like tenant IDs in database sessions.
Args:
functions_with_args: List of tuples each containing the function callable and a tuple of arguments.
@@ -29,7 +169,7 @@ def run_functions_tuples_in_parallel(
max_workers: Max number of worker threads
Returns:
dict: A dictionary mapping function names to their results or error messages.
list: A list of results from each function, in the same order as the input functions.
"""
workers = (
min(max_workers, len(functions_with_args))
@@ -56,7 +196,7 @@ def run_functions_tuples_in_parallel(
results.append((index, future.result()))
except Exception as e:
logger.exception(f"Function at index {index} failed due to {e}")
results.append((index, None))
results.append((index, None)) # type: ignore
if not allow_failures:
raise
@@ -158,7 +298,7 @@ def run_with_timeout(
if task.is_alive():
task.end()
return task.result
return task.result # type: ignore
# NOTE: this function should really only be used when run_functions_tuples_in_parallel is
@@ -174,9 +314,9 @@ def run_in_background(
"""
context = contextvars.copy_context()
# Timeout not used in the non-blocking case
task = TimeoutThread(-1, context.run, func, *args, **kwargs)
task = TimeoutThread(-1, context.run, func, *args, **kwargs) # type: ignore
task.start()
return task
return cast(TimeoutThread[R], task)
def wait_on_background(task: TimeoutThread[R]) -> R:
@@ -190,3 +330,27 @@ def wait_on_background(task: TimeoutThread[R]) -> R:
raise task.exception
return task.result
def _next_or_none(ind: int, g: Iterator[R]) -> tuple[int, R | None]:
return ind, next(g, None)
def parallel_yield(gens: list[Iterator[R]], max_workers: int = 10) -> Iterator[R]:
with ThreadPoolExecutor(max_workers=max_workers) as executor:
future_to_index: dict[Future[tuple[int, R | None]], int] = {
executor.submit(_next_or_none, i, g): i for i, g in enumerate(gens)
}
next_ind = len(gens)
while future_to_index:
done, _ = wait(future_to_index, return_when=FIRST_COMPLETED)
for future in done:
ind, result = future.result()
if result is not None:
yield result
future_to_index[
executor.submit(_next_or_none, ind, gens[ind])
] = next_ind
next_ind += 1
del future_to_index[future]

View File

@@ -38,7 +38,7 @@ langchainhub==0.1.21
langgraph==0.2.72
langgraph-checkpoint==2.0.13
langgraph-sdk==0.1.44
litellm==1.61.16
litellm==1.63.8
lxml==5.3.0
lxml_html_clean==0.2.2
llama-index==0.9.45
@@ -47,15 +47,16 @@ msal==1.28.0
nltk==3.8.1
Office365-REST-Python-Client==2.5.9
oauthlib==3.2.2
openai==1.61.0
openai==1.66.3
openpyxl==3.1.2
playwright==1.41.2
psutil==5.9.5
psycopg2-binary==2.9.9
puremagic==1.28
pyairtable==3.0.1
pycryptodome==3.19.1
pydantic==2.8.2
PyGithub==1.58.2
PyGithub==2.5.0
python-dateutil==2.8.2
python-gitlab==3.9.0
python-pptx==0.6.23

View File

@@ -0,0 +1,39 @@
#!/bin/bash
# USAGE: nohup ./docker_memory_tracking.sh &
# Set default output file or use the provided argument
OUTPUT_FILE="./docker_stats.log"
if [ $# -ge 1 ]; then
OUTPUT_FILE="$1"
fi
INTERVAL_SECONDS=600 # 10 minutes
# Create the output file if it doesn't exist, or append to it if it does
touch "$OUTPUT_FILE"
echo "Docker stats will be collected every 10 minutes and saved to $OUTPUT_FILE"
echo "Press Ctrl+C to stop the script"
# Function to handle script termination
cleanup() {
echo -e "\nStopping docker stats collection"
exit 0
}
# Set up trap for clean exit
trap cleanup SIGINT SIGTERM
# Main loop
while true; do
# Add timestamp
echo -e "\n--- Docker Stats: $(date) ---" >> "$OUTPUT_FILE"
# Run docker stats for a single snapshot (--no-stream ensures it runs once)
docker stats --no-stream --all >> "$OUTPUT_FILE"
# Wait for the next interval
echo "Stats collected at $(date). Next collection in 10 minutes."
sleep $INTERVAL_SECONDS
done

Binary file not shown.

After

Width:  |  Height:  |  Size: 6.6 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 44 KiB

View File

@@ -0,0 +1,77 @@
import os
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from onyx.configs.constants import BlobType
from onyx.connectors.blob.connector import BlobStorageConnector
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.file_processing.extract_file_text import ACCEPTED_DOCUMENT_FILE_EXTENSIONS
from onyx.file_processing.extract_file_text import ACCEPTED_IMAGE_FILE_EXTENSIONS
from onyx.file_processing.extract_file_text import ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS
from onyx.file_processing.extract_file_text import get_file_ext
@pytest.fixture
def blob_connector(request: pytest.FixtureRequest) -> BlobStorageConnector:
connector = BlobStorageConnector(
bucket_type=BlobType.S3, bucket_name="onyx-connector-tests"
)
connector.load_credentials(
{
"aws_access_key_id": os.environ["AWS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS"],
"aws_secret_access_key": os.environ[
"AWS_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS"
],
}
)
return connector
@patch(
"onyx.file_processing.extract_file_text.get_unstructured_api_key",
return_value=None,
)
def test_blob_s3_connector(
mock_get_api_key: MagicMock, blob_connector: BlobStorageConnector
) -> None:
"""
Plain and document file types should be fully indexed.
Multimedia and unknown file types will be indexed by title only with one empty section.
This is intentional in order to allow searching by just the title even if we can't
index the file content.
"""
all_docs: list[Document] = []
document_batches = blob_connector.load_from_state()
for doc_batch in document_batches:
for doc in doc_batch:
all_docs.append(doc)
#
assert len(all_docs) == 19
for doc in all_docs:
section = doc.sections[0]
assert isinstance(section, TextSection)
file_extension = get_file_ext(doc.semantic_identifier)
if file_extension in ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS:
assert len(section.text) > 0
continue
if file_extension in ACCEPTED_DOCUMENT_FILE_EXTENSIONS:
assert len(section.text) > 0
continue
if file_extension in ACCEPTED_IMAGE_FILE_EXTENSIONS:
assert len(section.text) == 0
continue
# unknown extension
assert len(section.text) == 0

Some files were not shown because too many files have changed in this diff Show More