Compare commits

...

58 Commits

Author SHA1 Message Date
Justin Tahara
5602ff8666 fix: use only celery-shared for security context (#5236) (#5239)
* fix: use only celery-shared for security context

* fix: bump helm chart version 0.2.8

Co-authored-by: Sam Waddell <shwaddell28@gmail.com>
2025-08-21 17:25:06 -07:00
Sam Waddell
2fc70781b4 fix: use only celery-shared for security context (#5236)
* fix: use only celery-shared for security context

* fix: bump helm chart version 0.2.8
2025-08-21 14:15:07 -07:00
Justin Tahara
f76b4dec4c feat(infra): Ignoring local Terraform files (#5227)
* feat(infra): Ignoring local Terraform files

* Addressing some comments
2025-08-21 09:43:18 -07:00
Jessica Singh
a5a516fa8a refactor(model): move api-based embeddings/reranking calls out of model server (#5216)
* move api-based embeddings/reranking calls to api server out of model server, added/modified unit tests

* ran pre-commit

* fix mypy errors

* mypy and precommit

* move utils to right place and add requirements

* precommit check

* removed extra constants, changed error msg

* Update backend/onyx/utils/search_nlp_models_utils.py

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>

* greptile

* addressed comments

* added code enforcement to throw error

---------

Co-authored-by: Jessica Singh <jessicasingh@Mac.attlocal.net>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2025-08-20 21:50:21 +00:00
Sam Waddell
811a198134 docs: add non-root user info (#5224) 2025-08-20 13:50:10 -07:00
Sam Waddell
5867ab1d7d feat: add non-root user to backend and model-server images (#5134)
* feat: add non-root user to backend and model-server image

* feat: update values to support security context for index, inference, and celery_shared

* feat: add security context support for index and inference

* feat: add celery_shared security context support to celery worker templates

* fix: cache management strategy

* fix: update deployment files for volume mount

* fix: address comments

* fix: bump helm chart version for new security context template changes

* fix: bump helm chart version for new security context template changes

* feat: move useradd earlier in build for reduced image size

---------

Co-authored-by: Phil Critchfield <phil.critchfield@liatrio.com>
2025-08-20 13:49:50 -07:00
Jose Bañez
dd6653eb1f fix(connector): #5178 Add error handling and logging for empty answer text in Loopio Connector (#5179)
* fix(connector): #5178 Add error handling and logging for empty answer text in LoopioConnector

* fix(connector): onyx-dot-app#5178:  Improve handling of empty answer text in LoopioConnector

---------

Co-authored-by: Jose Bañez <jose@4gclinical.com>
2025-08-20 09:14:08 -07:00
Richard Guan
db457ef432 fix(admin): [DAN-2202] Remove users from invited users after accept (#5214)
* .

* .

* .

* .

* .

* .

* .

---------

Co-authored-by: Richard Guan <richardguan@Richards-MacBook-Pro.local>
Co-authored-by: Richard Guan <richardguan@Mac.attlocal.net>
2025-08-20 03:55:02 +00:00
Richard Guan
de7fe939b2 . (#5212)
Co-authored-by: Richard Guan <richardguan@Richards-MBP.lan>
2025-08-20 02:36:44 +00:00
Chris Weaver
38114d9542 fix: PDF file upload (#5218)
* Fix / improve file upload

* Address cubic comment
2025-08-19 15:16:08 -07:00
Justin Tahara
32f20f2e2e feat(infra): Add WAF implementation (#5213) (#5217)
* feat(infra): Add WAF implementation

* Addressing greptile comments

* Additional removal of unnecessary code
2025-08-19 13:01:40 -07:00
Justin Tahara
3dd27099f7 feat(infra): Add WAF implementation (#5213)
* feat(infra): Add WAF implementation

* Addressing greptile comments

* Additional removal of unnecessary code
2025-08-18 17:45:50 -07:00
Cameron
91c4d43a80 Move @types packages to devDependencies (#5210) 2025-08-18 14:34:09 -07:00
SubashMohan
a63ba1bb03 fix: sharepoint group not found error and url with apostrophe (#5208)
* fix: handle ClientRequestException in SharePoint permission utils and connector

* feat: enhance SharePoint permission utilities with logging and URL handling

* greptile typo fix

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>

* enhance group sync handling for public groups

---------

Co-authored-by: Wenxi <wenxi@onyx.app>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2025-08-18 17:12:59 +00:00
Evan Lohn
7b6189e74c corrected routing (#5202) 2025-08-18 16:07:28 +00:00
Evan Lohn
ba423e5773 fix: model server concurrency (#5206)
* fix: model server race cond

* fix async

* different approach
2025-08-18 16:07:16 +00:00
SubashMohan
fe029eccae chore: add SharePoint sync environment variables to integration test (#5197)
* chore: add SharePoint sync environment variables to integration test workflows

* fix cubic comments

* test: skip SharePoint permission tests for non-enterprise

* test: update SharePoint permission tests to skip for non-enterprise environments
2025-08-18 03:21:04 +00:00
Wenxi
ea72af7698 fix sharepoint tests (#5209) 2025-08-17 22:25:47 +00:00
Wenxi
17abf85533 fix unpaused user files (#5205) 2025-08-16 01:39:16 +00:00
Wenxi
3bd162acb9 fix: sharepoint tests and indexing logic (#5204)
* don't index onedrive personal sites in sharepoint

* fix sharepoint tests and indexing behavior

* remove print
2025-08-15 18:19:42 -07:00
Evan Lohn
664ce441eb generous timeout between docfetching finishing and docprocessing starting (#5201) 2025-08-15 15:43:01 -07:00
Wenxi
6863fbee54 fix: validate sharepoint connector with validate_connector_settings (#5199)
* validate sharepoint connector with validate_connector_settings

* fix test

* fix tests
2025-08-15 00:38:31 +00:00
Justin Tahara
bb98088b80 fix(infra): Fix Helm Chart Test (#5198) 2025-08-14 23:28:17 +00:00
Justin Tahara
ce8cb1112a feat(infra): Adding new AWS Terraform Template Code (#5194)
* feat(infra): Adding new AWS Terraform Template Code

* Addressing greptile comments

* Applying some updates after the cubic reviews as well

* Adding one detail

* Removing unused var

* Addressing more cubic comments
2025-08-14 16:47:15 -07:00
Nils
a605bd4ca4 feat: make sharepoint documents and sharepoint pages optional (#5183)
* feat: make sharepoint documents and sharepoint pages optional

* fix: address review feedback for PR #5183

* fix: exclude personal sites from sharepoint connector

---------

Co-authored-by: Nils Kleinrahm <nils.kleinrahm@pledoc.de>
2025-08-14 15:17:23 -07:00
Dominic Feliton
0e8b5af619 fix(connector): user file helm start cmd + legacy file connector incompatibility (#5195)
* Fix user file helm start cmd + legacy file connector incompatibility

* typo

* remove unnecessary logic

* undo

* make recommended changes

* keep comment

* cleanup

* format

---------

Co-authored-by: Dominic Feliton <37809476+dominicfeliton@users.noreply.github.com>
2025-08-14 13:20:19 -07:00
SubashMohan
46f3af4f68 enhance file processing with content type handling (#5196) 2025-08-14 08:59:53 +00:00
Evan Lohn
2af64ebf4c fix: ensure exception strings don't get swallowed (#5192)
* ensure exception strings don't get swallowed

* just send exception code
2025-08-13 20:05:16 +00:00
Evan Lohn
0eb1824158 fix: sf connector docs (#5171)
* fix: sf connector docs

* more sf logs

* better logs and new attempt

* add fields to error temporarily

* fix sf

---------

Co-authored-by: Wenxi <wenxi@onyx.app>
2025-08-13 17:52:32 +00:00
Chris Weaver
e0a9a6fb66 feat: okta profile tool (#5184)
* Initial Okta profile tool

* Improve

* Fix

* Improve

* Improve

* Address EL comments
2025-08-13 09:57:31 -07:00
Wenxi
fe194076c2 make default personas hideable (#5190) 2025-08-13 01:12:51 +00:00
Wenxi
55dc24fd27 fix: seeded total doc count (#5188)
* fix seeded total doc count

* fix seeded total doc count
2025-08-13 00:19:06 +00:00
Evan Lohn
da02962a67 fix: thread safe approach to docprocessing logging (#5185)
* thread safe approach to docprocessing logging

* unify approaches

* reset
2025-08-12 02:25:47 +00:00
SubashMohan
9bc62cc803 feat: sharepoint perm sync (#5033)
* sharepoint perm sync first draft

* feat: Implement SharePoint permission synchronization

* mypy fix

* remove commented code

* bot comments fixes and job failure fixes

* introduce generic way to upload certificates in credentials

* mypy fix

* add checkpoiting to sharepoint connector

* add sharepoint integration tests

* Refactor SharePoint connector to derive tenant domain from verified domains and remove direct tenant domain input from credentials

* address review comments

* add permission sync to site pages

* mypy fix

* fix tests error

* fix tests and address comments

* Update file extraction behavior in SharePoint connector to continue processing on unprocessable files
2025-08-11 16:59:16 +00:00
Evan Lohn
bf6705a9a5 fix: max tokens param (#5174)
* max tokens param

* fix unit test

* fix unit test
2025-08-11 09:57:44 -07:00
Rei Meguro
df2fef3383 fix: removal of old tags + is_list differentiation (#5147)
* initial migration

* getting metadata from tags

* complete migration

* migration override for cloud

* fix: more robust structured tag gen

* tag and indexing update

* fix: move is_list to tags

* migration rebase

* test cases + bugfix on unique constraint

* fix logging
2025-08-10 22:39:33 +00:00
SubashMohan
8cec3448d7 fix: restrict user file access to current user only (#5177)
* fix: restrict user file access to current user only

* fix: enhance user file access control for recent folder
2025-08-10 19:00:18 +00:00
Justin Tahara
b81687995e fix(infra): Removing invalid helm version (#5176) 2025-08-08 18:40:55 -07:00
Justin Tahara
87c2253451 fix(infra): Update github workflow to not tag latest (#5172)
* fix(infra): Update github workflow to not tag latest

* Cleaned up the code a bit
2025-08-08 23:23:55 +00:00
Wenxi
297c2957b4 add gpt 5 display names (#5175) 2025-08-08 16:58:47 -07:00
Wenxi
bacee0d09d fix: sanitize slack payload before logging (#5167)
* sanitize slack payload before logging

* nit
2025-08-08 02:10:00 +00:00
Evan Lohn
297720c132 refactor: file processing (#5136)
* file processing refactor

* mypy

* CW comments

* address CW
2025-08-08 00:34:35 +00:00
Evan Lohn
bd4bd00cef feat: office parsing markitdown (#5115)
* switch to markitdown untested

* passing tests

* reset file

* dotenv version

* docs

* add test file

* add doc

* fix integration test
2025-08-07 23:26:02 +00:00
Chris Weaver
07c482f727 Make starter messages visible on smaller screens (#5170) 2025-08-07 16:49:18 -07:00
Wenxi
cf193dee29 feat: support gpt5 models (#5169)
* support gpt5 models

* gpt5mini visible
2025-08-07 12:35:46 -07:00
Evan Lohn
1b47fa2700 fix: remove erroneous error case and add valid error (#5163)
* fix: remove erroneous error case and add valid error

* also address docfetching-docprocessing limbo
2025-08-07 18:17:00 +00:00
Wenxi Onyx
e1a305d18a mask llm api key from logs 2025-08-07 00:01:29 -07:00
Evan Lohn
e2233d22c9 feat: salesforce custom query (#5158)
* WIP merged approach untested

* tested custom configs

* JT comments

* fix unit test

* CW comments

* fix unit test
2025-08-07 02:37:23 +00:00
Justin Tahara
20d1175312 feat(infra): Bump Vespa Helm Version (#5161)
* feat(infra): Bump Vespa Helm Version

* Adding the Chart.lock file
2025-08-06 19:06:18 -07:00
justin-tahara
7117774287 Revert that change. Let's do this properly 2025-08-06 18:54:21 -07:00
justin-tahara
77f2660bb2 feat(infra): Update Vespa Helm Chart Version 2025-08-06 18:53:02 -07:00
Wenxi
1b2f4f3b87 fix: slash command slackbot to respond in private msg (#5151)
* fix slash command slackbot to respond in private msg

* rename confusing variable. fix slash message response in DMs
2025-08-05 19:03:38 -07:00
Evan Lohn
d85b55a9d2 no more scheduled stalling (#5154) 2025-08-05 20:17:44 +00:00
Justin Tahara
e2bae5a2d9 fix(infra): Adding helm directory (#5156)
* feat(infra): Adding helm directory

* one more fix
2025-08-05 14:11:57 -07:00
Justin Tahara
cc9c76c4fb feat(infra): Release Charts on Github Pages (#5155) 2025-08-05 14:03:28 -07:00
Chris Weaver
258e08abcd feat: add customization via env vars for curator role (#5150)
* Add customization via env vars for curator role

* Simplify

* Simplify more

* Address comments
2025-08-05 09:58:36 -07:00
Evan Lohn
67047e42a7 fix: preserve error traces (#5152) 2025-08-05 09:44:55 -07:00
SubashMohan
146628e734 fix unsupported character error in minio migration (#5145)
* fix unsupported character error in minio migration

* slash fix
2025-08-04 12:42:07 -07:00
180 changed files with 7812 additions and 1997 deletions


@@ -18,23 +18,32 @@ jobs:
with:
fetch-depth: 0
- name: Configure Git
run: |
git config user.name "$GITHUB_ACTOR"
git config user.email "$GITHUB_ACTOR@users.noreply.github.com"
- - name: Install Helm
+ - name: Install Helm CLI
uses: azure/setup-helm@v4
with:
version: v3.12.1
- - name: Add Required Helm Repositories
+ - name: Add required Helm repositories
run: |
helm repo add bitnami https://charts.bitnami.com/bitnami
helm repo add onyx-vespa https://onyx-dot-app.github.io/vespa-helm-charts
helm repo update
- name: Run chart-releaser
uses: helm/chart-releaser-action@v1.7.0
env:
CR_TOKEN: "${{ secrets.GITHUB_TOKEN }}"
- name: Build chart dependencies
run: |
set -euo pipefail
for chart_dir in deployment/helm/charts/*; do
if [ -f "$chart_dir/Chart.yaml" ]; then
echo "Building dependencies for $chart_dir"
helm dependency build "$chart_dir"
fi
done
- name: Publish Helm charts to gh-pages
uses: stefanprodan/helm-gh-pages@v1.7.0
with:
token: ${{ secrets.GITHUB_TOKEN }}
charts_dir: deployment/helm/charts
branch: gh-pages
commit_username: ${{ github.actor }}
commit_email: ${{ github.actor }}@users.noreply.github.com


@@ -55,7 +55,25 @@ jobs:
- name: Run chart-testing (install)
if: steps.list-changed.outputs.changed == 'true'
-run: ct install --all --helm-extra-set-args="--set=nginx.enabled=false" --debug --config ct.yaml
+run: ct install --all \
--helm-extra-set-args="\
--set=nginx.enabled=false \
--set=postgresql.enabled=false \
--set=redis.enabled=false \
--set=minio.enabled=false \
--set=vespa.enabled=false \
--set=slackbot.enabled=false \
--set=api.replicaCount=0 \
--set=inferenceCapability.replicaCount=0 \
--set=indexCapability.replicaCount=0 \
--set=celery_beat.replicaCount=0 \
--set=celery_worker_heavy.replicaCount=0 \
--set=celery_worker_docprocessing.replicaCount=0 \
--set=celery_worker_light.replicaCount=0 \
--set=celery_worker_monitoring.replicaCount=0 \
--set=celery_worker_primary.replicaCount=0 \
--set=celery_worker_user_files_indexing.replicaCount=0" \
--debug --config ct.yaml
# the following would install only changed charts, but we only have one chart so
# don't worry about that for now
# run: ct install --target-branch ${{ github.event.repository.default_branch }}


@@ -19,6 +19,10 @@ env:
JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }}
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
PERM_SYNC_SHAREPOINT_CLIENT_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_CLIENT_ID }}
PERM_SYNC_SHAREPOINT_PRIVATE_KEY: ${{ secrets.PERM_SYNC_SHAREPOINT_PRIVATE_KEY }}
PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD: ${{ secrets.PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD }}
PERM_SYNC_SHAREPOINT_DIRECTORY_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_DIRECTORY_ID }}
PLATFORM_PAIR: linux-amd64
jobs:
@@ -272,6 +276,10 @@ jobs:
-e JIRA_BASE_URL=${JIRA_BASE_URL} \
-e JIRA_USER_EMAIL=${JIRA_USER_EMAIL} \
-e JIRA_API_TOKEN=${JIRA_API_TOKEN} \
-e PERM_SYNC_SHAREPOINT_CLIENT_ID=${PERM_SYNC_SHAREPOINT_CLIENT_ID} \
-e PERM_SYNC_SHAREPOINT_PRIVATE_KEY="${PERM_SYNC_SHAREPOINT_PRIVATE_KEY}" \
-e PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD=${PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD} \
-e PERM_SYNC_SHAREPOINT_DIRECTORY_ID=${PERM_SYNC_SHAREPOINT_DIRECTORY_ID} \
-e TEST_WEB_HOSTNAME=test-runner \
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
-e MOCK_CONNECTOR_SERVER_PORT=8001 \


@@ -19,6 +19,10 @@ env:
JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }}
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
PERM_SYNC_SHAREPOINT_CLIENT_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_CLIENT_ID }}
PERM_SYNC_SHAREPOINT_PRIVATE_KEY: ${{ secrets.PERM_SYNC_SHAREPOINT_PRIVATE_KEY }}
PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD: ${{ secrets.PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD }}
PERM_SYNC_SHAREPOINT_DIRECTORY_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_DIRECTORY_ID }}
PLATFORM_PAIR: linux-amd64
jobs:
integration-tests-mit:
@@ -207,6 +211,10 @@ jobs:
-e JIRA_BASE_URL=${JIRA_BASE_URL} \
-e JIRA_USER_EMAIL=${JIRA_USER_EMAIL} \
-e JIRA_API_TOKEN=${JIRA_API_TOKEN} \
-e PERM_SYNC_SHAREPOINT_CLIENT_ID=${PERM_SYNC_SHAREPOINT_CLIENT_ID} \
-e PERM_SYNC_SHAREPOINT_PRIVATE_KEY="${PERM_SYNC_SHAREPOINT_PRIVATE_KEY}" \
-e PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD=${PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD} \
-e PERM_SYNC_SHAREPOINT_DIRECTORY_ID=${PERM_SYNC_SHAREPOINT_DIRECTORY_ID} \
-e TEST_WEB_HOSTNAME=test-runner \
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
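The hunks above forward each SharePoint secret into the test container with `-e` flags. A small preflight check can catch unset or empty secrets before the container starts; this is a hypothetical helper for illustration, not part of the workflow itself:

```python
import os

# Names mirror the secrets forwarded with -e flags in the workflow above.
REQUIRED_ENV_VARS = [
    "PERM_SYNC_SHAREPOINT_CLIENT_ID",
    "PERM_SYNC_SHAREPOINT_PRIVATE_KEY",
    "PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD",
    "PERM_SYNC_SHAREPOINT_DIRECTORY_ID",
]


def missing_env_vars(names: list[str]) -> list[str]:
    """Return the names that are unset or empty in the environment."""
    return [name for name in names if not os.environ.get(name)]
```

Failing fast on missing secrets gives a clearer error than a connector auth failure deep inside the test run.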

.gitignore vendored

@@ -21,8 +21,19 @@ backend/tests/regression/search_quality/*.json
# secret files
.env
jira_test_env
settings.json
# others
/deployment/data/nginx/app.conf
*.sw?
/backend/tests/regression/answer_quality/search_test_config.yaml
# Local .terraform directories
**/.terraform/*
# Local .tfstate files
*.tfstate
*.tfstate.*
# Local .terraform.lock.hcl file
.terraform.lock.hcl


@@ -23,6 +23,9 @@ DISABLE_LLM_DOC_RELEVANCE=False
# Useful if you want to toggle auth on/off (google_oauth/OIDC specifically)
OAUTH_CLIENT_ID=<REPLACE THIS>
OAUTH_CLIENT_SECRET=<REPLACE THIS>
OPENID_CONFIG_URL=<REPLACE THIS>
SAML_CONF_DIR=/<ABSOLUTE PATH TO ONYX>/onyx/backend/ee/onyx/configs/saml_config
# Generally not useful for dev, we don't generally want to set up an SMTP server for dev
REQUIRE_EMAIL_VERIFICATION=False


@@ -31,14 +31,16 @@
],
"presentation": {
"group": "1"
}
},
"stopAll": true
},
{
"name": "Web / Model / API",
"configurations": ["Web Server", "Model Server", "API Server"],
"presentation": {
"group": "1"
}
},
"stopAll": true
},
{
"name": "Celery (all)",
@@ -53,7 +55,8 @@
],
"presentation": {
"group": "1"
}
},
"stopAll": true
}
],
"configurations": [


@@ -103,10 +103,10 @@ If using PowerShell, the command slightly differs:
Install the required python dependencies:
```bash
-pip install -r onyx/backend/requirements/default.txt
-pip install -r onyx/backend/requirements/dev.txt
-pip install -r onyx/backend/requirements/ee.txt
-pip install -r onyx/backend/requirements/model_server.txt
+pip install -r backend/requirements/default.txt
+pip install -r backend/requirements/dev.txt
+pip install -r backend/requirements/ee.txt
+pip install -r backend/requirements/model_server.txt
```
Install Playwright for Python (headless browser required by the Web Connector)


@@ -116,6 +116,14 @@ COPY ./assets /app/assets
ENV PYTHONPATH=/app
# Create non-root user for security best practices
RUN groupadd -g 1001 onyx && \
useradd -u 1001 -g onyx -m -s /bin/bash onyx && \
chown -R onyx:onyx /app && \
mkdir -p /var/log/onyx && \
chmod 755 /var/log/onyx && \
chown onyx:onyx /var/log/onyx
# Default command which does nothing
# This container is used by api server and background which specify their own CMD
CMD ["tail", "-f", "/dev/null"]


@@ -9,11 +9,20 @@ visit https://github.com/onyx-dot-app/onyx."
# Default ONYX_VERSION, typically overridden during builds by GitHub Actions.
ARG ONYX_VERSION=0.0.0-dev
ENV ONYX_VERSION=${ONYX_VERSION} \
-DANSWER_RUNNING_IN_DOCKER="true"
+DANSWER_RUNNING_IN_DOCKER="true" \
+HF_HOME=/app/.cache/huggingface
RUN echo "ONYX_VERSION: ${ONYX_VERSION}"
# Create non-root user for security best practices
RUN mkdir -p /app && \
groupadd -g 1001 onyx && \
useradd -u 1001 -g onyx -m -s /bin/bash onyx && \
chown -R onyx:onyx /app && \
mkdir -p /var/log/onyx && \
chmod 755 /var/log/onyx && \
chown onyx:onyx /var/log/onyx
COPY ./requirements/model_server.txt /tmp/requirements.txt
RUN pip install --no-cache-dir --upgrade \
--retries 5 \
@@ -38,9 +47,11 @@ snapshot_download('mixedbread-ai/mxbai-rerank-xsmall-v1'); \
from sentence_transformers import SentenceTransformer; \
SentenceTransformer(model_name_or_path='nomic-ai/nomic-embed-text-v1', trust_remote_code=True);"
-# In case the user has volumes mounted to /root/.cache/huggingface that they've downloaded while
-# running Onyx, don't overwrite it with the built in cache folder
-RUN mv /root/.cache/huggingface /root/.cache/temp_huggingface
+# In case the user has volumes mounted to /app/.cache/huggingface that they've downloaded while
+# running Onyx, move the current contents of the cache folder to a temporary location to ensure
+# it's preserved in order to combine with the user's cache contents
+RUN mv /app/.cache/huggingface /app/.cache/temp_huggingface && \
+    chown -R onyx:onyx /app
WORKDIR /app


@@ -0,0 +1,341 @@
"""tag-fix
Revision ID: 90e3b9af7da4
Revises: 62c3a055a141
Create Date: 2025-08-01 20:58:14.607624
"""
import json
import logging
import os
from typing import cast
from typing import Generator
from alembic import op
import sqlalchemy as sa
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
from onyx.db.search_settings import SearchSettings
from onyx.configs.app_configs import AUTH_TYPE
from onyx.configs.constants import AuthType
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
logger = logging.getLogger("alembic.runtime.migration")
# revision identifiers, used by Alembic.
revision = "90e3b9af7da4"
down_revision = "62c3a055a141"
branch_labels = None
depends_on = None
SKIP_TAG_FIX = os.environ.get("SKIP_TAG_FIX", "true").lower() == "true"
# override for cloud
if AUTH_TYPE == AuthType.CLOUD:
SKIP_TAG_FIX = True
def set_is_list_for_known_tags() -> None:
"""
Sets is_list to true for all tags that are known to be lists.
"""
LIST_METADATA: list[tuple[str, str]] = [
("CLICKUP", "tags"),
("CONFLUENCE", "labels"),
("DISCOURSE", "tags"),
("FRESHDESK", "emails"),
("GITHUB", "assignees"),
("GITHUB", "labels"),
("GURU", "tags"),
("GURU", "folders"),
("HUBSPOT", "associated_contact_ids"),
("HUBSPOT", "associated_company_ids"),
("HUBSPOT", "associated_deal_ids"),
("HUBSPOT", "associated_ticket_ids"),
("JIRA", "labels"),
("MEDIAWIKI", "categories"),
("ZENDESK", "labels"),
("ZENDESK", "content_tags"),
]
bind = op.get_bind()
for source, key in LIST_METADATA:
bind.execute(
sa.text(
f"""
UPDATE tag
SET is_list = true
WHERE tag_key = '{key}'
AND source = '{source}'
"""
)
)
def set_is_list_for_list_tags() -> None:
"""
Sets is_list to true for all tags which have multiple values for a given
document, key, and source triplet. This only works if we remove old tags
from the database.
"""
bind = op.get_bind()
bind.execute(
sa.text(
"""
UPDATE tag
SET is_list = true
FROM (
SELECT DISTINCT tag.tag_key, tag.source
FROM tag
JOIN document__tag ON tag.id = document__tag.tag_id
GROUP BY tag.tag_key, tag.source, document__tag.document_id
HAVING count(*) > 1
) AS list_tags
WHERE tag.tag_key = list_tags.tag_key
AND tag.source = list_tags.source
"""
)
)
def log_list_tags() -> None:
bind = op.get_bind()
result = bind.execute(
sa.text(
"""
SELECT DISTINCT source, tag_key
FROM tag
WHERE is_list
ORDER BY source, tag_key
"""
)
).fetchall()
logger.info(
"List tags:\n" + "\n".join(f" {source}: {key}" for source, key in result)
)
def remove_old_tags() -> None:
"""
Removes old tags from the database.
Previously, there was a bug where if a document got indexed with a tag and then
the document got reindexed, the old tag would not be removed.
This function removes those old tags by comparing it against the tags in vespa.
"""
current_search_settings, future_search_settings = active_search_settings()
document_index = get_default_document_index(
current_search_settings, future_search_settings
)
# Get the index name
if hasattr(document_index, "index_name"):
index_name = document_index.index_name
else:
# Default index name if we can't get it from the document_index
index_name = "danswer_index"
for batch in _get_batch_documents_with_multiple_tags():
n_deleted = 0
for document_id in batch:
true_metadata = _get_vespa_metadata(document_id, index_name)
tags = _get_document_tags(document_id)
# identify document__tags to delete
to_delete: list[str] = []
for tag_id, tag_key, tag_value in tags:
true_val = true_metadata.get(tag_key, "")
if (isinstance(true_val, list) and tag_value not in true_val) or (
isinstance(true_val, str) and tag_value != true_val
):
to_delete.append(str(tag_id))
if not to_delete:
continue
# delete old document__tags
bind = op.get_bind()
result = bind.execute(
sa.text(
f"""
DELETE FROM document__tag
WHERE document_id = '{document_id}'
AND tag_id IN ({','.join(to_delete)})
"""
)
)
n_deleted += result.rowcount
logger.info(f"Processed {len(batch)} documents and deleted {n_deleted} tags")
def active_search_settings() -> tuple[SearchSettings, SearchSettings | None]:
result = op.get_bind().execute(
sa.text(
"""
SELECT * FROM search_settings WHERE status = 'PRESENT' ORDER BY id DESC LIMIT 1
"""
)
)
search_settings_fetch = result.fetchall()
search_settings = (
SearchSettings(**search_settings_fetch[0]._asdict())
if search_settings_fetch
else None
)
result2 = op.get_bind().execute(
sa.text(
"""
SELECT * FROM search_settings WHERE status = 'FUTURE' ORDER BY id DESC LIMIT 1
"""
)
)
search_settings_future_fetch = result2.fetchall()
search_settings_future = (
SearchSettings(**search_settings_future_fetch[0]._asdict())
if search_settings_future_fetch
else None
)
if not isinstance(search_settings, SearchSettings):
raise RuntimeError(
"current search settings is of type " + str(type(search_settings))
)
if (
not isinstance(search_settings_future, SearchSettings)
and search_settings_future is not None
):
raise RuntimeError(
"future search settings is of type " + str(type(search_settings_future))
)
return search_settings, search_settings_future
def _get_batch_documents_with_multiple_tags(
batch_size: int = 128,
) -> Generator[list[str], None, None]:
"""
Returns batches of document ids which contain a one-to-many tag.
The document may either contain a list metadata value, or may contain leftover
old tags from reindexing.
"""
offset_clause = ""
bind = op.get_bind()
while True:
batch = bind.execute(
sa.text(
f"""
SELECT DISTINCT document__tag.document_id
FROM tag
JOIN document__tag ON tag.id = document__tag.tag_id
GROUP BY tag.tag_key, tag.source, document__tag.document_id
HAVING count(*) > 1 {offset_clause}
ORDER BY document__tag.document_id
LIMIT {batch_size}
"""
)
).fetchall()
if not batch:
break
doc_ids = [document_id for document_id, in batch]
yield doc_ids
offset_clause = f"AND document__tag.document_id > '{doc_ids[-1]}'"
def _get_vespa_metadata(
document_id: str, index_name: str
) -> dict[str, str | list[str]]:
url = DOCUMENT_ID_ENDPOINT.format(index_name=index_name)
# Document-Selector language
selection = (
f"{index_name}.document_id=='{document_id}' and {index_name}.chunk_id==0"
)
params: dict[str, str | int] = {
"selection": selection,
"wantedDocumentCount": 1,
"fieldSet": f"{index_name}:metadata",
}
with get_vespa_http_client() as client:
resp = client.get(url, params=params)
resp.raise_for_status()
docs = resp.json().get("documents", [])
if not docs:
raise RuntimeError(f"No chunk-0 found for document {document_id}")
# for some reason, metadata is a string
metadata = docs[0]["fields"]["metadata"]
return json.loads(metadata)
def _get_document_tags(document_id: str) -> list[tuple[int, str, str]]:
bind = op.get_bind()
result = bind.execute(
sa.text(
f"""
SELECT tag.id, tag.tag_key, tag.tag_value
FROM tag
JOIN document__tag ON tag.id = document__tag.tag_id
WHERE document__tag.document_id = '{document_id}'
"""
)
).fetchall()
return cast(list[tuple[int, str, str]], result)
def upgrade() -> None:
op.add_column(
"tag",
sa.Column("is_list", sa.Boolean(), nullable=False, server_default="false"),
)
op.drop_constraint(
constraint_name="_tag_key_value_source_uc",
table_name="tag",
type_="unique",
)
op.create_unique_constraint(
constraint_name="_tag_key_value_source_list_uc",
table_name="tag",
columns=["tag_key", "tag_value", "source", "is_list"],
)
set_is_list_for_known_tags()
if SKIP_TAG_FIX:
logger.warning(
"Skipping removal of old tags. "
"This can cause issues when using the knowledge graph, or "
"when filtering for documents by tags."
)
log_list_tags()
return
remove_old_tags()
set_is_list_for_list_tags()
# debug
log_list_tags()
def downgrade() -> None:
# the migration adds and populates the is_list column, and removes old bugged tags
# there isn't a point in adding back the bugged tags, so we just drop the column
op.drop_constraint(
constraint_name="_tag_key_value_source_list_uc",
table_name="tag",
type_="unique",
)
op.create_unique_constraint(
constraint_name="_tag_key_value_source_uc",
table_name="tag",
columns=["tag_key", "tag_value", "source"],
)
op.drop_column("tag", "is_list")
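The statements in this migration interpolate keys, sources, and document ids into SQL with f-strings. SQLAlchemy's `sa.text` also accepts bound parameters, which handles values containing quotes safely. A minimal sketch of the pattern, using the stdlib `sqlite3` driver purely for illustration (the table and values here are toy stand-ins, not the real schema):

```python
import sqlite3

# Illustration only: the migration above builds SQL like
#   f"WHERE tag_key = '{key}' AND source = '{source}'"
# Bound parameters sidestep quoting problems (e.g. apostrophes in values).
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE tag (tag_key TEXT, source TEXT, is_list INTEGER)")
conn.execute("INSERT INTO tag VALUES ('labels', 'JIRA', 0)")


def mark_is_list(conn: sqlite3.Connection, key: str, source: str) -> int:
    # Equivalent SQLAlchemy form would be:
    #   bind.execute(
    #       sa.text("UPDATE tag SET is_list = true "
    #               "WHERE tag_key = :key AND source = :source"),
    #       {"key": key, "source": source},
    #   )
    cur = conn.execute(
        "UPDATE tag SET is_list = 1 WHERE tag_key = ? AND source = ?",
        (key, source),
    )
    return cur.rowcount  # rows affected
```

The same dict-of-parameters call shape works with `op.get_bind().execute` inside an Alembic migration.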


@@ -0,0 +1,33 @@
"""Pause finished user file connectors
Revision ID: b558f51620b4
Revises: 90e3b9af7da4
Create Date: 2025-08-15 17:17:02.456704
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "b558f51620b4"
down_revision = "90e3b9af7da4"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Set all user file connector credential pairs with ACTIVE status to PAUSED
# This ensures user files don't continue to run indexing tasks after processing
op.execute(
"""
UPDATE connector_credential_pair
SET status = 'PAUSED'
WHERE is_user_file = true
AND status = 'ACTIVE'
"""
)
def downgrade() -> None:
pass


@@ -102,6 +102,19 @@ TEAMS_PERMISSION_DOC_SYNC_FREQUENCY = int(
os.environ.get("TEAMS_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
)
#####
# SharePoint
#####
# In seconds, default is 30 minutes
SHAREPOINT_PERMISSION_DOC_SYNC_FREQUENCY = int(
os.environ.get("SHAREPOINT_PERMISSION_DOC_SYNC_FREQUENCY") or 30 * 60
)
# In seconds, default is 5 minutes
SHAREPOINT_PERMISSION_GROUP_SYNC_FREQUENCY = int(
os.environ.get("SHAREPOINT_PERMISSION_GROUP_SYNC_FREQUENCY") or 5 * 60
)
####
# Celery Job Frequency
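The `int(os.environ.get(NAME) or default)` idiom used for these frequencies falls back to the default for both missing and empty variables, because `or` treats `""` as falsy; `os.environ.get(NAME, default)` would instead pass an empty string through to `int()` and crash. A small sketch of the behavior (helper name is illustrative):

```python
import os


def env_seconds(name: str, default: int) -> int:
    # `or` makes the empty string fall back to the default as well;
    # os.environ.get(name, default) would feed "" to int() and raise.
    return int(os.environ.get(name) or default)
```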


@@ -2,18 +2,14 @@ from collections.abc import Callable
from collections.abc import Generator
from typing import Optional
from typing import Protocol
from typing import TYPE_CHECKING
from ee.onyx.db.external_perm import ExternalUserGroup
from onyx.access.models import DocExternalAccess
from onyx.context.search.models import InferenceChunk
from onyx.db.models import ConnectorCredentialPair
from onyx.db.utils import DocumentRow
from onyx.db.utils import SortOrder
# Avoid circular imports
if TYPE_CHECKING:
from ee.onyx.db.external_perm import ExternalUserGroup # noqa
from onyx.access.models import DocExternalAccess # noqa
from onyx.db.models import ConnectorCredentialPair # noqa
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface # noqa
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
class FetchAllDocumentsFunction(Protocol):
@@ -52,20 +48,20 @@ class FetchAllDocumentsIdsFunction(Protocol):
# Defining the input/output types for the sync functions
DocSyncFuncType = Callable[
[
"ConnectorCredentialPair",
ConnectorCredentialPair,
FetchAllDocumentsFunction,
FetchAllDocumentsIdsFunction,
Optional["IndexingHeartbeatInterface"],
Optional[IndexingHeartbeatInterface],
],
Generator["DocExternalAccess", None, None],
Generator[DocExternalAccess, None, None],
]
GroupSyncFuncType = Callable[
[
str, # tenant_id
"ConnectorCredentialPair", # cc_pair
ConnectorCredentialPair, # cc_pair
],
Generator["ExternalUserGroup", None, None],
Generator[ExternalUserGroup, None, None],
]
# list of chunks to be censored and the user email. returns censored chunks
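The refactor above drops the quoted forward references (previously needed under `TYPE_CHECKING` to avoid circular imports) in favor of direct imports, so the callable aliases are built from real classes. A toy sketch of the same shape, with stand-in types (none of these names are the real Onyx classes):

```python
from collections.abc import Callable, Generator
from typing import Protocol


# Stand-in for a model class; illustrative only.
class DocExternalAccess:
    def __init__(self, doc_id: str) -> None:
        self.doc_id = doc_id


# Protocol-typed dependency, mirroring FetchAllDocumentsFunction.
class FetchAllDocumentsFunction(Protocol):
    def __call__(self) -> list[str]: ...


# Mirrors the shape of DocSyncFuncType: a plain Callable alias built from
# concrete classes instead of "quoted" forward references.
DocSyncFuncType = Callable[
    [FetchAllDocumentsFunction],
    Generator[DocExternalAccess, None, None],
]


def demo_sync(
    fetch_docs: FetchAllDocumentsFunction,
) -> Generator[DocExternalAccess, None, None]:
    for doc_id in fetch_docs():
        yield DocExternalAccess(doc_id)


sync_fn: DocSyncFuncType = demo_sync  # this assignment type-checks under mypy
```

With direct imports, mypy checks the alias against real classes at definition time instead of resolving strings lazily.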


@@ -0,0 +1,36 @@
from collections.abc import Generator
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
from ee.onyx.external_permissions.utils import generic_doc_sync
from onyx.access.models import DocExternalAccess
from onyx.configs.constants import DocumentSource
from onyx.connectors.sharepoint.connector import SharepointConnector
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
logger = setup_logger()
SHAREPOINT_DOC_SYNC_TAG = "sharepoint_doc_sync"
def sharepoint_doc_sync(
cc_pair: ConnectorCredentialPair,
fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,
callback: IndexingHeartbeatInterface | None = None,
) -> Generator[DocExternalAccess, None, None]:
sharepoint_connector = SharepointConnector(
**cc_pair.connector.connector_specific_config,
)
sharepoint_connector.load_credentials(cc_pair.credential.credential_json)
yield from generic_doc_sync(
cc_pair=cc_pair,
fetch_all_existing_docs_ids_fn=fetch_all_existing_docs_ids_fn,
callback=callback,
doc_source=DocumentSource.SHAREPOINT,
slim_connector=sharepoint_connector,
label=SHAREPOINT_DOC_SYNC_TAG,
)

View File

@@ -0,0 +1,63 @@
from collections.abc import Generator
from office365.sharepoint.client_context import ClientContext # type: ignore[import-untyped]
from ee.onyx.db.external_perm import ExternalUserGroup
from ee.onyx.external_permissions.sharepoint.permission_utils import (
get_sharepoint_external_groups,
)
from onyx.connectors.sharepoint.connector import acquire_token_for_rest
from onyx.connectors.sharepoint.connector import SharepointConnector
from onyx.db.models import ConnectorCredentialPair
from onyx.utils.logger import setup_logger
logger = setup_logger()
def sharepoint_group_sync(
tenant_id: str,
cc_pair: ConnectorCredentialPair,
) -> Generator[ExternalUserGroup, None, None]:
"""Sync SharePoint groups and their members"""
# Get site URLs from connector config
connector_config = cc_pair.connector.connector_specific_config
# Create SharePoint connector instance and load credentials
connector = SharepointConnector(**connector_config)
connector.load_credentials(cc_pair.credential.credential_json)
if not connector.msal_app:
raise RuntimeError("MSAL app not initialized in connector")
if not connector.sp_tenant_domain:
raise RuntimeError("Tenant domain not initialized in connector")
# Get site descriptors from connector (either configured sites or all sites)
site_descriptors = connector.site_descriptors or connector.fetch_sites()
if not site_descriptors:
raise RuntimeError("No SharePoint sites found for group sync")
logger.info(f"Processing {len(site_descriptors)} sites for group sync")
msal_app = connector.msal_app
sp_tenant_domain = connector.sp_tenant_domain
# Process each site
for site_descriptor in site_descriptors:
logger.debug(f"Processing site: {site_descriptor.url}")
# Create client context for the site using connector's MSAL app
ctx = ClientContext(site_descriptor.url).with_access_token(
lambda: acquire_token_for_rest(msal_app, sp_tenant_domain)
)
# Get external groups for this site
external_groups = get_sharepoint_external_groups(ctx, connector.graph_client)
# Yield each group
for group in external_groups:
logger.debug(
f"Found group: {group.id} with {len(group.user_emails)} members"
)
yield group

View File

@@ -0,0 +1,684 @@
import re
from collections import deque
from typing import Any
from urllib.parse import unquote
from urllib.parse import urlparse
from office365.graph_client import GraphClient # type: ignore[import-untyped]
from office365.onedrive.driveitems.driveItem import DriveItem # type: ignore[import-untyped]
from office365.runtime.client_request import ClientRequestException # type: ignore
from office365.sharepoint.client_context import ClientContext # type: ignore[import-untyped]
from office365.sharepoint.permissions.securable_object import RoleAssignmentCollection # type: ignore[import-untyped]
from pydantic import BaseModel
from ee.onyx.db.external_perm import ExternalUserGroup
from onyx.access.models import ExternalAccess
from onyx.access.utils import build_ext_group_name_for_onyx
from onyx.configs.constants import DocumentSource
from onyx.connectors.sharepoint.connector import sleep_and_retry
from onyx.utils.logger import setup_logger
logger = setup_logger()
# These values represent different types of SharePoint principals used in permission assignments
USER_PRINCIPAL_TYPE = 1 # Individual user accounts
ANONYMOUS_USER_PRINCIPAL_TYPE = 3 # Anonymous/unauthenticated users (public access)
AZURE_AD_GROUP_PRINCIPAL_TYPE = 4 # Azure Active Directory security groups
SHAREPOINT_GROUP_PRINCIPAL_TYPE = 8 # SharePoint site groups (local to the site)
MICROSOFT_DOMAIN = ".onmicrosoft"
# Limited Access role types; Limited Access is a pass-through permission, not an actual permission
LIMITED_ACCESS_ROLE_TYPES = [1, 9]
LIMITED_ACCESS_ROLE_NAMES = ["Limited Access", "Web-Only Limited Access"]
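The constants above drive the Limited Access filter used later in this file: a role assignment is skipped only when every bound role is one of these pass-through roles. A minimal sketch of that predicate, using plain dicts in place of the office365 role-definition-binding objects (hypothetical shape):

```python
# Sketch of the "Limited Access only" predicate, with dicts standing in
# for the office365 role-definition-binding objects.
LIMITED_ACCESS_ROLE_TYPES = [1, 9]
LIMITED_ACCESS_ROLE_NAMES = ["Limited Access", "Web-Only Limited Access"]

def is_limited_access_only(bindings: list[dict]) -> bool:
    # An assignment is pass-through only if *every* bound role is Limited Access.
    if not bindings:
        return False
    return all(
        b["role_type_kind"] in LIMITED_ACCESS_ROLE_TYPES
        and b["name"] in LIMITED_ACCESS_ROLE_NAMES
        for b in bindings
    )

print(is_limited_access_only([{"role_type_kind": 1, "name": "Limited Access"}]))  # True
```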
class SharepointGroup(BaseModel):
model_config = {"frozen": True}
name: str
login_name: str
principal_type: int
class GroupsResult(BaseModel):
groups_to_emails: dict[str, set[str]]
found_public_group: bool
def _get_azuread_group_guid_by_name(
graph_client: GraphClient, group_name: str
) -> str | None:
try:
# Search for groups by display name
groups = sleep_and_retry(
graph_client.groups.filter(f"displayName eq '{group_name}'").get(),
"get_azuread_group_guid_by_name",
)
if groups and len(groups) > 0:
return groups[0].id
return None
except Exception as e:
logger.error(f"Failed to get Azure AD group GUID for name {group_name}: {e}")
return None
def _extract_guid_from_claims_token(claims_token: str) -> str | None:
try:
# Pattern to match GUID in claims token
# Claims tokens often have format: c:0o.c|provider|GUID_suffix
guid_pattern = r"([0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12})"
match = re.search(guid_pattern, claims_token, re.IGNORECASE)
if match:
return match.group(1)
return None
except Exception as e:
logger.error(f"Failed to extract GUID from claims token {claims_token}: {e}")
return None
def _get_group_guid_from_identifier(
graph_client: GraphClient, identifier: str
) -> str | None:
try:
# Check if it's already a GUID
guid_pattern = r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$"
if re.match(guid_pattern, identifier, re.IGNORECASE):
return identifier
# Check if it's a SharePoint claims token
if identifier.startswith("c:0") and "|" in identifier:
guid = _extract_guid_from_claims_token(identifier)
if guid:
logger.info(f"Extracted GUID {guid} from claims token {identifier}")
return guid
# Try to search by display name as fallback
return _get_azuread_group_guid_by_name(graph_client, identifier)
except Exception as e:
logger.error(f"Failed to get group GUID from identifier {identifier}: {e}")
return None
def _get_security_group_owners(graph_client: GraphClient, group_id: str) -> list[str]:
try:
# Get group owners using Graph API
group = graph_client.groups[group_id]
owners = sleep_and_retry(
group.owners.get_all(page_loaded=lambda _: None),
"get_security_group_owners",
)
owner_emails: list[str] = []
logger.info(f"Owners: {owners}")
for owner in owners:
owner_data = owner.to_json()
# Extract email from the JSON data
mail: str | None = owner_data.get("mail")
user_principal_name: str | None = owner_data.get("userPrincipalName")
# Check if owner is a user and has an email
if mail:
if MICROSOFT_DOMAIN in mail:
mail = mail.replace(MICROSOFT_DOMAIN, "")
owner_emails.append(mail)
elif user_principal_name:
if MICROSOFT_DOMAIN in user_principal_name:
user_principal_name = user_principal_name.replace(
MICROSOFT_DOMAIN, ""
)
owner_emails.append(user_principal_name)
logger.info(
f"Retrieved {len(owner_emails)} owners from security group {group_id}"
)
return owner_emails
except Exception as e:
logger.error(f"Failed to get security group owners for group {group_id}: {e}")
return []
def _get_sharepoint_list_item_id(drive_item: DriveItem) -> str | None:
try:
# First try to get the list item directly from the drive item
if hasattr(drive_item, "listItem"):
list_item = drive_item.listItem
if list_item:
# Load the list item properties to get the ID
sleep_and_retry(list_item.get(), "get_sharepoint_list_item_id")
if hasattr(list_item, "id") and list_item.id:
return str(list_item.id)
# The SharePoint list item ID is typically available in the sharepointIds property
sharepoint_ids = getattr(drive_item, "sharepoint_ids", None)
if sharepoint_ids and hasattr(sharepoint_ids, "listItemId"):
return sharepoint_ids.listItemId
# Alternative: try to get it from the properties
properties = getattr(drive_item, "properties", None)
if properties:
# Sometimes the SharePoint list item ID is in the properties
for prop_name, prop_value in properties.items():
if "listitemid" in prop_name.lower():
return str(prop_value)
return None
except Exception as e:
logger.error(
f"Error getting SharePoint list item ID for item {drive_item.id}: {e}"
)
raise e
def _is_public_item(drive_item: DriveItem) -> bool:
is_public = False
try:
permissions = sleep_and_retry(
drive_item.permissions.get_all(page_loaded=lambda _: None), "is_public_item"
)
for permission in permissions:
if permission.link and (
permission.link.scope == "anonymous"
or permission.link.scope == "organization"
):
is_public = True
break
return is_public
except Exception as e:
logger.error(f"Failed to check if item {drive_item.id} is public: {e}")
return False
def _is_public_login_name(login_name: str) -> bool:
# Patterns that indicate public access
# This list is derived from the below link
# https://learn.microsoft.com/en-us/answers/questions/2085339/guid-in-the-loginname-of-site-user-everyone-except
public_login_patterns: list[str] = [
"c:0-.f|rolemanager|spo-grid-all-users/",
"c:0(.s|true",
]
for pattern in public_login_patterns:
if pattern in login_name:
logger.info(f"Login name {login_name} is public")
return True
return False
# Azure AD allows multiple groups to share the same display name, so we append the GUID to the name
def _get_group_name_with_suffix(
login_name: str, group_name: str, graph_client: GraphClient
) -> str:
ad_group_suffix = _get_group_guid_from_identifier(graph_client, login_name)
return f"{group_name}_{ad_group_suffix}"
def _get_sharepoint_groups(
client_context: ClientContext, group_name: str, graph_client: GraphClient
) -> tuple[set[SharepointGroup], set[str]]:
groups: set[SharepointGroup] = set()
user_emails: set[str] = set()
def process_users(users: list[Any]) -> None:
nonlocal groups, user_emails
for user in users:
logger.debug(f"User: {user.to_json()}")
if user.principal_type == USER_PRINCIPAL_TYPE and hasattr(
user, "user_principal_name"
):
if user.user_principal_name:
email = user.user_principal_name
if MICROSOFT_DOMAIN in email:
email = email.replace(MICROSOFT_DOMAIN, "")
user_emails.add(email)
else:
logger.warning(
f"User does not have a user principal name: {user.login_name}"
)
elif user.principal_type in [
AZURE_AD_GROUP_PRINCIPAL_TYPE,
SHAREPOINT_GROUP_PRINCIPAL_TYPE,
]:
name = user.title
if user.principal_type == AZURE_AD_GROUP_PRINCIPAL_TYPE:
name = _get_group_name_with_suffix(
user.login_name, name, graph_client
)
groups.add(
SharepointGroup(
login_name=user.login_name,
principal_type=user.principal_type,
name=name,
)
)
group = client_context.web.site_groups.get_by_name(group_name)
sleep_and_retry(
group.users.get_all(page_loaded=process_users), "get_sharepoint_groups"
)
return groups, user_emails
def _get_azuread_groups(
graph_client: GraphClient, group_name: str
) -> tuple[set[SharepointGroup], set[str]]:
group_id = _get_group_guid_from_identifier(graph_client, group_name)
if not group_id:
logger.error(f"Failed to get Azure AD group GUID for name {group_name}")
return set(), set()
group = graph_client.groups[group_id]
groups: set[SharepointGroup] = set()
user_emails: set[str] = set()
def process_members(members: list[Any]) -> None:
nonlocal groups, user_emails
for member in members:
member_data = member.to_json()
logger.debug(f"Member: {member_data}")
# Check for user-specific attributes
user_principal_name = member_data.get("userPrincipalName")
mail = member_data.get("mail")
display_name = member_data.get("displayName") or member_data.get(
"display_name"
)
# Check object attributes directly (if available)
is_user = False
is_group = False
# Users typically have userPrincipalName or mail
if user_principal_name or (mail and "@" in str(mail)):
is_user = True
# Groups typically have displayName but no userPrincipalName
elif display_name and not user_principal_name:
# Additional check: try to access group-specific properties
if (
hasattr(member, "groupTypes")
or member_data.get("groupTypes") is not None
):
is_group = True
# Or check if it has an 'id' field typical for groups
elif member_data.get("id") and not user_principal_name:
is_group = True
# Check the object type name (fallback)
if not is_user and not is_group:
obj_type = type(member).__name__.lower()
if "user" in obj_type:
is_user = True
elif "group" in obj_type:
is_group = True
# Process based on identification
if is_user:
if user_principal_name:
email = user_principal_name
if MICROSOFT_DOMAIN in email:
email = email.replace(MICROSOFT_DOMAIN, "")
user_emails.add(email)
elif mail:
email = mail
if MICROSOFT_DOMAIN in email:
email = email.replace(MICROSOFT_DOMAIN, "")
user_emails.add(email)
logger.info(f"Added user: {user_principal_name or mail}")
elif is_group:
if not display_name:
logger.error(f"No display name for group: {member_data.get('id')}")
continue
name = _get_group_name_with_suffix(
member_data.get("id", ""), display_name, graph_client
)
groups.add(
SharepointGroup(
login_name=member_data.get("id", ""), # Use ID for groups
principal_type=AZURE_AD_GROUP_PRINCIPAL_TYPE,
name=name,
)
)
logger.info(f"Added group: {name}")
else:
# Log unidentified members for debugging
logger.warning(f"Could not identify member type for: {member_data}")
sleep_and_retry(
group.members.get_all(page_loaded=process_members), "get_azuread_groups"
)
owner_emails = _get_security_group_owners(graph_client, group_id)
user_emails.update(owner_emails)
return groups, user_emails
def _get_groups_and_members_recursively(
client_context: ClientContext,
graph_client: GraphClient,
groups: set[SharepointGroup],
is_group_sync: bool = False,
) -> GroupsResult:
"""
Get all groups and their members recursively.
"""
group_queue: deque[SharepointGroup] = deque(groups)
visited_groups: set[str] = set()
visited_group_name_to_emails: dict[str, set[str]] = {}
found_public_group = False
while group_queue:
group = group_queue.popleft()
if group.login_name in visited_groups:
continue
visited_groups.add(group.login_name)
visited_group_name_to_emails[group.name] = set()
logger.info(
f"Processing group: {group.name} principal type: {group.principal_type}"
)
if group.principal_type == SHAREPOINT_GROUP_PRINCIPAL_TYPE:
group_info, user_emails = _get_sharepoint_groups(
client_context, group.login_name, graph_client
)
visited_group_name_to_emails[group.name].update(user_emails)
if group_info:
group_queue.extend(group_info)
if group.principal_type == AZURE_AD_GROUP_PRINCIPAL_TYPE:
try:
# if the site is public, we have default groups assigned to it, so we return early
if _is_public_login_name(group.login_name):
found_public_group = True
if not is_group_sync:
return GroupsResult(
groups_to_emails={}, found_public_group=True
)
else:
# we don't want to sync public groups, so we skip them
continue
group_info, user_emails = _get_azuread_groups(
graph_client, group.login_name
)
visited_group_name_to_emails[group.name].update(user_emails)
if group_info:
group_queue.extend(group_info)
except ClientRequestException as e:
# If the group is not found, we skip it. There is a chance the group is still referenced
# in SharePoint but has been removed from Azure AD. There is no official documentation on
# this, but we have observed it in our testing.
if e.response is not None and e.response.status_code == 404:
logger.warning(f"Group {group.login_name} not found")
continue
raise e
return GroupsResult(
groups_to_emails=visited_group_name_to_emails,
found_public_group=found_public_group,
)
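Stripped of the SharePoint specifics, the loop above is a breadth-first expansion over a group graph, with a visited set guarding against revisits and cycles. A minimal sketch, with a hypothetical `expand()` callback returning `(nested_groups, member_emails)` for one group:

```python
from collections import deque

def expand_all(roots: list[str], expand) -> dict[str, set[str]]:
    # Breadth-first expansion with a visited set, tolerant of nested/cyclic groups.
    queue, visited, result = deque(roots), set(), {}
    while queue:
        group = queue.popleft()
        if group in visited:
            continue
        visited.add(group)
        nested, emails = expand(group)  # hypothetical: one level of membership
        result[group] = set(emails)
        queue.extend(nested)
    return result

# Hypothetical two-level hierarchy: a site group containing an AD group.
tree = {"site-members": (["ad-eng"], ["a@x.com"]), "ad-eng": ([], ["b@x.com"])}
print(expand_all(["site-members"], lambda g: tree.get(g, ([], []))))
```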
def get_external_access_from_sharepoint(
client_context: ClientContext,
graph_client: GraphClient,
drive_name: str | None,
drive_item: DriveItem | None,
site_page: dict[str, Any] | None,
add_prefix: bool = False,
) -> ExternalAccess:
"""
Get external access information from SharePoint.
"""
groups: set[SharepointGroup] = set()
user_emails: set[str] = set()
group_ids: set[str] = set()
# Add all members to a processing set first
def add_user_and_group_to_sets(
role_assignments: RoleAssignmentCollection,
) -> None:
nonlocal user_emails, groups
for assignment in role_assignments:
logger.debug(f"Assignment: {assignment.to_json()}")
if assignment.role_definition_bindings:
is_limited_access = True
for role_definition_binding in assignment.role_definition_bindings:
if (
role_definition_binding.role_type_kind
not in LIMITED_ACCESS_ROLE_TYPES
or role_definition_binding.name not in LIMITED_ACCESS_ROLE_NAMES
):
is_limited_access = False
break
# Skip if the only role is Limited Access, because it is a pass-through permission, not an actual permission
if is_limited_access:
logger.info(
"Skipping assignment because it has only Limited Access role"
)
continue
if assignment.member:
member = assignment.member
if member.principal_type == USER_PRINCIPAL_TYPE and hasattr(
member, "user_principal_name"
):
email = member.user_principal_name
if MICROSOFT_DOMAIN in email:
email = email.replace(MICROSOFT_DOMAIN, "")
user_emails.add(email)
elif member.principal_type in [
AZURE_AD_GROUP_PRINCIPAL_TYPE,
SHAREPOINT_GROUP_PRINCIPAL_TYPE,
]:
name = member.title
if member.principal_type == AZURE_AD_GROUP_PRINCIPAL_TYPE:
name = _get_group_name_with_suffix(
member.login_name, name, graph_client
)
groups.add(
SharepointGroup(
login_name=member.login_name,
principal_type=member.principal_type,
name=name,
)
)
if drive_item and drive_name:
# Check whether the item has any public links; if so, return early
is_public = _is_public_item(drive_item)
if is_public:
logger.info(f"Item {drive_item.id} is public")
return ExternalAccess(
external_user_emails=set(),
external_user_group_ids=set(),
is_public=True,
)
item_id = _get_sharepoint_list_item_id(drive_item)
if not item_id:
raise RuntimeError(
f"Failed to get SharePoint list item ID for item {drive_item.id}"
)
if drive_name == "Shared Documents":
drive_name = "Documents"
item = client_context.web.lists.get_by_title(drive_name).items.get_by_id(
item_id
)
sleep_and_retry(
item.role_assignments.expand(["Member", "RoleDefinitionBindings"]).get_all(
page_loaded=add_user_and_group_to_sets,
),
"get_external_access_from_sharepoint",
)
elif site_page:
site_url = site_page.get("webUrl")
# Prefer server-relative URL to avoid OData filters that break on apostrophes
server_relative_url = unquote(urlparse(site_url).path)
file_obj = client_context.web.get_file_by_server_relative_url(
server_relative_url
)
item = file_obj.listItemAllFields
sleep_and_retry(
item.role_assignments.expand(["Member", "RoleDefinitionBindings"]).get_all(
page_loaded=add_user_and_group_to_sets,
),
"get_external_access_from_sharepoint",
)
else:
raise RuntimeError("No drive item or site page provided")
groups_and_members: GroupsResult = _get_groups_and_members_recursively(
client_context, graph_client, groups
)
# If the site is public, we have default groups assigned to it, so we return early
if groups_and_members.found_public_group:
return ExternalAccess(
external_user_emails=set(),
external_user_group_ids=set(),
is_public=True,
)
for group_name, _ in groups_and_members.groups_to_emails.items():
if add_prefix:
group_name = build_ext_group_name_for_onyx(
group_name, DocumentSource.SHAREPOINT
)
group_ids.add(group_name.lower())
logger.info(f"User emails: {len(user_emails)}")
logger.info(f"Group IDs: {len(group_ids)}")
return ExternalAccess(
external_user_emails=user_emails,
external_user_group_ids=group_ids,
is_public=False,
)
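The site-page branch above derives a server-relative URL instead of building an OData filter, so apostrophes in file names never need escaping. A quick illustration with a made-up URL:

```python
from urllib.parse import unquote, urlparse

# Take the path component of the page's web URL and percent-decode it; the
# resulting server-relative URL is what gets passed to
# get_file_by_server_relative_url, sidestepping OData filter strings entirely.
url = "https://contoso.sharepoint.com/sites/Eng/SitePages/O'Brien%27s%20page.aspx"
server_relative_url = unquote(urlparse(url).path)
print(server_relative_url)  # /sites/Eng/SitePages/O'Brien's page.aspx
```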
def get_sharepoint_external_groups(
client_context: ClientContext, graph_client: GraphClient
) -> list[ExternalUserGroup]:
groups: set[SharepointGroup] = set()
def add_group_to_sets(role_assignments: RoleAssignmentCollection) -> None:
nonlocal groups
for assignment in role_assignments:
if assignment.role_definition_bindings:
is_limited_access = True
for role_definition_binding in assignment.role_definition_bindings:
if (
role_definition_binding.role_type_kind
not in LIMITED_ACCESS_ROLE_TYPES
or role_definition_binding.name not in LIMITED_ACCESS_ROLE_NAMES
):
is_limited_access = False
break
# Skip if the role assignment is only Limited Access, because it is a pass-through
# permission, not an actual permission
if is_limited_access:
logger.info(
"Skipping assignment because it has only Limited Access role"
)
continue
if assignment.member:
member = assignment.member
if member.principal_type in [
AZURE_AD_GROUP_PRINCIPAL_TYPE,
SHAREPOINT_GROUP_PRINCIPAL_TYPE,
]:
name = member.title
if member.principal_type == AZURE_AD_GROUP_PRINCIPAL_TYPE:
name = _get_group_name_with_suffix(
member.login_name, name, graph_client
)
groups.add(
SharepointGroup(
login_name=member.login_name,
principal_type=member.principal_type,
name=name,
)
)
sleep_and_retry(
client_context.web.role_assignments.expand(
["Member", "RoleDefinitionBindings"]
).get_all(page_loaded=add_group_to_sets),
"get_sharepoint_external_groups",
)
groups_and_members: GroupsResult = _get_groups_and_members_recursively(
client_context, graph_client, groups, is_group_sync=True
)
# Get all Azure AD groups, so that any group assigned to a drive item is not missed
# SharePoint groups can't be assigned to drive items or drives, so we don't need to fetch all of them
azure_ad_groups = sleep_and_retry(
graph_client.groups.get_all(page_loaded=lambda _: None),
"get_sharepoint_external_groups:get_azure_ad_groups",
)
logger.info(f"Azure AD Groups: {len(azure_ad_groups)}")
identified_groups: set[str] = set(groups_and_members.groups_to_emails.keys())
ad_groups_to_emails: dict[str, set[str]] = {}
for group in azure_ad_groups:
# If the group is already identified, we don't need to get the members
if group.display_name in identified_groups:
continue
# Azure AD allows multiple groups to share the same display name, so we append the GUID to the name
name = group.display_name
name = _get_group_name_with_suffix(group.id, name, graph_client)
members = sleep_and_retry(
group.members.get_all(page_loaded=lambda _: None),
"get_sharepoint_external_groups:get_azure_ad_groups:get_members",
)
for member in members:
member_data = member.to_json()
user_principal_name = member_data.get("userPrincipalName")
mail = member_data.get("mail")
if not ad_groups_to_emails.get(name):
ad_groups_to_emails[name] = set()
if user_principal_name:
if MICROSOFT_DOMAIN in user_principal_name:
user_principal_name = user_principal_name.replace(
MICROSOFT_DOMAIN, ""
)
ad_groups_to_emails[name].add(user_principal_name)
elif mail:
if MICROSOFT_DOMAIN in mail:
mail = mail.replace(MICROSOFT_DOMAIN, "")
ad_groups_to_emails[name].add(mail)
external_user_groups: list[ExternalUserGroup] = []
for group_name, emails in groups_and_members.groups_to_emails.items():
external_user_group = ExternalUserGroup(
id=group_name,
user_emails=list(emails),
)
external_user_groups.append(external_user_group)
for group_name, emails in ad_groups_to_emails.items():
external_user_group = ExternalUserGroup(
id=group_name,
user_emails=list(emails),
)
external_user_groups.append(external_user_group)
return external_user_groups

View File

@@ -11,6 +11,8 @@ from ee.onyx.configs.app_configs import GITHUB_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import GITHUB_PERMISSION_GROUP_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import JIRA_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import SHAREPOINT_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import SHAREPOINT_PERMISSION_GROUP_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import SLACK_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import TEAMS_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.external_permissions.confluence.doc_sync import confluence_doc_sync
@@ -29,6 +31,8 @@ from ee.onyx.external_permissions.perm_sync_types import GroupSyncFuncType
from ee.onyx.external_permissions.salesforce.postprocessing import (
censor_salesforce_chunks,
)
from ee.onyx.external_permissions.sharepoint.doc_sync import sharepoint_doc_sync
from ee.onyx.external_permissions.sharepoint.group_sync import sharepoint_group_sync
from ee.onyx.external_permissions.slack.doc_sync import slack_doc_sync
from ee.onyx.external_permissions.teams.doc_sync import teams_doc_sync
from onyx.configs.constants import DocumentSource
@@ -156,6 +160,18 @@ _SOURCE_TO_SYNC_CONFIG: dict[DocumentSource, SyncConfig] = {
initial_index_should_sync=True,
),
),
DocumentSource.SHAREPOINT: SyncConfig(
doc_sync_config=DocSyncConfig(
doc_sync_frequency=SHAREPOINT_PERMISSION_DOC_SYNC_FREQUENCY,
doc_sync_func=sharepoint_doc_sync,
initial_index_should_sync=True,
),
group_sync_config=GroupSyncConfig(
group_sync_frequency=SHAREPOINT_PERMISSION_GROUP_SYNC_FREQUENCY,
group_sync_func=sharepoint_group_sync,
group_sync_is_cc_pair_agnostic=False,
),
),
}

View File

@@ -206,7 +206,7 @@ def _handle_standard_answers(
restate_question_blocks = get_restate_blocks(
msg=query_msg.message,
is_bot_msg=message_info.is_bot_msg,
is_slash_command=message_info.is_slash_command,
)
answer_blocks = build_standard_answer_blocks(

View File

@@ -67,7 +67,7 @@ def generate_chat_messages_report(
file_id = file_store.save_file(
content=temp_file,
display_name=file_name,
file_origin=FileOrigin.OTHER,
file_origin=FileOrigin.GENERATED_REPORT,
file_type="text/csv",
)
@@ -99,7 +99,7 @@ def generate_user_report(
file_id = file_store.save_file(
content=temp_file,
display_name=file_name,
file_origin=FileOrigin.OTHER,
file_origin=FileOrigin.GENERATED_REPORT,
file_type="text/csv",
)

View File

@@ -1,34 +1,5 @@
from shared_configs.enums import EmbeddingProvider
from shared_configs.enums import EmbedTextType
MODEL_WARM_UP_STRING = "hi " * 512
INFORMATION_CONTENT_MODEL_WARM_UP_STRING = "hi " * 16
DEFAULT_OPENAI_MODEL = "text-embedding-3-small"
DEFAULT_COHERE_MODEL = "embed-english-light-v3.0"
DEFAULT_VOYAGE_MODEL = "voyage-large-2-instruct"
DEFAULT_VERTEX_MODEL = "text-embedding-005"
class EmbeddingModelTextType:
PROVIDER_TEXT_TYPE_MAP = {
EmbeddingProvider.COHERE: {
EmbedTextType.QUERY: "search_query",
EmbedTextType.PASSAGE: "search_document",
},
EmbeddingProvider.VOYAGE: {
EmbedTextType.QUERY: "query",
EmbedTextType.PASSAGE: "document",
},
EmbeddingProvider.GOOGLE: {
EmbedTextType.QUERY: "RETRIEVAL_QUERY",
EmbedTextType.PASSAGE: "RETRIEVAL_DOCUMENT",
},
}
@staticmethod
def get_type(provider: EmbeddingProvider, text_type: EmbedTextType) -> str:
return EmbeddingModelTextType.PROVIDER_TEXT_TYPE_MAP[provider][text_type]
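The lookup above is a plain two-level map from provider to per-text-type label. A sketch with string keys standing in for the `EmbeddingProvider` and `EmbedTextType` enums used in the real class:

```python
# Provider -> text-type lookup, mirroring PROVIDER_TEXT_TYPE_MAP above,
# with plain strings standing in for the enum members.
PROVIDER_TEXT_TYPE_MAP = {
    "cohere": {"query": "search_query", "passage": "search_document"},
    "voyage": {"query": "query", "passage": "document"},
    "google": {"query": "RETRIEVAL_QUERY", "passage": "RETRIEVAL_DOCUMENT"},
}

def get_type(provider: str, text_type: str) -> str:
    return PROVIDER_TEXT_TYPE_MAP[provider][text_type]

print(get_type("cohere", "query"))  # search_query
```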
class GPUStatus:

View File

@@ -1,55 +1,30 @@
import asyncio
import json
import time
from types import TracebackType
from typing import cast
from typing import Any
from typing import Optional
import aioboto3 # type: ignore
import httpx
import openai
import vertexai # type: ignore
import voyageai # type: ignore
from cohere import AsyncClient as CohereAsyncClient
from fastapi import APIRouter
from fastapi import HTTPException
from fastapi import Request
from google.oauth2 import service_account # type: ignore
from litellm import aembedding
from litellm.exceptions import RateLimitError
from retry import retry
from sentence_transformers import CrossEncoder # type: ignore
from sentence_transformers import SentenceTransformer # type: ignore
from vertexai.language_models import TextEmbeddingInput # type: ignore
from vertexai.language_models import TextEmbeddingModel # type: ignore
from model_server.constants import DEFAULT_COHERE_MODEL
from model_server.constants import DEFAULT_OPENAI_MODEL
from model_server.constants import DEFAULT_VERTEX_MODEL
from model_server.constants import DEFAULT_VOYAGE_MODEL
from model_server.constants import EmbeddingModelTextType
from model_server.constants import EmbeddingProvider
from model_server.utils import pass_aws_key
from model_server.utils import simple_log_function_time
from onyx.utils.logger import setup_logger
from shared_configs.configs import API_BASED_EMBEDDING_TIMEOUT
from shared_configs.configs import INDEXING_ONLY
from shared_configs.configs import OPENAI_EMBEDDING_TIMEOUT
from shared_configs.configs import VERTEXAI_EMBEDDING_LOCAL_BATCH_SIZE
from shared_configs.enums import EmbedTextType
from shared_configs.enums import RerankerProvider
from shared_configs.model_server_models import Embedding
from shared_configs.model_server_models import EmbedRequest
from shared_configs.model_server_models import EmbedResponse
from shared_configs.model_server_models import RerankRequest
from shared_configs.model_server_models import RerankResponse
from shared_configs.utils import batch_list
logger = setup_logger()
router = APIRouter(prefix="/encoder")
_GLOBAL_MODELS_DICT: dict[str, "SentenceTransformer"] = {}
_RERANK_MODEL: Optional["CrossEncoder"] = None
@@ -57,315 +32,6 @@ _RERANK_MODEL: Optional["CrossEncoder"] = None
_RETRY_DELAY = 10 if INDEXING_ONLY else 0.1
_RETRY_TRIES = 10 if INDEXING_ONLY else 2
# OpenAI only allows 2048 embeddings to be computed at once
_OPENAI_MAX_INPUT_LEN = 2048
# Cohere allows up to 96 embeddings in a single embedding call
_COHERE_MAX_INPUT_LEN = 96
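These per-provider input caps are what the imported `batch_list` helper enforces: inputs are split into provider-sized chunks before each request. A minimal stand-in implementation of that chunking:

```python
# Minimal stand-in for the imported batch_list helper: split the inputs
# into chunks no larger than the provider's per-request limit.
_OPENAI_MAX_INPUT_LEN = 2048  # OpenAI: at most 2048 embeddings per call
_COHERE_MAX_INPUT_LEN = 96    # Cohere: at most 96 embeddings per call

def batch_list(items: list, batch_size: int) -> list[list]:
    return [items[i : i + batch_size] for i in range(0, len(items), batch_size)]

texts = [f"doc {i}" for i in range(200)]
print([len(b) for b in batch_list(texts, _COHERE_MAX_INPUT_LEN)])  # [96, 96, 8]
```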
# Authentication error string constants
_AUTH_ERROR_401 = "401"
_AUTH_ERROR_UNAUTHORIZED = "unauthorized"
_AUTH_ERROR_INVALID_API_KEY = "invalid api key"
_AUTH_ERROR_PERMISSION = "permission"
def is_authentication_error(error: Exception) -> bool:
"""Check if an exception is related to authentication issues.
Args:
error: The exception to check
Returns:
bool: True if the error appears to be authentication-related
"""
error_str = str(error).lower()
return (
_AUTH_ERROR_401 in error_str
or _AUTH_ERROR_UNAUTHORIZED in error_str
or _AUTH_ERROR_INVALID_API_KEY in error_str
or _AUTH_ERROR_PERMISSION in error_str
)
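`is_authentication_error` is a substring scan over the stringified exception. A compact equivalent of that check, for illustration; the marker tuple mirrors the `_AUTH_ERROR_*` constants:

```python
# Substring-based auth-error detection, equivalent to the function above.
AUTH_MARKERS = ("401", "unauthorized", "invalid api key", "permission")

def looks_like_auth_error(error: Exception) -> bool:
    msg = str(error).lower()
    return any(marker in msg for marker in AUTH_MARKERS)

print(looks_like_auth_error(RuntimeError("HTTP 401 Unauthorized")))  # True
print(looks_like_auth_error(TimeoutError("request timed out")))      # False
```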
def format_embedding_error(
error: Exception,
service_name: str,
model: str | None,
provider: EmbeddingProvider,
sanitized_api_key: str | None = None,
status_code: int | None = None,
) -> str:
"""
Format a standardized error string for embedding errors.
"""
detail = f"Status {status_code}" if status_code else f"{type(error)}"
return (
f"{'HTTP error' if status_code else 'Exception'} embedding text with {service_name} - {detail}: "
f"Model: {model} "
f"Provider: {provider} "
f"API Key: {sanitized_api_key} "
f"Exception: {error}"
)
# Custom exception for authentication errors
class AuthenticationError(Exception):
"""Raised when authentication fails with a provider."""
def __init__(self, provider: str, message: str = "API key is invalid or expired"):
self.provider = provider
self.message = message
super().__init__(f"{provider} authentication failed: {message}")
class CloudEmbedding:
def __init__(
self,
api_key: str,
provider: EmbeddingProvider,
api_url: str | None = None,
api_version: str | None = None,
timeout: int = API_BASED_EMBEDDING_TIMEOUT,
) -> None:
self.provider = provider
self.api_key = api_key
self.api_url = api_url
self.api_version = api_version
self.timeout = timeout
self.http_client = httpx.AsyncClient(timeout=timeout)
self._closed = False
self.sanitized_api_key = api_key[:4] + "********" + api_key[-4:]
async def _embed_openai(
self, texts: list[str], model: str | None, reduced_dimension: int | None
) -> list[Embedding]:
if not model:
model = DEFAULT_OPENAI_MODEL
# Use the OpenAI specific timeout for this one
client = openai.AsyncOpenAI(
api_key=self.api_key, timeout=OPENAI_EMBEDDING_TIMEOUT
)
final_embeddings: list[Embedding] = []
for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN):
response = await client.embeddings.create(
input=text_batch,
model=model,
dimensions=reduced_dimension or openai.NOT_GIVEN,
)
final_embeddings.extend(
[embedding.embedding for embedding in response.data]
)
return final_embeddings
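`_embed_openai` relies on a `batch_list` helper (defined elsewhere in the codebase) to stay under the provider's per-request input limit. A plausible minimal implementation, assuming it simply slices the list in order:

```python
from typing import Iterator, TypeVar

T = TypeVar("T")

def batch_list(items: list[T], batch_size: int) -> Iterator[list[T]]:
    # Yield successive slices of at most batch_size elements, preserving order.
    for i in range(0, len(items), batch_size):
        yield items[i : i + batch_size]
```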
async def _embed_cohere(
self, texts: list[str], model: str | None, embedding_type: str
) -> list[Embedding]:
if not model:
model = DEFAULT_COHERE_MODEL
client = CohereAsyncClient(api_key=self.api_key)
final_embeddings: list[Embedding] = []
for text_batch in batch_list(texts, _COHERE_MAX_INPUT_LEN):
# Cohere does not use the same tokenizer as the Onyx API server, but it is
# approximately the same; empirically it is only off by a few tokens, so
# the difference is negligible.
response = await client.embed(
texts=text_batch,
model=model,
input_type=embedding_type,
truncate="END",
)
final_embeddings.extend(cast(list[Embedding], response.embeddings))
return final_embeddings
async def _embed_voyage(
self, texts: list[str], model: str | None, embedding_type: str
) -> list[Embedding]:
if not model:
model = DEFAULT_VOYAGE_MODEL
client = voyageai.AsyncClient(
api_key=self.api_key, timeout=API_BASED_EMBEDDING_TIMEOUT
)
response = await client.embed(
texts=texts,
model=model,
input_type=embedding_type,
truncation=True,
)
return response.embeddings
async def _embed_azure(
self, texts: list[str], model: str | None
) -> list[Embedding]:
response = await aembedding(
model=model,
input=texts,
timeout=API_BASED_EMBEDDING_TIMEOUT,
api_key=self.api_key,
api_base=self.api_url,
api_version=self.api_version,
)
embeddings = [embedding["embedding"] for embedding in response.data]
return embeddings
async def _embed_vertex(
self, texts: list[str], model: str | None, embedding_type: str
) -> list[Embedding]:
if not model:
model = DEFAULT_VERTEX_MODEL
credentials = service_account.Credentials.from_service_account_info(
json.loads(self.api_key)
)
project_id = json.loads(self.api_key)["project_id"]
vertexai.init(project=project_id, credentials=credentials)
client = TextEmbeddingModel.from_pretrained(model)
inputs = [TextEmbeddingInput(text, embedding_type) for text in texts]
# Split into batches of 25 texts
max_texts_per_batch = VERTEXAI_EMBEDDING_LOCAL_BATCH_SIZE
batches = [
inputs[i : i + max_texts_per_batch]
for i in range(0, len(inputs), max_texts_per_batch)
]
# Dispatch all embedding calls asynchronously at once
tasks = [
client.get_embeddings_async(batch, auto_truncate=True) for batch in batches
]
# Wait for all tasks to complete in parallel
results = await asyncio.gather(*tasks)
return [embedding.values for batch in results for embedding in batch]
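The Vertex path above batches the inputs and fans the calls out with `asyncio.gather`, then flattens the per-batch results back into one ordered list. The same pattern in isolation, with a stub standing in for the real embedding client:

```python
import asyncio

async def embed_batch(batch: list[str]) -> list[int]:
    # Stand-in for a real embedding call; each "vector" is just the text length.
    await asyncio.sleep(0)
    return [len(text) for text in batch]

async def embed_all(texts: list[str], batch_size: int) -> list[int]:
    batches = [texts[i : i + batch_size] for i in range(0, len(texts), batch_size)]
    # Dispatch every batch concurrently; gather preserves input order, so the
    # flattened result lines up with the original texts.
    results = await asyncio.gather(*(embed_batch(batch) for batch in batches))
    return [value for batch in results for value in batch]
```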
async def _embed_litellm_proxy(
self, texts: list[str], model_name: str | None
) -> list[Embedding]:
if not model_name:
raise ValueError("Model name is required for LiteLLM proxy embedding.")
if not self.api_url:
raise ValueError("API URL is required for LiteLLM proxy embedding.")
headers = (
{} if not self.api_key else {"Authorization": f"Bearer {self.api_key}"}
)
response = await self.http_client.post(
self.api_url,
json={
"model": model_name,
"input": texts,
},
headers=headers,
)
response.raise_for_status()
result = response.json()
return [embedding["embedding"] for embedding in result["data"]]
@retry(tries=_RETRY_TRIES, delay=_RETRY_DELAY)
async def embed(
self,
*,
texts: list[str],
text_type: EmbedTextType,
model_name: str | None = None,
deployment_name: str | None = None,
reduced_dimension: int | None = None,
) -> list[Embedding]:
try:
if self.provider == EmbeddingProvider.OPENAI:
return await self._embed_openai(texts, model_name, reduced_dimension)
elif self.provider == EmbeddingProvider.AZURE:
return await self._embed_azure(texts, f"azure/{deployment_name}")
elif self.provider == EmbeddingProvider.LITELLM:
return await self._embed_litellm_proxy(texts, model_name)
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
if self.provider == EmbeddingProvider.COHERE:
return await self._embed_cohere(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.VOYAGE:
return await self._embed_voyage(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.GOOGLE:
return await self._embed_vertex(texts, model_name, embedding_type)
else:
raise ValueError(f"Unsupported provider: {self.provider}")
except openai.AuthenticationError:
raise AuthenticationError(provider="OpenAI")
except httpx.HTTPStatusError as e:
if e.response.status_code == 401:
raise AuthenticationError(provider=str(self.provider))
error_string = format_embedding_error(
e,
str(self.provider),
model_name or deployment_name,
self.provider,
sanitized_api_key=self.sanitized_api_key,
status_code=e.response.status_code,
)
logger.error(error_string)
logger.debug(f"Exception texts: {texts}")
raise RuntimeError(error_string)
except Exception as e:
if is_authentication_error(e):
raise AuthenticationError(provider=str(self.provider))
error_string = format_embedding_error(
e,
str(self.provider),
model_name or deployment_name,
self.provider,
sanitized_api_key=self.sanitized_api_key,
)
logger.error(error_string)
logger.debug(f"Exception texts: {texts}")
raise RuntimeError(error_string)
@staticmethod
def create(
api_key: str,
provider: EmbeddingProvider,
api_url: str | None = None,
api_version: str | None = None,
) -> "CloudEmbedding":
logger.debug(f"Creating Embedding instance for provider: {provider}")
return CloudEmbedding(api_key, provider, api_url, api_version)
async def aclose(self) -> None:
"""Explicitly close the client."""
if not self._closed:
await self.http_client.aclose()
self._closed = True
async def __aenter__(self) -> "CloudEmbedding":
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
await self.aclose()
def __del__(self) -> None:
"""Finalizer to warn about unclosed clients."""
if not self._closed:
logger.warning(
"CloudEmbedding was not properly closed. Use 'async with' or call aclose()"
)
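`CloudEmbedding` tracks a `_closed` flag so that `aclose()` is idempotent and `__del__` can warn about leaked clients. A sketch of that lifecycle pattern reduced to its skeleton (no real HTTP client involved):

```python
import asyncio

class ManagedClient:
    # Minimal sketch of the close-tracking pattern: a _closed flag,
    # an idempotent aclose(), and async-context-manager hooks.
    def __init__(self) -> None:
        self._closed = False

    async def aclose(self) -> None:
        if not self._closed:
            self._closed = True  # safe to call more than once

    async def __aenter__(self) -> "ManagedClient":
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        await self.aclose()

async def demo() -> bool:
    async with ManagedClient() as client:
        pass  # client is live inside the block
    return client._closed
```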
def get_embedding_model(
model_name: str,
@@ -404,20 +70,34 @@ def get_local_reranking_model(
return _RERANK_MODEL
ENCODING_RETRIES = 3
ENCODING_RETRY_DELAY = 0.1
def _concurrent_embedding(
texts: list[str], model: "SentenceTransformer", normalize_embeddings: bool
) -> Any:
"""Synchronous wrapper for concurrent_embedding to use with run_in_executor."""
for _ in range(ENCODING_RETRIES):
try:
return model.encode(texts, normalize_embeddings=normalize_embeddings)
except RuntimeError as e:
# There is a concurrency bug in the SentenceTransformer library that causes
# the model to fail to encode texts. It's pretty rare and we want to allow
# concurrent embedding, hence we retry (the specific error is
# "RuntimeError: Already borrowed" and occurs in the transformers library)
logger.error(f"Error encoding texts, retrying: {e}")
time.sleep(ENCODING_RETRY_DELAY)
return model.encode(texts, normalize_embeddings=normalize_embeddings)
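`_concurrent_embedding` retries a bounded number of times and then falls through to one last unguarded call so the final failure propagates. That retry shape in isolation (names are ours; the flaky callable simulates the "Already borrowed" error):

```python
import time

def retry_on_runtime_error(fn, tries: int = 3, delay: float = 0.0):
    # Retry fn() on RuntimeError up to `tries` times, then make one final
    # unguarded attempt so the last exception reaches the caller.
    for _ in range(tries):
        try:
            return fn()
        except RuntimeError:
            time.sleep(delay)
    return fn()

attempts: list[int] = []

def flaky() -> int:
    attempts.append(1)
    if len(attempts) < 3:
        raise RuntimeError("Already borrowed")
    return 7
```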
@simple_log_function_time()
async def embed_text(
texts: list[str],
text_type: EmbedTextType,
model_name: str | None,
deployment_name: str | None,
max_context_length: int,
normalize_embeddings: bool,
api_key: str | None,
provider_type: EmbeddingProvider | None,
prefix: str | None,
api_url: str | None,
api_version: str | None,
reduced_dimension: int | None,
gpu_type: str = "UNKNOWN",
) -> list[Embedding]:
if not all(texts):
@@ -434,52 +114,10 @@ async def embed_text(
for text in texts:
total_chars += len(text)
if provider_type is not None:
logger.info(
f"Embedding {len(texts)} texts with {total_chars} total characters with provider: {provider_type}"
)
# Only local models should call this function now
# API providers should go directly to API server
if api_key is None:
logger.error("API key not provided for cloud model")
raise RuntimeError("API key not provided for cloud model")
if prefix:
logger.warning("Prefix provided for cloud model, which is not supported")
raise ValueError(
"Prefix string is not valid for cloud models. "
"Cloud models take an explicit text type instead."
)
async with CloudEmbedding(
api_key=api_key,
provider=provider_type,
api_url=api_url,
api_version=api_version,
) as cloud_model:
embeddings = await cloud_model.embed(
texts=texts,
model_name=model_name,
deployment_name=deployment_name,
text_type=text_type,
reduced_dimension=reduced_dimension,
)
if any(embedding is None for embedding in embeddings):
error_message = "Embeddings contain None values\n"
error_message += "Corresponding texts:\n"
error_message += "\n".join(texts)
logger.error(error_message)
raise ValueError(error_message)
elapsed = time.monotonic() - start
logger.info(
f"event=embedding_provider "
f"texts={len(texts)} "
f"chars={total_chars} "
f"provider={provider_type} "
f"elapsed={elapsed:.2f}"
)
elif model_name is not None:
if model_name is not None:
logger.info(
f"Embedding {len(texts)} texts with {total_chars} total characters with local model: {model_name}"
)
@@ -492,8 +130,8 @@ async def embed_text(
# Run CPU-bound embedding in a thread pool
embeddings_vectors = await asyncio.get_event_loop().run_in_executor(
None,
lambda: local_model.encode(
prefixed_texts, normalize_embeddings=normalize_embeddings
lambda: _concurrent_embedding(
prefixed_texts, local_model, normalize_embeddings
),
)
embeddings = [
@@ -515,10 +153,8 @@ async def embed_text(
f"elapsed={elapsed:.2f}"
)
else:
logger.error("Neither model name nor provider specified for embedding")
raise ValueError(
"Either model name or provider must be provided to run embeddings."
)
logger.error("Model name not specified for embedding")
raise ValueError("Model name must be provided to run embeddings.")
return embeddings
@@ -533,77 +169,6 @@ async def local_rerank(query: str, docs: list[str], model_name: str) -> list[flo
)
async def cohere_rerank_api(
query: str, docs: list[str], model_name: str, api_key: str
) -> list[float]:
cohere_client = CohereAsyncClient(api_key=api_key)
response = await cohere_client.rerank(query=query, documents=docs, model=model_name)
results = response.results
sorted_results = sorted(results, key=lambda item: item.index)
return [result.relevance_score for result in sorted_results]
async def cohere_rerank_aws(
query: str,
docs: list[str],
model_name: str,
region_name: str,
aws_access_key_id: str,
aws_secret_access_key: str,
) -> list[float]:
session = aioboto3.Session(
aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key
)
async with session.client(
"bedrock-runtime", region_name=region_name
) as bedrock_client:
body = json.dumps(
{
"query": query,
"documents": docs,
"api_version": 2,
}
)
# Invoke the Bedrock model asynchronously
response = await bedrock_client.invoke_model(
modelId=model_name,
accept="application/json",
contentType="application/json",
body=body,
)
# Read the response asynchronously
response_body = json.loads(await response["body"].read())
# Extract and sort the results
results = response_body.get("results", [])
sorted_results = sorted(results, key=lambda item: item["index"])
return [result["relevance_score"] for result in sorted_results]
async def litellm_rerank(
query: str, docs: list[str], api_url: str, model_name: str, api_key: str | None
) -> list[float]:
headers = {} if not api_key else {"Authorization": f"Bearer {api_key}"}
async with httpx.AsyncClient() as client:
response = await client.post(
api_url,
json={
"model": model_name,
"query": query,
"documents": docs,
},
headers=headers,
)
response.raise_for_status()
result = response.json()
return [
item["relevance_score"]
for item in sorted(result["results"], key=lambda x: x["index"])
]
@router.post("/bi-encoder-embed")
async def route_bi_encoder_embed(
request: Request,
@@ -615,6 +180,13 @@ async def route_bi_encoder_embed(
async def process_embed_request(
embed_request: EmbedRequest, gpu_type: str = "UNKNOWN"
) -> EmbedResponse:
# Only local models should use this endpoint - API providers should make direct API calls
if embed_request.provider_type is not None:
raise ValueError(
f"Model server embedding endpoint should only be used for local models. "
f"API provider '{embed_request.provider_type}' should make direct API calls instead."
)
if not embed_request.texts:
raise HTTPException(status_code=400, detail="No texts to be embedded")
@@ -632,26 +204,12 @@ async def process_embed_request(
embeddings = await embed_text(
texts=embed_request.texts,
model_name=embed_request.model_name,
deployment_name=embed_request.deployment_name,
max_context_length=embed_request.max_context_length,
normalize_embeddings=embed_request.normalize_embeddings,
api_key=embed_request.api_key,
provider_type=embed_request.provider_type,
text_type=embed_request.text_type,
api_url=embed_request.api_url,
api_version=embed_request.api_version,
reduced_dimension=embed_request.reduced_dimension,
prefix=prefix,
gpu_type=gpu_type,
)
return EmbedResponse(embeddings=embeddings)
except AuthenticationError as e:
# Handle authentication errors consistently
logger.error(f"Authentication error: {e.provider}")
raise HTTPException(
status_code=401,
detail=f"Authentication failed: {e.message}",
)
except RateLimitError as e:
raise HTTPException(
status_code=429,
@@ -669,6 +227,13 @@ async def process_embed_request(
@router.post("/cross-encoder-scores")
async def process_rerank_request(rerank_request: RerankRequest) -> RerankResponse:
"""Cross encoders can be purely black box from the app perspective"""
# Only local models should use this endpoint - API providers should make direct API calls
if rerank_request.provider_type is not None:
raise ValueError(
f"Model server reranking endpoint should only be used for local models. "
f"API provider '{rerank_request.provider_type}' should make direct API calls instead."
)
if INDEXING_ONLY:
raise RuntimeError("Indexing model server should not call intent endpoint")
@@ -680,55 +245,13 @@ async def process_rerank_request(rerank_request: RerankRequest) -> RerankRespons
raise ValueError("Empty documents cannot be reranked.")
try:
if rerank_request.provider_type is None:
sim_scores = await local_rerank(
query=rerank_request.query,
docs=rerank_request.documents,
model_name=rerank_request.model_name,
)
return RerankResponse(scores=sim_scores)
elif rerank_request.provider_type == RerankerProvider.LITELLM:
if rerank_request.api_url is None:
raise ValueError("API URL is required for LiteLLM reranking.")
sim_scores = await litellm_rerank(
query=rerank_request.query,
docs=rerank_request.documents,
api_url=rerank_request.api_url,
model_name=rerank_request.model_name,
api_key=rerank_request.api_key,
)
return RerankResponse(scores=sim_scores)
elif rerank_request.provider_type == RerankerProvider.COHERE:
if rerank_request.api_key is None:
raise RuntimeError("Cohere Rerank Requires an API Key")
sim_scores = await cohere_rerank_api(
query=rerank_request.query,
docs=rerank_request.documents,
model_name=rerank_request.model_name,
api_key=rerank_request.api_key,
)
return RerankResponse(scores=sim_scores)
elif rerank_request.provider_type == RerankerProvider.BEDROCK:
if rerank_request.api_key is None:
raise RuntimeError("Bedrock Rerank Requires an API Key")
aws_access_key_id, aws_secret_access_key, aws_region = pass_aws_key(
rerank_request.api_key
)
sim_scores = await cohere_rerank_aws(
query=rerank_request.query,
docs=rerank_request.documents,
model_name=rerank_request.model_name,
region_name=aws_region,
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
)
return RerankResponse(scores=sim_scores)
else:
raise ValueError(f"Unsupported provider: {rerank_request.provider_type}")
# At this point, provider_type is None, so handle local reranking
sim_scores = await local_rerank(
query=rerank_request.query,
docs=rerank_request.documents,
model_name=rerank_request.model_name,
)
return RerankResponse(scores=sim_scores)
except Exception as e:
logger.exception(f"Error during reranking process:\n{str(e)}")


@@ -34,8 +34,8 @@ from shared_configs.configs import SENTRY_DSN
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
HF_CACHE_PATH = Path(os.path.expanduser("~")) / ".cache/huggingface"
TEMP_HF_CACHE_PATH = Path(os.path.expanduser("~")) / ".cache/temp_huggingface"
HF_CACHE_PATH = Path(".cache/huggingface")
TEMP_HF_CACHE_PATH = Path(".cache/temp_huggingface")
transformer_logging.set_verbosity_error()


@@ -70,32 +70,3 @@ def get_gpu_type() -> str:
return GPUStatus.MAC_MPS
return GPUStatus.NONE
def pass_aws_key(api_key: str) -> tuple[str, str, str]:
"""Parse AWS API key string into components.
Args:
api_key: String in format 'aws_ACCESSKEY_SECRETKEY_REGION'
Returns:
Tuple of (access_key, secret_key, region)
Raises:
ValueError: If key format is invalid
"""
if not api_key.startswith("aws"):
raise ValueError("API key must start with 'aws' prefix")
parts = api_key.split("_")
if len(parts) != 4:
raise ValueError(
f"API key must be in format 'aws_ACCESSKEY_SECRETKEY_REGION', got {len(parts) - 1} parts"
"this is an onyx specific format for formatting the aws secrets for bedrock"
)
try:
_, aws_access_key_id, aws_secret_access_key, aws_region = parts
return aws_access_key_id, aws_secret_access_key, aws_region
except Exception as e:
raise ValueError(f"Failed to parse AWS key components: {str(e)}")


@@ -1,50 +0,0 @@
from collections.abc import Hashable
from datetime import datetime
from langgraph.types import Send
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
AnswerQuestionOutput,
)
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
SubQuestionAnsweringInput,
)
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.states import (
SubQuestionRetrievalState,
)
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
def parallelize_initial_sub_question_answering(
state: SubQuestionRetrievalState,
) -> list[Send | Hashable]:
"""
LangGraph edge to parallelize the initial sub-question answering. If there are no sub-questions,
we send empty answers to the initial answer generation, and that answer would be generated
solely based on the documents retrieved for the original question.
"""
edge_start_time = datetime.now()
if len(state.initial_sub_questions) > 0:
return [
Send(
"answer_query_subgraph",
SubQuestionAnsweringInput(
question=question,
question_id=make_question_id(0, question_num + 1),
log_messages=[
f"{edge_start_time} -- Main Edge - Parallelize Initial Sub-question Answering"
],
),
)
for question_num, question in enumerate(state.initial_sub_questions)
]
else:
return [
Send(
"ingest_answers",
AnswerQuestionOutput(
answer_results=[],
),
)
]


@@ -40,7 +40,7 @@ def parallelize_initial_sub_question_answering(
else:
return [
Send(
"ingest_answers",
"format_initial_sub_question_answers",
AnswerQuestionOutput(
answer_results=[],
),


@@ -43,36 +43,6 @@ def route_initial_tool_choice(
return "call_tool"
def parallelize_initial_sub_question_answering(
state: MainState,
) -> list[Send | Hashable]:
edge_start_time = datetime.now()
if len(state.initial_sub_questions) > 0:
return [
Send(
"answer_query_subgraph",
SubQuestionAnsweringInput(
question=question,
question_id=make_question_id(0, question_num + 1),
log_messages=[
f"{edge_start_time} -- Main Edge - Parallelize Initial Sub-question Answering"
],
),
)
for question_num, question in enumerate(state.initial_sub_questions)
]
else:
return [
Send(
"ingest_answers",
AnswerQuestionOutput(
answer_results=[],
),
)
]
# Define the function that determines whether to continue or not
def continue_to_refined_answer_or_end(
state: RequireRefinemenEvalUpdate,


@@ -7,6 +7,17 @@ from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.utils.special_types import JSON_ro
def remove_user_from_invited_users(email: str) -> int:
try:
store = get_kv_store()
user_emails = cast(list, store.load(KV_USER_STORE_KEY))
remaining_users = [user for user in user_emails if user != email]
store.store(KV_USER_STORE_KEY, cast(JSON_ro, remaining_users))
return len(remaining_users)
except KvKeyNotFoundError:
return 0
def get_invited_users() -> list[str]:
try:
store = get_kv_store()


@@ -60,6 +60,7 @@ from onyx.auth.api_key import get_hashed_api_key_from_request
from onyx.auth.email_utils import send_forgot_password_email
from onyx.auth.email_utils import send_user_verification_email
from onyx.auth.invited_users import get_invited_users
from onyx.auth.invited_users import remove_user_from_invited_users
from onyx.auth.schemas import AuthBackend
from onyx.auth.schemas import UserCreate
from onyx.auth.schemas import UserRole
@@ -241,7 +242,7 @@ def verify_email_domain(email: str) -> None:
status_code=status.HTTP_400_BAD_REQUEST,
detail="Email is not valid",
)
domain = email.split("@")[-1]
domain = email.split("@")[-1].lower()
if domain not in VALID_EMAIL_DOMAINS:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
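The one-line change above (`.lower()`) matters because email domains are case-insensitive, so the membership test must normalize first. In isolation (helper name and signature are ours):

```python
def is_allowed_domain(email: str, valid_domains: set[str]) -> bool:
    # Domains are case-insensitive, so normalize before the membership test;
    # without .lower(), 'User@Example.COM' would be wrongly rejected when
    # valid_domains contains 'example.com'.
    domain = email.split("@")[-1].lower()
    return domain in valid_domains
```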
@@ -350,6 +351,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
role=user_create.role,
)
user = await self.update(user_update, user)
remove_user_from_invited_users(user_create.email)
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
return user
@@ -527,7 +529,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
):
await self.user_db.update(user, {"oidc_expiry": None})
user.oidc_expiry = None # type: ignore
remove_user_from_invited_users(user.email)
if token:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)


@@ -231,10 +231,7 @@ class DynamicTenantScheduler(PersistentScheduler):
True if equivalent, False if not."""
current_tasks = set(name for name, _ in schedule1)
new_tasks = set(schedule2.keys())
if current_tasks != new_tasks:
return False
return True
return current_tasks == new_tasks
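The change above collapses the if/return-False/return-True branching into a single set comparison. An equivalent standalone form:

```python
def schedules_match(schedule1: list[tuple[str, object]], schedule2: dict) -> bool:
    # Two schedules are considered equivalent when they contain
    # exactly the same task names, regardless of order.
    current_tasks = {name for name, _ in schedule1}
    return current_tasks == set(schedule2.keys())
```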
@beat_init.connect


@@ -32,7 +32,6 @@ from onyx.db.indexing_coordination import IndexingCoordination
from onyx.redis.redis_connector_delete import RedisConnectorDelete
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from onyx.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
from onyx.redis.redis_connector_index import RedisConnectorIndex
from onyx.redis.redis_connector_prune import RedisConnectorPrune
from onyx.redis.redis_connector_stop import RedisConnectorStop
from onyx.redis.redis_document_set import RedisDocumentSet
@@ -161,7 +160,6 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
RedisUserGroup.reset_all(r)
RedisConnectorDelete.reset_all(r)
RedisConnectorPrune.reset_all(r)
RedisConnectorIndex.reset_all(r)
RedisConnectorStop.reset_all(r)
RedisConnectorPermissionSync.reset_all(r)
RedisConnectorExternalGroupSync.reset_all(r)


@@ -1,3 +1,5 @@
from collections.abc import Generator
from collections.abc import Iterator
from datetime import datetime
from datetime import timezone
from pathlib import Path
@@ -8,10 +10,12 @@ import httpx
from onyx.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
from onyx.configs.app_configs import VESPA_REQUEST_TIMEOUT
from onyx.connectors.connector_runner import batched_doc_ids
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
)
from onyx.connectors.interfaces import BaseConnector
from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SlimConnector
@@ -22,12 +26,14 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
PRUNING_CHECKPOINTED_BATCH_SIZE = 32
def document_batch_to_ids(
doc_batch: list[Document],
) -> set[str]:
return {doc.id for doc in doc_batch}
doc_batch: Iterator[list[Document]],
) -> Generator[set[str], None, None]:
for doc_list in doc_batch:
yield {doc.id for doc in doc_list}
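The refactor above turns `document_batch_to_ids` from a one-shot set builder into a generator over batches, so id sets are produced lazily as the connector yields documents. The shape of that change, simplified to plain strings:

```python
from collections.abc import Generator, Iterator

def batches_to_id_sets(
    batches: Iterator[list[str]],
) -> Generator[set[str], None, None]:
    # Lazy version: one set per batch is yielded as the source produces it,
    # instead of materializing every id up front.
    for batch in batches:
        yield set(batch)
```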
def extract_ids_from_runnable_connector(
@@ -46,33 +52,50 @@ def extract_ids_from_runnable_connector(
for metadata_batch in runnable_connector.retrieve_all_slim_documents():
all_connector_doc_ids.update({doc.id for doc in metadata_batch})
doc_batch_generator = None
doc_batch_id_generator = None
if isinstance(runnable_connector, LoadConnector):
doc_batch_generator = runnable_connector.load_from_state()
doc_batch_id_generator = document_batch_to_ids(
runnable_connector.load_from_state()
)
elif isinstance(runnable_connector, PollConnector):
start = datetime(1970, 1, 1, tzinfo=timezone.utc).timestamp()
end = datetime.now(timezone.utc).timestamp()
doc_batch_generator = runnable_connector.poll_source(start=start, end=end)
doc_batch_id_generator = document_batch_to_ids(
runnable_connector.poll_source(start=start, end=end)
)
elif isinstance(runnable_connector, CheckpointedConnector):
start = datetime(1970, 1, 1, tzinfo=timezone.utc).timestamp()
end = datetime.now(timezone.utc).timestamp()
checkpoint = runnable_connector.build_dummy_checkpoint()
checkpoint_generator = runnable_connector.load_from_checkpoint(
start=start, end=end, checkpoint=checkpoint
)
doc_batch_id_generator = batched_doc_ids(
checkpoint_generator, batch_size=PRUNING_CHECKPOINTED_BATCH_SIZE
)
else:
raise RuntimeError("Pruning job could not find a valid runnable_connector.")
doc_batch_processing_func = document_batch_to_ids
# this function is called per batch for rate limiting
def doc_batch_processing_func(doc_batch_ids: set[str]) -> set[str]:
return doc_batch_ids
if MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE:
doc_batch_processing_func = rate_limit_builder(
max_calls=MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE, period=60
)(document_batch_to_ids)
for doc_batch in doc_batch_generator:
)(lambda x: x)
for doc_batch_ids in doc_batch_id_generator:
if callback:
if callback.should_stop():
raise RuntimeError(
"extract_ids_from_runnable_connector: Stop signal detected"
)
all_connector_doc_ids.update(doc_batch_processing_func(doc_batch))
all_connector_doc_ids.update(doc_batch_processing_func(doc_batch_ids))
if callback:
callback.progress("extract_ids_from_runnable_connector", len(doc_batch))
callback.progress("extract_ids_from_runnable_connector", len(doc_batch_ids))
return all_connector_doc_ids


@@ -193,12 +193,7 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str) -> bool | N
task_logger.info(
"Timed out waiting for tasks blocking deletion. Resetting blocking fences."
)
search_settings_list = get_all_search_settings(db_session)
for search_settings in search_settings_list:
redis_connector_index = redis_connector.new_index(
search_settings.id
)
redis_connector_index.reset()
redis_connector.prune.reset()
redis_connector.permissions.reset()
redis_connector.external_group_sync.reset()


@@ -2,7 +2,6 @@ import multiprocessing
import os
import time
import traceback
from http import HTTPStatus
from time import sleep
import sentry_sdk
@@ -22,7 +21,7 @@ from onyx.background.celery.tasks.models import SimpleJobResult
from onyx.background.indexing.job_client import SimpleJob
from onyx.background.indexing.job_client import SimpleJobClient
from onyx.background.indexing.job_client import SimpleJobException
from onyx.background.indexing.run_docfetching import run_indexing_entrypoint
from onyx.background.indexing.run_docfetching import run_docfetching_entrypoint
from onyx.configs.constants import CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT
from onyx.configs.constants import OnyxCeleryTask
from onyx.connectors.exceptions import ConnectorValidationError
@@ -34,7 +33,6 @@ from onyx.db.index_attempt import mark_attempt_canceled
from onyx.db.index_attempt import mark_attempt_failed
from onyx.db.indexing_coordination import IndexingCoordination
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_index import RedisConnectorIndex
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import global_version
from shared_configs.configs import SENTRY_DSN
@@ -156,7 +154,6 @@ def _docfetching_task(
)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector.new_index(search_settings_id)
# TODO: remove all fences, cause all signals to be set in postgres
if redis_connector.delete.fenced:
@@ -214,7 +211,7 @@ def _docfetching_task(
)
# This is where the heavy/real work happens
run_indexing_entrypoint(
run_docfetching_entrypoint(
app,
index_attempt_id,
tenant_id,
@@ -261,7 +258,7 @@ def _docfetching_task(
def process_job_result(
job: SimpleJob,
connector_source: str | None,
redis_connector_index: RedisConnectorIndex,
index_attempt_id: int,
log_builder: ConnectorIndexingLogBuilder,
) -> SimpleJobResult:
result = SimpleJobResult()
@@ -278,13 +275,11 @@ def process_job_result(
# In EKS, there is an edge case where successful tasks return exit
# code 1 in the cloud due to the set_spawn_method not sticking.
# We've since worked around this, but the following is a safe way to
# work around this issue. Basically, we ignore the job error state
# if the completion signal is OK.
status_int = redis_connector_index.get_completion()
if status_int:
status_enum = HTTPStatus(status_int)
if status_enum == HTTPStatus.OK:
# Workaround: check that the total number of batches is set, since this only
# happens when docfetching completed successfully
with get_session_with_current_tenant() as db_session:
index_attempt = get_index_attempt(db_session, index_attempt_id)
if index_attempt and index_attempt.total_batches is not None:
ignore_exitcode = True
if ignore_exitcode:
@@ -300,7 +295,11 @@ def process_job_result(
if result.exit_code is not None:
result.status = IndexingWatchdogTerminalStatus.from_code(result.exit_code)
result.exception_str = job.exception()
job_level_exception = job.exception()
result.exception_str = (
f"Docfetching returned exit code {result.exit_code} "
f"with exception: {job_level_exception}"
)
return result
@@ -458,9 +457,6 @@ def docfetching_proxy_task(
)
)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings_id)
# Track the last time memory info was emitted
last_memory_emit_time = 0.0
@@ -487,7 +483,7 @@ def docfetching_proxy_task(
if job.done():
try:
result = process_job_result(
job, result.connector_source, redis_connector_index, log_builder
job, result.connector_source, index_attempt_id, log_builder
)
except Exception:
task_logger.exception(


@@ -5,6 +5,9 @@ from sqlalchemy import update
from onyx.configs.constants import INDEXING_WORKER_HEARTBEAT_INTERVAL
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.models import IndexAttempt
from onyx.utils.logger import setup_logger
logger = setup_logger()
def start_heartbeat(index_attempt_id: int) -> tuple[threading.Thread, threading.Event]:
@@ -21,9 +24,15 @@ def start_heartbeat(index_attempt_id: int) -> tuple[threading.Thread, threading.
.values(heartbeat_counter=IndexAttempt.heartbeat_counter + 1)
)
db_session.commit()
logger.debug(
"Updated heartbeat counter for index attempt %s",
index_attempt_id,
)
except Exception:
# Silently continue if heartbeat fails
pass
logger.exception(
"Failed to update heartbeat counter for index attempt %s",
index_attempt_id,
)
thread = threading.Thread(target=heartbeat_loop, daemon=True)
thread.start()
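`start_heartbeat` runs a daemon thread that bumps a counter until signalled to stop. A self-contained sketch of that thread-plus-event shape, with the interval shortened for the demo and a plain callback in place of the database update:

```python
import threading
import time

def start_heartbeat(tick, interval: float = 0.01) -> tuple[threading.Thread, threading.Event]:
    # Call tick() every `interval` seconds until the stop event is set.
    stop = threading.Event()

    def loop() -> None:
        # Event.wait doubles as an interruptible sleep: it returns True
        # (and ends the loop) as soon as stop.set() is called.
        while not stop.wait(timeout=interval):
            tick()

    thread = threading.Thread(target=loop, daemon=True)
    thread.start()
    return thread, stop

ticks: list[int] = []
thread, stop = start_heartbeat(lambda: ticks.append(1))
time.sleep(0.05)
stop.set()
thread.join(timeout=1)
```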


@@ -4,7 +4,6 @@ from collections import defaultdict
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from http import HTTPStatus
from typing import Any
from celery import shared_task
@@ -16,6 +15,8 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
from onyx.background.celery.tasks.docprocessing.heartbeat import start_heartbeat
@@ -66,6 +67,7 @@ from onyx.db.index_attempt import mark_attempt_failed
from onyx.db.index_attempt import mark_attempt_partially_succeeded
from onyx.db.index_attempt import mark_attempt_succeeded
from onyx.db.indexing_coordination import CoordinationStatus
from onyx.db.indexing_coordination import INDEXING_PROGRESS_TIMEOUT_HOURS
from onyx.db.indexing_coordination import IndexingCoordination
from onyx.db.models import IndexAttempt
from onyx.db.search_settings import get_active_search_settings_list
@@ -90,7 +92,6 @@ from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
from onyx.redis.redis_utils import is_fence
from onyx.server.runtime.onyx_runtime import OnyxRuntime
from onyx.utils.logger import setup_logger
from onyx.utils.logger import TaskAttemptSingleton
from onyx.utils.middleware import make_randomized_onyx_request_id
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
@@ -98,10 +99,16 @@ from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
from shared_configs.configs import INDEXING_MODEL_SERVER_PORT
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import INDEX_ATTEMPT_INFO_CONTEXTVAR
logger = setup_logger()
USER_FILE_INDEXING_LIMIT = 100
DOCPROCESSING_STALL_TIMEOUT_MULTIPLIER = 4
DOCPROCESSING_HEARTBEAT_TIMEOUT_MULTIPLIER = 24
# Heartbeat timeout: if no heartbeat received for 30 minutes, consider it dead
# This should be much longer than INDEXING_WORKER_HEARTBEAT_INTERVAL (30s)
HEARTBEAT_TIMEOUT_SECONDS = 30 * 60 # 30 minutes
def _get_fence_validation_block_expiration() -> int:
@@ -133,14 +140,10 @@ def validate_active_indexing_attempts(
every INDEXING_WORKER_HEARTBEAT_INTERVAL seconds.
"""
logger.info("Validating active indexing attempts")
# Heartbeat timeout: if no heartbeat received for 5 minutes, consider it dead
# This should be much longer than INDEXING_WORKER_HEARTBEAT_INTERVAL (30s)
HEARTBEAT_TIMEOUT_SECONDS = 30 * 60 # 30 minutes
heartbeat_timeout_seconds = HEARTBEAT_TIMEOUT_SECONDS
with get_session_with_current_tenant() as db_session:
cutoff_time = datetime.now(timezone.utc) - timedelta(
seconds=HEARTBEAT_TIMEOUT_SECONDS
)
# Find all active indexing attempts
active_attempts = (
@@ -199,6 +202,15 @@ def validate_active_indexing_attempts(
)
continue
if fresh_attempt.total_batches and fresh_attempt.completed_batches == 0:
heartbeat_timeout_seconds = (
HEARTBEAT_TIMEOUT_SECONDS
* DOCPROCESSING_HEARTBEAT_TIMEOUT_MULTIPLIER
)
cutoff_time = datetime.now(timezone.utc) - timedelta(
seconds=heartbeat_timeout_seconds
)
# Heartbeat hasn't advanced - check if it's been too long
if last_check_time >= cutoff_time:
task_logger.debug(
@@ -208,7 +220,7 @@ def validate_active_indexing_attempts(
# No heartbeat for too long - mark as failed
failure_reason = (
f"No heartbeat received for {HEARTBEAT_TIMEOUT_SECONDS} seconds"
f"No heartbeat received for {heartbeat_timeout_seconds} seconds"
)
task_logger.warning(
@@ -257,7 +269,7 @@ class ConnectorIndexingLogBuilder:
def monitor_indexing_attempt_progress(
attempt: IndexAttempt, tenant_id: str, db_session: Session
attempt: IndexAttempt, tenant_id: str, db_session: Session, task: Task
) -> None:
"""
TODO: rewrite this docstring
@@ -316,7 +328,9 @@ def monitor_indexing_attempt_progress(
# Check task completion using Celery
try:
check_indexing_completion(attempt.id, coordination_status, storage, tenant_id)
check_indexing_completion(
attempt.id, coordination_status, storage, tenant_id, task
)
except Exception as e:
logger.exception(
f"Failed to monitor document processing completion: "
@@ -350,6 +364,7 @@ def check_indexing_completion(
coordination_status: CoordinationStatus,
storage: DocumentBatchStorage,
tenant_id: str,
task: Task,
) -> None:
logger.info(
@@ -376,20 +391,78 @@ def check_indexing_completion(
# Update progress tracking and check for stalls
with get_session_with_current_tenant() as db_session:
# Update progress tracking
stalled_timeout_hours = INDEXING_PROGRESS_TIMEOUT_HOURS
# Index attempts that are waiting between docfetching and
# docprocessing get a generous stalling timeout
if batches_total is not None and batches_processed == 0:
stalled_timeout_hours = (
stalled_timeout_hours * DOCPROCESSING_STALL_TIMEOUT_MULTIPLIER
)
timed_out = not IndexingCoordination.update_progress_tracking(
db_session, index_attempt_id, batches_processed
db_session,
index_attempt_id,
batches_processed,
timeout_hours=stalled_timeout_hours,
)
# Check for stalls (3-6 hour timeout)
if timed_out:
logger.error(
f"Indexing attempt {index_attempt_id} has been indexing for 3-6 hours without progress. "
f"Marking it as failed."
)
mark_attempt_failed(
index_attempt_id, db_session, failure_reason="Stalled indexing"
)
# Check for stalls (3-6 hour timeout). Only applies to in-progress attempts.
attempt = get_index_attempt(db_session, index_attempt_id)
if attempt and timed_out:
if attempt.status == IndexingStatus.IN_PROGRESS:
logger.error(
f"Indexing attempt {index_attempt_id} has been indexing for "
f"{stalled_timeout_hours//2}-{stalled_timeout_hours} hours without progress. "
f"Marking it as failed."
)
mark_attempt_failed(
index_attempt_id, db_session, failure_reason="Stalled indexing"
)
elif (
attempt.status == IndexingStatus.NOT_STARTED and attempt.celery_task_id
):
# Check if the task exists in the celery queue
# This handles the case where Redis dies after task creation but before task execution
redis_celery = task.app.broker_connection().channel().client # type: ignore
task_exists = celery_find_task(
attempt.celery_task_id,
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING,
redis_celery,
)
unacked_task_ids = celery_get_unacked_task_ids(
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING, redis_celery
)
if not task_exists and attempt.celery_task_id not in unacked_task_ids:
# there is a race condition where the docfetching task has been taken off
# the queues (i.e. started) but the indexing attempt still has a status of
# Not Started because the switch to in progress takes like 0.1 seconds.
# sleep a bit and confirm that the attempt is still not in progress.
time.sleep(1)
attempt = get_index_attempt(db_session, index_attempt_id)
if attempt and attempt.status == IndexingStatus.NOT_STARTED:
logger.error(
f"Task {attempt.celery_task_id} attached to indexing attempt "
f"{index_attempt_id} does not exist in the queue. "
f"Marking indexing attempt as failed."
)
mark_attempt_failed(
index_attempt_id,
db_session,
failure_reason="Task not in queue",
)
else:
logger.info(
f"Indexing attempt {index_attempt_id} is {attempt.status}. "
f"{stalled_timeout_hours//2}-{stalled_timeout_hours} hours without progress "
"but task is in the queue. Likely underprovisioned docfetching worker."
)
# Update last progress time so we won't time out again for another 3 hours
IndexingCoordination.update_progress_tracking(
db_session,
index_attempt_id,
batches_processed,
force_update_progress=True,
)
# check again on the next check_for_indexing task
# TODO: on the cloud this is currently 25 minutes at most, which
@@ -432,7 +505,14 @@ def check_indexing_completion(
ConnectorCredentialPairStatus.SCHEDULED,
ConnectorCredentialPairStatus.INITIAL_INDEXING,
]:
cc_pair.status = ConnectorCredentialPairStatus.ACTIVE
# User file connectors must be paused on success
# NOTE: _run_indexing doesn't update connectors if the index attempt is the future embedding model
# TODO: figure out why this doesn't pause connectors during swap
cc_pair.status = (
ConnectorCredentialPairStatus.PAUSED
if cc_pair.is_user_file
else ConnectorCredentialPairStatus.ACTIVE
)
db_session.commit()
# Clear repeated error state on success
@@ -449,15 +529,6 @@ def check_indexing_completion(
db_session=db_session,
)
# TODO: make it so we don't need this (might already be true)
redis_connector = RedisConnector(
tenant_id, attempt.connector_credential_pair_id
)
redis_connector_index = redis_connector.new_index(
attempt.search_settings_id
)
redis_connector_index.set_generator_complete(HTTPStatus.OK.value)
# Clean up FileStore storage (still needed for document batches during transition)
try:
logger.info(f"Cleaning up storage after indexing completion: {storage}")
@@ -811,7 +882,9 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
for attempt in active_attempts:
try:
monitor_indexing_attempt_progress(attempt, tenant_id, db_session)
monitor_indexing_attempt_progress(
attempt, tenant_id, db_session, self
)
except Exception:
task_logger.exception(f"Error monitoring attempt {attempt.id}")
@@ -1015,9 +1088,12 @@ def docprocessing_task(
# Start heartbeat for this indexing attempt
heartbeat_thread, stop_event = start_heartbeat(index_attempt_id)
try:
# Cannot use the TaskSingleton approach here because the worker is multithreaded
token = INDEX_ATTEMPT_INFO_CONTEXTVAR.set((cc_pair_id, index_attempt_id))
_docprocessing_task(index_attempt_id, cc_pair_id, tenant_id, batch_num)
finally:
stop_heartbeat(heartbeat_thread, stop_event) # Stop heartbeat before exiting
INDEX_ATTEMPT_INFO_CONTEXTVAR.reset(token)
def _docprocessing_task(
@@ -1028,9 +1104,6 @@ def _docprocessing_task(
) -> None:
start_time = time.monotonic()
# set the indexing attempt ID so that all log messages from this process
# will have it added as a prefix
TaskAttemptSingleton.set_cc_and_index_id(index_attempt_id, cc_pair_id)
if tenant_id:
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
@@ -1085,12 +1158,8 @@ def _docprocessing_task(
f"Index attempt {index_attempt_id} is not running, status {index_attempt.status}"
)
redis_connector_index = redis_connector.new_index(
index_attempt.search_settings.id
)
cross_batch_db_lock: RedisLock = r.lock(
redis_connector_index.db_lock_key,
redis_connector.db_lock_key(index_attempt.search_settings.id),
timeout=CELERY_INDEXING_LOCK_TIMEOUT,
thread_local=False,
)
@@ -1230,17 +1299,6 @@ def _docprocessing_task(
f"attempt={index_attempt_id} "
)
# on failure, signal completion with an error to unblock the watchdog
with get_session_with_current_tenant() as db_session:
index_attempt = get_index_attempt(db_session, index_attempt_id)
if index_attempt and index_attempt.search_settings:
redis_connector_index = redis_connector.new_index(
index_attempt.search_settings.id
)
redis_connector_index.set_generator_complete(
HTTPStatus.INTERNAL_SERVER_ERROR.value
)
raise
finally:
if per_batch_lock and per_batch_lock.owned():

View File

@@ -47,7 +47,6 @@ from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.models import ConnectorCredentialPair
from onyx.db.search_settings import get_current_search_settings
from onyx.db.sync_record import insert_sync_record
from onyx.db.sync_record import update_sync_record_status
from onyx.db.tag import delete_orphan_tags__no_commit
@@ -519,9 +518,6 @@ def connector_pruning_generator_task(
cc_pair.credential,
)
search_settings = get_current_search_settings(db_session)
redis_connector.new_index(search_settings.id)
callback = PruneCallback(
0,
redis_connector,

View File

@@ -153,10 +153,9 @@ class SimpleJob:
if self._exception is None and self.queue and not self.queue.empty():
self._exception = self.queue.get() # Get exception from queue
if self._exception:
return self._exception
return f"Job with ID '{self.id}' did not report an exception."
return (
self._exception or f"Job with ID '{self.id}' did not report an exception."
)
class SimpleJobClient:

View File

@@ -71,13 +71,13 @@ from onyx.natural_language_processing.search_nlp_models import (
InformationContentClassificationModel,
)
from onyx.utils.logger import setup_logger
from onyx.utils.logger import TaskAttemptSingleton
from onyx.utils.middleware import make_randomized_onyx_request_id
from onyx.utils.telemetry import create_milestone_and_report
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
from onyx.utils.variable_functionality import global_version
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import INDEX_ATTEMPT_INFO_CONTEXTVAR
logger = setup_logger(propagate=False)
@@ -226,8 +226,12 @@ def _check_connector_and_attempt_status(
raise ConnectorStopSignal(f"Index attempt {index_attempt_id} was canceled")
if index_attempt_loop.status != IndexingStatus.IN_PROGRESS:
error_str = ""
if index_attempt_loop.error_msg:
error_str = f" Original error: {index_attempt_loop.error_msg}"
raise RuntimeError(
f"Index Attempt is not running, status is {index_attempt_loop.status}"
f"Index Attempt is not running, status is {index_attempt_loop.status}.{error_str}"
)
if index_attempt_loop.celery_task_id is None:
@@ -267,7 +271,7 @@ def _check_failure_threshold(
# NOTE: this is the old run_indexing function that the new decoupled approach
# is based on. Leaving this for comparison purposes, but if you see this comment
# has been here for >1 month, please delete this function.
# has been here for >2 months, please delete this function.
def _run_indexing(
db_session: Session,
index_attempt_id: int,
@@ -832,7 +836,7 @@ def _run_indexing(
)
def run_indexing_entrypoint(
def run_docfetching_entrypoint(
app: Celery,
index_attempt_id: int,
tenant_id: str,
@@ -847,8 +851,8 @@ def run_indexing_entrypoint(
# set the indexing attempt ID so that all log messages from this process
# will have it added as a prefix
TaskAttemptSingleton.set_cc_and_index_id(
index_attempt_id, connector_credential_pair_id
token = INDEX_ATTEMPT_INFO_CONTEXTVAR.set(
(connector_credential_pair_id, index_attempt_id)
)
with get_session_with_current_tenant() as db_session:
attempt = transition_attempt_to_in_progress(index_attempt_id, db_session)
@@ -886,6 +890,8 @@ def run_indexing_entrypoint(
f"credentials='{credential_id}'"
)
INDEX_ATTEMPT_INFO_CONTEXTVAR.reset(token)
def connector_document_extraction(
app: Celery,
@@ -1350,6 +1356,9 @@ def reissue_old_batches(
)
path_info = batch_storage.extract_path_info(batch_id)
if path_info is None:
logger.warning(
f"Could not extract path info from batch {batch_id}, skipping"
)
continue
if path_info.cc_pair_id != cc_pair_id:
raise RuntimeError(f"Batch {batch_id} is not for cc pair {cc_pair_id}")

View File

@@ -108,7 +108,11 @@ _VALID_EMAIL_DOMAINS_STR = (
os.environ.get("VALID_EMAIL_DOMAINS", "") or _VALID_EMAIL_DOMAIN
)
VALID_EMAIL_DOMAINS = (
[domain.strip() for domain in _VALID_EMAIL_DOMAINS_STR.split(",")]
[
domain.strip().lower()
for domain in _VALID_EMAIL_DOMAINS_STR.split(",")
if domain.strip()
]
if _VALID_EMAIL_DOMAINS_STR
else []
)
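The tightened `VALID_EMAIL_DOMAINS` parsing above now lowercases entries and drops empty segments. The behavior can be checked in isolation with a standalone helper (the module itself uses an inline expression, not a function):

```python
def parse_valid_email_domains(raw: str) -> list[str]:
    """Split a comma-separated domain list: trim, lowercase, drop empties."""
    if not raw:
        return []
    return [domain.strip().lower() for domain in raw.split(",") if domain.strip()]
```

Lowercasing here avoids case-sensitive mismatches when comparing against normalized user email domains.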
@@ -121,6 +125,8 @@ OAUTH_CLIENT_SECRET = (
os.environ.get("OAUTH_CLIENT_SECRET", os.environ.get("GOOGLE_OAUTH_CLIENT_SECRET"))
or ""
)
# OpenID Connect configuration URL for Okta Profile Tool and other OIDC integrations
OPENID_CONFIG_URL = os.environ.get("OPENID_CONFIG_URL") or ""
USER_AUTH_SECRET = os.environ.get("USER_AUTH_SECRET", "")
@@ -359,6 +365,12 @@ POLL_CONNECTOR_OFFSET = 30 # Minutes overlap between poll windows
# only very select connectors are enabled and admins cannot add other connector types
ENABLED_CONNECTOR_TYPES = os.environ.get("ENABLED_CONNECTOR_TYPES") or ""
# If set to true, curators can only access and edit assistants that they created
CURATORS_CANNOT_VIEW_OR_EDIT_NON_OWNED_ASSISTANTS = (
os.environ.get("CURATORS_CANNOT_VIEW_OR_EDIT_NON_OWNED_ASSISTANTS", "").lower()
== "true"
)
# Some calls to get information on expert users are quite costly especially with rate limiting
# Since experts are not used in the actual user experience, currently it is turned off
# for some connectors
@@ -611,6 +623,17 @@ AVERAGE_SUMMARY_EMBEDDINGS = (
MAX_TOKENS_FOR_FULL_INCLUSION = 4096
#####
# Tool Configs
#####
OKTA_PROFILE_TOOL_ENABLED = (
os.environ.get("OKTA_PROFILE_TOOL_ENABLED", "").lower() == "true"
)
# API token for SSWS auth to Okta Admin API. If set, Users API will be used to enrich profile.
OKTA_API_TOKEN = os.environ.get("OKTA_API_TOKEN") or ""
#####
# Miscellaneous
#####

View File

@@ -25,6 +25,28 @@ TimeRange = tuple[datetime, datetime]
CT = TypeVar("CT", bound=ConnectorCheckpoint)
def batched_doc_ids(
checkpoint_connector_generator: CheckpointOutput[CT],
batch_size: int,
) -> Generator[set[str], None, None]:
batch: set[str] = set()
for document, failure, next_checkpoint in CheckpointOutputWrapper[CT]()(
checkpoint_connector_generator
):
if document is not None:
batch.add(document.id)
elif (
failure and failure.failed_document and failure.failed_document.document_id
):
batch.add(failure.failed_document.document_id)
if len(batch) >= batch_size:
yield batch
batch = set()
if len(batch) > 0:
yield batch
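The new `batched_doc_ids` generator above collects ids from both successful documents and failures into fixed-size sets, flushing the remainder at the end. A simplified, dependency-free sketch of the same batching over plain ids (without the `CheckpointOutputWrapper` triples):

```python
from collections.abc import Generator, Iterable


def batched_ids(
    ids: Iterable[str], batch_size: int
) -> Generator[set[str], None, None]:
    """Yield sets of at most batch_size unique ids; flush the remainder at the end."""
    batch: set[str] = set()
    for doc_id in ids:
        batch.add(doc_id)
        if len(batch) >= batch_size:
            yield batch
            batch = set()
    if batch:
        yield batch
```

Note that deduplication is per batch, not global: an id seen in one batch can reappear in a later one, just as in the original.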
class CheckpointOutputWrapper(Generic[CT]):
"""
Wraps a CheckpointOutput generator to give things back in a more digestible format,

View File

@@ -24,6 +24,7 @@ from onyx.file_processing.image_utils import store_image_and_create_section
from onyx.file_store.file_store import get_default_file_store
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -72,6 +73,7 @@ def _process_file(
file: IO[Any],
metadata: dict[str, Any] | None,
pdf_pass: str | None,
file_type: str | None,
) -> list[Document]:
"""
Process a file and return a list of Documents.
@@ -148,6 +150,7 @@ def _process_file(
file=file,
file_name=file_name,
pdf_pass=pdf_pass,
content_type=file_type,
)
# Each file may have file-specific ONYX_METADATA https://docs.onyx.app/connectors/file
@@ -229,21 +232,18 @@ class LocalFileConnector(LoadConnector):
# Note: file_names is a required parameter, but should not break backwards compatibility.
# If add_file_names migration is not run, old file connector configs will not have file_names.
# This is fine because the configs are not re-used to instantiate the connector.
# file_names is only used for display purposes in the UI and file_locations is used as a fallback.
def __init__(
self,
file_locations: list[Path | str],
file_names: list[
str
], # Must accept this parameter as connector_specific_config is unpacked as args
zip_metadata: dict[str, Any],
file_names: list[str] | None = None,
zip_metadata: dict[str, Any] | None = None,
batch_size: int = INDEX_BATCH_SIZE,
) -> None:
self.file_locations = [str(loc) for loc in file_locations]
self.batch_size = batch_size
self.pdf_pass: str | None = None
self.zip_metadata = zip_metadata
self.zip_metadata = zip_metadata or {}
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self.pdf_pass = credentials.get("pdf_password")
@@ -278,6 +278,7 @@ class LocalFileConnector(LoadConnector):
file=file_io,
metadata=metadata,
pdf_pass=self.pdf_pass,
file_type=file_record.file_type,
)
documents.extend(new_docs)

View File

@@ -119,7 +119,19 @@ class LoopioConnector(LoadConnector, PollConnector):
part["name"] for part in entry["location"].values() if part
)
answer = parse_html_page_basic(entry.get("answer", {}).get("text", ""))
answer_text = entry.get("answer", {}).get("text", "")
if not answer_text:
logger.warning(
f"The Library entry {entry['id']} has no answer text. Skipping."
)
continue
try:
answer = parse_html_page_basic(answer_text)
except Exception as e:
logger.error(f"Error parsing HTML for entry {entry['id']}: {e}")
continue
questions = [
question.get("text").replace("\xa0", " ")
for question in entry["questions"]

View File

@@ -1,5 +1,6 @@
import csv
import gc
import json
import os
import sys
import tempfile
@@ -28,8 +29,12 @@ from onyx.connectors.salesforce.doc_conversion import ID_PREFIX
from onyx.connectors.salesforce.onyx_salesforce import OnyxSalesforce
from onyx.connectors.salesforce.salesforce_calls import fetch_all_csvs_in_parallel
from onyx.connectors.salesforce.sqlite_functions import OnyxSalesforceSQLite
from onyx.connectors.salesforce.utils import ACCOUNT_OBJECT_TYPE
from onyx.connectors.salesforce.utils import BASE_DATA_PATH
from onyx.connectors.salesforce.utils import get_sqlite_db_path
from onyx.connectors.salesforce.utils import MODIFIED_FIELD
from onyx.connectors.salesforce.utils import NAME_FIELD
from onyx.connectors.salesforce.utils import USER_OBJECT_TYPE
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
@@ -38,27 +43,27 @@ from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
_DEFAULT_PARENT_OBJECT_TYPES = ["Account"]
_DEFAULT_PARENT_OBJECT_TYPES = [ACCOUNT_OBJECT_TYPE]
_DEFAULT_ATTRIBUTES_TO_KEEP: dict[str, dict[str, str]] = {
"Opportunity": {
"Account": "account",
ACCOUNT_OBJECT_TYPE: "account",
"FiscalQuarter": "fiscal_quarter",
"FiscalYear": "fiscal_year",
"IsClosed": "is_closed",
"Name": "name",
NAME_FIELD: "name",
"StageName": "stage_name",
"Type": "type",
"Amount": "amount",
"CloseDate": "close_date",
"Probability": "probability",
"CreatedDate": "created_date",
"LastModifiedDate": "last_modified_date",
MODIFIED_FIELD: "last_modified_date",
},
"Contact": {
"Account": "account",
ACCOUNT_OBJECT_TYPE: "account",
"CreatedDate": "created_date",
"LastModifiedDate": "last_modified_date",
MODIFIED_FIELD: "last_modified_date",
},
}
@@ -74,19 +79,77 @@ class SalesforceConnectorContext:
parent_to_child_types: dict[str, set[str]] = {} # map from parent to child types
child_to_parent_types: dict[str, set[str]] = {} # map from child to parent types
parent_reference_fields_by_type: dict[str, dict[str, list[str]]] = {}
type_to_queryable_fields: dict[str, list[str]] = {}
type_to_queryable_fields: dict[str, set[str]] = {}
prefix_to_type: dict[str, str] = {} # infer the object type of an id immediately
parent_to_child_relationships: dict[str, set[str]] = (
{}
) # map from parent to child relationships
parent_to_relationship_queryable_fields: dict[str, dict[str, list[str]]] = (
parent_to_relationship_queryable_fields: dict[str, dict[str, set[str]]] = (
{}
) # map from relationship to queryable fields
parent_child_names_to_relationships: dict[str, str] = {}
def _extract_fields_and_associations_from_config(
config: dict[str, Any], object_type: str
) -> tuple[list[str] | None, dict[str, list[str]]]:
"""
Extract fields and associations for a specific object type from custom config.
Returns:
tuple of (fields_list, associations_dict)
- fields_list: List of fields to query, or None if not specified (use all)
- associations_dict: Dict mapping association names to their config
"""
if object_type not in config:
return None, {}
obj_config = config[object_type]
fields = obj_config.get("fields")
associations = obj_config.get("associations", {})
return fields, associations
def _validate_custom_query_config(config: dict[str, Any]) -> None:
"""
Validate the structure of the custom query configuration.
"""
for object_type, obj_config in config.items():
if not isinstance(obj_config, dict):
raise ValueError(
f"top level object {object_type} must be mapped to a dictionary"
)
# Check if fields is a list when present
if "fields" in obj_config:
if not isinstance(obj_config["fields"], list):
raise ValueError("if fields key exists, value must be a list")
for v in obj_config["fields"]:
if not isinstance(v, str):
raise ValueError(f"fields list value {v} is not a string")
# Check if associations is a dict when present
if "associations" in obj_config:
if not isinstance(obj_config["associations"], dict):
raise ValueError(
"if associations key exists, value must be a dictionary"
)
for assoc_name, assoc_fields in obj_config["associations"].items():
if not isinstance(assoc_fields, list):
raise ValueError(
f"associations list value {assoc_fields} for key {assoc_name} is not a list"
)
for v in assoc_fields:
if not isinstance(v, str):
raise ValueError(
f"associations list value {v} is not a string"
)
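The shape `_validate_custom_query_config` accepts — top-level Salesforce object name mapped to a dict with optional `fields` (list of strings) and optional `associations` (dict of string lists) — looks like this hypothetical payload, checked here with the same structural rules:

```python
import json

# A hypothetical custom_query_config payload in the shape the validator accepts:
# {object type: {"fields": [...], "associations": {child type: [...]}}}
raw = json.dumps(
    {
        "Opportunity": {
            "fields": ["StageName", "Amount"],
            "associations": {"Contact": ["Email"]},
        }
    }
)

config = json.loads(raw)
```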
class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
"""Approach outline
@@ -134,14 +197,25 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
self,
batch_size: int = INDEX_BATCH_SIZE,
requested_objects: list[str] = [],
custom_query_config: str | None = None,
) -> None:
self.batch_size = batch_size
self._sf_client: OnyxSalesforce | None = None
self.parent_object_list = (
[obj.capitalize() for obj in requested_objects]
if requested_objects
else _DEFAULT_PARENT_OBJECT_TYPES
)
# Validate and store custom query config
if custom_query_config:
config_json = json.loads(custom_query_config)
self.custom_query_config: dict[str, Any] | None = config_json
# If custom query config is provided, use the object types from it
self.parent_object_list = list(config_json.keys())
else:
self.custom_query_config = None
# Use the traditional requested_objects approach
self.parent_object_list = (
[obj.strip().capitalize() for obj in requested_objects]
if requested_objects
else _DEFAULT_PARENT_OBJECT_TYPES
)
def load_credentials(
self,
@@ -187,7 +261,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
@staticmethod
def _download_object_csvs(
all_types_to_filter: dict[str, bool],
queryable_fields_by_type: dict[str, list[str]],
queryable_fields_by_type: dict[str, set[str]],
directory: str,
sf_client: OnyxSalesforce,
start: SecondsSinceUnixEpoch | None = None,
@@ -325,9 +399,9 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
# all_types.update(child_types.keys())
# # Always want to make sure user is grabbed for permissioning purposes
# all_types.add("User")
# all_types.add(USER_OBJECT_TYPE)
# # Always want to make sure account is grabbed for reference purposes
# all_types.add("Account")
# all_types.add(ACCOUNT_OBJECT_TYPE)
# logger.info(f"All object types: num={len(all_types)} list={all_types}")
@@ -351,7 +425,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
# all_types.update(child_types)
# # Always want to make sure user is grabbed for permissioning purposes
# all_types.add("User")
# all_types.add(USER_OBJECT_TYPE)
# logger.info(f"All object types: num={len(all_types)} list={all_types}")
@@ -364,7 +438,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
) -> GenerateDocumentsOutput:
type_to_processed: dict[str, int] = {}
logger.info("_fetch_from_salesforce starting.")
logger.info("_fetch_from_salesforce starting (full sync).")
if not self._sf_client:
raise RuntimeError("self._sf_client is None!")
@@ -548,7 +622,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
) -> GenerateDocumentsOutput:
type_to_processed: dict[str, int] = {}
logger.info("_fetch_from_salesforce starting.")
logger.info("_fetch_from_salesforce starting (delta sync).")
if not self._sf_client:
raise RuntimeError("self._sf_client is None!")
@@ -677,7 +751,9 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
try:
last_modified_by_id = record["LastModifiedById"]
user_record = self.sf_client.query_object(
"User", last_modified_by_id, ctx.type_to_queryable_fields
USER_OBJECT_TYPE,
last_modified_by_id,
ctx.type_to_queryable_fields,
)
if user_record:
primary_owner = BasicExpertInfo.from_dict(user_record)
@@ -792,7 +868,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
parent_reference_fields_by_type: dict[str, dict[str, list[str]]] = (
{}
) # for a given object, the fields reference parent objects
type_to_queryable_fields: dict[str, list[str]] = {}
type_to_queryable_fields: dict[str, set[str]] = {}
prefix_to_type: dict[str, str] = {}
parent_to_child_relationships: dict[str, set[str]] = (
@@ -802,15 +878,13 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
# relationship keys are formatted as "parent__relationship"
# we have to do this because relationship names are not unique!
# values are a dict of relationship names to a list of queryable fields
parent_to_relationship_queryable_fields: dict[str, dict[str, list[str]]] = {}
parent_to_relationship_queryable_fields: dict[str, dict[str, set[str]]] = {}
parent_child_names_to_relationships: dict[str, str] = {}
full_sync = False
if start is None and end is None:
full_sync = True
full_sync = start is None and end is None
# Step 1 - make a list of all the types to download (parent + direct child + "User")
# Step 1 - make a list of all the types to download (parent + direct child + USER_OBJECT_TYPE)
# prefixes = {}
global_description = sf_client.describe()
@@ -831,16 +905,63 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
logger.info(f"Parent object types: num={len(parent_types)} list={parent_types}")
for parent_type in parent_types:
# parent_onyx_sf_type = OnyxSalesforceType(parent_type, sf_client)
type_to_queryable_fields[parent_type] = (
sf_client.get_queryable_fields_by_type(parent_type)
)
child_types_working = sf_client.get_children_of_sf_type(parent_type)
logger.debug(
f"Found {len(child_types_working)} child types for {parent_type}"
)
custom_fields: list[str] | None = []
associations_config: dict[str, list[str]] | None = None
# parent_to_child_relationships[parent_type] = child_types_working
# Set queryable fields for parent type
if self.custom_query_config:
custom_fields, associations_config = (
_extract_fields_and_associations_from_config(
self.custom_query_config, parent_type
)
)
custom_fields = custom_fields or []
# Get custom fields for parent type
field_set = set(custom_fields)
# these are expected and used during doc conversion
field_set.add(NAME_FIELD)
field_set.add(MODIFIED_FIELD)
# Use only the specified fields
type_to_queryable_fields[parent_type] = field_set
logger.info(f"Using custom fields for {parent_type}: {field_set}")
else:
# Use all queryable fields
type_to_queryable_fields[parent_type] = (
sf_client.get_queryable_fields_by_type(parent_type)
)
logger.info(f"Using all fields for {parent_type}")
child_types_all = sf_client.get_children_of_sf_type(parent_type)
logger.debug(f"Found {len(child_types_all)} child types for {parent_type}")
logger.debug(f"child types: {child_types_all}")
child_types_working = child_types_all.copy()
if associations_config is not None:
child_types_working = {
k: v for k, v in child_types_all.items() if k in associations_config
}
any_not_found = False
for k in associations_config:
if k not in child_types_working:
any_not_found = True
logger.warning(f"Association {k} not found in {parent_type}")
if any_not_found:
queryable_fields = sf_client.get_queryable_fields_by_type(
parent_type
)
raise RuntimeError(
f"Associations {associations_config} not found in {parent_type}; "
"make sure your parent-child associations are in the right order"
# f"with child objects {child_types_all}"
# f" and fields {queryable_fields}"
)
parent_to_child_relationships[parent_type] = set()
parent_to_child_types[parent_type] = set()
parent_to_relationship_queryable_fields[parent_type] = {}
for child_type, child_relationship in child_types_working.items():
child_type = cast(str, child_type)
@@ -848,8 +969,6 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
# onyx_sf_type = OnyxSalesforceType(child_type, sf_client)
# map parent name to child name
if parent_type not in parent_to_child_types:
parent_to_child_types[parent_type] = set()
parent_to_child_types[parent_type].add(child_type)
# reverse map child name to parent name
@@ -858,19 +977,25 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
child_to_parent_types[child_type].add(parent_type)
# map parent name to child relationship
if parent_type not in parent_to_child_relationships:
parent_to_child_relationships[parent_type] = set()
parent_to_child_relationships[parent_type].add(child_relationship)
# map relationship to queryable fields of the target table
queryable_fields = sf_client.get_queryable_fields_by_type(child_type)
if config_fields := (
associations_config and associations_config.get(child_type)
):
field_set = set(config_fields)
# these are expected and used during doc conversion
field_set.add(NAME_FIELD)
field_set.add(MODIFIED_FIELD)
queryable_fields = field_set
else:
queryable_fields = sf_client.get_queryable_fields_by_type(
child_type
)
if child_relationship in parent_to_relationship_queryable_fields:
raise RuntimeError(f"{child_relationship=} already exists")
if parent_type not in parent_to_relationship_queryable_fields:
parent_to_relationship_queryable_fields[parent_type] = {}
parent_to_relationship_queryable_fields[parent_type][
child_relationship
] = queryable_fields
@@ -894,14 +1019,22 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
all_types.update(child_types)
# NOTE(rkuo): should this be an implicit parent type?
all_types.add("User") # Always add User for permissioning purposes
all_types.add("Account") # Always add Account for reference purposes
all_types.add(USER_OBJECT_TYPE) # Always add User for permissioning purposes
all_types.add(ACCOUNT_OBJECT_TYPE) # Always add Account for reference purposes
logger.info(f"All object types: num={len(all_types)} list={all_types}")
# Ensure User and Account have queryable fields if they weren't already processed
essential_types = [USER_OBJECT_TYPE, ACCOUNT_OBJECT_TYPE]
for essential_type in essential_types:
if essential_type not in type_to_queryable_fields:
type_to_queryable_fields[essential_type] = (
sf_client.get_queryable_fields_by_type(essential_type)
)
# 1.1 - Detect all fields in child types which reference a parent type.
# build dicts to detect relationships between parent and child
for child_type in child_types:
for child_type in child_types.union(essential_types):
# onyx_sf_type = OnyxSalesforceType(child_type, sf_client)
parent_reference_fields = sf_client.get_parent_reference_fields(
child_type, parent_types
@@ -1003,6 +1136,32 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
yield doc_metadata_list
def validate_connector_settings(self) -> None:
"""
Validate that the Salesforce credentials and connector settings are correct.
Specifically checks that we can make an authenticated request to Salesforce.
"""
try:
# Attempt to fetch a small batch of objects (arbitrary endpoint) to verify credentials
self.sf_client.describe()
except Exception as e:
raise ConnectorMissingCredentialError(
"Failed to validate Salesforce credentials. Please check your "
f"credentials and try again. Error: {e}"
)
if self.custom_query_config:
try:
_validate_custom_query_config(self.custom_query_config)
except Exception as e:
raise ConnectorMissingCredentialError(
"Failed to validate Salesforce custom query config. Please check your "
f"config and try again. Error: {e}"
)
logger.info("Salesforce credentials validated successfully.")
# @override
# def load_from_checkpoint(
# self,
@@ -1032,7 +1191,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
if __name__ == "__main__":
connector = SalesforceConnector(requested_objects=["Account"])
connector = SalesforceConnector(requested_objects=[ACCOUNT_OBJECT_TYPE])
connector.load_credentials(
{


@@ -10,6 +10,8 @@ from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.connectors.salesforce.onyx_salesforce import OnyxSalesforce
from onyx.connectors.salesforce.sqlite_functions import OnyxSalesforceSQLite
from onyx.connectors.salesforce.utils import MODIFIED_FIELD
from onyx.connectors.salesforce.utils import NAME_FIELD
from onyx.connectors.salesforce.utils import SalesforceObject
from onyx.utils.logger import setup_logger
@@ -140,7 +142,7 @@ def _extract_primary_owner(
first_name=user_data.get("FirstName"),
last_name=user_data.get("LastName"),
email=user_data.get("Email"),
display_name=user_data.get("Name"),
display_name=user_data.get(NAME_FIELD),
)
# Check if all fields are None
@@ -166,8 +168,8 @@ def convert_sf_query_result_to_doc(
"""Generates a yieldable Document from query results"""
base_url = f"https://{sf_client.sf_instance}"
extracted_doc_updated_at = time_str_to_utc(record["LastModifiedDate"])
extracted_semantic_identifier = record.get("Name", "Unknown Object")
extracted_doc_updated_at = time_str_to_utc(record[MODIFIED_FIELD])
extracted_semantic_identifier = record.get(NAME_FIELD, "Unknown Object")
sections = [_extract_section(record, f"{base_url}/{record_id}")]
for child_record_key, child_record in child_records.items():
@@ -205,8 +207,8 @@ def convert_sf_object_to_doc(
salesforce_id = object_dict["Id"]
onyx_salesforce_id = f"{ID_PREFIX}{salesforce_id}"
base_url = f"https://{sf_instance}"
extracted_doc_updated_at = time_str_to_utc(object_dict["LastModifiedDate"])
extracted_semantic_identifier = object_dict.get("Name", "Unknown Object")
extracted_doc_updated_at = time_str_to_utc(object_dict[MODIFIED_FIELD])
extracted_semantic_identifier = object_dict.get(NAME_FIELD, "Unknown Object")
sections = [_extract_section(sf_object.data, f"{base_url}/{sf_object.id}")]
for id in sf_db.get_child_ids(sf_object.id):


@@ -60,7 +60,7 @@ class OnyxSalesforce(Salesforce):
return True
for suffix in SALESFORCE_BLACKLISTED_SUFFIXES:
if object_type_lower.endswith(prefix):
if object_type_lower.endswith(suffix):
return True
return False
@@ -112,7 +112,7 @@ class OnyxSalesforce(Salesforce):
object_id: str,
sf_type: str,
child_relationships: list[str],
relationships_to_fields: dict[str, list[str]],
relationships_to_fields: dict[str, set[str]],
) -> str:
"""Returns a SOQL query given the object id, type and child relationships.
@@ -148,7 +148,7 @@ class OnyxSalesforce(Salesforce):
self,
object_type: str,
object_id: str,
type_to_queryable_fields: dict[str, list[str]],
type_to_queryable_fields: dict[str, set[str]],
) -> dict[str, Any] | None:
record: dict[str, Any] = {}
@@ -172,7 +172,7 @@ class OnyxSalesforce(Salesforce):
object_id: str,
sf_type: str,
child_relationships: list[str],
relationships_to_fields: dict[str, list[str]],
relationships_to_fields: dict[str, set[str]],
) -> dict[str, dict[str, Any]]:
"""There's a limit on the number of subqueries we can put in a single query."""
child_records: dict[str, dict[str, Any]] = {}
@@ -264,10 +264,10 @@ class OnyxSalesforce(Salesforce):
time.sleep(3)
raise
def get_queryable_fields_by_type(self, name: str) -> list[str]:
def get_queryable_fields_by_type(self, name: str) -> set[str]:
object_description = self.describe_type(name)
if object_description is None:
return []
return set()
fields: list[dict[str, Any]] = object_description["fields"]
valid_fields: set[str] = set()
@@ -286,7 +286,7 @@ class OnyxSalesforce(Salesforce):
if field_name:
valid_fields.add(field_name)
return list(valid_fields - field_names_to_remove)
return valid_fields - field_names_to_remove
def get_children_of_sf_type(self, sf_type: str) -> dict[str, str]:
"""Returns a dict of child object names to relationship names.


@@ -14,6 +14,7 @@ from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
)
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.salesforce.utils import MODIFIED_FIELD
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder
@@ -54,12 +55,12 @@ def _build_created_date_time_filter_for_salesforce(
def _make_time_filter_for_sf_type(
queryable_fields: list[str],
queryable_fields: set[str],
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
) -> str | None:
if "LastModifiedDate" in queryable_fields:
if MODIFIED_FIELD in queryable_fields:
return _build_last_modified_time_filter_for_salesforce(start, end)
if "CreatedDate" in queryable_fields:
@@ -69,14 +70,14 @@ def _make_time_filter_for_sf_type(
def _make_time_filtered_query(
queryable_fields: list[str], sf_type: str, time_filter: str
queryable_fields: set[str], sf_type: str, time_filter: str
) -> str:
query = f"SELECT {', '.join(queryable_fields)} FROM {sf_type}{time_filter}"
return query
def get_object_by_id_query(
object_id: str, sf_type: str, queryable_fields: list[str]
object_id: str, sf_type: str, queryable_fields: set[str]
) -> str:
query = (
f"SELECT {', '.join(queryable_fields)} FROM {sf_type} WHERE Id = '{object_id}'"
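Switching `queryable_fields` from `list[str]` to `set[str]` means the joined field order in the generated SOQL depends on set iteration. A hedged sketch of the same query builder, with sorting added for determinism (the connector itself joins the set directly; `build_object_by_id_query` is an illustrative name):

```python
def build_object_by_id_query(
    object_id: str, sf_type: str, queryable_fields: set[str]
) -> str:
    # Sort the set so the generated SOQL is stable across runs.
    fields = ", ".join(sorted(queryable_fields))
    return f"SELECT {fields} FROM {sf_type} WHERE Id = '{object_id}'"
```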
@@ -193,7 +194,7 @@ def _bulk_retrieve_from_salesforce(
def fetch_all_csvs_in_parallel(
sf_client: Salesforce,
all_types_to_filter: dict[str, bool],
queryable_fields_by_type: dict[str, list[str]],
queryable_fields_by_type: dict[str, set[str]],
start: SecondsSinceUnixEpoch | None,
end: SecondsSinceUnixEpoch | None,
target_dir: str,


@@ -8,11 +8,15 @@ from pathlib import Path
from typing import Any
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.salesforce.utils import ACCOUNT_OBJECT_TYPE
from onyx.connectors.salesforce.utils import NAME_FIELD
from onyx.connectors.salesforce.utils import SalesforceObject
from onyx.connectors.salesforce.utils import USER_OBJECT_TYPE
from onyx.connectors.salesforce.utils import validate_salesforce_id
from onyx.utils.logger import setup_logger
from shared_configs.utils import batch_list
logger = setup_logger()
@@ -567,7 +571,7 @@ class OnyxSalesforceSQLite:
uncommitted_rows = 0
# If we're updating User objects, update the email map
if object_type == "User":
if object_type == USER_OBJECT_TYPE:
OnyxSalesforceSQLite._update_user_email_map(cursor)
return updated_ids
@@ -619,7 +623,7 @@ class OnyxSalesforceSQLite:
with self._conn:
cursor = self._conn.cursor()
# Get the object data and account data
if object_type == "Account" or isChild:
if object_type == ACCOUNT_OBJECT_TYPE or isChild:
cursor.execute(
"SELECT data FROM salesforce_objects WHERE id = ?", (object_id,)
)
@@ -638,7 +642,7 @@ class OnyxSalesforceSQLite:
data = json.loads(result[0][0])
if object_type != "Account":
if object_type != ACCOUNT_OBJECT_TYPE:
# convert any account ids of the relationships back into data fields, with name
for row in result:
@@ -647,14 +651,14 @@ class OnyxSalesforceSQLite:
if len(row) < 3:
continue
if row[1] and row[2] and row[2] == "Account":
if row[1] and row[2] and row[2] == ACCOUNT_OBJECT_TYPE:
data["AccountId"] = row[1]
cursor.execute(
"SELECT data FROM salesforce_objects WHERE id = ?",
(row[1],),
)
account_data = json.loads(cursor.fetchone()[0])
data["Account"] = account_data.get("Name", "")
data[ACCOUNT_OBJECT_TYPE] = account_data.get(NAME_FIELD, "")
return SalesforceObject(id=object_id, type=object_type, data=data)


@@ -2,6 +2,11 @@ import os
from dataclasses import dataclass
from typing import Any
NAME_FIELD = "Name"
MODIFIED_FIELD = "LastModifiedDate"
ACCOUNT_OBJECT_TYPE = "Account"
USER_OBJECT_TYPE = "User"
@dataclass
class SalesforceObject:

File diff suppressed because it is too large.


@@ -0,0 +1,38 @@
from typing import Any
from office365.graph_client import GraphClient # type: ignore[import-untyped]
from office365.onedrive.driveitems.driveItem import DriveItem # type: ignore[import-untyped]
from office365.sharepoint.client_context import ClientContext # type: ignore[import-untyped]
from onyx.connectors.models import ExternalAccess
from onyx.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,
)
def get_sharepoint_external_access(
ctx: ClientContext,
graph_client: GraphClient,
drive_item: DriveItem | None = None,
drive_name: str | None = None,
site_page: dict[str, Any] | None = None,
add_prefix: bool = False,
) -> ExternalAccess:
if drive_item and drive_item.id is None:
raise ValueError("DriveItem ID is required")
# Get external access using the EE implementation
def noop_fallback(*args: Any, **kwargs: Any) -> ExternalAccess:
return ExternalAccess.empty()
get_external_access_func = fetch_versioned_implementation_with_fallback(
"onyx.external_permissions.sharepoint.permission_utils",
"get_external_access_from_sharepoint",
fallback=noop_fallback,
)
external_access = get_external_access_func(
ctx, graph_client, drive_name, drive_item, site_page, add_prefix
)
return external_access


@@ -267,6 +267,7 @@ class IndexingCoordination:
index_attempt_id: int,
current_batches_completed: int,
timeout_hours: int = INDEXING_PROGRESS_TIMEOUT_HOURS,
force_update_progress: bool = False,
) -> bool:
"""
Update progress tracking for stall detection.
@@ -281,7 +282,8 @@ class IndexingCoordination:
current_time = get_db_current_time(db_session)
# No progress - check if this is the first time tracking
if attempt.last_progress_time is None:
# or if the caller wants to simulate guaranteed progress
if attempt.last_progress_time is None or force_update_progress:
# First time tracking - initialize
attempt.last_progress_time = current_time
attempt.last_batches_completed_count = current_batches_completed


@@ -1293,6 +1293,7 @@ class Tag(Base):
source: Mapped[DocumentSource] = mapped_column(
Enum(DocumentSource, native_enum=False)
)
is_list: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
documents = relationship(
"Document",
@@ -1302,7 +1303,11 @@ class Tag(Base):
__table_args__ = (
UniqueConstraint(
"tag_key", "tag_value", "source", name="_tag_key_value_source_uc"
"tag_key",
"tag_value",
"source",
"is_list",
name="_tag_key_value_source_list_uc",
),
)
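Adding `is_list` to the unique constraint matters because the same `(tag_key, tag_value, source)` triple can legitimately exist twice: once as a scalar tag and once as an element of a list tag. A minimal stdlib-`sqlite3` sketch (not the project's SQLAlchemy model) of that behavior:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    """CREATE TABLE tag (
        tag_key TEXT, tag_value TEXT, source TEXT, is_list INTEGER,
        UNIQUE (tag_key, tag_value, source, is_list)
    )"""
)
# Same key/value/source, differing only on is_list: both rows are allowed.
conn.execute("INSERT INTO tag VALUES ('team', 'core', 'salesforce', 0)")
conn.execute("INSERT INTO tag VALUES ('team', 'core', 'salesforce', 1)")
try:
    # A true duplicate (all four columns equal) is rejected.
    conn.execute("INSERT INTO tag VALUES ('team', 'core', 'salesforce', 1)")
except sqlite3.IntegrityError:
    pass
```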
@@ -1685,12 +1690,14 @@ class IndexAttempt(Base):
# can be taken to the FileStore to grab the actual checkpoint value
checkpoint_pointer: Mapped[str | None] = mapped_column(String, nullable=True)
# NEW: Database-based coordination fields (replacing Redis fencing)
# Database-based coordination fields (replacing Redis fencing)
celery_task_id: Mapped[str | None] = mapped_column(String, nullable=True)
cancellation_requested: Mapped[bool] = mapped_column(Boolean, default=False)
# NEW: Batch coordination fields (replacing FileStore state)
# Batch coordination fields
# Once this is set, docfetching has completed
total_batches: Mapped[int | None] = mapped_column(Integer, nullable=True)
# batches that are fully indexed (i.e. have completed docfetching and docprocessing)
completed_batches: Mapped[int] = mapped_column(Integer, default=0)
# TODO: unused, remove this column
total_failures_batch_level: Mapped[int] = mapped_column(Integer, default=0)
@@ -1702,7 +1709,7 @@ class IndexAttempt(Base):
)
last_batches_completed_count: Mapped[int] = mapped_column(Integer, default=0)
# NEW: Heartbeat tracking for worker liveness detection
# Heartbeat tracking for worker liveness detection
heartbeat_counter: Mapped[int] = mapped_column(Integer, default=0)
last_heartbeat_value: Mapped[int] = mapped_column(Integer, default=0)
last_heartbeat_time: Mapped[datetime.datetime | None] = mapped_column(


@@ -15,6 +15,7 @@ from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.auth.schemas import UserRole
from onyx.configs.app_configs import CURATORS_CANNOT_VIEW_OR_EDIT_NON_OWNED_ASSISTANTS
from onyx.configs.app_configs import DISABLE_AUTH
from onyx.configs.chat_configs import BING_API_KEY
from onyx.configs.chat_configs import CONTEXT_CHUNKS_ABOVE
@@ -96,6 +97,14 @@ def _add_user_filters(
where_clause = Persona.is_public == True # noqa: E712
return stmt.where(where_clause)
# If curator ownership restriction is enabled, curators can only access their own assistants
if CURATORS_CANNOT_VIEW_OR_EDIT_NON_OWNED_ASSISTANTS and user.role in [
UserRole.CURATOR,
UserRole.GLOBAL_CURATOR,
]:
where_clause = (Persona.user_id == user.id) | (Persona.user_id.is_(None))
return stmt.where(where_clause)
where_clause = User__UserGroup.user_id == user.id
if user.role == UserRole.CURATOR and get_editable:
where_clause &= User__UserGroup.is_curator == True # noqa: E712


@@ -47,11 +47,12 @@ def create_or_add_document_tag(
Tag.tag_key == tag_key,
Tag.tag_value == tag_value,
Tag.source == source,
Tag.is_list.is_(False),
)
tag = db_session.execute(tag_stmt).scalar_one_or_none()
if not tag:
tag = Tag(tag_key=tag_key, tag_value=tag_value, source=source)
tag = Tag(tag_key=tag_key, tag_value=tag_value, source=source, is_list=False)
db_session.add(tag)
if tag not in document.tags:
@@ -82,6 +83,7 @@ def create_or_add_document_tag_list(
Tag.tag_key == tag_key,
Tag.tag_value.in_(valid_tag_values),
Tag.source == source,
Tag.is_list.is_(True),
)
existing_tags = list(db_session.execute(existing_tags_stmt).scalars().all())
existing_tag_values = {tag.tag_value for tag in existing_tags}
@@ -89,7 +91,9 @@ def create_or_add_document_tag_list(
new_tags = []
for tag_value in valid_tag_values:
if tag_value not in existing_tag_values:
new_tag = Tag(tag_key=tag_key, tag_value=tag_value, source=source)
new_tag = Tag(
tag_key=tag_key, tag_value=tag_value, source=source, is_list=True
)
db_session.add(new_tag)
new_tags.append(new_tag)
existing_tag_values.add(tag_value)
@@ -109,6 +113,45 @@ def create_or_add_document_tag_list(
return all_tags
def upsert_document_tags(
document_id: str,
source: DocumentSource,
metadata: dict[str, str | list[str]],
db_session: Session,
) -> list[Tag]:
document = db_session.get(Document, document_id)
if not document:
raise ValueError("Invalid Document, cannot attach Tags")
old_tag_ids: set[int] = {tag.id for tag in document.tags}
new_tags: list[Tag] = []
new_tag_ids: set[int] = set()
for k, v in metadata.items():
if isinstance(v, list):
new_tags.extend(
create_or_add_document_tag_list(k, v, source, document_id, db_session)
)
new_tag_ids.update({tag.id for tag in new_tags})
continue
new_tag = create_or_add_document_tag(k, v, source, document_id, db_session)
if new_tag:
new_tag_ids.add(new_tag.id)
new_tags.append(new_tag)
delete_tags = old_tag_ids - new_tag_ids
if delete_tags:
delete_stmt = delete(Document__Tag).where(
Document__Tag.document_id == document_id,
Document__Tag.tag_id.in_(delete_tags),
)
db_session.execute(delete_stmt)
db_session.commit()
return new_tags
def find_tags(
tag_key_prefix: str | None,
tag_value_prefix: str | None,
@@ -147,24 +190,37 @@ def find_tags(
def get_structured_tags_for_document(
document_id: str, db_session: Session
) -> dict[str, str | list[str]]:
"""Essentially returns the document metadata from postgres."""
document = db_session.get(Document, document_id)
if not document:
raise ValueError("Invalid Document, cannot find tags")
document_metadata: dict[str, Any] = {}
for tag in document.tags:
if tag.tag_key in document_metadata:
# NOTE: we convert to list if there are multiple values for the same key
# Thus, it won't know if a tag is a list if it only contains one value
if isinstance(document_metadata[tag.tag_key], str):
document_metadata[tag.tag_key] = [
document_metadata[tag.tag_key],
tag.tag_value,
]
else:
document_metadata[tag.tag_key].append(tag.tag_value)
else:
document_metadata[tag.tag_key] = tag.tag_value
if tag.is_list:
document_metadata.setdefault(tag.tag_key, [])
# should always be a list (if tag.is_list is always True for this key), but just in case
if not isinstance(document_metadata[tag.tag_key], list):
logger.warning(
"Inconsistent is_list for document %s, tag_key %s",
document_id,
tag.tag_key,
)
document_metadata[tag.tag_key] = [document_metadata[tag.tag_key]]
document_metadata[tag.tag_key].append(tag.tag_value)
continue
# set value (ignore duplicate keys, though there should be none)
document_metadata.setdefault(tag.tag_key, tag.tag_value)
# should always be a value, but just in case (treat it as a list in this case)
if isinstance(document_metadata[tag.tag_key], list):
logger.warning(
"Inconsistent is_list for document %s, tag_key %s",
document_id,
tag.tag_key,
)
document_metadata[tag.tag_key] = [document_metadata[tag.tag_key]]
return document_metadata


@@ -12,8 +12,7 @@ from sqlalchemy.sql import expression
from sqlalchemy.sql.elements import ColumnElement
from sqlalchemy.sql.elements import KeyedColumnElement
from onyx.auth.invited_users import get_invited_users
from onyx.auth.invited_users import write_invited_users
from onyx.auth.invited_users import remove_user_from_invited_users
from onyx.auth.schemas import UserRole
from onyx.db.api_key import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
from onyx.db.models import DocumentSet__User
@@ -342,10 +341,4 @@ def delete_user_from_db(
# NOTE: edge case may exist with race conditions
# with this `invited user` scheme generally.
user_emails = get_invited_users()
remaining_users = [
remaining_user_email
for remaining_user_email in user_emails
if remaining_user_email != user_to_delete.email
]
write_invited_users(remaining_users)
remove_user_from_invited_users(user_to_delete.email)


@@ -17,11 +17,11 @@ from typing import NamedTuple
from zipfile import BadZipFile
import chardet
import docx # type: ignore
import openpyxl # type: ignore
import pptx # type: ignore
from docx import Document as DocxDocument
from fastapi import UploadFile
from markitdown import FileConversionException
from markitdown import MarkItDown
from markitdown import UnsupportedFormatException
from PIL import Image
from pypdf import PdfReader
from pypdf.errors import PdfStreamError
@@ -29,6 +29,7 @@ from pypdf.errors import PdfStreamError
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import ONYX_METADATA_FILENAME
from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
from onyx.file_processing.file_validation import TEXT_MIME_TYPE
from onyx.file_processing.html_utils import parse_html_page_basic
from onyx.file_processing.unstructured import get_unstructured_api_key
from onyx.file_processing.unstructured import unstructured_to_text
@@ -83,11 +84,6 @@ IMAGE_MEDIA_TYPES = [
"image/webp",
]
KNOWN_OPENPYXL_BUGS = [
"Value must be either numerical or a string containing a wildcard",
"File contains no valid workbook part",
]
class OnyxExtensionType(IntFlag):
Plain = auto()
@@ -149,6 +145,13 @@ def is_macos_resource_fork_file(file_name: str) -> bool:
)
def to_bytesio(stream: IO[bytes]) -> BytesIO:
if isinstance(stream, BytesIO):
return stream
data = stream.read() # consumes the stream!
return BytesIO(data)
def load_files_from_zip(
zip_file_io: IO,
ignore_macos_resource_fork_files: bool = True,
@@ -305,19 +308,38 @@ def read_pdf_file(
return "", metadata, []
def extract_docx_images(docx_bytes: IO[Any]) -> list[tuple[bytes, str]]:
"""
Given the bytes of a docx file, extract all the images.
Returns a list of tuples (image_bytes, image_name).
"""
out = []
try:
with zipfile.ZipFile(docx_bytes) as z:
for name in z.namelist():
if name.startswith("word/media/"):
out.append((z.read(name), name.split("/")[-1]))
except Exception:
logger.exception("Failed to extract all docx images")
return out
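`extract_docx_images` works because a `.docx` is just a zip archive whose embedded images live under `word/media/`. A self-contained sketch that builds a tiny zip standing in for a docx and pulls the media entries out the same way:

```python
import io
import zipfile


def extract_zip_media(data: io.BytesIO) -> list[tuple[bytes, str]]:
    # Collect (bytes, basename) for every entry under word/media/.
    out = []
    with zipfile.ZipFile(data) as z:
        for name in z.namelist():
            if name.startswith("word/media/"):
                out.append((z.read(name), name.split("/")[-1]))
    return out


# Minimal stand-in for a .docx: a zip with one document part and one image.
buf = io.BytesIO()
with zipfile.ZipFile(buf, "w") as z:
    z.writestr("word/document.xml", "<w:document/>")
    z.writestr("word/media/image1.png", b"\x89PNG...")
images = extract_zip_media(buf)
```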
def docx_to_text_and_images(
file: IO[Any], file_name: str = ""
) -> tuple[str, Sequence[tuple[bytes, str]]]:
"""
Extract text from a docx. If embed_images=True, also extract inline images.
Extract text from a docx.
Return (text_content, list_of_images).
"""
paragraphs = []
embedded_images: list[tuple[bytes, str]] = []
md = MarkItDown(enable_plugins=False)
try:
doc = docx.Document(file)
except (BadZipFile, ValueError) as e:
doc = md.convert(to_bytesio(file))
except (
BadZipFile,
ValueError,
FileConversionException,
UnsupportedFormatException,
) as e:
logger.warning(
f"Failed to extract docx {file_name or 'docx file'}: {e}. Attempting to read as text file."
)
@@ -330,96 +352,44 @@ def docx_to_text_and_images(
)
return text_content_raw or "", []
# Grab text from paragraphs
for paragraph in doc.paragraphs:
paragraphs.append(paragraph.text)
# Reset position so we can re-load the doc (python-docx has read the stream)
# Note: if python-docx has fully consumed the stream, you may need to open it again from memory.
# For large docs, a more robust approach is needed.
# This is a simplified example.
for rel_id, rel in doc.part.rels.items():
if "image" in rel.reltype:
# Skip images that are linked rather than embedded (TargetMode="External")
if getattr(rel, "is_external", False):
continue
try:
# image is typically in rel.target_part.blob
image_bytes = rel.target_part.blob
except ValueError:
# Safeguard against relationships that lack an internal target_part
# (e.g., external relationships or other anomalies)
continue
image_name = rel.target_part.partname
# store
embedded_images.append((image_bytes, os.path.basename(str(image_name))))
text_content = "\n".join(paragraphs)
return text_content, embedded_images
file.seek(0)
return doc.markdown, extract_docx_images(to_bytesio(file))
def pptx_to_text(file: IO[Any], file_name: str = "") -> str:
md = MarkItDown(enable_plugins=False)
try:
presentation = pptx.Presentation(file)
except BadZipFile as e:
presentation = md.convert(to_bytesio(file))
except (
BadZipFile,
ValueError,
FileConversionException,
UnsupportedFormatException,
) as e:
error_str = f"Failed to extract text from {file_name or 'pptx file'}: {e}"
logger.warning(error_str)
return ""
text_content = []
for slide_number, slide in enumerate(presentation.slides, start=1):
slide_text = f"\nSlide {slide_number}:\n"
for shape in slide.shapes:
if hasattr(shape, "text"):
slide_text += shape.text + "\n"
text_content.append(slide_text)
return TEXT_SECTION_SEPARATOR.join(text_content)
return presentation.markdown
def xlsx_to_text(file: IO[Any], file_name: str = "") -> str:
md = MarkItDown(enable_plugins=False)
try:
workbook = openpyxl.load_workbook(file, read_only=True)
except BadZipFile as e:
workbook = md.convert(to_bytesio(file))
except (
BadZipFile,
ValueError,
FileConversionException,
UnsupportedFormatException,
) as e:
error_str = f"Failed to extract text from {file_name or 'xlsx file'}: {e}"
if file_name.startswith("~"):
logger.debug(error_str + " (this is expected for files with ~)")
else:
logger.warning(error_str)
return ""
except Exception as e:
if any(s in str(e) for s in KNOWN_OPENPYXL_BUGS):
logger.error(
f"Failed to extract text from {file_name or 'xlsx file'}. This happens due to a bug in openpyxl. {e}"
)
return ""
raise e
text_content = []
for sheet in workbook.worksheets:
rows = []
num_empty_consecutive_rows = 0
for row in sheet.iter_rows(min_row=1, values_only=True):
row_str = ",".join(str(cell or "") for cell in row)
# Only add the row if there are any values in the cells
if len(row_str) >= len(row):
rows.append(row_str)
num_empty_consecutive_rows = 0
else:
num_empty_consecutive_rows += 1
if num_empty_consecutive_rows > 100:
# handle massive excel sheets with mostly empty cells
logger.warning(
f"Found {num_empty_consecutive_rows} empty rows in {file_name},"
" skipping rest of file"
)
break
sheet_str = "\n".join(rows)
text_content.append(sheet_str)
return TEXT_SECTION_SEPARATOR.join(text_content)
return workbook.markdown
def eml_to_text(file: IO[Any]) -> str:
@@ -472,9 +442,9 @@ def extract_file_text(
"""
extension_to_function: dict[str, Callable[[IO[Any]], str]] = {
".pdf": pdf_to_text,
".docx": lambda f: docx_to_text_and_images(f)[0], # no images
".pptx": pptx_to_text,
".xlsx": xlsx_to_text,
".docx": lambda f: docx_to_text_and_images(f, file_name)[0], # no images
".pptx": lambda f: pptx_to_text(f, file_name),
".xlsx": lambda f: xlsx_to_text(f, file_name),
".eml": eml_to_text,
".epub": epub_to_text,
".html": parse_html_page_basic,
@@ -523,10 +493,23 @@ class ExtractionResult(NamedTuple):
metadata: dict[str, Any]
def extract_result_from_text_file(file: IO[Any]) -> ExtractionResult:
encoding = detect_encoding(file)
text_content_raw, file_metadata = read_text_file(
file, encoding=encoding, ignore_onyx_metadata=False
)
return ExtractionResult(
text_content=text_content_raw,
embedded_images=[],
metadata=file_metadata,
)
def extract_text_and_images(
file: IO[Any],
file_name: str,
pdf_pass: str | None = None,
content_type: str | None = None,
) -> ExtractionResult:
"""
Primary new function for the updated connector.
@@ -547,13 +530,20 @@ def extract_text_and_images(
)
file.seek(0) # Reset file pointer just in case
# When we upload a document via a connector or MyDocuments, we extract and store the content of files
# with content types in UploadMimeTypes.DOCUMENT_MIME_TYPES as plain text files.
# As a result, the file name extension may differ from the original content type.
# We process files with a plain text content type first to handle this scenario.
if content_type == TEXT_MIME_TYPE:
return extract_result_from_text_file(file)
# Default processing
try:
extension = get_file_ext(file_name)
# docx example for embedded images
if extension == ".docx":
text_content, images = docx_to_text_and_images(file)
text_content, images = docx_to_text_and_images(file, file_name)
return ExtractionResult(
text_content=text_content, embedded_images=images, metadata={}
)
@@ -605,15 +595,7 @@ def extract_text_and_images(
# If we reach here and it's a recognized text extension
if is_text_file_extension(file_name):
encoding = detect_encoding(file)
text_content_raw, file_metadata = read_text_file(
file, encoding=encoding, ignore_onyx_metadata=False
)
return ExtractionResult(
text_content=text_content_raw,
embedded_images=[],
metadata=file_metadata,
)
return extract_result_from_text_file(file)
# If it's an image file or something else, we do not parse embedded images from them
# just return empty text


@@ -21,6 +21,9 @@ EXCLUDED_IMAGE_TYPES = [
"image/avif",
]
# Text MIME types
TEXT_MIME_TYPE = "text/plain"
def is_valid_image_type(mime_type: str) -> bool:
"""
@@ -32,9 +35,11 @@ def is_valid_image_type(mime_type: str) -> bool:
Returns:
True if the MIME type is a valid image type, False otherwise
"""
if not mime_type:
return False
return mime_type.startswith("image/") and mime_type not in EXCLUDED_IMAGE_TYPES
return (
bool(mime_type)
and mime_type.startswith("image/")
and mime_type not in EXCLUDED_IMAGE_TYPES
)
def is_supported_by_vision_llm(mime_type: str) -> bool:


@@ -1,6 +1,7 @@
import re
from copy import copy
from dataclasses import dataclass
from io import BytesIO
from typing import IO
import bs4
@@ -161,7 +162,7 @@ def format_document_soup(
return strip_excessive_newlines_and_spaces(text)
def parse_html_page_basic(text: str | IO[bytes]) -> str:
def parse_html_page_basic(text: str | BytesIO | IO[bytes]) -> str:
soup = bs4.BeautifulSoup(text, "html.parser")
return format_document_soup(soup)


@@ -196,6 +196,9 @@ class FileStoreDocumentBatchStorage(DocumentBatchStorage):
for batch_file_name in batch_names:
path_info = self.extract_path_info(batch_file_name)
if path_info is None:
logger.warning(
f"Could not extract path info from batch file: {batch_file_name}"
)
continue
new_batch_file_name = self._get_batch_file_name(path_info.batch_num)
self.file_store.change_file_id(batch_file_name, new_batch_file_name)


@@ -19,6 +19,14 @@ class ChatFileType(str, Enum):
# "user knowledge" is not a file type, it's a source or intent
USER_KNOWLEDGE = "user_knowledge"
def is_text_file(self) -> bool:
return self in (
ChatFileType.PLAIN_TEXT,
ChatFileType.DOC,
ChatFileType.CSV,
ChatFileType.USER_KNOWLEDGE,
)
class FileDescriptor(TypedDict):
"""NOTE: is a `TypedDict` so it can be used as a type hint for a JSONB column


@@ -49,11 +49,10 @@ def sanitize_s3_key_name(file_name: str) -> str:
# Characters to avoid completely (replace with underscore)
# These are characters that AWS recommends avoiding
avoid_chars = r'[\\{}^%`\[\]"<>#|~]'
avoid_chars = r'[\\{}^%`\[\]"<>#|~/]'
# Replace avoided characters with underscore
sanitized = re.sub(avoid_chars, "_", file_name)
# Characters that might require special handling but are allowed
# We'll URL encode these to be safe
special_chars = r"[&$@=;:+,?\s]"
@@ -81,6 +80,9 @@ def sanitize_s3_key_name(file_name: str) -> str:
# Remove any trailing periods to avoid download issues
sanitized = sanitized.rstrip(".")
# Remove multiple separators
sanitized = re.sub(r"[-_]{2,}", "-", sanitized)
# If sanitization resulted in empty string, use a default
if not sanitized:
sanitized = "sanitized_file"
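The sanitization steps this diff touches can be condensed into a short sketch: replace the characters AWS recommends avoiding (now including `/`), strip trailing periods, collapse repeated separators, and fall back to a default when nothing survives. `sanitize_key` is an abbreviated stand-in, omitting the URL-encoding step for the "allowed but special" characters.

```python
import re


def sanitize_key(name: str) -> str:
    # Replace characters AWS recommends avoiding (including "/").
    name = re.sub(r'[\\{}^%`\[\]"<>#|~/]', "_", name)
    # Trailing periods can break downloads.
    name = name.rstrip(".")
    # Collapse runs of separators into a single dash.
    name = re.sub(r"[-_]{2,}", "-", name)
    return name or "sanitized_file"
```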


@@ -22,6 +22,8 @@ from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
logger = setup_logger()
RECENT_FOLDER_ID = -1
def user_file_id_to_plaintext_file_name(user_file_id: int) -> str:
"""Generate a consistent file name for storing plaintext content of a user file."""
@@ -46,7 +48,6 @@ def store_user_file_plaintext(user_file_id: int, plaintext_content: str) -> bool
# Get plaintext file name
plaintext_file_name = user_file_id_to_plaintext_file_name(user_file_id)
# Use a separate session to avoid committing the caller's transaction
try:
file_store = get_default_file_store()
file_content = BytesIO(plaintext_content.encode("utf-8"))
@@ -245,14 +246,21 @@ def get_user_files_as_user(
Fetches all UserFile database records for a given user.
"""
user_files = get_user_files(user_file_ids, user_folder_ids, db_session)
current_user_files = []
for user_file in user_files:
# Note: if user_id is None, then all files should be None as well
# (since auth must be disabled in this case)
if user_file.user_id != user_id:
raise ValueError(
f"User {user_id} does not have access to file {user_file.id}"
)
return user_files
if user_file.folder_id == RECENT_FOLDER_ID:
if user_file.user_id == user_id:
current_user_files.append(user_file)
else:
if user_file.user_id != user_id:
raise ValueError(
f"User {user_id} does not have access to file {user_file.id}"
)
current_user_files.append(user_file)
return current_user_files
def save_file_from_url(url: str) -> str:
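The new branch for the Recent pseudo-folder can be illustrated in isolation; `UserFile` here is a hypothetical minimal stand-in for the ORM model:

```python
RECENT_FOLDER_ID = -1  # sentinel folder id from the diff

class UserFile:
    # Hypothetical minimal stand-in for the database model
    def __init__(self, id: int, user_id, folder_id):
        self.id, self.user_id, self.folder_id = id, user_id, folder_id

def filter_user_files(user_files, user_id):
    current = []
    for f in user_files:
        if f.folder_id == RECENT_FOLDER_ID:
            # Recent files owned by other users are silently skipped
            if f.user_id == user_id:
                current.append(f)
        else:
            # Outside Recent, a mismatched owner is still an access error
            if f.user_id != user_id:
                raise ValueError(f"User {user_id} does not have access to file {f.id}")
            current.append(f)
    return current

files = [UserFile(1, "u1", -1), UserFile(2, "u2", -1), UserFile(3, "u1", 5)]
print([f.id for f in filter_user_files(files, "u1")])  # → [1, 3]
```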

View File

@@ -44,8 +44,7 @@ from onyx.db.document_set import fetch_document_sets_for_documents
from onyx.db.models import Document as DBDocument
from onyx.db.models import IndexModelStatus
from onyx.db.search_settings import get_active_search_settings
from onyx.db.tag import create_or_add_document_tag
from onyx.db.tag import create_or_add_document_tag_list
from onyx.db.tag import upsert_document_tags
from onyx.db.user_documents import fetch_user_files_for_documents
from onyx.db.user_documents import fetch_user_folders_for_documents
from onyx.db.user_documents import update_user_file_token_count__no_commit
@@ -150,24 +149,12 @@ def _upsert_documents_in_db(
# Insert document content metadata
for doc in documents:
for k, v in doc.metadata.items():
if isinstance(v, list):
create_or_add_document_tag_list(
tag_key=k,
tag_values=v,
source=doc.source,
document_id=doc.id,
db_session=db_session,
)
continue
create_or_add_document_tag(
tag_key=k,
tag_value=v,
source=doc.source,
document_id=doc.id,
db_session=db_session,
)
upsert_document_tags(
document_id=doc.id,
source=doc.source,
metadata=doc.metadata,
db_session=db_session,
)
def _get_aggregated_chunk_boost_factor(
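The consolidation above replaces the per-value tag calls with a single `upsert_document_tags`. A sketch of the flattening such an upsert needs internally; `normalize_metadata` is an assumed helper name, not from the source:

```python
def normalize_metadata(metadata: dict) -> list:
    # Flatten scalar and list values into (key, value) tag pairs,
    # mirroring the old create_or_add_document_tag / _list split
    pairs = []
    for k, v in metadata.items():
        values = v if isinstance(v, list) else [v]
        pairs.extend((k, str(val)) for val in values)
    return pairs

print(normalize_metadata({"team": "infra", "tags": ["a", "b"]}))
# → [('team', 'infra'), ('tags', 'a'), ('tags', 'b')]
```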
@@ -867,31 +854,27 @@ def index_doc_batch(
user_file_id_to_raw_text: dict[int, str] = {}
for document_id in updatable_ids:
# Only calculate token counts for documents that have a user file ID
if (
document_id in doc_id_to_user_file_id
and doc_id_to_user_file_id[document_id] is not None
):
user_file_id = doc_id_to_user_file_id[document_id]
if not user_file_id:
continue
document_chunks = [
chunk
for chunk in chunks_with_embeddings
if chunk.source_document.id == document_id
]
if document_chunks:
combined_content = " ".join(
[chunk.content for chunk in document_chunks]
)
token_count = (
len(llm_tokenizer.encode(combined_content))
if llm_tokenizer
else 0
)
user_file_id_to_token_count[user_file_id] = token_count
user_file_id_to_raw_text[user_file_id] = combined_content
else:
user_file_id_to_token_count[user_file_id] = None
user_file_id = doc_id_to_user_file_id.get(document_id)
if user_file_id is None:
continue
document_chunks = [
chunk
for chunk in chunks_with_embeddings
if chunk.source_document.id == document_id
]
if document_chunks:
combined_content = " ".join(
[chunk.content for chunk in document_chunks]
)
token_count = (
len(llm_tokenizer.encode(combined_content)) if llm_tokenizer else 0
)
user_file_id_to_token_count[user_file_id] = token_count
user_file_id_to_raw_text[user_file_id] = combined_content
else:
user_file_id_to_token_count[user_file_id] = None
# we're concerned about race conditions where multiple simultaneous indexings might result
# in one set of metadata overwriting another one in vespa.

View File

@@ -24,6 +24,7 @@ from langchain_core.messages import SystemMessageChunk
from langchain_core.messages.tool import ToolCallChunk
from langchain_core.messages.tool import ToolMessage
from langchain_core.prompt_values import PromptValue
from litellm.utils import get_supported_openai_params
from onyx.configs.app_configs import LOG_DANSWER_MODEL_INTERACTIONS
from onyx.configs.app_configs import MOCK_LLM_RESPONSE
@@ -52,6 +53,8 @@ litellm.telemetry = False
_LLM_PROMPT_LONG_TERM_LOG_CATEGORY = "llm_prompt"
VERTEX_CREDENTIALS_FILE_KWARG = "vertex_credentials"
VERTEX_LOCATION_KWARG = "vertex_location"
LEGACY_MAX_TOKENS_KWARG = "max_tokens"
STANDARD_MAX_TOKENS_KWARG = "max_completion_tokens"
class LLMTimeoutError(Exception):
@@ -313,14 +316,22 @@ class DefaultMultiLLM(LLM):
self._model_kwargs = model_kwargs
def log_model_configs(self) -> None:
logger.debug(f"Config: {self.config}")
self._max_token_param = LEGACY_MAX_TOKENS_KWARG
try:
params = get_supported_openai_params(model_name, model_provider)
if STANDARD_MAX_TOKENS_KWARG in (params or []):
self._max_token_param = STANDARD_MAX_TOKENS_KWARG
except Exception as e:
logger.warning(f"Error getting supported openai params: {e}")
def _safe_model_config(self) -> dict:
dump = self.config.model_dump()
dump["api_key"] = mask_string(dump.get("api_key", ""))
return dump
def log_model_configs(self) -> None:
logger.debug(f"Config: {self._safe_model_config()}")
def _record_call(self, prompt: LanguageModelInput) -> None:
if self._long_term_logger:
self._long_term_logger.record(
@@ -393,11 +404,14 @@ class DefaultMultiLLM(LLM):
messages=processed_prompt,
tools=tools,
tool_choice=tool_choice if tools else None,
max_tokens=max_tokens,
# streaming choice
stream=stream,
# model params
temperature=self._temperature,
temperature=(
1
if self.config.model_name in ["gpt-5", "gpt-5-mini", "gpt-5-nano"]
else self._temperature
),
timeout=timeout_override or self._timeout,
# For now, we don't support parallel tool calls
# NOTE: we can't pass this in if tools are not specified
@@ -422,6 +436,7 @@ class DefaultMultiLLM(LLM):
if structured_response_format
else {}
),
**({self._max_token_param: max_tokens} if max_tokens else {}),
**self._model_kwargs,
)
except Exception as e:
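The kwarg selection above probes litellm once at init, then splats the chosen name into the completion call only when a limit is set. Roughly:

```python
LEGACY_MAX_TOKENS_KWARG = "max_tokens"
STANDARD_MAX_TOKENS_KWARG = "max_completion_tokens"

def pick_max_token_param(supported_params) -> str:
    # Prefer the newer max_completion_tokens kwarg when the provider
    # advertises it; otherwise fall back to the legacy max_tokens
    if STANDARD_MAX_TOKENS_KWARG in (supported_params or []):
        return STANDARD_MAX_TOKENS_KWARG
    return LEGACY_MAX_TOKENS_KWARG

def max_token_kwargs(max_tokens, param: str) -> dict:
    # Only include the kwarg in the completion call when a limit is set
    return {param: max_tokens} if max_tokens else {}

print(max_token_kwargs(512, pick_max_token_param(["max_completion_tokens"])))
# → {'max_completion_tokens': 512}
```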

View File

@@ -47,6 +47,9 @@ class WellKnownLLMProviderDescriptor(BaseModel):
OPENAI_PROVIDER_NAME = "openai"
OPEN_AI_MODEL_NAMES = [
"gpt-5",
"gpt-5-mini",
"gpt-5-nano",
"o4-mini",
"o3-mini",
"o1-mini",
@@ -73,7 +76,14 @@ OPEN_AI_MODEL_NAMES = [
"gpt-3.5-turbo-16k-0613",
"gpt-3.5-turbo-0301",
]
OPEN_AI_VISIBLE_MODEL_NAMES = ["o1", "o3-mini", "gpt-4o", "gpt-4o-mini"]
OPEN_AI_VISIBLE_MODEL_NAMES = [
"gpt-5",
"gpt-5-mini",
"o1",
"o3-mini",
"gpt-4o",
"gpt-4o-mini",
]
BEDROCK_PROVIDER_NAME = "bedrock"
# need to remove all the weird "bedrock/eu-central-1/anthropic.claude-v1" named

View File

@@ -72,6 +72,7 @@ class PreviousMessage(BaseModel):
message_type = MessageType.USER
elif isinstance(msg, AIMessage):
message_type = MessageType.ASSISTANT
message = message_to_string(msg)
return cls(
message=message,

View File

@@ -136,16 +136,7 @@ def _build_content(
if not files:
return message
text_files = [
file
for file in files
if file.file_type
in (
ChatFileType.PLAIN_TEXT,
ChatFileType.CSV,
ChatFileType.USER_KNOWLEDGE,
)
]
text_files = [file for file in files if file.file_type.is_text_file()]
if not text_files:
return message
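The inline tuple is replaced by a method on the enum; a sketch of what `is_text_file` presumably looks like (member values here are assumptions):

```python
from enum import Enum

class ChatFileType(str, Enum):
    PLAIN_TEXT = "plain_text"
    DOC = "doc"
    CSV = "csv"
    USER_KNOWLEDGE = "user_knowledge"
    IMAGE = "image"

    def is_text_file(self) -> bool:
        # Centralizes the membership check previously inlined at call sites
        return self in (
            ChatFileType.PLAIN_TEXT,
            ChatFileType.DOC,
            ChatFileType.CSV,
            ChatFileType.USER_KNOWLEDGE,
        )

print(ChatFileType.CSV.is_text_file(), ChatFileType.IMAGE.is_text_file())  # → True False
```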

View File

@@ -0,0 +1,40 @@
"""
Constants for natural language processing, including embedding and reranking models.
This file contains constants moved out of the model server to support the gradual
migration of API-based calls to bypass it.
"""
from shared_configs.enums import EmbeddingProvider
from shared_configs.enums import EmbedTextType
# Default model names for different providers
DEFAULT_OPENAI_MODEL = "text-embedding-3-small"
DEFAULT_COHERE_MODEL = "embed-english-light-v3.0"
DEFAULT_VOYAGE_MODEL = "voyage-large-2-instruct"
DEFAULT_VERTEX_MODEL = "text-embedding-005"
class EmbeddingModelTextType:
"""Mapping of Onyx text types to provider-specific text types."""
PROVIDER_TEXT_TYPE_MAP = {
EmbeddingProvider.COHERE: {
EmbedTextType.QUERY: "search_query",
EmbedTextType.PASSAGE: "search_document",
},
EmbeddingProvider.VOYAGE: {
EmbedTextType.QUERY: "query",
EmbedTextType.PASSAGE: "document",
},
EmbeddingProvider.GOOGLE: {
EmbedTextType.QUERY: "RETRIEVAL_QUERY",
EmbedTextType.PASSAGE: "RETRIEVAL_DOCUMENT",
},
}
@staticmethod
def get_type(provider: EmbeddingProvider, text_type: EmbedTextType) -> str:
"""Get provider-specific text type string."""
return EmbeddingModelTextType.PROVIDER_TEXT_TYPE_MAP[provider][text_type]
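A runnable sketch of the mapping, with minimal stand-ins for the `shared_configs` enums:

```python
from enum import Enum

class EmbeddingProvider(str, Enum):  # stand-in for shared_configs.enums
    COHERE = "cohere"
    VOYAGE = "voyage"
    GOOGLE = "google"

class EmbedTextType(str, Enum):  # stand-in for shared_configs.enums
    QUERY = "query"
    PASSAGE = "passage"

PROVIDER_TEXT_TYPE_MAP = {
    EmbeddingProvider.COHERE: {
        EmbedTextType.QUERY: "search_query",
        EmbedTextType.PASSAGE: "search_document",
    },
    EmbeddingProvider.VOYAGE: {
        EmbedTextType.QUERY: "query",
        EmbedTextType.PASSAGE: "document",
    },
    EmbeddingProvider.GOOGLE: {
        EmbedTextType.QUERY: "RETRIEVAL_QUERY",
        EmbedTextType.PASSAGE: "RETRIEVAL_DOCUMENT",
    },
}

def get_type(provider: EmbeddingProvider, text_type: EmbedTextType) -> str:
    # Translate an Onyx text type into the provider's expected string
    return PROVIDER_TEXT_TYPE_MAP[provider][text_type]

print(get_type(EmbeddingProvider.GOOGLE, EmbedTextType.PASSAGE))  # → RETRIEVAL_DOCUMENT
```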

View File

@@ -1,3 +1,5 @@
import asyncio
import json
import threading
import time
from collections.abc import Callable
@@ -5,14 +7,26 @@ from concurrent.futures import as_completed
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from functools import wraps
from types import TracebackType
from typing import Any
from typing import cast
import aioboto3 # type: ignore
import httpx
import openai
import requests
import vertexai # type: ignore
import voyageai # type: ignore
from cohere import AsyncClient as CohereAsyncClient
from google.oauth2 import service_account # type: ignore
from httpx import HTTPError
from litellm import aembedding
from requests import JSONDecodeError
from requests import RequestException
from requests import Response
from retry import retry
from vertexai.language_models import TextEmbeddingInput # type: ignore
from vertexai.language_models import TextEmbeddingModel # type: ignore
from onyx.configs.app_configs import INDEXING_EMBEDDING_MODEL_NUM_THREADS
from onyx.configs.app_configs import LARGE_CHUNK_RATIO
@@ -25,16 +39,26 @@ from onyx.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from onyx.connectors.models import ConnectorStopSignal
from onyx.db.models import SearchSettings
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.natural_language_processing.constants import DEFAULT_COHERE_MODEL
from onyx.natural_language_processing.constants import DEFAULT_OPENAI_MODEL
from onyx.natural_language_processing.constants import DEFAULT_VERTEX_MODEL
from onyx.natural_language_processing.constants import DEFAULT_VOYAGE_MODEL
from onyx.natural_language_processing.constants import EmbeddingModelTextType
from onyx.natural_language_processing.exceptions import (
ModelServerRateLimitError,
)
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.natural_language_processing.utils import tokenizer_trim_content
from onyx.utils.logger import setup_logger
from onyx.utils.search_nlp_models_utils import pass_aws_key
from shared_configs.configs import API_BASED_EMBEDDING_TIMEOUT
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
from shared_configs.configs import INDEXING_MODEL_SERVER_PORT
from shared_configs.configs import INDEXING_ONLY
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.configs import OPENAI_EMBEDDING_TIMEOUT
from shared_configs.configs import VERTEXAI_EMBEDDING_LOCAL_BATCH_SIZE
from shared_configs.enums import EmbeddingProvider
from shared_configs.enums import EmbedTextType
from shared_configs.enums import RerankerProvider
@@ -53,6 +77,21 @@ from shared_configs.utils import batch_list
logger = setup_logger()
# If we are not only indexing, we don't want to retry for very long
_RETRY_DELAY = 10 if INDEXING_ONLY else 0.1
_RETRY_TRIES = 10 if INDEXING_ONLY else 2
# OpenAI only allows 2048 embeddings to be computed at once
_OPENAI_MAX_INPUT_LEN = 2048
# Cohere allows up to 96 embeddings in a single embedding call
_COHERE_MAX_INPUT_LEN = 96
# Authentication error string constants
_AUTH_ERROR_401 = "401"
_AUTH_ERROR_UNAUTHORIZED = "unauthorized"
_AUTH_ERROR_INVALID_API_KEY = "invalid api key"
_AUTH_ERROR_PERMISSION = "permission"
WARM_UP_STRINGS = [
"Onyx is amazing!",
@@ -79,6 +118,377 @@ def build_model_server_url(
return f"http://{model_server_url}"
def is_authentication_error(error: Exception) -> bool:
"""Check if an exception is related to authentication issues.
Args:
error: The exception to check
Returns:
bool: True if the error appears to be authentication-related
"""
error_str = str(error).lower()
return (
_AUTH_ERROR_401 in error_str
or _AUTH_ERROR_UNAUTHORIZED in error_str
or _AUTH_ERROR_INVALID_API_KEY in error_str
or _AUTH_ERROR_PERMISSION in error_str
)
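The same heuristic condensed, to show how the four marker constants combine:

```python
# The diff's substring heuristic for spotting auth failures, condensed
_AUTH_MARKERS = ("401", "unauthorized", "invalid api key", "permission")

def is_authentication_error(error: Exception) -> bool:
    error_str = str(error).lower()
    return any(marker in error_str for marker in _AUTH_MARKERS)

print(is_authentication_error(RuntimeError("HTTP 401 Unauthorized")))  # → True
print(is_authentication_error(ValueError("request timed out")))        # → False
```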
def format_embedding_error(
error: Exception,
service_name: str,
model: str | None,
provider: EmbeddingProvider,
sanitized_api_key: str | None = None,
status_code: int | None = None,
) -> str:
"""
Format a standardized error string for embedding errors.
"""
detail = f"Status {status_code}" if status_code else f"{type(error)}"
return (
f"{'HTTP error' if status_code else 'Exception'} embedding text with {service_name} - {detail}: "
f"Model: {model} "
f"Provider: {provider} "
f"API Key: {sanitized_api_key} "
f"Exception: {error}"
)
# Custom exception for authentication errors
class AuthenticationError(Exception):
"""Raised when authentication fails with a provider."""
def __init__(self, provider: str, message: str = "API key is invalid or expired"):
self.provider = provider
self.message = message
super().__init__(f"{provider} authentication failed: {message}")
class CloudEmbedding:
def __init__(
self,
api_key: str,
provider: EmbeddingProvider,
api_url: str | None = None,
api_version: str | None = None,
timeout: int = API_BASED_EMBEDDING_TIMEOUT,
) -> None:
self.provider = provider
self.api_key = api_key
self.api_url = api_url
self.api_version = api_version
self.timeout = timeout
self.http_client = httpx.AsyncClient(timeout=timeout)
self._closed = False
self.sanitized_api_key = api_key[:4] + "********" + api_key[-4:]
async def _embed_openai(
self, texts: list[str], model: str | None, reduced_dimension: int | None
) -> list[Embedding]:
if not model:
model = DEFAULT_OPENAI_MODEL
# Use the OpenAI specific timeout for this one
client = openai.AsyncOpenAI(
api_key=self.api_key, timeout=OPENAI_EMBEDDING_TIMEOUT
)
final_embeddings: list[Embedding] = []
for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN):
response = await client.embeddings.create(
input=text_batch,
model=model,
dimensions=reduced_dimension or openai.NOT_GIVEN,
)
final_embeddings.extend(
[embedding.embedding for embedding in response.data]
)
return final_embeddings
async def _embed_cohere(
self, texts: list[str], model: str | None, embedding_type: str
) -> list[Embedding]:
if not model:
model = DEFAULT_COHERE_MODEL
client = CohereAsyncClient(api_key=self.api_key)
final_embeddings: list[Embedding] = []
for text_batch in batch_list(texts, _COHERE_MAX_INPUT_LEN):
# Does not use the same tokenizer as the Onyx API server, but it's approximately the same;
# empirically it's off by only a few tokens, so it's not a big deal
response = await client.embed(
texts=text_batch,
model=model,
input_type=embedding_type,
truncate="END",
)
final_embeddings.extend(cast(list[Embedding], response.embeddings))
return final_embeddings
async def _embed_voyage(
self, texts: list[str], model: str | None, embedding_type: str
) -> list[Embedding]:
if not model:
model = DEFAULT_VOYAGE_MODEL
client = voyageai.AsyncClient(
api_key=self.api_key, timeout=API_BASED_EMBEDDING_TIMEOUT
)
response = await client.embed(
texts=texts,
model=model,
input_type=embedding_type,
truncation=True,
)
return response.embeddings
async def _embed_azure(
self, texts: list[str], model: str | None
) -> list[Embedding]:
response = await aembedding(
model=model,
input=texts,
timeout=API_BASED_EMBEDDING_TIMEOUT,
api_key=self.api_key,
api_base=self.api_url,
api_version=self.api_version,
)
embeddings = [embedding["embedding"] for embedding in response.data]
return embeddings
async def _embed_vertex(
self, texts: list[str], model: str | None, embedding_type: str
) -> list[Embedding]:
if not model:
model = DEFAULT_VERTEX_MODEL
service_account_info = json.loads(self.api_key)
credentials = service_account.Credentials.from_service_account_info(
service_account_info
)
project_id = service_account_info["project_id"]
vertexai.init(project=project_id, credentials=credentials)
client = TextEmbeddingModel.from_pretrained(model)
inputs = [TextEmbeddingInput(text, embedding_type) for text in texts]
# Split into batches of up to VERTEXAI_EMBEDDING_LOCAL_BATCH_SIZE texts
max_texts_per_batch = VERTEXAI_EMBEDDING_LOCAL_BATCH_SIZE
batches = [
inputs[i : i + max_texts_per_batch]
for i in range(0, len(inputs), max_texts_per_batch)
]
# Dispatch all embedding calls asynchronously at once
tasks = [
client.get_embeddings_async(batch, auto_truncate=True) for batch in batches
]
# Wait for all tasks to complete in parallel
results = await asyncio.gather(*tasks)
return [embedding.values for batch in results for embedding in batch]
async def _embed_litellm_proxy(
self, texts: list[str], model_name: str | None
) -> list[Embedding]:
if not model_name:
raise ValueError("Model name is required for LiteLLM proxy embedding.")
if not self.api_url:
raise ValueError("API URL is required for LiteLLM proxy embedding.")
headers = (
{} if not self.api_key else {"Authorization": f"Bearer {self.api_key}"}
)
response = await self.http_client.post(
self.api_url,
json={
"model": model_name,
"input": texts,
},
headers=headers,
)
response.raise_for_status()
result = response.json()
return [embedding["embedding"] for embedding in result["data"]]
@retry(tries=_RETRY_TRIES, delay=_RETRY_DELAY)
async def embed(
self,
*,
texts: list[str],
text_type: EmbedTextType,
model_name: str | None = None,
deployment_name: str | None = None,
reduced_dimension: int | None = None,
) -> list[Embedding]:
try:
if self.provider == EmbeddingProvider.OPENAI:
return await self._embed_openai(texts, model_name, reduced_dimension)
elif self.provider == EmbeddingProvider.AZURE:
return await self._embed_azure(texts, f"azure/{deployment_name}")
elif self.provider == EmbeddingProvider.LITELLM:
return await self._embed_litellm_proxy(texts, model_name)
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
if self.provider == EmbeddingProvider.COHERE:
return await self._embed_cohere(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.VOYAGE:
return await self._embed_voyage(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.GOOGLE:
return await self._embed_vertex(texts, model_name, embedding_type)
else:
raise ValueError(f"Unsupported provider: {self.provider}")
except openai.AuthenticationError:
raise AuthenticationError(provider="OpenAI")
except httpx.HTTPStatusError as e:
if e.response.status_code == 401:
raise AuthenticationError(provider=str(self.provider))
error_string = format_embedding_error(
e,
str(self.provider),
model_name or deployment_name,
self.provider,
sanitized_api_key=self.sanitized_api_key,
status_code=e.response.status_code,
)
logger.error(error_string)
logger.debug(f"Exception texts: {texts}")
raise RuntimeError(error_string)
except Exception as e:
if is_authentication_error(e):
raise AuthenticationError(provider=str(self.provider))
error_string = format_embedding_error(
e,
str(self.provider),
model_name or deployment_name,
self.provider,
sanitized_api_key=self.sanitized_api_key,
)
logger.error(error_string)
logger.debug(f"Exception texts: {texts}")
raise RuntimeError(error_string)
@staticmethod
def create(
api_key: str,
provider: EmbeddingProvider,
api_url: str | None = None,
api_version: str | None = None,
) -> "CloudEmbedding":
logger.debug(f"Creating Embedding instance for provider: {provider}")
return CloudEmbedding(api_key, provider, api_url, api_version)
async def aclose(self) -> None:
"""Explicitly close the client."""
if not self._closed:
await self.http_client.aclose()
self._closed = True
async def __aenter__(self) -> "CloudEmbedding":
return self
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
await self.aclose()
def __del__(self) -> None:
"""Finalizer to warn about unclosed clients."""
if not self._closed:
logger.warning(
"CloudEmbedding was not properly closed. Use 'async with' or call aclose()"
)
# API-based reranking functions (moved from model server)
async def cohere_rerank_api(
query: str, docs: list[str], model_name: str, api_key: str
) -> list[float]:
cohere_client = CohereAsyncClient(api_key=api_key)
response = await cohere_client.rerank(query=query, documents=docs, model=model_name)
results = response.results
sorted_results = sorted(results, key=lambda item: item.index)
return [result.relevance_score for result in sorted_results]
async def cohere_rerank_aws(
query: str,
docs: list[str],
model_name: str,
region_name: str,
aws_access_key_id: str,
aws_secret_access_key: str,
) -> list[float]:
session = aioboto3.Session(
aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key
)
async with session.client(
"bedrock-runtime", region_name=region_name
) as bedrock_client:
body = json.dumps(
{
"query": query,
"documents": docs,
"api_version": 2,
}
)
# Invoke the Bedrock model asynchronously
response = await bedrock_client.invoke_model(
modelId=model_name,
accept="application/json",
contentType="application/json",
body=body,
)
# Read the response asynchronously
response_body = json.loads(await response["body"].read())
# Extract and sort the results
results = response_body.get("results", [])
sorted_results = sorted(results, key=lambda item: item["index"])
return [result["relevance_score"] for result in sorted_results]
async def litellm_rerank(
query: str, docs: list[str], api_url: str, model_name: str, api_key: str | None
) -> list[float]:
headers = {} if not api_key else {"Authorization": f"Bearer {api_key}"}
async with httpx.AsyncClient() as client:
response = await client.post(
api_url,
json={
"model": model_name,
"query": query,
"documents": docs,
},
headers=headers,
)
response.raise_for_status()
result = response.json()
return [
item["relevance_score"]
for item in sorted(result["results"], key=lambda x: x["index"])
]
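All three rerank paths above end with the same sort-by-index step, isolated here:

```python
def scores_in_document_order(results: list) -> list:
    # Providers return (index, relevance_score) pairs ranked by relevance;
    # sorting by index restores the original passage order
    return [r["relevance_score"] for r in sorted(results, key=lambda x: x["index"])]

print(scores_in_document_order([
    {"index": 2, "relevance_score": 0.1},
    {"index": 0, "relevance_score": 0.9},
    {"index": 1, "relevance_score": 0.5},
]))  # → [0.9, 0.5, 0.1]
```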
class EmbeddingModel:
def __init__(
self,
@@ -113,8 +523,84 @@ class EmbeddingModel:
)
self.callback = callback
model_server_url = build_model_server_url(server_host, server_port)
self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"
# Only build model server endpoint for local models
if self.provider_type is None:
model_server_url = build_model_server_url(server_host, server_port)
self.embed_server_endpoint: str | None = (
f"{model_server_url}/encoder/bi-encoder-embed"
)
else:
# API providers don't need model server endpoint
self.embed_server_endpoint = None
async def _make_direct_api_call(
self,
embed_request: EmbedRequest,
tenant_id: str | None = None,
request_id: str | None = None,
) -> EmbedResponse:
"""Make direct API call to cloud provider, bypassing model server."""
if self.provider_type is None:
raise ValueError("Provider type is required for direct API calls")
if self.api_key is None:
logger.error("API key not provided for cloud model")
raise RuntimeError("API key not provided for cloud model")
# Check for prefix usage with cloud models
if embed_request.manual_query_prefix or embed_request.manual_passage_prefix:
logger.warning("Prefix provided for cloud model, which is not supported")
raise ValueError(
"Prefix string is not valid for cloud models. "
"Cloud models take an explicit text type instead."
)
if not all(embed_request.texts):
logger.error("Empty strings provided for embedding")
raise ValueError("Empty strings are not allowed for embedding.")
if not embed_request.texts:
logger.error("No texts provided for embedding")
raise ValueError("No texts provided for embedding.")
start_time = time.monotonic()
total_chars = sum(len(text) for text in embed_request.texts)
logger.info(
f"Embedding {len(embed_request.texts)} texts with {total_chars} total characters with provider: {self.provider_type}"
)
async with CloudEmbedding(
api_key=self.api_key,
provider=self.provider_type,
api_url=self.api_url,
api_version=self.api_version,
) as cloud_model:
embeddings = await cloud_model.embed(
texts=embed_request.texts,
model_name=embed_request.model_name,
deployment_name=embed_request.deployment_name,
text_type=embed_request.text_type,
reduced_dimension=embed_request.reduced_dimension,
)
if any(embedding is None for embedding in embeddings):
error_message = "Embeddings contain None values\n"
error_message += "Corresponding texts:\n"
error_message += "\n".join(embed_request.texts)
logger.error(error_message)
raise ValueError(error_message)
elapsed = time.monotonic() - start_time
logger.info(
f"event=embedding_provider "
f"texts={len(embed_request.texts)} "
f"chars={total_chars} "
f"provider={self.provider_type} "
f"elapsed={elapsed:.2f}"
)
return EmbedResponse(embeddings=embeddings)
def _make_model_server_request(
self,
@@ -122,6 +608,12 @@ class EmbeddingModel:
tenant_id: str | None = None,
request_id: str | None = None,
) -> EmbedResponse:
if self.embed_server_endpoint is None:
raise ValueError("Model server endpoint is not configured for local models")
# Store the endpoint in a local variable to help mypy understand it's not None
endpoint = self.embed_server_endpoint
def _make_request() -> Response:
headers = {}
if tenant_id:
@@ -131,7 +623,7 @@ class EmbeddingModel:
headers["X-Onyx-Request-ID"] = request_id
response = requests.post(
self.embed_server_endpoint,
endpoint,
headers=headers,
json=embed_request.model_dump(),
)
@@ -219,11 +711,28 @@ class EmbeddingModel:
reduced_dimension=self.reduced_dimension,
)
start_time = time.time()
response = self._make_model_server_request(
embed_request, tenant_id=tenant_id, request_id=request_id
)
end_time = time.time()
start_time = time.monotonic()
# Route between direct API calls and model server calls
if self.provider_type is not None:
# For API providers, make direct API call
loop = asyncio.new_event_loop()
try:
asyncio.set_event_loop(loop)
response = loop.run_until_complete(
self._make_direct_api_call(
embed_request, tenant_id=tenant_id, request_id=request_id
)
)
finally:
loop.close()
else:
# For local models, use model server
response = self._make_model_server_request(
embed_request, tenant_id=tenant_id, request_id=request_id
)
end_time = time.monotonic()
processing_time = end_time - start_time
logger.debug(
@@ -360,29 +869,92 @@ class RerankingModel:
model_server_host: str = MODEL_SERVER_HOST,
model_server_port: int = MODEL_SERVER_PORT,
) -> None:
model_server_url = build_model_server_url(model_server_host, model_server_port)
self.rerank_server_endpoint = model_server_url + "/encoder/cross-encoder-scores"
self.model_name = model_name
self.provider_type = provider_type
self.api_key = api_key
self.api_url = api_url
# Only build model server endpoint for local models
if self.provider_type is None:
model_server_url = build_model_server_url(
model_server_host, model_server_port
)
self.rerank_server_endpoint: str | None = (
model_server_url + "/encoder/cross-encoder-scores"
)
else:
# API providers don't need model server endpoint
self.rerank_server_endpoint = None
async def _make_direct_rerank_call(
self, query: str, passages: list[str]
) -> list[float]:
"""Make direct API call to cloud provider, bypassing model server."""
if self.provider_type is None:
raise ValueError("Provider type is required for direct API calls")
if self.api_key is None:
raise ValueError("API key is required for cloud provider")
if self.provider_type == RerankerProvider.COHERE:
return await cohere_rerank_api(
query, passages, self.model_name, self.api_key
)
elif self.provider_type == RerankerProvider.BEDROCK:
aws_access_key_id, aws_secret_access_key, aws_region = pass_aws_key(
self.api_key
)
return await cohere_rerank_aws(
query,
passages,
self.model_name,
aws_region,
aws_access_key_id,
aws_secret_access_key,
)
elif self.provider_type == RerankerProvider.LITELLM:
if self.api_url is None:
raise ValueError("API URL is required for LiteLLM reranking.")
return await litellm_rerank(
query, passages, self.api_url, self.model_name, self.api_key
)
else:
raise ValueError(f"Unsupported reranking provider: {self.provider_type}")
def predict(self, query: str, passages: list[str]) -> list[float]:
rerank_request = RerankRequest(
query=query,
documents=passages,
model_name=self.model_name,
provider_type=self.provider_type,
api_key=self.api_key,
api_url=self.api_url,
)
# Route between direct API calls and model server calls
if self.provider_type is not None:
# For API providers, make direct API call
loop = asyncio.new_event_loop()
try:
asyncio.set_event_loop(loop)
return loop.run_until_complete(
self._make_direct_rerank_call(query, passages)
)
finally:
loop.close()
else:
# For local models, use model server
if self.rerank_server_endpoint is None:
raise ValueError(
"Rerank server endpoint is not configured for local models"
)
response = requests.post(
self.rerank_server_endpoint, json=rerank_request.model_dump()
)
response.raise_for_status()
rerank_request = RerankRequest(
query=query,
documents=passages,
model_name=self.model_name,
provider_type=self.provider_type,
api_key=self.api_key,
api_url=self.api_url,
)
return RerankResponse(**response.json()).scores
response = requests.post(
self.rerank_server_endpoint, json=rerank_request.model_dump()
)
response.raise_for_status()
return RerankResponse(**response.json()).scores
class QueryAnalysisModel:
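Both `EmbeddingModel` and `RerankingModel` above bridge sync callers into the new async provider calls the same way: a fresh event loop per call, closed in a `finally` so it never leaks. The pattern in isolation:

```python
import asyncio

# Sketch of the sync→async bridge used in the diff; _direct_call stands in
# for the provider-specific coroutine
async def _direct_call(x: int) -> int:
    await asyncio.sleep(0)
    return x * 2

def predict(x: int) -> int:
    loop = asyncio.new_event_loop()
    try:
        asyncio.set_event_loop(loop)
        return loop.run_until_complete(_direct_call(x))
    finally:
        loop.close()

print(predict(21))  # → 42
```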

View File

@@ -151,7 +151,7 @@ def _build_ephemeral_publication_block(
email=message_info.email,
sender_id=message_info.sender_id,
thread_messages=[],
is_bot_msg=message_info.is_bot_msg,
is_slash_command=message_info.is_slash_command,
is_bot_dm=message_info.is_bot_dm,
thread_to_respond=respond_ts,
)
@@ -225,10 +225,10 @@ def _build_doc_feedback_block(
def get_restate_blocks(
msg: str,
is_bot_msg: bool,
is_slash_command: bool,
) -> list[Block]:
# Only the slash command needs this context because the user doesn't see their own input
if not is_bot_msg:
if not is_slash_command:
return []
return [
@@ -576,7 +576,7 @@ def build_slack_response_blocks(
# If called with the OnyxBot slash command, the question is lost so we have to reshow it
if not skip_restated_question:
restate_question_block = get_restate_blocks(
message_info.thread_messages[-1].message, message_info.is_bot_msg
message_info.thread_messages[-1].message, message_info.is_slash_command
)
else:
restate_question_block = []

View File

@@ -177,7 +177,7 @@ def handle_generate_answer_button(
sender_id=user_id or None,
email=email or None,
bypass_filters=True,
is_bot_msg=False,
is_slash_command=False,
is_bot_dm=False,
),
slack_channel_config=slack_channel_config,

View File

@@ -28,7 +28,7 @@ logger_base = setup_logger()
def send_msg_ack_to_user(details: SlackMessageInfo, client: WebClient) -> None:
if details.is_bot_msg and details.sender_id:
if details.is_slash_command and details.sender_id:
respond_in_thread_or_channel(
client=client,
channel=details.channel_to_respond,
@@ -124,11 +124,11 @@ def handle_message(
messages = message_info.thread_messages
sender_id = message_info.sender_id
bypass_filters = message_info.bypass_filters
is_bot_msg = message_info.is_bot_msg
is_slash_command = message_info.is_slash_command
is_bot_dm = message_info.is_bot_dm
action = "slack_message"
if is_bot_msg:
if is_slash_command:
action = "slack_slash_message"
elif bypass_filters:
action = "slack_tag_message"
@@ -197,7 +197,7 @@ def handle_message(
# If configured to respond to team members only, then it cannot be used with a /OnyxBot command
# which would just respond to the sender
if send_to and is_bot_msg:
if send_to and is_slash_command:
if sender_id:
respond_in_thread_or_channel(
client=client,

View File

@@ -81,15 +81,15 @@ def handle_regular_answer(
messages = message_info.thread_messages
message_ts_to_respond_to = message_info.msg_to_respond
is_bot_msg = message_info.is_bot_msg
is_slash_command = message_info.is_slash_command
# Capture whether response mode for channel is ephemeral. Even if the channel is set
# to respond with an ephemeral message, we still send as non-ephemeral if
# the message is a dm with the Onyx bot.
send_as_ephemeral = (
slack_channel_config.channel_config.get("is_ephemeral", False)
and not message_info.is_bot_dm
)
or message_info.is_slash_command
) and not message_info.is_bot_dm
# If the channel is configured to respond with an ephemeral message,
# or the message is a dm to the Onyx bot, we should use the proper onyx user from the email.
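The reworked boolean above can be checked in isolation: slash-command responses now also go out as ephemeral, while DMs with the bot never do:

```python
def send_as_ephemeral(is_ephemeral_channel: bool, is_slash_command: bool, is_bot_dm: bool) -> bool:
    # Ephemeral if the channel asks for it OR the message was a slash
    # command, but never for a direct message with the bot
    return (is_ephemeral_channel or is_slash_command) and not is_bot_dm

print(send_as_ephemeral(False, True, False))  # → True
print(send_as_ephemeral(True, False, True))   # → False
```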
@@ -164,7 +164,7 @@ def handle_regular_answer(
# in an attached document set were available to all users in the channel.)
bypass_acl = False
if not message_ts_to_respond_to and not is_bot_msg:
if not message_ts_to_respond_to and not is_slash_command:
# if the message is not a "/onyx" command, then it should have a message ts to respond to
raise RuntimeError(
"No message timestamp to respond to in `handle_message`. This should never happen."
@@ -316,13 +316,14 @@ def handle_regular_answer(
return True
# Got an answer at this point, can remove reaction and give results
update_emote_react(
emoji=DANSWER_REACT_EMOJI,
channel=message_info.channel_to_respond,
message_ts=message_info.msg_to_respond,
remove=True,
client=client,
)
if not is_slash_command: # Slash commands don't have reactions
update_emote_react(
emoji=DANSWER_REACT_EMOJI,
channel=message_info.channel_to_respond,
message_ts=message_info.msg_to_respond,
remove=True,
client=client,
)
if answer.answer_valid is False:
logger.notice(


@@ -130,6 +130,10 @@ _SLACK_GREETINGS_TO_IGNORE = {
# This is always (currently) the user id of Slack's official slackbot
_OFFICIAL_SLACKBOT_USER_ID = "USLACKBOT"
# Fields to exclude from Slack payload logging
# Intention is to not log slack message content
_EXCLUDED_SLACK_PAYLOAD_FIELDS = {"text", "blocks"}
class SlackbotHandler:
def __init__(self) -> None:
@@ -570,6 +574,20 @@ class SlackbotHandler:
sys.exit(0)
def sanitize_slack_payload(payload: dict) -> dict:
"""Remove message content from Slack payload for logging"""
sanitized = {
k: v for k, v in payload.items() if k not in _EXCLUDED_SLACK_PAYLOAD_FIELDS
}
if "event" in sanitized and isinstance(sanitized["event"], dict):
sanitized["event"] = {
k: v
for k, v in sanitized["event"].items()
if k not in _EXCLUDED_SLACK_PAYLOAD_FIELDS
}
return sanitized
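The sanitizer above can be exercised on its own; a minimal sketch (the payload shape below is illustrative, not a real Slack event):

```python
# Stand-alone copy of the sanitizer above, for illustration;
# the sample payload is made up.
_EXCLUDED_SLACK_PAYLOAD_FIELDS = {"text", "blocks"}

def sanitize_slack_payload(payload: dict) -> dict:
    """Remove message content from a Slack payload before logging."""
    sanitized = {
        k: v for k, v in payload.items() if k not in _EXCLUDED_SLACK_PAYLOAD_FIELDS
    }
    if "event" in sanitized and isinstance(sanitized["event"], dict):
        sanitized["event"] = {
            k: v
            for k, v in sanitized["event"].items()
            if k not in _EXCLUDED_SLACK_PAYLOAD_FIELDS
        }
    return sanitized

payload = {
    "type": "events_api",
    "text": "secret message",
    "event": {"type": "message", "text": "secret", "user": "U123"},
}
clean = sanitize_slack_payload(payload)
# "text" is stripped at the top level and inside "event";
# non-content fields like "user" survive.
```

Note the copy is shallow and only one level deep: content fields nested below `event` (e.g. inside attachments) would still pass through.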
def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -> bool:
"""True to keep going, False to ignore this Slack request"""
@@ -762,7 +780,10 @@ def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -
if not check_message_limit():
return False
logger.debug(f"Handling Slack request: {client.bot_name=} '{req.payload=}'")
# Don't log Slack message content
logger.debug(
f"Handling Slack request: {client.bot_name=} '{sanitize_slack_payload(req.payload)=}'"
)
return True
@@ -876,12 +897,13 @@ def build_request_details(
sender_id=sender_id,
email=email,
bypass_filters=tagged,
is_bot_msg=False,
is_slash_command=False,
is_bot_dm=event.get("channel_type") == "im",
)
elif req.type == "slash_commands":
channel = req.payload["channel_id"]
channel_name = req.payload["channel_name"]
msg = req.payload["text"]
sender = req.payload["user_id"]
expert_info = expert_info_from_slack_id(
@@ -899,8 +921,8 @@ def build_request_details(
sender_id=sender,
email=email,
bypass_filters=True,
is_bot_msg=True,
is_bot_dm=False,
is_slash_command=True,
is_bot_dm=channel_name == "directmessage",
)
raise RuntimeError("Programming fault, this should never happen.")
@@ -928,10 +950,9 @@ def process_message(
if req.type == "events_api":
event = cast(dict[str, Any], req.payload["event"])
event_type = event.get("type")
msg = cast(str, event.get("text", ""))
logger.info(
f"process_message start: {tenant_id=} {req.type=} {req.envelope_id=} "
f"{event_type=} {msg=}"
f"{event_type=}"
)
else:
logger.info(


@@ -13,7 +13,7 @@ class SlackMessageInfo(BaseModel):
sender_id: str | None
email: str | None
bypass_filters: bool # User has tagged @OnyxBot
is_bot_msg: bool # User is using /OnyxBot
is_slash_command: bool # User is using /OnyxBot
is_bot_dm: bool # User is direct messaging to OnyxBot
@@ -25,7 +25,7 @@ class ActionValuesEphemeralMessageMessageInfo(BaseModel):
email: str | None
sender_id: str | None
thread_messages: list[ThreadMessage] | None
is_bot_msg: bool | None
is_slash_command: bool | None
is_bot_dm: bool | None
thread_to_respond: str | None


@@ -3,7 +3,6 @@ import redis
from onyx.redis.redis_connector_delete import RedisConnectorDelete
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from onyx.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
from onyx.redis.redis_connector_index import RedisConnectorIndex
from onyx.redis.redis_connector_prune import RedisConnectorPrune
from onyx.redis.redis_connector_stop import RedisConnectorStop
from onyx.redis.redis_pool import get_redis_client
@@ -31,11 +30,6 @@ class RedisConnector:
tenant_id, cc_pair_id, self.redis
)
def new_index(self, search_settings_id: int) -> RedisConnectorIndex:
return RedisConnectorIndex(
self.tenant_id, self.cc_pair_id, search_settings_id, self.redis
)
@staticmethod
def get_id_from_fence_key(key: str) -> str | None:
"""
@@ -81,3 +75,11 @@ class RedisConnector:
object_id = parts[1]
return object_id
def db_lock_key(self, search_settings_id: int) -> str:
"""
Key for the db lock for an indexing attempt.
Prevents multiple modifications to the current indexing attempt row
from multiple docfetching/docprocessing tasks.
"""
return f"da_lock:indexing:db_{self.cc_pair_id}/{search_settings_id}"
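The docstring above describes one lock per indexing attempt; as a sketch, the key format reduces to a plain f-string (the IDs here are invented):

```python
# Sketch of the lock-key naming convention from db_lock_key; IDs are made up.
cc_pair_id, search_settings_id = 42, 7
db_lock_key = f"da_lock:indexing:db_{cc_pair_id}/{search_settings_id}"
# One key per (cc_pair, search_settings) pair, so concurrent docfetching/
# docprocessing tasks for the same attempt contend on the same lock.
```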


@@ -1,126 +1,10 @@
from datetime import datetime
from typing import cast
import redis
from pydantic import BaseModel
from onyx.configs.constants import CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT
class RedisConnectorIndexPayload(BaseModel):
index_attempt_id: int | None
started: datetime | None
submitted: datetime
celery_task_id: str | None
class RedisConnectorIndex:
"""Manages interactions with redis for indexing tasks. Should only be accessed
through RedisConnector."""
PREFIX = "connectorindexing"
FENCE_PREFIX = f"{PREFIX}_fence" # "connectorindexing_fence"
GENERATOR_TASK_PREFIX = PREFIX + "+generator" # "connectorindexing+generator_fence"
GENERATOR_PROGRESS_PREFIX = (
PREFIX + "_generator_progress"
) # connectorindexing_generator_progress
GENERATOR_COMPLETE_PREFIX = (
PREFIX + "_generator_complete"
) # connectorindexing_generator_complete
GENERATOR_LOCK_PREFIX = "da_lock:indexing:docfetching"
FILESTORE_LOCK_PREFIX = "da_lock:indexing:filestore"
DB_LOCK_PREFIX = "da_lock:indexing:db"
PER_WORKER_LOCK_PREFIX = "da_lock:indexing:per_worker"
TERMINATE_PREFIX = PREFIX + "_terminate" # connectorindexing_terminate
TERMINATE_TTL = 600
# used to signal the overall workflow is still active
# it's impossible to get the exact state of the system at a single point in time
# so we need a signal with a TTL to bridge gaps in our checks
ACTIVE_PREFIX = PREFIX + "_active"
ACTIVE_TTL = 3600
# used to signal that the watchdog is running
WATCHDOG_PREFIX = PREFIX + "_watchdog"
WATCHDOG_TTL = 300
# used to signal that the connector itself is still running
CONNECTOR_ACTIVE_PREFIX = PREFIX + "_connector_active"
CONNECTOR_ACTIVE_TTL = CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT
def __init__(
self,
tenant_id: str,
cc_pair_id: int,
search_settings_id: int,
redis: redis.Redis,
) -> None:
self.tenant_id: str = tenant_id
self.cc_pair_id = cc_pair_id
self.search_settings_id = search_settings_id
self.redis = redis
self.generator_complete_key = (
f"{self.GENERATOR_COMPLETE_PREFIX}_{cc_pair_id}/{search_settings_id}"
)
self.filestore_lock_key = (
f"{self.FILESTORE_LOCK_PREFIX}_{cc_pair_id}/{search_settings_id}"
)
self.generator_lock_key = (
f"{self.GENERATOR_LOCK_PREFIX}_{cc_pair_id}/{search_settings_id}"
)
self.per_worker_lock_key = (
f"{self.PER_WORKER_LOCK_PREFIX}_{cc_pair_id}/{search_settings_id}"
)
self.db_lock_key = f"{self.DB_LOCK_PREFIX}_{cc_pair_id}/{search_settings_id}"
self.terminate_key = (
f"{self.TERMINATE_PREFIX}_{cc_pair_id}/{search_settings_id}"
)
def set_generator_complete(self, payload: int | None) -> None:
if not payload:
self.redis.delete(self.generator_complete_key)
return
self.redis.set(self.generator_complete_key, payload)
def generator_clear(self) -> None:
self.redis.delete(self.generator_complete_key)
def get_completion(self) -> int | None:
bytes = self.redis.get(self.generator_complete_key)
if bytes is None:
return None
status = int(cast(int, bytes))
return status
def reset(self) -> None:
self.redis.delete(self.filestore_lock_key)
self.redis.delete(self.db_lock_key)
self.redis.delete(self.generator_lock_key)
self.redis.delete(self.generator_complete_key)
@staticmethod
def reset_all(r: redis.Redis) -> None:
"""Deletes all redis values for all connectors"""
# leaving these temporarily for backwards compat, TODO: remove
for key in r.scan_iter(RedisConnectorIndex.CONNECTOR_ACTIVE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndex.ACTIVE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndex.FILESTORE_LOCK_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndex.GENERATOR_COMPLETE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndex.GENERATOR_PROGRESS_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndex.FENCE_PREFIX + "*"):
r.delete(key)


@@ -1,6 +1,5 @@
from onyx.redis.redis_connector_delete import RedisConnectorDelete
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from onyx.redis.redis_connector_index import RedisConnectorIndex
from onyx.redis.redis_connector_prune import RedisConnectorPrune
from onyx.redis.redis_document_set import RedisDocumentSet
from onyx.redis.redis_usergroup import RedisUserGroup
@@ -16,8 +15,6 @@ def is_fence(key_bytes: bytes) -> bool:
return True
if key_str.startswith(RedisConnectorPrune.FENCE_PREFIX):
return True
if key_str.startswith(RedisConnectorIndex.FENCE_PREFIX):
return True
if key_str.startswith(RedisConnectorPermissionSync.FENCE_PREFIX):
return True


@@ -21,6 +21,7 @@ from onyx.db.connector import create_connector
from onyx.db.connector_credential_pair import add_credential_to_connector
from onyx.db.credentials import PUBLIC_CREDENTIAL_ID
from onyx.db.document import check_docs_exist
from onyx.db.document import mark_document_as_indexed_for_cc_pair__no_commit
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.index_attempt import mock_successful_index_attempt
@@ -264,5 +265,13 @@ def seed_initial_documents(
.values(chunk_count=doc.chunk_count)
)
# Since we bypass the indexing flow, we need to manually mark the document as indexed
mark_document_as_indexed_for_cc_pair__no_commit(
connector_id=connector_id,
credential_id=PUBLIC_CREDENTIAL_ID,
document_ids=[doc.id for doc in docs],
db_session=db_session,
)
db_session.commit()
kv_store.store(KV_DOCUMENTS_SEEDED_KEY, True)


@@ -1,3 +1,4 @@
import io
import json
import mimetypes
import os
@@ -101,8 +102,9 @@ from onyx.db.models import IndexAttempt
from onyx.db.models import IndexingStatus
from onyx.db.models import User
from onyx.db.models import UserGroup__ConnectorCredentialPair
from onyx.file_processing.extract_file_text import convert_docx_to_txt
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_store.file_store import get_default_file_store
from onyx.file_store.models import ChatFileType
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.server.documents.models import AuthStatus
from onyx.server.documents.models import AuthUrl
@@ -124,6 +126,7 @@ from onyx.server.documents.models import IndexAttemptSnapshot
from onyx.server.documents.models import ObjectCreationIdResponse
from onyx.server.documents.models import RunConnectorRequest
from onyx.server.models import StatusResponse
from onyx.server.query_and_chat.chat_utils import mime_type_to_chat_file_type
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import create_milestone_and_report
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
@@ -438,7 +441,9 @@ def is_zip_file(file: UploadFile) -> bool:
)
def upload_files(files: list[UploadFile]) -> FileUploadResponse:
def upload_files(
files: list[UploadFile], file_origin: FileOrigin = FileOrigin.CONNECTOR
) -> FileUploadResponse:
for file in files:
if not file.filename:
raise HTTPException(status_code=400, detail="File name cannot be empty")
@@ -487,12 +492,17 @@ def upload_files(files: list[UploadFile]) -> FileUploadResponse:
# For mypy, actual check happens at start of function
assert file.filename is not None
# Special handling for docx files - only store the plaintext version
if file.content_type and file.content_type.startswith(
"application/vnd.openxmlformats-officedocument.wordprocessingml.document"
):
docx_file_id = convert_docx_to_txt(file, file_store)
deduped_file_paths.append(docx_file_id)
# Special handling for doc files - only store the plaintext version
file_type = mime_type_to_chat_file_type(file.content_type)
if file_type == ChatFileType.DOC:
extracted_text = extract_file_text(file.file, file.filename or "")
text_file_id = file_store.save_file(
content=io.BytesIO(extracted_text.encode()),
display_name=file.filename,
file_origin=file_origin,
file_type="text/plain",
)
deduped_file_paths.append(text_file_id)
deduped_file_names.append(file.filename)
continue
@@ -520,7 +530,7 @@ def upload_files_api(
files: list[UploadFile],
_: User = Depends(current_curator_or_admin_user),
) -> FileUploadResponse:
return upload_files(files)
return upload_files(files, FileOrigin.OTHER)
@router.get("/admin/connector")


@@ -1,7 +1,12 @@
import json
from fastapi import APIRouter
from fastapi import Depends
from fastapi import File
from fastapi import Form
from fastapi import HTTPException
from fastapi import Query
from fastapi import UploadFile
from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
@@ -27,6 +32,9 @@ from onyx.server.documents.models import CredentialDataUpdateRequest
from onyx.server.documents.models import CredentialSnapshot
from onyx.server.documents.models import CredentialSwapRequest
from onyx.server.documents.models import ObjectCreationIdResponse
from onyx.server.documents.private_key_types import FILE_TYPE_TO_FILE_PROCESSOR
from onyx.server.documents.private_key_types import PrivateKeyFileTypes
from onyx.server.documents.private_key_types import ProcessPrivateKeyFileProtocol
from onyx.server.models import StatusResponse
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
@@ -76,6 +84,7 @@ def get_cc_source_full_info(
document_source=source_type,
get_editable=get_editable,
)
return [
CredentialSnapshot.from_credential_db_model(credential)
for credential in credentials
@@ -149,6 +158,70 @@ def create_credential_from_model(
)
@router.post("/credential/private-key")
def create_credential_with_private_key(
credential_json: str = Form(...),
admin_public: bool = Form(False),
curator_public: bool = Form(False),
groups: list[int] = Form([]),
name: str | None = Form(None),
source: str = Form(...),
user: User | None = Depends(current_curator_or_admin_user),
uploaded_file: UploadFile = File(...),
field_key: str = Form(...),
type_definition_key: str = Form(...),
db_session: Session = Depends(get_session),
) -> ObjectCreationIdResponse:
try:
credential_data = json.loads(credential_json)
except json.JSONDecodeError as e:
raise HTTPException(
status_code=400,
detail=f"Invalid JSON in credential_json: {str(e)}",
)
private_key_processor: ProcessPrivateKeyFileProtocol | None = (
FILE_TYPE_TO_FILE_PROCESSOR.get(PrivateKeyFileTypes(type_definition_key))
)
if private_key_processor is None:
raise HTTPException(
status_code=400,
detail="Invalid type definition key for private key file",
)
private_key_content: str = private_key_processor(uploaded_file)
credential_data[field_key] = private_key_content
credential_info = CredentialBase(
credential_json=credential_data,
admin_public=admin_public,
curator_public=curator_public,
groups=groups,
name=name,
source=DocumentSource(source),
)
if not _ignore_credential_permissions(DocumentSource(source)):
fetch_ee_implementation_or_noop(
"onyx.db.user_group", "validate_object_creation_for_user", None
)(
db_session=db_session,
user=user,
target_group_ids=groups,
object_is_public=curator_public,
)
# Temporary fix for empty Google App credentials
if DocumentSource(source) == DocumentSource.GMAIL:
cleanup_gmail_credentials(db_session=db_session)
credential = create_credential(credential_info, user, db_session)
return ObjectCreationIdResponse(
id=credential.id,
credential=CredentialSnapshot.from_credential_db_model(credential),
)
"""Endpoints for all"""
@@ -209,6 +282,53 @@ def update_credential_data(
return CredentialSnapshot.from_credential_db_model(credential)
@router.put("/admin/credential/private-key/{credential_id}")
def update_credential_private_key(
credential_id: int,
name: str = Form(...),
credential_json: str = Form(...),
uploaded_file: UploadFile = File(...),
field_key: str = Form(...),
type_definition_key: str = Form(...),
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> CredentialBase:
try:
credential_data = json.loads(credential_json)
except json.JSONDecodeError as e:
raise HTTPException(
status_code=400,
detail=f"Invalid JSON in credential_json: {str(e)}",
)
private_key_processor: ProcessPrivateKeyFileProtocol | None = (
FILE_TYPE_TO_FILE_PROCESSOR.get(PrivateKeyFileTypes(type_definition_key))
)
if private_key_processor is None:
raise HTTPException(
status_code=400,
detail="Invalid type definition key for private key file",
)
private_key_content: str = private_key_processor(uploaded_file)
credential_data[field_key] = private_key_content
credential = alter_credential(
credential_id,
name,
credential_data,
user,
db_session,
)
if credential is None:
raise HTTPException(
status_code=401,
detail=f"Credential {credential_id} does not exist or does not belong to user",
)
return CredentialSnapshot.from_credential_db_model(credential)
@router.patch("/credential/{credential_id}")
def update_credential_from_model(
credential_id: int,


@@ -0,0 +1,75 @@
from cryptography.hazmat.primitives.serialization import pkcs12
from onyx.utils.logger import setup_logger
logger = setup_logger()
def _is_password_related_error(error: Exception) -> bool:
"""
Check if the exception indicates a password-related issue rather than a format issue.
"""
error_msg = str(error).lower()
password_keywords = ["mac", "integrity", "password", "authentication", "verify"]
return any(keyword in error_msg for keyword in password_keywords)
def validate_pkcs12_content(file_bytes: bytes) -> bool:
"""
Validate that the file content is actually a PKCS#12 file.
This performs basic format validation without requiring passwords.
"""
try:
# Basic file size check
if len(file_bytes) < 10:
logger.debug("File too small to be a valid PKCS#12 file")
return False
# Check for PKCS#12 magic bytes/ASN.1 structure
# PKCS#12 files start with ASN.1 SEQUENCE tag (0x30)
if file_bytes[0] != 0x30:
logger.debug("File does not start with ASN.1 SEQUENCE tag")
return False
# Try to parse the outer ASN.1 structure without password validation
# This checks if the file has the basic PKCS#12 structure
try:
# Attempt to load just to validate the basic format
# We expect this to fail due to password, but it should fail with a specific error
pkcs12.load_key_and_certificates(file_bytes, password=None)
return True
except ValueError as e:
# Check if the error is related to password (expected) vs format issues
if _is_password_related_error(e):
# These errors indicate the file format is correct but password is wrong/missing
logger.debug(
f"PKCS#12 format appears valid, password-related error: {e}"
)
return True
else:
# Other ValueError likely indicates format issues
logger.debug(f"PKCS#12 format validation failed: {e}")
return False
except Exception as e:
# Try with empty password as fallback
try:
pkcs12.load_key_and_certificates(file_bytes, password=b"")
return True
except ValueError as e2:
if _is_password_related_error(e2):
logger.debug(
f"PKCS#12 format appears valid with empty password attempt: {e2}"
)
return True
else:
logger.debug(
f"PKCS#12 validation failed on both attempts: {e}, {e2}"
)
return False
except Exception:
logger.debug(f"PKCS#12 validation failed: {e}")
return False
except Exception as e:
logger.debug(f"Unexpected error during PKCS#12 validation: {e}")
return False
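The cheap format checks at the top of validate_pkcs12_content stand alone without the `cryptography` dependency; a sketch:

```python
# Sketch of the pre-checks from validate_pkcs12_content: a PKCS#12 (.pfx)
# file is DER-encoded ASN.1, so it must be non-trivially sized and start
# with the SEQUENCE tag 0x30. This rejects obvious spoofs before any
# parsing is attempted; it does not prove the file is valid PKCS#12.
def looks_like_pkcs12(file_bytes: bytes) -> bool:
    if len(file_bytes) < 10:
        return False
    return file_bytes[0] == 0x30
```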


@@ -0,0 +1,57 @@
import base64
from enum import Enum
from typing import Protocol
from fastapi import HTTPException
from fastapi import UploadFile
from onyx.server.documents.document_utils import validate_pkcs12_content
class ProcessPrivateKeyFileProtocol(Protocol):
def __call__(self, file: UploadFile) -> str:
"""
Accepts a file-like object, validates the file (e.g., checks extension and content),
and returns its contents as a base64-encoded string if valid.
Raises an exception if validation fails.
"""
...
class PrivateKeyFileTypes(Enum):
SHAREPOINT_PFX_FILE = "sharepoint_pfx_file"
def process_sharepoint_private_key_file(file: UploadFile) -> str:
"""
Process and validate a private key file upload.
Validates both the file extension and file content to ensure it's a valid PKCS#12 file.
Content validation prevents attacks that rely on file extension spoofing.
"""
# First check file extension (basic filter)
if not (file.filename and file.filename.lower().endswith(".pfx")):
raise HTTPException(
status_code=400, detail="Invalid file type. Only .pfx files are supported."
)
# Read file content for validation and processing
private_key_bytes = file.file.read()
# Validate file content to prevent extension spoofing attacks
if not validate_pkcs12_content(private_key_bytes):
raise HTTPException(
status_code=400,
detail="Invalid file content. The uploaded file does not appear to be a valid PKCS#12 (.pfx) file.",
)
# Convert to base64 if validation passes
pfx_64 = base64.b64encode(private_key_bytes).decode("ascii")
return pfx_64
FILE_TYPE_TO_FILE_PROCESSOR: dict[
PrivateKeyFileTypes, ProcessPrivateKeyFileProtocol
] = {
PrivateKeyFileTypes.SHAREPOINT_PFX_FILE: process_sharepoint_private_key_file,
}
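The encoding step in process_sharepoint_private_key_file is a lossless base64 round trip, which lets raw .pfx bytes travel inside the JSON credential as an ASCII-safe string; a sketch (the key bytes are fake):

```python
import base64

# Round-trip sketch of the encoding used above; the .pfx bytes are made up.
private_key_bytes = b"\x30\x82\x01\x00fake-pfx-bytes"  # 0x30: ASN.1 SEQUENCE tag
pfx_b64 = base64.b64encode(private_key_bytes).decode("ascii")
recovered = base64.b64decode(pfx_b64)
# recovered is byte-for-byte identical to the original input.
```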


@@ -23,6 +23,7 @@ from sqlalchemy.orm import Session
from onyx.auth.email_utils import send_user_email_invite
from onyx.auth.invited_users import get_invited_users
from onyx.auth.invited_users import remove_user_from_invited_users
from onyx.auth.invited_users import write_invited_users
from onyx.auth.noauth_user import fetch_no_auth_user
from onyx.auth.noauth_user import set_no_auth_user_preferences
@@ -367,15 +368,11 @@ def remove_invited_user(
db_session: Session = Depends(get_session),
) -> int:
tenant_id = get_current_tenant_id()
user_emails = get_invited_users()
remaining_users = [user for user in user_emails if user != user_email.user_email]
if MULTI_TENANT:
fetch_ee_implementation_or_noop(
"onyx.server.tenants.user_mapping", "remove_users_from_tenant", None
)([user_email.user_email], tenant_id)
number_of_invited_users = write_invited_users(remaining_users)
number_of_invited_users = remove_user_from_invited_users(user_email.user_email)
try:
if MULTI_TENANT and not DEV_MODE:


@@ -1,6 +1,5 @@
import asyncio
import datetime
import io
import json
import os
import time
@@ -31,7 +30,6 @@ from onyx.chat.prompt_builder.citations_prompt import (
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.chat_configs import HARD_DELETE_CHATS
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import MessageType
from onyx.configs.constants import MilestoneRecordType
from onyx.configs.model_configs import LITELLM_PASS_THROUGH_HEADERS
@@ -63,9 +61,7 @@ from onyx.db.models import User
from onyx.db.persona import get_persona_by_id
from onyx.db.user_documents import create_user_files
from onyx.file_processing.extract_file_text import docx_to_txt_filename
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_store.file_store import get_default_file_store
from onyx.file_store.models import ChatFileType
from onyx.file_store.models import FileDescriptor
from onyx.llm.exceptions import GenAIDisabledException
from onyx.llm.factory import get_default_llms
@@ -717,106 +713,65 @@ def upload_files_for_chat(
):
raise HTTPException(
status_code=400,
detail="File size must be less than 20MB",
detail="Images must be less than 20MB",
)
file_store = get_default_file_store()
file_info: list[tuple[str, str | None, ChatFileType]] = []
for file in files:
file_type = mime_type_to_chat_file_type(file.content_type)
file_content = file.file.read() # Read the file content
# NOTE: Image conversion to JPEG used to be enforced here.
# This was removed to:
# 1. Preserve original file content for downloads
# 2. Maintain transparency in formats like PNG
# 3. Ameliorate issue with file conversion
file_content_io = io.BytesIO(file_content)
new_content_type = file.content_type
# Store the file normally
file_id = file_store.save_file(
content=file_content_io,
display_name=file.filename,
file_origin=FileOrigin.CHAT_UPLOAD,
file_type=new_content_type or file_type.value,
# 5) Create a user file for each uploaded file
user_files = create_user_files(files, RECENT_DOCS_FOLDER_ID, user, db_session)
for user_file in user_files:
# 6) Create connector
connector_base = ConnectorBase(
name=f"UserFile-{int(time.time())}",
source=DocumentSource.FILE,
input_type=InputType.LOAD_STATE,
connector_specific_config={
"file_locations": [user_file.file_id],
"file_names": [user_file.name],
"zip_metadata": {},
},
refresh_freq=None,
prune_freq=None,
indexing_start=None,
)
connector = create_connector(
db_session=db_session,
connector_data=connector_base,
)
# 4) If the file is a doc, extract text and store that separately
if file_type == ChatFileType.DOC:
# Re-wrap bytes in a fresh BytesIO so we start at position 0
extracted_text_io = io.BytesIO(file_content)
extracted_text = extract_file_text(
file=extracted_text_io, # use the bytes we already read
file_name=file.filename or "",
)
# 7) Create credential
credential_info = CredentialBase(
credential_json={},
admin_public=True,
source=DocumentSource.FILE,
curator_public=True,
groups=[],
name=f"UserFileCredential-{int(time.time())}",
is_user_file=True,
)
credential = create_credential(credential_info, user, db_session)
text_file_id = file_store.save_file(
content=io.BytesIO(extracted_text.encode()),
display_name=file.filename,
file_origin=FileOrigin.CHAT_UPLOAD,
file_type="text/plain",
)
# Return the text file as the "main" file descriptor for doc types
file_info.append((text_file_id, file.filename, ChatFileType.PLAIN_TEXT))
else:
file_info.append((file_id, file.filename, file_type))
# 5) Create a user file for each uploaded file
user_files = create_user_files([file], RECENT_DOCS_FOLDER_ID, user, db_session)
for user_file in user_files:
# 6) Create connector
connector_base = ConnectorBase(
name=f"UserFile-{int(time.time())}",
source=DocumentSource.FILE,
input_type=InputType.LOAD_STATE,
connector_specific_config={
"file_locations": [user_file.file_id],
"file_names": [user_file.name],
"zip_metadata": {},
},
refresh_freq=None,
prune_freq=None,
indexing_start=None,
)
connector = create_connector(
db_session=db_session,
connector_data=connector_base,
)
# 7) Create credential
credential_info = CredentialBase(
credential_json={},
admin_public=True,
source=DocumentSource.FILE,
curator_public=True,
groups=[],
name=f"UserFileCredential-{int(time.time())}",
is_user_file=True,
)
credential = create_credential(credential_info, user, db_session)
# 8) Create connector credential pair
cc_pair = add_credential_to_connector(
db_session=db_session,
user=user,
connector_id=connector.id,
credential_id=credential.id,
cc_pair_name=f"UserFileCCPair-{int(time.time())}",
access_type=AccessType.PRIVATE,
auto_sync_options=None,
groups=[],
)
user_file.cc_pair_id = cc_pair.data
db_session.commit()
# 8) Create connector credential pair
cc_pair = add_credential_to_connector(
db_session=db_session,
user=user,
connector_id=connector.id,
credential_id=credential.id,
cc_pair_name=f"UserFileCCPair-{int(time.time())}",
access_type=AccessType.PRIVATE,
auto_sync_options=None,
groups=[],
)
user_file.cc_pair_id = cc_pair.data
db_session.commit()
return {
"files": [
{"id": file_id, "type": file_type, "name": file_name}
for file_id, file_name, file_type in file_info
{
"id": user_file.file_id,
"type": mime_type_to_chat_file_type(user_file.content_type),
"name": user_file.name,
}
for user_file in user_files
]
}


@@ -6,6 +6,7 @@ from sqlalchemy import or_
from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.configs.app_configs import OKTA_PROFILE_TOOL_ENABLED
from onyx.db.models import Persona
from onyx.db.models import Tool as ToolDBModel
from onyx.tools.tool_implementations.images.image_generation_tool import (
@@ -17,6 +18,9 @@ from onyx.tools.tool_implementations.internet_search.internet_search_tool import
from onyx.tools.tool_implementations.internet_search.providers import (
get_available_providers,
)
from onyx.tools.tool_implementations.okta_profile.okta_profile_tool import (
OktaProfileTool,
)
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.tool import Tool
from onyx.utils.logger import setup_logger
@@ -63,6 +67,19 @@ BUILT_IN_TOOLS: list[InCodeToolInfo] = [
if (bool(get_available_providers()))
else []
),
# Show Okta Profile tool if the environment variables are set
*(
[
InCodeToolInfo(
cls=OktaProfileTool,
description="The Okta Profile Action allows the assistant to fetch user information from Okta.",
in_code_tool_id=OktaProfileTool.__name__,
display_name=OktaProfileTool._DISPLAY_NAME,
)
]
if OKTA_PROFILE_TOOL_ENABLED
else []
),
]


@@ -5,12 +5,12 @@ from typing import Generic
from typing import TYPE_CHECKING
from typing import TypeVar
from onyx.llm.interfaces import LLM
from onyx.llm.models import PreviousMessage
from onyx.utils.special_types import JSON_ro
if TYPE_CHECKING:
from onyx.llm.interfaces import LLM
from onyx.llm.models import PreviousMessage
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.tools.message import ToolCallSummary
from onyx.tools.models import ToolResponse
@@ -53,8 +53,8 @@ class Tool(abc.ABC, Generic[OVERRIDE_T]):
def get_args_for_non_tool_calling_llm(
self,
query: str,
history: list[PreviousMessage],
llm: LLM,
history: list["PreviousMessage"],
llm: "LLM",
force_run: bool = False,
) -> dict[str, Any] | None:
raise NotImplementedError


@@ -14,6 +14,10 @@ from onyx.configs.app_configs import AZURE_DALLE_API_KEY
from onyx.configs.app_configs import AZURE_DALLE_API_VERSION
from onyx.configs.app_configs import AZURE_DALLE_DEPLOYMENT_NAME
from onyx.configs.app_configs import IMAGE_MODEL_NAME
from onyx.configs.app_configs import OAUTH_CLIENT_ID
from onyx.configs.app_configs import OAUTH_CLIENT_SECRET
from onyx.configs.app_configs import OKTA_API_TOKEN
from onyx.configs.app_configs import OPENID_CONFIG_URL
from onyx.configs.chat_configs import NUM_INTERNET_SEARCH_CHUNKS
from onyx.configs.chat_configs import NUM_INTERNET_SEARCH_RESULTS
from onyx.configs.model_configs import GEN_AI_TEMPERATURE
@@ -41,6 +45,9 @@ from onyx.tools.tool_implementations.images.image_generation_tool import (
from onyx.tools.tool_implementations.internet_search.internet_search_tool import (
InternetSearchTool,
)
from onyx.tools.tool_implementations.okta_profile.okta_profile_tool import (
OktaProfileTool,
)
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.utils import compute_all_tool_tokens
from onyx.tools.utils import explicit_tool_calling_supported
@@ -265,6 +272,33 @@ def construct_tools(
"Internet search tool requires a Bing or Exa API key, please contact your Onyx admin to get it added!"
)
# Handle Okta Profile Tool
elif tool_cls.__name__ == OktaProfileTool.__name__:
if not user_oauth_token:
raise ValueError(
"Okta Profile Tool requires user OAuth token but none found"
)
if not all([OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET, OPENID_CONFIG_URL]):
raise ValueError(
"Okta Profile Tool requires OAuth configuration to be set"
)
if not OKTA_API_TOKEN:
raise ValueError(
"Okta Profile Tool requires OKTA_API_TOKEN to be set"
)
tool_dict[db_tool_model.id] = [
OktaProfileTool(
access_token=user_oauth_token,
client_id=OAUTH_CLIENT_ID,
client_secret=OAUTH_CLIENT_SECRET,
openid_config_url=OPENID_CONFIG_URL,
okta_api_token=OKTA_API_TOKEN,
)
]
# Handle custom tools
elif db_tool_model.openapi_schema:
if not custom_tool_config:


@@ -0,0 +1,243 @@
import json
from collections.abc import Generator
from typing import Any
from urllib.parse import urlparse
import requests
from pydantic import BaseModel
from onyx.llm.interfaces import LLM
from onyx.llm.models import PreviousMessage
from onyx.llm.utils import message_to_string
from onyx.prompts.constants import GENERAL_SEP_PAT
from onyx.tools.base_tool import BaseTool
from onyx.tools.models import ToolResponse
from onyx.utils.logger import setup_logger
from onyx.utils.special_types import JSON_ro
logger = setup_logger()
OKTA_PROFILE_RESPONSE_ID = "okta_profile"
OKTA_TOOL_DESCRIPTION = """
The Okta profile tool can retrieve user profile information from Okta including:
- User ID, status, creation date
- Profile details like name, email, department, location, title, manager, and more
- Account status and activity
"""
class OIDCConfig(BaseModel):
issuer: str
jwks_uri: str | None = None
userinfo_endpoint: str | None = None
introspection_endpoint: str | None = None
token_endpoint: str | None = None
class OktaProfileTool(BaseTool):
_NAME = "get_okta_profile"
_DESCRIPTION = "This tool is used to get the user's profile information."
_DISPLAY_NAME = "Okta Profile"
def __init__(
self,
access_token: str,
client_id: str,
client_secret: str,
openid_config_url: str,
okta_api_token: str,
request_timeout_sec: int = 15,
) -> None:
self.access_token = access_token
self.client_id = client_id
self.client_secret = client_secret
self.openid_config_url = openid_config_url
self.request_timeout_sec = request_timeout_sec
# Extract Okta org URL from OpenID config URL using URL parsing
        # OpenID config URL format: https://{org}.okta.com/.well-known/openid-configuration
parsed_url = urlparse(self.openid_config_url)
self.okta_org_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
self.okta_api_token = okta_api_token
self._oidc_config: OIDCConfig | None = None
@property
def name(self) -> str:
return self._NAME
@property
def description(self) -> str:
return self._DESCRIPTION
@property
def display_name(self) -> str:
return self._DISPLAY_NAME
def tool_definition(self) -> dict:
return {
"type": "function",
"function": {
"name": self.name,
"description": self.description,
"parameters": {"type": "object", "properties": {}, "required": []},
},
}
def _load_oidc_config(self) -> OIDCConfig:
if self._oidc_config is not None:
return self._oidc_config
resp = requests.get(self.openid_config_url, timeout=self.request_timeout_sec)
resp.raise_for_status()
data = resp.json()
self._oidc_config = OIDCConfig(**data)
logger.debug(f"Loaded OIDC config from {self.openid_config_url}")
return self._oidc_config
def _call_userinfo(self, access_token: str) -> dict[str, Any] | None:
try:
cfg = self._load_oidc_config()
if not cfg.userinfo_endpoint:
logger.info("OIDC config missing userinfo_endpoint")
return None
headers = {"Authorization": f"Bearer {access_token}"}
r = requests.get(
cfg.userinfo_endpoint, headers=headers, timeout=self.request_timeout_sec
)
if r.status_code == 200:
return r.json()
logger.info(
f"userinfo call returned status {r.status_code}: {r.text[:200]}"
)
return None
except requests.RequestException as e:
logger.debug(f"userinfo request failed: {e}")
return None
def _call_introspection(self, access_token: str) -> dict[str, Any] | None:
try:
cfg = self._load_oidc_config()
if not cfg.introspection_endpoint:
logger.info("OIDC config missing introspection_endpoint")
return None
data = {
"token": access_token,
"token_type_hint": "access_token",
}
auth: tuple[str, str] | None = (self.client_id, self.client_secret)
r = requests.post(
cfg.introspection_endpoint,
data=data,
auth=auth,
headers={"Accept": "application/json"},
timeout=self.request_timeout_sec,
)
if r.status_code == 200:
return r.json()
logger.info(
f"introspection call returned status {r.status_code}: {r.text[:200]}"
)
return None
except requests.RequestException as e:
logger.debug(f"introspection request failed: {e}")
return None
def _call_users_api(self, uid: str) -> dict[str, Any]:
"""Call Okta Users API to fetch full user profile.
        Requires okta_org_url and okta_api_token to be set. Raises ValueError on any error.
"""
if not self.okta_org_url or not self.okta_api_token:
raise ValueError(
"Okta org URL and API token are required for user profile lookup"
)
try:
url = f"{self.okta_org_url.rstrip('/')}/api/v1/users/{uid}"
headers = {"Authorization": f"SSWS {self.okta_api_token}"}
r = requests.get(url, headers=headers, timeout=self.request_timeout_sec)
if r.status_code == 200:
return r.json()
raise ValueError(
f"Okta Users API call failed with status {r.status_code}: {r.text[:200]}"
)
except requests.RequestException as e:
raise ValueError(f"Okta Users API request failed: {e}") from e
def build_tool_message_content(
self, *args: ToolResponse
) -> str | list[str | dict[str, Any]]:
# The tool emits a single aggregated packet; pass it through as compact JSON
profile = args[-1].response if args else {}
return json.dumps(profile)
def get_args_for_non_tool_calling_llm(
self,
query: str,
history: list[PreviousMessage],
llm: LLM,
force_run: bool = False,
) -> dict[str, Any] | None:
if force_run:
return {}
# Use LLM to determine if this tool should be called based on the query
prompt = f"""
You are helping to determine if an Okta profile lookup tool should be called based on a user's query.
{OKTA_TOOL_DESCRIPTION}
Query: {query}
Conversation history:
{GENERAL_SEP_PAT}
{history}
{GENERAL_SEP_PAT}
Should the Okta profile tool be called for this query? Respond with only "YES" or "NO".
""".strip()
response = llm.invoke(prompt)
if response and "YES" in message_to_string(response).upper():
return {}
return None
def run(
self, override_kwargs: None = None, **llm_kwargs: Any
) -> Generator[ToolResponse, None, None]:
# Try to get UID from userinfo first, then fallback to introspection
uid_candidate = None
# Try userinfo endpoint first
userinfo_data = self._call_userinfo(self.access_token)
if userinfo_data and isinstance(userinfo_data, dict):
uid_candidate = userinfo_data.get("uid")
# Only try introspection if userinfo didn't provide a UID
if not uid_candidate:
introspection_data = self._call_introspection(self.access_token)
if introspection_data and isinstance(introspection_data, dict):
uid_candidate = introspection_data.get("uid")
if not uid_candidate:
raise ValueError(
"Unable to fetch user profile from Okta. This likely means your Okta "
"token has expired. Please logout, log back in, and try again."
)
# Call Users API to get full profile - this is now required
users_api_data = self._call_users_api(uid_candidate)
yield ToolResponse(
id=OKTA_PROFILE_RESPONSE_ID, response=users_api_data["profile"]
)
def final_result(self, *args: ToolResponse) -> JSON_ro:
# Return the single aggregated profile packet
if not args:
return {}
return args[-1].response
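The YES/NO gating in `get_args_for_non_tool_calling_llm` can be sketched in isolation. The `invoke` callable below is a stand-in for the real `LLM` interface, an assumption for illustration only:

```python
# Standalone sketch of the YES/NO tool-gating pattern above.
# `invoke` is a stand-in for the real LLM interface (an assumption here).
def should_run_okta_tool(query: str, invoke) -> bool:
    prompt = (
        "Should the Okta profile tool be called for this query? "
        f'Query: {query}\nRespond with only "YES" or "NO".'
    )
    response = invoke(prompt)
    # Mirrors the tool: any response containing YES (case-insensitive) runs it
    return bool(response) and "YES" in response.upper()

print(should_run_okta_tool("What is my manager's title?", lambda p: "YES"))  # True
print(should_run_okta_tool("What's the weather today?", lambda p: "no"))  # False
```

Checking for a substring rather than an exact match keeps the gate tolerant of models that answer with extra words.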

View File

@@ -13,6 +13,7 @@ from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import SLACK_CHANNEL_ID
from shared_configs.configs import TENANT_ID_PREFIX
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import INDEX_ATTEMPT_INFO_CONTEXTVAR
from shared_configs.contextvars import ONYX_REQUEST_ID_CONTEXTVAR
@@ -34,30 +35,6 @@ class LoggerContextVars:
doc_permission_sync_ctx.set(dict())
class TaskAttemptSingleton:
"""Used to tell if this process is an indexing job, and if so what is the
unique identifier for this indexing attempt. For things like the API server,
main background job (scheduler), etc. this will not be used."""
_INDEX_ATTEMPT_ID: None | int = None
_CONNECTOR_CREDENTIAL_PAIR_ID: None | int = None
@classmethod
def get_index_attempt_id(cls) -> None | int:
return cls._INDEX_ATTEMPT_ID
@classmethod
def get_connector_credential_pair_id(cls) -> None | int:
return cls._CONNECTOR_CREDENTIAL_PAIR_ID
@classmethod
def set_cc_and_index_id(
cls, index_attempt_id: int, connector_credential_pair_id: int
) -> None:
cls._INDEX_ATTEMPT_ID = index_attempt_id
cls._CONNECTOR_CREDENTIAL_PAIR_ID = connector_credential_pair_id
def get_log_level_from_str(log_level_str: str = LOG_LEVEL) -> int:
log_level_dict = {
"CRITICAL": logging.CRITICAL,
@@ -102,14 +79,12 @@ class OnyxLoggingAdapter(logging.LoggerAdapter):
msg = f"[Doc Permissions Sync: {doc_permission_sync_ctx_dict['request_id']}] {msg}"
break
index_attempt_id = TaskAttemptSingleton.get_index_attempt_id()
cc_pair_id = TaskAttemptSingleton.get_connector_credential_pair_id()
if index_attempt_id is not None:
msg = f"[Index Attempt: {index_attempt_id}] {msg}"
if cc_pair_id is not None:
msg = f"[CC Pair: {cc_pair_id}] {msg}"
index_attempt_info = INDEX_ATTEMPT_INFO_CONTEXTVAR.get()
if index_attempt_info:
cc_pair_id, index_attempt_id = index_attempt_info
msg = (
f"[Index Attempt: {index_attempt_id}] [CC Pair: {cc_pair_id}] {msg}"
)
break
@@ -230,7 +205,7 @@ def setup_logger(
log_levels = ["debug", "info", "notice"]
for level in log_levels:
file_name = (
f"/var/log/{LOG_FILE_NAME}_{level}.log"
f"/var/log/onyx/{LOG_FILE_NAME}_{level}.log"
if is_containerized
else f"./log/{LOG_FILE_NAME}_{level}.log"
)

View File

@@ -0,0 +1,26 @@
def pass_aws_key(api_key: str) -> tuple[str, str, str]:
"""Parse AWS API key string into components.
Args:
api_key: String in format 'aws_ACCESSKEY_SECRETKEY_REGION'
Returns:
Tuple of (access_key, secret_key, region)
Raises:
ValueError: If key format is invalid
"""
if not api_key.startswith("aws"):
raise ValueError("API key must start with 'aws' prefix")
parts = api_key.split("_")
if len(parts) != 4:
raise ValueError(
            f"API key must be in format 'aws_ACCESSKEY_SECRETKEY_REGION', got "
            f"{len(parts) - 1} parts after the 'aws' prefix. "
            "This is an Onyx-specific format for passing AWS secrets to Bedrock."
)
try:
_, aws_access_key_id, aws_secret_access_key, aws_region = parts
return aws_access_key_id, aws_secret_access_key, aws_region
except Exception as e:
        raise ValueError(f"Failed to parse AWS key components: {str(e)}") from e
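A self-contained sketch of the parser above; the key material below is fake and the function is reproduced locally so the example runs on its own:

```python
# Mirror of the helper above, shown with a fake key for illustration.
def pass_aws_key(api_key: str) -> tuple[str, str, str]:
    if not api_key.startswith("aws"):
        raise ValueError("API key must start with 'aws' prefix")
    parts = api_key.split("_")
    if len(parts) != 4:
        raise ValueError("API key must be in format 'aws_ACCESSKEY_SECRETKEY_REGION'")
    _, aws_access_key_id, aws_secret_access_key, aws_region = parts
    return aws_access_key_id, aws_secret_access_key, aws_region

# AWS access keys, secret keys, and region names never contain underscores,
# so splitting on "_" is unambiguous.
access_key, secret_key, region = pass_aws_key("aws_AKIAFAKEKEY_fakeSecret123_us-east-1")
print(access_key, region)  # AKIAFAKEKEY us-east-1
```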

View File

@@ -44,12 +44,12 @@ litellm==1.72.2
lxml==5.3.0
lxml_html_clean==0.2.2
Mako==1.2.4
markitdown[pdf, docx, pptx, xlsx, xls]==0.1.2
msal==1.28.0
nltk==3.9.1
Office365-REST-Python-Client==2.5.9
oauthlib==3.2.2
openai==1.75.0
openpyxl==3.0.10
passlib==1.7.4
playwright==1.41.2
psutil==5.9.5
@@ -66,7 +66,7 @@ pypdf==5.4.0
pytest-mock==3.12.0
pytest-playwright==0.7.0
python-docx==1.1.2
python-dotenv==1.0.0
python-dotenv==1.1.1
python-multipart==0.0.20
pywikibot==9.0.0
redis==5.0.8
@@ -101,3 +101,5 @@ prometheus_client==0.21.0
fastapi-limiter==0.1.6
prometheus_fastapi_instrumentator==7.1.0
sendgrid==6.11.0
voyageai==0.2.3
cohere==5.6.1

View File

@@ -22,7 +22,6 @@ from onyx.configs.app_configs import REDIS_SSL
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.users import get_user_by_email
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_index import RedisConnectorIndex
from onyx.redis.redis_pool import RedisPool
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
@@ -130,9 +129,6 @@ def onyx_redis(
logger.info(f"Purging locks associated with deleting cc_pair={cc_pair_id}.")
redis_connector = RedisConnector(tenant_id, cc_pair_id)
match_pattern = f"{tenant_id}:{RedisConnectorIndex.FENCE_PREFIX}_{cc_pair_id}/*"
purge_by_match_and_type(match_pattern, "string", batch, dry_run, r)
redis_delete_if_exists_helper(
f"{tenant_id}:{redis_connector.prune.fence_key}", dry_run, r
)

View File

@@ -21,6 +21,11 @@ ONYX_REQUEST_ID_CONTEXTVAR: contextvars.ContextVar[str | None] = contextvars.Con
"onyx_request_id", default=None
)
# Used to store cc pair id and index attempt id in multithreaded environments
INDEX_ATTEMPT_INFO_CONTEXTVAR: contextvars.ContextVar[tuple[int, int] | None] = (
contextvars.ContextVar("index_attempt_info", default=None)
)
"""Utils related to contextvars"""

View File

@@ -11,6 +11,7 @@ import pytest
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import Document
from onyx.connectors.salesforce.connector import SalesforceConnector
from onyx.connectors.salesforce.utils import ACCOUNT_OBJECT_TYPE
def extract_key_value_pairs_to_set(
@@ -35,7 +36,7 @@ def _load_reference_data(
@pytest.fixture
def salesforce_connector() -> SalesforceConnector:
connector = SalesforceConnector(
requested_objects=["Account", "Contact", "Opportunity"],
requested_objects=[ACCOUNT_OBJECT_TYPE, "Contact", "Opportunity"],
)
username = os.environ["SF_USERNAME"]

View File

@@ -1,14 +1,20 @@
import os
import time
from dataclasses import dataclass
from datetime import datetime
from datetime import timezone
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
from onyx.connectors.sharepoint.connector import SharepointConnector
from tests.daily.connectors.utils import load_all_docs_from_checkpoint_connector
# NOTE: Sharepoint site for tests is "sharepoint-tests"
@dataclass
@@ -43,6 +49,24 @@ EXPECTED_DOCUMENTS = [
),
]
EXPECTED_PAGES = [
ExpectedDocument(
semantic_identifier="CollabHome",
content=(
"# Home\n\nDisplay recent news.\n\n## News\n\nShow recent activities from your site\n\n"
"## Site activity\n\n## Quick links\n\nLearn about a team site\n\nLearn how to add a page\n\n"
"Add links to important documents and pages.\n\n## Quick links\n\nDocuments\n\n"
"Add a document library\n\n## Document library"
),
folder_path=None,
),
ExpectedDocument(
semantic_identifier="Home",
content="# Home",
folder_path=None,
),
]
def verify_document_metadata(doc: Document) -> None:
"""Verify common metadata that should be present on all documents."""
@@ -61,7 +85,7 @@ def verify_document_content(doc: Document, expected: ExpectedDocument) -> None:
assert doc.semantic_identifier == expected.semantic_identifier
assert len(doc.sections) == 1
assert doc.sections[0].text is not None
assert expected.content in doc.sections[0].text
assert expected.content == doc.sections[0].text
verify_document_metadata(doc)
@@ -76,6 +100,17 @@ def find_document(documents: list[Document], semantic_identifier: str) -> Docume
return matching_docs[0]
@pytest.fixture
def mock_store_image() -> MagicMock:
"""Mock store_image_and_create_section to return a predefined ImageSection."""
mock = MagicMock()
mock.return_value = (
ImageSection(image_file_id="mocked-file-id", link="https://example.com/image"),
"mocked-file-id",
)
return mock
@pytest.fixture
def sharepoint_credentials() -> dict[str, str]:
return {
@@ -87,199 +122,247 @@ def sharepoint_credentials() -> dict[str, str]:
def test_sharepoint_connector_all_sites__docs_only(
mock_get_unstructured_api_key: MagicMock,
mock_store_image: MagicMock,
sharepoint_credentials: dict[str, str],
) -> None:
    with patch(
        "onyx.connectors.sharepoint.connector.store_image_and_create_section",
        mock_store_image,
    ):
        # Initialize connector with no sites
        connector = SharepointConnector(
            include_site_pages=False, include_site_documents=True
        )

        # Load credentials
        connector.load_credentials(sharepoint_credentials)

        # Not asserting expected sites because that can change in test tenant at any time
        # Finding any docs is good enough to verify that the connector is working
        document_batches = load_all_docs_from_checkpoint_connector(
            connector=connector,
            start=0,
            end=time.time(),
        )
        assert document_batches, "Should find documents from all sites"
def test_sharepoint_connector_all_sites__pages_only(
mock_get_unstructured_api_key: MagicMock,
mock_store_image: MagicMock,
sharepoint_credentials: dict[str, str],
) -> None:
with patch(
"onyx.connectors.sharepoint.connector.store_image_and_create_section",
mock_store_image,
):
# Initialize connector with no docs
connector = SharepointConnector(
include_site_pages=True, include_site_documents=False
)
# Load credentials
connector.load_credentials(sharepoint_credentials)
# Not asserting expected sites because that can change in test tenant at any time
# Finding any docs is good enough to verify that the connector is working
document_batches = load_all_docs_from_checkpoint_connector(
connector=connector,
start=0,
end=time.time(),
)
assert document_batches, "Should find site pages from all sites"
def test_sharepoint_connector_specific_folder(
mock_get_unstructured_api_key: MagicMock,
mock_store_image: MagicMock,
sharepoint_credentials: dict[str, str],
) -> None:
    with patch(
        "onyx.connectors.sharepoint.connector.store_image_and_create_section",
        mock_store_image,
    ):
        # Initialize connector with the test site URL and specific folder
        connector = SharepointConnector(
            sites=[os.environ["SHAREPOINT_SITE"] + "/Shared Documents/test"],
            include_site_pages=False,
            include_site_documents=True,
        )

        # Load credentials
        connector.load_credentials(sharepoint_credentials)

        # Get all documents
        found_documents: list[Document] = load_all_docs_from_checkpoint_connector(
            connector=connector,
            start=0,
            end=time.time(),
        )

        # Should only find documents in the test folder
        test_folder_docs = [
            doc
            for doc in EXPECTED_DOCUMENTS
            if doc.folder_path and doc.folder_path.startswith("test")
        ]
        assert len(found_documents) == len(
            test_folder_docs
        ), "Should only find documents in test folder"

        # Verify each expected document
        for expected in test_folder_docs:
            doc = find_document(found_documents, expected.semantic_identifier)
            verify_document_content(doc, expected)
def test_sharepoint_connector_root_folder__docs_only(
mock_get_unstructured_api_key: MagicMock,
mock_store_image: MagicMock,
sharepoint_credentials: dict[str, str],
) -> None:
    with patch(
        "onyx.connectors.sharepoint.connector.store_image_and_create_section",
        mock_store_image,
    ):
        # Initialize connector with the base site URL
        connector = SharepointConnector(
            sites=[os.environ["SHAREPOINT_SITE"]],
            include_site_pages=False,
            include_site_documents=True,
        )

        # Load credentials
        connector.load_credentials(sharepoint_credentials)

        # Get all documents
        found_documents: list[Document] = load_all_docs_from_checkpoint_connector(
            connector=connector,
            start=0,
            end=time.time(),
        )

        assert len(found_documents) == len(
            EXPECTED_DOCUMENTS
        ), "Should find all documents in main library"

        # Verify each expected document
        for expected in EXPECTED_DOCUMENTS:
            doc = find_document(found_documents, expected.semantic_identifier)
            verify_document_content(doc, expected)
def test_sharepoint_connector_other_library(
mock_get_unstructured_api_key: MagicMock,
mock_store_image: MagicMock,
sharepoint_credentials: dict[str, str],
) -> None:
    with patch(
        "onyx.connectors.sharepoint.connector.store_image_and_create_section",
        mock_store_image,
    ):
        # Initialize connector with the other library
        connector = SharepointConnector(
            sites=[
                os.environ["SHAREPOINT_SITE"] + "/Other Library",
            ],
            include_site_pages=False,
            include_site_documents=True,
        )

        # Load credentials
        connector.load_credentials(sharepoint_credentials)

        # Get all documents
        found_documents: list[Document] = load_all_docs_from_checkpoint_connector(
            connector=connector,
            start=0,
            end=time.time(),
        )
        expected_documents: list[ExpectedDocument] = [
            doc for doc in EXPECTED_DOCUMENTS if doc.library == "Other Library"
        ]

        # Should find all documents in `Other Library`
        assert len(found_documents) == len(
            expected_documents
        ), "Should find all documents in `Other Library`"

        # Verify each expected document
        for expected in expected_documents:
            doc = find_document(found_documents, expected.semantic_identifier)
            verify_document_content(doc, expected)
def test_sharepoint_connector_poll(
mock_get_unstructured_api_key: MagicMock,
mock_store_image: MagicMock,
sharepoint_credentials: dict[str, str],
) -> None:
    with patch(
        "onyx.connectors.sharepoint.connector.store_image_and_create_section",
        mock_store_image,
    ):
        # Initialize connector with the base site URL
        connector = SharepointConnector(sites=[os.environ["SHAREPOINT_SITE"]])

        # Load credentials
        connector.load_credentials(sharepoint_credentials)

        # Set time window to only capture test1.docx (modified at 2025-01-28 20:51:42+00:00)
        start = datetime(
            2025, 1, 28, 20, 51, 30, tzinfo=timezone.utc
        )  # 12 seconds before
        end = datetime(2025, 1, 28, 20, 51, 50, tzinfo=timezone.utc)  # 8 seconds after

        # Get documents within the time window
        found_documents: list[Document] = load_all_docs_from_checkpoint_connector(
            connector=connector,
            start=start.timestamp(),
            end=end.timestamp(),
        )

        # Should only find test1.docx
        assert (
            len(found_documents) == 1
        ), "Should only find one document in the time window"
        doc = found_documents[0]
        assert doc.semantic_identifier == "test1.docx"
        verify_document_content(
            doc,
            next(
                d for d in EXPECTED_DOCUMENTS if d.semantic_identifier == "test1.docx"
            ),
        )
def test_sharepoint_connector_pages(
mock_get_unstructured_api_key: MagicMock,
mock_store_image: MagicMock,
sharepoint_credentials: dict[str, str],
) -> None:
    with patch(
        "onyx.connectors.sharepoint.connector.store_image_and_create_section",
        mock_store_image,
    ):
        connector = SharepointConnector(
            sites=[os.environ["SHAREPOINT_SITE"]],
            include_site_pages=True,
            include_site_documents=False,
        )

        # Load credentials
        connector.load_credentials(sharepoint_credentials)

        # Get documents within the time window
        found_documents = load_all_docs_from_checkpoint_connector(
            connector=connector,
            start=0,
            end=time.time(),
        )

        assert len(found_documents) == len(
            EXPECTED_PAGES
        ), "Should find all pages in test site"

        for expected in EXPECTED_PAGES:
            doc = find_document(found_documents, expected.semantic_identifier)
            verify_document_content(doc, expected)

View File

@@ -0,0 +1,113 @@
import os
from collections.abc import Generator
from datetime import datetime
from datetime import timezone
import pytest
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import InputType
from onyx.connectors.sharepoint.connector import SharepointAuthMethod
from onyx.db.enums import AccessType
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.connector import ConnectorManager
from tests.integration.common_utils.managers.credential import CredentialManager
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.reset import reset_all
from tests.integration.common_utils.test_models import DATestCCPair
from tests.integration.common_utils.test_models import DATestConnector
from tests.integration.common_utils.test_models import DATestCredential
from tests.integration.common_utils.test_models import DATestUser
SharepointTestEnvSetupTuple = tuple[
DATestUser, # admin_user
DATestUser, # regular_user_1
DATestUser, # regular_user_2
DATestCredential,
DATestConnector,
DATestCCPair,
]
@pytest.fixture(scope="module")
def sharepoint_test_env_setup() -> Generator[SharepointTestEnvSetupTuple]:
# Reset all data before running the test
reset_all()
# Required environment variables for SharePoint certificate authentication
sp_client_id = os.environ.get("PERM_SYNC_SHAREPOINT_CLIENT_ID")
sp_private_key = os.environ.get("PERM_SYNC_SHAREPOINT_PRIVATE_KEY")
sp_certificate_password = os.environ.get(
"PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD"
)
sp_directory_id = os.environ.get("PERM_SYNC_SHAREPOINT_DIRECTORY_ID")
sharepoint_sites = "https://danswerai.sharepoint.com/sites/Permisisonsync"
admin_email = "admin@onyx.app"
user1_email = "subash@onyx.app"
user2_email = "raunak@onyx.app"
    if (
        not sp_client_id
        or not sp_private_key
        or not sp_certificate_password
        or not sp_directory_id
    ):
pytest.skip("Skipping test because required environment variables are not set")
# Certificate-based credentials
credentials = {
"authentication_method": SharepointAuthMethod.CERTIFICATE.value,
"sp_client_id": sp_client_id,
"sp_private_key": sp_private_key,
"sp_certificate_password": sp_certificate_password,
"sp_directory_id": sp_directory_id,
}
# Create users
admin_user: DATestUser = UserManager.create(email=admin_email)
regular_user_1: DATestUser = UserManager.create(email=user1_email)
regular_user_2: DATestUser = UserManager.create(email=user2_email)
# Create LLM provider for search functionality
LLMProviderManager.create(user_performing_action=admin_user)
# Create credential
credential: DATestCredential = CredentialManager.create(
source=DocumentSource.SHAREPOINT,
credential_json=credentials,
user_performing_action=admin_user,
)
# Create connector with SharePoint-specific configuration
connector: DATestConnector = ConnectorManager.create(
name="SharePoint Test",
input_type=InputType.POLL,
source=DocumentSource.SHAREPOINT,
connector_specific_config={
"sites": sharepoint_sites.split(","),
},
access_type=AccessType.SYNC, # Enable permission sync
user_performing_action=admin_user,
)
# Create CC pair with permission sync enabled
cc_pair: DATestCCPair = CCPairManager.create(
credential_id=credential.id,
connector_id=connector.id,
access_type=AccessType.SYNC, # Enable permission sync
user_performing_action=admin_user,
)
# Wait for both indexing and permission sync to complete
before = datetime.now(tz=timezone.utc)
CCPairManager.wait_for_indexing_completion(
cc_pair=cc_pair,
after=before,
user_performing_action=admin_user,
timeout=float("inf"),
)
# Wait for permission sync completion specifically
CCPairManager.wait_for_sync(
cc_pair=cc_pair,
after=before,
user_performing_action=admin_user,
timeout=float("inf"),
)
yield admin_user, regular_user_1, regular_user_2, credential, connector, cc_pair

View File

@@ -0,0 +1,214 @@
import os
from typing import List
from uuid import UUID
import pytest
from sqlalchemy.orm import Session
from ee.onyx.access.access import _get_access_for_documents
from ee.onyx.db.external_perm import fetch_external_groups_for_user
from onyx.access.utils import prefix_external_group
from onyx.access.utils import prefix_user_email
from onyx.configs.constants import PUBLIC_DOC_PAT
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.models import User
from onyx.db.users import fetch_user_by_id
from onyx.utils.logger import setup_logger
from tests.integration.common_utils.test_models import DATestCCPair
from tests.integration.common_utils.test_models import DATestUser
from tests.integration.connector_job_tests.sharepoint.conftest import (
SharepointTestEnvSetupTuple,
)
logger = setup_logger()
def get_user_acl(user: User, db_session: Session) -> set[str]:
db_external_groups = (
fetch_external_groups_for_user(db_session, user.id) if user else []
)
prefixed_external_groups = [
prefix_external_group(db_external_group.external_user_group_id)
for db_external_group in db_external_groups
]
user_acl = set(prefixed_external_groups)
user_acl.update({prefix_user_email(user.email), PUBLIC_DOC_PAT})
return user_acl
def get_user_document_access_via_acl(
test_user: DATestUser, document_ids: List[str], db_session: Session
) -> List[str]:
# Get the actual User object from the database
user = fetch_user_by_id(db_session, UUID(test_user.id))
if not user:
logger.error(f"Could not find user with ID {test_user.id}")
return []
user_acl = get_user_acl(user, db_session)
logger.info(f"User {user.email} ACL entries: {user_acl}")
# Get document access information
doc_access_map = _get_access_for_documents(document_ids, db_session)
logger.info(f"Found access info for {len(doc_access_map)} documents")
accessible_docs = []
for doc_id, doc_access in doc_access_map.items():
doc_acl = doc_access.to_acl()
logger.info(f"Document {doc_id} ACL: {doc_acl}")
# Check if user has any matching ACL entry
if user_acl.intersection(doc_acl):
accessible_docs.append(doc_id)
logger.info(f"User {user.email} has access to document {doc_id}")
else:
logger.info(f"User {user.email} does NOT have access to document {doc_id}")
return accessible_docs
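The access check above boils down to a set intersection between the user's ACL entries and the document's ACL entries. A minimal standalone sketch of that logic (the entry strings below are illustrative placeholders, not the real prefixed values produced by `prefix_user_email` / `prefix_external_group`):

```python
# Hypothetical ACL entries for illustration only; real entries come from
# the prefix helpers and the external-group sync.
user_acl = {"user_email:alice@example.com", "PUBLIC"}
doc_acl_public = {"PUBLIC"}
doc_acl_group = {"external_group:engineering"}


def has_access(user_acl: set[str], doc_acl: set[str]) -> bool:
    # Access is granted when the user and document share any ACL entry.
    return bool(user_acl & doc_acl)


print(has_access(user_acl, doc_acl_public))  # True: the document is public
print(has_access(user_acl, doc_acl_group))   # False: user is not in the group
```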


def get_all_connector_documents(
cc_pair: DATestCCPair, db_session: Session
) -> List[str]:
from onyx.db.models import DocumentByConnectorCredentialPair
from sqlalchemy import select
stmt = select(DocumentByConnectorCredentialPair.id).where(
DocumentByConnectorCredentialPair.connector_id == cc_pair.connector_id,
DocumentByConnectorCredentialPair.credential_id == cc_pair.credential_id,
)
result = db_session.execute(stmt)
document_ids = [row[0] for row in result.fetchall()]
logger.info(
f"Found {len(document_ids)} documents for connector {cc_pair.connector_id}"
)
return document_ids


def get_documents_by_permission_type(
document_ids: List[str], db_session: Session
) -> List[str]:
    """
    Identify the public documents among the given document IDs.
    Returns a list of document IDs whose access is marked public.
    """
doc_access_map = _get_access_for_documents(document_ids, db_session)
public_docs = []
for doc_id, doc_access in doc_access_map.items():
if doc_access.is_public:
public_docs.append(doc_id)
return public_docs


@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="Permission tests are enterprise only",
)
def test_public_documents_accessible_by_all_users(
sharepoint_test_env_setup: SharepointTestEnvSetupTuple,
) -> None:
"""Test that public documents are accessible by both test users using ACL verification"""
(
admin_user,
regular_user_1,
regular_user_2,
credential,
connector,
cc_pair,
) = sharepoint_test_env_setup
with get_session_with_current_tenant() as db_session:
# Get all documents for this connector
all_document_ids = get_all_connector_documents(cc_pair, db_session)
# Test that regular_user_1 can access documents
accessible_docs_user1 = get_user_document_access_via_acl(
test_user=regular_user_1,
document_ids=all_document_ids,
db_session=db_session,
)
# Test that regular_user_2 can access documents
accessible_docs_user2 = get_user_document_access_via_acl(
test_user=regular_user_2,
document_ids=all_document_ids,
db_session=db_session,
)
logger.info(f"User 1 has access to {len(accessible_docs_user1)} documents")
logger.info(f"User 2 has access to {len(accessible_docs_user2)} documents")
        # These exact counts reflect the fixture data set up for the test
        # SharePoint site: user 1 is expected to see 8 documents, user 2
        # only 1.
        assert len(accessible_docs_user1) == 8, (
            f"User 1 should have access to exactly 8 documents. Found "
            f"{len(accessible_docs_user1)} accessible docs out of "
            f"{len(all_document_ids)} total"
        )
        assert len(accessible_docs_user2) == 1, (
            f"User 2 should have access to exactly 1 document. Found "
            f"{len(accessible_docs_user2)} accessible docs out of "
            f"{len(all_document_ids)} total"
        )
logger.info(
"Successfully verified public documents are accessible by users via ACL"
)


@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="Permission tests are enterprise only",
)
def test_group_based_permissions(
sharepoint_test_env_setup: SharepointTestEnvSetupTuple,
) -> None:
"""Test that documents with group permissions are accessible only by users in that group using ACL verification"""
(
admin_user,
regular_user_1,
regular_user_2,
credential,
connector,
cc_pair,
) = sharepoint_test_env_setup
with get_session_with_current_tenant() as db_session:
# Get all documents for this connector
all_document_ids = get_all_connector_documents(cc_pair, db_session)
if not all_document_ids:
pytest.skip("No documents found for connector - skipping test")
# Test access for both users
accessible_docs_user1 = get_user_document_access_via_acl(
test_user=regular_user_1,
document_ids=all_document_ids,
db_session=db_session,
)
accessible_docs_user2 = get_user_document_access_via_acl(
test_user=regular_user_2,
document_ids=all_document_ids,
db_session=db_session,
)
logger.info(f"User 1 has access to {len(accessible_docs_user1)} documents")
logger.info(f"User 2 has access to {len(accessible_docs_user2)} documents")
public_docs = get_documents_by_permission_type(all_document_ids, db_session)
# Check if user 2 has access to any non-public documents
non_public_access_user2 = [
doc for doc in accessible_docs_user2 if doc not in public_docs
]
assert (
len(non_public_access_user2) == 0
), f"User 2 should only have access to public documents. Found access to non-public docs: {non_public_access_user2}"
