mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-02 14:15:44 +00:00
Compare commits
58 Commits
refactor-m
...
v1.4.0-clo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5602ff8666 | ||
|
|
2fc70781b4 | ||
|
|
f76b4dec4c | ||
|
|
a5a516fa8a | ||
|
|
811a198134 | ||
|
|
5867ab1d7d | ||
|
|
dd6653eb1f | ||
|
|
db457ef432 | ||
|
|
de7fe939b2 | ||
|
|
38114d9542 | ||
|
|
32f20f2e2e | ||
|
|
3dd27099f7 | ||
|
|
91c4d43a80 | ||
|
|
a63ba1bb03 | ||
|
|
7b6189e74c | ||
|
|
ba423e5773 | ||
|
|
fe029eccae | ||
|
|
ea72af7698 | ||
|
|
17abf85533 | ||
|
|
3bd162acb9 | ||
|
|
664ce441eb | ||
|
|
6863fbee54 | ||
|
|
bb98088b80 | ||
|
|
ce8cb1112a | ||
|
|
a605bd4ca4 | ||
|
|
0e8b5af619 | ||
|
|
46f3af4f68 | ||
|
|
2af64ebf4c | ||
|
|
0eb1824158 | ||
|
|
e0a9a6fb66 | ||
|
|
fe194076c2 | ||
|
|
55dc24fd27 | ||
|
|
da02962a67 | ||
|
|
9bc62cc803 | ||
|
|
bf6705a9a5 | ||
|
|
df2fef3383 | ||
|
|
8cec3448d7 | ||
|
|
b81687995e | ||
|
|
87c2253451 | ||
|
|
297c2957b4 | ||
|
|
bacee0d09d | ||
|
|
297720c132 | ||
|
|
bd4bd00cef | ||
|
|
07c482f727 | ||
|
|
cf193dee29 | ||
|
|
1b47fa2700 | ||
|
|
e1a305d18a | ||
|
|
e2233d22c9 | ||
|
|
20d1175312 | ||
|
|
7117774287 | ||
|
|
77f2660bb2 | ||
|
|
1b2f4f3b87 | ||
|
|
d85b55a9d2 | ||
|
|
e2bae5a2d9 | ||
|
|
cc9c76c4fb | ||
|
|
258e08abcd | ||
|
|
67047e42a7 | ||
|
|
146628e734 |
31
.github/workflows/helm-chart-releases.yml
vendored
31
.github/workflows/helm-chart-releases.yml
vendored
@@ -18,23 +18,32 @@ jobs:
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Configure Git
|
||||
run: |
|
||||
git config user.name "$GITHUB_ACTOR"
|
||||
git config user.email "$GITHUB_ACTOR@users.noreply.github.com"
|
||||
|
||||
- name: Install Helm
|
||||
- name: Install Helm CLI
|
||||
uses: azure/setup-helm@v4
|
||||
with:
|
||||
version: v3.12.1
|
||||
|
||||
- name: Add Required Helm Repositories
|
||||
- name: Add required Helm repositories
|
||||
run: |
|
||||
helm repo add bitnami https://charts.bitnami.com/bitnami
|
||||
helm repo add onyx-vespa https://onyx-dot-app.github.io/vespa-helm-charts
|
||||
helm repo update
|
||||
|
||||
- name: Run chart-releaser
|
||||
uses: helm/chart-releaser-action@v1.7.0
|
||||
env:
|
||||
CR_TOKEN: "${{ secrets.GITHUB_TOKEN }}"
|
||||
- name: Build chart dependencies
|
||||
run: |
|
||||
set -euo pipefail
|
||||
for chart_dir in deployment/helm/charts/*; do
|
||||
if [ -f "$chart_dir/Chart.yaml" ]; then
|
||||
echo "Building dependencies for $chart_dir"
|
||||
helm dependency build "$chart_dir"
|
||||
fi
|
||||
done
|
||||
|
||||
- name: Publish Helm charts to gh-pages
|
||||
uses: stefanprodan/helm-gh-pages@v1.7.0
|
||||
with:
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
charts_dir: deployment/helm/charts
|
||||
branch: gh-pages
|
||||
commit_username: ${{ github.actor }}
|
||||
commit_email: ${{ github.actor }}@users.noreply.github.com
|
||||
20
.github/workflows/pr-helm-chart-testing.yml
vendored
20
.github/workflows/pr-helm-chart-testing.yml
vendored
@@ -55,7 +55,25 @@ jobs:
|
||||
|
||||
- name: Run chart-testing (install)
|
||||
if: steps.list-changed.outputs.changed == 'true'
|
||||
run: ct install --all --helm-extra-set-args="--set=nginx.enabled=false" --debug --config ct.yaml
|
||||
run: ct install --all \
|
||||
--helm-extra-set-args="\
|
||||
--set=nginx.enabled=false \
|
||||
--set=postgresql.enabled=false \
|
||||
--set=redis.enabled=false \
|
||||
--set=minio.enabled=false \
|
||||
--set=vespa.enabled=false \
|
||||
--set=slackbot.enabled=false \
|
||||
--set=api.replicaCount=0 \
|
||||
--set=inferenceCapability.replicaCount=0 \
|
||||
--set=indexCapability.replicaCount=0 \
|
||||
--set=celery_beat.replicaCount=0 \
|
||||
--set=celery_worker_heavy.replicaCount=0 \
|
||||
--set=celery_worker_docprocessing.replicaCount=0 \
|
||||
--set=celery_worker_light.replicaCount=0 \
|
||||
--set=celery_worker_monitoring.replicaCount=0 \
|
||||
--set=celery_worker_primary.replicaCount=0 \
|
||||
--set=celery_worker_user_files_indexing.replicaCount=0" \
|
||||
--debug --config ct.yaml
|
||||
# the following would install only changed charts, but we only have one chart so
|
||||
# don't worry about that for now
|
||||
# run: ct install --target-branch ${{ github.event.repository.default_branch }}
|
||||
|
||||
8
.github/workflows/pr-integration-tests.yml
vendored
8
.github/workflows/pr-integration-tests.yml
vendored
@@ -19,6 +19,10 @@ env:
|
||||
JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }}
|
||||
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
|
||||
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
|
||||
PERM_SYNC_SHAREPOINT_CLIENT_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_CLIENT_ID }}
|
||||
PERM_SYNC_SHAREPOINT_PRIVATE_KEY: ${{ secrets.PERM_SYNC_SHAREPOINT_PRIVATE_KEY }}
|
||||
PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD: ${{ secrets.PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD }}
|
||||
PERM_SYNC_SHAREPOINT_DIRECTORY_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_DIRECTORY_ID }}
|
||||
PLATFORM_PAIR: linux-amd64
|
||||
|
||||
jobs:
|
||||
@@ -272,6 +276,10 @@ jobs:
|
||||
-e JIRA_BASE_URL=${JIRA_BASE_URL} \
|
||||
-e JIRA_USER_EMAIL=${JIRA_USER_EMAIL} \
|
||||
-e JIRA_API_TOKEN=${JIRA_API_TOKEN} \
|
||||
-e PERM_SYNC_SHAREPOINT_CLIENT_ID=${PERM_SYNC_SHAREPOINT_CLIENT_ID} \
|
||||
-e PERM_SYNC_SHAREPOINT_PRIVATE_KEY="${PERM_SYNC_SHAREPOINT_PRIVATE_KEY}" \
|
||||
-e PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD=${PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD} \
|
||||
-e PERM_SYNC_SHAREPOINT_DIRECTORY_ID=${PERM_SYNC_SHAREPOINT_DIRECTORY_ID} \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
|
||||
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
|
||||
|
||||
@@ -19,6 +19,10 @@ env:
|
||||
JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }}
|
||||
JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }}
|
||||
JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }}
|
||||
PERM_SYNC_SHAREPOINT_CLIENT_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_CLIENT_ID }}
|
||||
PERM_SYNC_SHAREPOINT_PRIVATE_KEY: ${{ secrets.PERM_SYNC_SHAREPOINT_PRIVATE_KEY }}
|
||||
PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD: ${{ secrets.PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD }}
|
||||
PERM_SYNC_SHAREPOINT_DIRECTORY_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_DIRECTORY_ID }}
|
||||
PLATFORM_PAIR: linux-amd64
|
||||
jobs:
|
||||
integration-tests-mit:
|
||||
@@ -207,6 +211,10 @@ jobs:
|
||||
-e JIRA_BASE_URL=${JIRA_BASE_URL} \
|
||||
-e JIRA_USER_EMAIL=${JIRA_USER_EMAIL} \
|
||||
-e JIRA_API_TOKEN=${JIRA_API_TOKEN} \
|
||||
-e PERM_SYNC_SHAREPOINT_CLIENT_ID=${PERM_SYNC_SHAREPOINT_CLIENT_ID} \
|
||||
-e PERM_SYNC_SHAREPOINT_PRIVATE_KEY="${PERM_SYNC_SHAREPOINT_PRIVATE_KEY}" \
|
||||
-e PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD=${PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD} \
|
||||
-e PERM_SYNC_SHAREPOINT_DIRECTORY_ID=${PERM_SYNC_SHAREPOINT_DIRECTORY_ID} \
|
||||
-e TEST_WEB_HOSTNAME=test-runner \
|
||||
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
|
||||
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
|
||||
|
||||
11
.gitignore
vendored
11
.gitignore
vendored
@@ -21,8 +21,19 @@ backend/tests/regression/search_quality/*.json
|
||||
# secret files
|
||||
.env
|
||||
jira_test_env
|
||||
settings.json
|
||||
|
||||
# others
|
||||
/deployment/data/nginx/app.conf
|
||||
*.sw?
|
||||
/backend/tests/regression/answer_quality/search_test_config.yaml
|
||||
|
||||
# Local .terraform directories
|
||||
**/.terraform/*
|
||||
|
||||
# Local .tfstate files
|
||||
*.tfstate
|
||||
*.tfstate.*
|
||||
|
||||
# Local .terraform.lock.hcl file
|
||||
.terraform.lock.hcl
|
||||
|
||||
3
.vscode/env_template.txt
vendored
3
.vscode/env_template.txt
vendored
@@ -23,6 +23,9 @@ DISABLE_LLM_DOC_RELEVANCE=False
|
||||
# Useful if you want to toggle auth on/off (google_oauth/OIDC specifically)
|
||||
OAUTH_CLIENT_ID=<REPLACE THIS>
|
||||
OAUTH_CLIENT_SECRET=<REPLACE THIS>
|
||||
OPENID_CONFIG_URL=<REPLACE THIS>
|
||||
SAML_CONF_DIR=/<ABSOLUTE PATH TO ONYX>/onyx/backend/ee/onyx/configs/saml_config
|
||||
|
||||
# Generally not useful for dev, we don't generally want to set up an SMTP server for dev
|
||||
REQUIRE_EMAIL_VERIFICATION=False
|
||||
|
||||
|
||||
9
.vscode/launch.template.jsonc
vendored
9
.vscode/launch.template.jsonc
vendored
@@ -31,14 +31,16 @@
|
||||
],
|
||||
"presentation": {
|
||||
"group": "1"
|
||||
}
|
||||
},
|
||||
"stopAll": true
|
||||
},
|
||||
{
|
||||
"name": "Web / Model / API",
|
||||
"configurations": ["Web Server", "Model Server", "API Server"],
|
||||
"presentation": {
|
||||
"group": "1"
|
||||
}
|
||||
},
|
||||
"stopAll": true
|
||||
},
|
||||
{
|
||||
"name": "Celery (all)",
|
||||
@@ -53,7 +55,8 @@
|
||||
],
|
||||
"presentation": {
|
||||
"group": "1"
|
||||
}
|
||||
},
|
||||
"stopAll": true
|
||||
}
|
||||
],
|
||||
"configurations": [
|
||||
|
||||
@@ -103,10 +103,10 @@ If using PowerShell, the command slightly differs:
|
||||
Install the required python dependencies:
|
||||
|
||||
```bash
|
||||
pip install -r onyx/backend/requirements/default.txt
|
||||
pip install -r onyx/backend/requirements/dev.txt
|
||||
pip install -r onyx/backend/requirements/ee.txt
|
||||
pip install -r onyx/backend/requirements/model_server.txt
|
||||
pip install -r backend/requirements/default.txt
|
||||
pip install -r backend/requirements/dev.txt
|
||||
pip install -r backend/requirements/ee.txt
|
||||
pip install -r backend/requirements/model_server.txt
|
||||
```
|
||||
|
||||
Install Playwright for Python (headless browser required by the Web Connector)
|
||||
|
||||
@@ -116,6 +116,14 @@ COPY ./assets /app/assets
|
||||
|
||||
ENV PYTHONPATH=/app
|
||||
|
||||
# Create non-root user for security best practices
|
||||
RUN groupadd -g 1001 onyx && \
|
||||
useradd -u 1001 -g onyx -m -s /bin/bash onyx && \
|
||||
chown -R onyx:onyx /app && \
|
||||
mkdir -p /var/log/onyx && \
|
||||
chmod 755 /var/log/onyx && \
|
||||
chown onyx:onyx /var/log/onyx
|
||||
|
||||
# Default command which does nothing
|
||||
# This container is used by api server and background which specify their own CMD
|
||||
CMD ["tail", "-f", "/dev/null"]
|
||||
|
||||
@@ -9,11 +9,20 @@ visit https://github.com/onyx-dot-app/onyx."
|
||||
# Default ONYX_VERSION, typically overridden during builds by GitHub Actions.
|
||||
ARG ONYX_VERSION=0.0.0-dev
|
||||
ENV ONYX_VERSION=${ONYX_VERSION} \
|
||||
DANSWER_RUNNING_IN_DOCKER="true"
|
||||
|
||||
DANSWER_RUNNING_IN_DOCKER="true" \
|
||||
HF_HOME=/app/.cache/huggingface
|
||||
|
||||
RUN echo "ONYX_VERSION: ${ONYX_VERSION}"
|
||||
|
||||
# Create non-root user for security best practices
|
||||
RUN mkdir -p /app && \
|
||||
groupadd -g 1001 onyx && \
|
||||
useradd -u 1001 -g onyx -m -s /bin/bash onyx && \
|
||||
chown -R onyx:onyx /app && \
|
||||
mkdir -p /var/log/onyx && \
|
||||
chmod 755 /var/log/onyx && \
|
||||
chown onyx:onyx /var/log/onyx
|
||||
|
||||
COPY ./requirements/model_server.txt /tmp/requirements.txt
|
||||
RUN pip install --no-cache-dir --upgrade \
|
||||
--retries 5 \
|
||||
@@ -38,9 +47,11 @@ snapshot_download('mixedbread-ai/mxbai-rerank-xsmall-v1'); \
|
||||
from sentence_transformers import SentenceTransformer; \
|
||||
SentenceTransformer(model_name_or_path='nomic-ai/nomic-embed-text-v1', trust_remote_code=True);"
|
||||
|
||||
# In case the user has volumes mounted to /root/.cache/huggingface that they've downloaded while
|
||||
# running Onyx, don't overwrite it with the built in cache folder
|
||||
RUN mv /root/.cache/huggingface /root/.cache/temp_huggingface
|
||||
# In case the user has volumes mounted to /app/.cache/huggingface that they've downloaded while
|
||||
# running Onyx, move the current contents of the cache folder to a temporary location to ensure
|
||||
# it's preserved in order to combine with the user's cache contents
|
||||
RUN mv /app/.cache/huggingface /app/.cache/temp_huggingface && \
|
||||
chown -R onyx:onyx /app
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
|
||||
341
backend/alembic/versions/90e3b9af7da4_tag_fix.py
Normal file
341
backend/alembic/versions/90e3b9af7da4_tag_fix.py
Normal file
@@ -0,0 +1,341 @@
|
||||
"""tag-fix
|
||||
|
||||
Revision ID: 90e3b9af7da4
|
||||
Revises: 62c3a055a141
|
||||
Create Date: 2025-08-01 20:58:14.607624
|
||||
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
|
||||
from typing import cast
|
||||
from typing import Generator
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
|
||||
from onyx.db.search_settings import SearchSettings
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
|
||||
|
||||
logger = logging.getLogger("alembic.runtime.migration")
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "90e3b9af7da4"
|
||||
down_revision = "62c3a055a141"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
SKIP_TAG_FIX = os.environ.get("SKIP_TAG_FIX", "true").lower() == "true"
|
||||
|
||||
# override for cloud
|
||||
if AUTH_TYPE == AuthType.CLOUD:
|
||||
SKIP_TAG_FIX = True
|
||||
|
||||
|
||||
def set_is_list_for_known_tags() -> None:
|
||||
"""
|
||||
Sets is_list to true for all tags that are known to be lists.
|
||||
"""
|
||||
LIST_METADATA: list[tuple[str, str]] = [
|
||||
("CLICKUP", "tags"),
|
||||
("CONFLUENCE", "labels"),
|
||||
("DISCOURSE", "tags"),
|
||||
("FRESHDESK", "emails"),
|
||||
("GITHUB", "assignees"),
|
||||
("GITHUB", "labels"),
|
||||
("GURU", "tags"),
|
||||
("GURU", "folders"),
|
||||
("HUBSPOT", "associated_contact_ids"),
|
||||
("HUBSPOT", "associated_company_ids"),
|
||||
("HUBSPOT", "associated_deal_ids"),
|
||||
("HUBSPOT", "associated_ticket_ids"),
|
||||
("JIRA", "labels"),
|
||||
("MEDIAWIKI", "categories"),
|
||||
("ZENDESK", "labels"),
|
||||
("ZENDESK", "content_tags"),
|
||||
]
|
||||
|
||||
bind = op.get_bind()
|
||||
for source, key in LIST_METADATA:
|
||||
bind.execute(
|
||||
sa.text(
|
||||
f"""
|
||||
UPDATE tag
|
||||
SET is_list = true
|
||||
WHERE tag_key = '{key}'
|
||||
AND source = '{source}'
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def set_is_list_for_list_tags() -> None:
|
||||
"""
|
||||
Sets is_list to true for all tags which have multiple values for a given
|
||||
document, key, and source triplet. This only works if we remove old tags
|
||||
from the database.
|
||||
"""
|
||||
bind = op.get_bind()
|
||||
bind.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE tag
|
||||
SET is_list = true
|
||||
FROM (
|
||||
SELECT DISTINCT tag.tag_key, tag.source
|
||||
FROM tag
|
||||
JOIN document__tag ON tag.id = document__tag.tag_id
|
||||
GROUP BY tag.tag_key, tag.source, document__tag.document_id
|
||||
HAVING count(*) > 1
|
||||
) AS list_tags
|
||||
WHERE tag.tag_key = list_tags.tag_key
|
||||
AND tag.source = list_tags.source
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def log_list_tags() -> None:
|
||||
bind = op.get_bind()
|
||||
result = bind.execute(
|
||||
sa.text(
|
||||
"""
|
||||
SELECT DISTINCT source, tag_key
|
||||
FROM tag
|
||||
WHERE is_list
|
||||
ORDER BY source, tag_key
|
||||
"""
|
||||
)
|
||||
).fetchall()
|
||||
logger.info(
|
||||
"List tags:\n" + "\n".join(f" {source}: {key}" for source, key in result)
|
||||
)
|
||||
|
||||
|
||||
def remove_old_tags() -> None:
|
||||
"""
|
||||
Removes old tags from the database.
|
||||
Previously, there was a bug where if a document got indexed with a tag and then
|
||||
the document got reindexed, the old tag would not be removed.
|
||||
This function removes those old tags by comparing it against the tags in vespa.
|
||||
"""
|
||||
current_search_settings, future_search_settings = active_search_settings()
|
||||
document_index = get_default_document_index(
|
||||
current_search_settings, future_search_settings
|
||||
)
|
||||
|
||||
# Get the index name
|
||||
if hasattr(document_index, "index_name"):
|
||||
index_name = document_index.index_name
|
||||
else:
|
||||
# Default index name if we can't get it from the document_index
|
||||
index_name = "danswer_index"
|
||||
|
||||
for batch in _get_batch_documents_with_multiple_tags():
|
||||
n_deleted = 0
|
||||
|
||||
for document_id in batch:
|
||||
true_metadata = _get_vespa_metadata(document_id, index_name)
|
||||
tags = _get_document_tags(document_id)
|
||||
|
||||
# identify document__tags to delete
|
||||
to_delete: list[str] = []
|
||||
for tag_id, tag_key, tag_value in tags:
|
||||
true_val = true_metadata.get(tag_key, "")
|
||||
if (isinstance(true_val, list) and tag_value not in true_val) or (
|
||||
isinstance(true_val, str) and tag_value != true_val
|
||||
):
|
||||
to_delete.append(str(tag_id))
|
||||
|
||||
if not to_delete:
|
||||
continue
|
||||
|
||||
# delete old document__tags
|
||||
bind = op.get_bind()
|
||||
result = bind.execute(
|
||||
sa.text(
|
||||
f"""
|
||||
DELETE FROM document__tag
|
||||
WHERE document_id = '{document_id}'
|
||||
AND tag_id IN ({','.join(to_delete)})
|
||||
"""
|
||||
)
|
||||
)
|
||||
n_deleted += result.rowcount
|
||||
logger.info(f"Processed {len(batch)} documents and deleted {n_deleted} tags")
|
||||
|
||||
|
||||
def active_search_settings() -> tuple[SearchSettings, SearchSettings | None]:
|
||||
result = op.get_bind().execute(
|
||||
sa.text(
|
||||
"""
|
||||
SELECT * FROM search_settings WHERE status = 'PRESENT' ORDER BY id DESC LIMIT 1
|
||||
"""
|
||||
)
|
||||
)
|
||||
search_settings_fetch = result.fetchall()
|
||||
search_settings = (
|
||||
SearchSettings(**search_settings_fetch[0]._asdict())
|
||||
if search_settings_fetch
|
||||
else None
|
||||
)
|
||||
|
||||
result2 = op.get_bind().execute(
|
||||
sa.text(
|
||||
"""
|
||||
SELECT * FROM search_settings WHERE status = 'FUTURE' ORDER BY id DESC LIMIT 1
|
||||
"""
|
||||
)
|
||||
)
|
||||
search_settings_future_fetch = result2.fetchall()
|
||||
search_settings_future = (
|
||||
SearchSettings(**search_settings_future_fetch[0]._asdict())
|
||||
if search_settings_future_fetch
|
||||
else None
|
||||
)
|
||||
|
||||
if not isinstance(search_settings, SearchSettings):
|
||||
raise RuntimeError(
|
||||
"current search settings is of type " + str(type(search_settings))
|
||||
)
|
||||
if (
|
||||
not isinstance(search_settings_future, SearchSettings)
|
||||
and search_settings_future is not None
|
||||
):
|
||||
raise RuntimeError(
|
||||
"future search settings is of type " + str(type(search_settings_future))
|
||||
)
|
||||
|
||||
return search_settings, search_settings_future
|
||||
|
||||
|
||||
def _get_batch_documents_with_multiple_tags(
|
||||
batch_size: int = 128,
|
||||
) -> Generator[list[str], None, None]:
|
||||
"""
|
||||
Yields batches of document ids that have more than one tag for the same tag key and source.
|
||||
The document may either contain a list metadata value, or may contain leftover
|
||||
old tags from reindexing.
|
||||
"""
|
||||
offset_clause = ""
|
||||
bind = op.get_bind()
|
||||
|
||||
while True:
|
||||
batch = bind.execute(
|
||||
sa.text(
|
||||
f"""
|
||||
SELECT DISTINCT document__tag.document_id
|
||||
FROM tag
|
||||
JOIN document__tag ON tag.id = document__tag.tag_id
|
||||
GROUP BY tag.tag_key, tag.source, document__tag.document_id
|
||||
HAVING count(*) > 1 {offset_clause}
|
||||
ORDER BY document__tag.document_id
|
||||
LIMIT {batch_size}
|
||||
"""
|
||||
)
|
||||
).fetchall()
|
||||
if not batch:
|
||||
break
|
||||
doc_ids = [document_id for document_id, in batch]
|
||||
yield doc_ids
|
||||
offset_clause = f"AND document__tag.document_id > '{doc_ids[-1]}'"
|
||||
|
||||
|
||||
def _get_vespa_metadata(
|
||||
document_id: str, index_name: str
|
||||
) -> dict[str, str | list[str]]:
|
||||
url = DOCUMENT_ID_ENDPOINT.format(index_name=index_name)
|
||||
|
||||
# Document-Selector language
|
||||
selection = (
|
||||
f"{index_name}.document_id=='{document_id}' and {index_name}.chunk_id==0"
|
||||
)
|
||||
|
||||
params: dict[str, str | int] = {
|
||||
"selection": selection,
|
||||
"wantedDocumentCount": 1,
|
||||
"fieldSet": f"{index_name}:metadata",
|
||||
}
|
||||
|
||||
with get_vespa_http_client() as client:
|
||||
resp = client.get(url, params=params)
|
||||
resp.raise_for_status()
|
||||
|
||||
docs = resp.json().get("documents", [])
|
||||
if not docs:
|
||||
raise RuntimeError(f"No chunk-0 found for document {document_id}")
|
||||
|
||||
# for some reason, metadata is a string
|
||||
metadata = docs[0]["fields"]["metadata"]
|
||||
return json.loads(metadata)
|
||||
|
||||
|
||||
def _get_document_tags(document_id: str) -> list[tuple[int, str, str]]:
|
||||
bind = op.get_bind()
|
||||
result = bind.execute(
|
||||
sa.text(
|
||||
f"""
|
||||
SELECT tag.id, tag.tag_key, tag.tag_value
|
||||
FROM tag
|
||||
JOIN document__tag ON tag.id = document__tag.tag_id
|
||||
WHERE document__tag.document_id = '{document_id}'
|
||||
"""
|
||||
)
|
||||
).fetchall()
|
||||
return cast(list[tuple[int, str, str]], result)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"tag",
|
||||
sa.Column("is_list", sa.Boolean(), nullable=False, server_default="false"),
|
||||
)
|
||||
op.drop_constraint(
|
||||
constraint_name="_tag_key_value_source_uc",
|
||||
table_name="tag",
|
||||
type_="unique",
|
||||
)
|
||||
op.create_unique_constraint(
|
||||
constraint_name="_tag_key_value_source_list_uc",
|
||||
table_name="tag",
|
||||
columns=["tag_key", "tag_value", "source", "is_list"],
|
||||
)
|
||||
set_is_list_for_known_tags()
|
||||
|
||||
if SKIP_TAG_FIX:
|
||||
logger.warning(
|
||||
"Skipping removal of old tags. "
|
||||
"This can cause issues when using the knowledge graph, or "
|
||||
"when filtering for documents by tags."
|
||||
)
|
||||
log_list_tags()
|
||||
return
|
||||
|
||||
remove_old_tags()
|
||||
set_is_list_for_list_tags()
|
||||
|
||||
# debug
|
||||
log_list_tags()
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# the migration adds and populates the is_list column, and removes old bugged tags
|
||||
# there isn't a point in adding back the bugged tags, so we just drop the column
|
||||
op.drop_constraint(
|
||||
constraint_name="_tag_key_value_source_list_uc",
|
||||
table_name="tag",
|
||||
type_="unique",
|
||||
)
|
||||
op.create_unique_constraint(
|
||||
constraint_name="_tag_key_value_source_uc",
|
||||
table_name="tag",
|
||||
columns=["tag_key", "tag_value", "source"],
|
||||
)
|
||||
op.drop_column("tag", "is_list")
|
||||
@@ -0,0 +1,33 @@
|
||||
"""Pause finished user file connectors
|
||||
|
||||
Revision ID: b558f51620b4
|
||||
Revises: 90e3b9af7da4
|
||||
Create Date: 2025-08-15 17:17:02.456704
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "b558f51620b4"
|
||||
down_revision = "90e3b9af7da4"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Set all user file connector credential pairs with ACTIVE status to PAUSED
|
||||
# This ensures user files don't continue to run indexing tasks after processing
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE connector_credential_pair
|
||||
SET status = 'PAUSED'
|
||||
WHERE is_user_file = true
|
||||
AND status = 'ACTIVE'
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
pass
|
||||
@@ -102,6 +102,19 @@ TEAMS_PERMISSION_DOC_SYNC_FREQUENCY = int(
|
||||
os.environ.get("TEAMS_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
|
||||
)
|
||||
|
||||
#####
|
||||
# SharePoint
|
||||
#####
|
||||
# In seconds, default is 30 minutes
|
||||
SHAREPOINT_PERMISSION_DOC_SYNC_FREQUENCY = int(
|
||||
os.environ.get("SHAREPOINT_PERMISSION_DOC_SYNC_FREQUENCY") or 30 * 60
|
||||
)
|
||||
|
||||
# In seconds, default is 5 minutes
|
||||
SHAREPOINT_PERMISSION_GROUP_SYNC_FREQUENCY = int(
|
||||
os.environ.get("SHAREPOINT_PERMISSION_GROUP_SYNC_FREQUENCY") or 5 * 60
|
||||
)
|
||||
|
||||
|
||||
####
|
||||
# Celery Job Frequency
|
||||
|
||||
@@ -2,18 +2,14 @@ from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from typing import Optional
|
||||
from typing import Protocol
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ee.onyx.db.external_perm import ExternalUserGroup
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.utils import DocumentRow
|
||||
from onyx.db.utils import SortOrder
|
||||
|
||||
# Avoid circular imports
|
||||
if TYPE_CHECKING:
|
||||
from ee.onyx.db.external_perm import ExternalUserGroup # noqa
|
||||
from onyx.access.models import DocExternalAccess # noqa
|
||||
from onyx.db.models import ConnectorCredentialPair # noqa
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface # noqa
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
|
||||
|
||||
class FetchAllDocumentsFunction(Protocol):
|
||||
@@ -52,20 +48,20 @@ class FetchAllDocumentsIdsFunction(Protocol):
|
||||
# Defining the input/output types for the sync functions
|
||||
DocSyncFuncType = Callable[
|
||||
[
|
||||
"ConnectorCredentialPair",
|
||||
ConnectorCredentialPair,
|
||||
FetchAllDocumentsFunction,
|
||||
FetchAllDocumentsIdsFunction,
|
||||
Optional["IndexingHeartbeatInterface"],
|
||||
Optional[IndexingHeartbeatInterface],
|
||||
],
|
||||
Generator["DocExternalAccess", None, None],
|
||||
Generator[DocExternalAccess, None, None],
|
||||
]
|
||||
|
||||
GroupSyncFuncType = Callable[
|
||||
[
|
||||
str, # tenant_id
|
||||
"ConnectorCredentialPair", # cc_pair
|
||||
ConnectorCredentialPair, # cc_pair
|
||||
],
|
||||
Generator["ExternalUserGroup", None, None],
|
||||
Generator[ExternalUserGroup, None, None],
|
||||
]
|
||||
|
||||
# list of chunks to be censored and the user email. returns censored chunks
|
||||
|
||||
36
backend/ee/onyx/external_permissions/sharepoint/doc_sync.py
Normal file
36
backend/ee/onyx/external_permissions/sharepoint/doc_sync.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
|
||||
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
|
||||
from ee.onyx.external_permissions.utils import generic_doc_sync
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.sharepoint.connector import SharepointConnector
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
SHAREPOINT_DOC_SYNC_TAG = "sharepoint_doc_sync"
|
||||
|
||||
|
||||
def sharepoint_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
|
||||
fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> Generator[DocExternalAccess, None, None]:
|
||||
sharepoint_connector = SharepointConnector(
|
||||
**cc_pair.connector.connector_specific_config,
|
||||
)
|
||||
sharepoint_connector.load_credentials(cc_pair.credential.credential_json)
|
||||
|
||||
yield from generic_doc_sync(
|
||||
cc_pair=cc_pair,
|
||||
fetch_all_existing_docs_ids_fn=fetch_all_existing_docs_ids_fn,
|
||||
callback=callback,
|
||||
doc_source=DocumentSource.SHAREPOINT,
|
||||
slim_connector=sharepoint_connector,
|
||||
label=SHAREPOINT_DOC_SYNC_TAG,
|
||||
)
|
||||
@@ -0,0 +1,63 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from office365.sharepoint.client_context import ClientContext # type: ignore[import-untyped]
|
||||
|
||||
from ee.onyx.db.external_perm import ExternalUserGroup
|
||||
from ee.onyx.external_permissions.sharepoint.permission_utils import (
|
||||
get_sharepoint_external_groups,
|
||||
)
|
||||
from onyx.connectors.sharepoint.connector import acquire_token_for_rest
|
||||
from onyx.connectors.sharepoint.connector import SharepointConnector
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def sharepoint_group_sync(
|
||||
tenant_id: str,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> Generator[ExternalUserGroup, None, None]:
|
||||
"""Sync SharePoint groups and their members"""
|
||||
|
||||
# Get site URLs from connector config
|
||||
connector_config = cc_pair.connector.connector_specific_config
|
||||
|
||||
# Create SharePoint connector instance and load credentials
|
||||
connector = SharepointConnector(**connector_config)
|
||||
connector.load_credentials(cc_pair.credential.credential_json)
|
||||
|
||||
if not connector.msal_app:
|
||||
raise RuntimeError("MSAL app not initialized in connector")
|
||||
|
||||
if not connector.sp_tenant_domain:
|
||||
raise RuntimeError("Tenant domain not initialized in connector")
|
||||
|
||||
# Get site descriptors from connector (either configured sites or all sites)
|
||||
site_descriptors = connector.site_descriptors or connector.fetch_sites()
|
||||
|
||||
if not site_descriptors:
|
||||
raise RuntimeError("No SharePoint sites found for group sync")
|
||||
|
||||
logger.info(f"Processing {len(site_descriptors)} sites for group sync")
|
||||
|
||||
msal_app = connector.msal_app
|
||||
sp_tenant_domain = connector.sp_tenant_domain
|
||||
# Process each site
|
||||
for site_descriptor in site_descriptors:
|
||||
logger.debug(f"Processing site: {site_descriptor.url}")
|
||||
|
||||
# Create client context for the site using connector's MSAL app
|
||||
ctx = ClientContext(site_descriptor.url).with_access_token(
|
||||
lambda: acquire_token_for_rest(msal_app, sp_tenant_domain)
|
||||
)
|
||||
|
||||
# Get external groups for this site
|
||||
external_groups = get_sharepoint_external_groups(ctx, connector.graph_client)
|
||||
|
||||
# Yield each group
|
||||
for group in external_groups:
|
||||
logger.debug(
|
||||
f"Found group: {group.id} with {len(group.user_emails)} members"
|
||||
)
|
||||
yield group
|
||||
@@ -0,0 +1,684 @@
|
||||
import re
|
||||
from collections import deque
|
||||
from typing import Any
|
||||
from urllib.parse import unquote
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from office365.graph_client import GraphClient # type: ignore[import-untyped]
|
||||
from office365.onedrive.driveitems.driveItem import DriveItem # type: ignore[import-untyped]
|
||||
from office365.runtime.client_request import ClientRequestException # type: ignore
|
||||
from office365.sharepoint.client_context import ClientContext # type: ignore[import-untyped]
|
||||
from office365.sharepoint.permissions.securable_object import RoleAssignmentCollection # type: ignore[import-untyped]
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ee.onyx.db.external_perm import ExternalUserGroup
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.access.utils import build_ext_group_name_for_onyx
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.sharepoint.connector import sleep_and_retry
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
# These values represent different types of SharePoint principals used in permission assignments
|
||||
USER_PRINCIPAL_TYPE = 1 # Individual user accounts
|
||||
ANONYMOUS_USER_PRINCIPAL_TYPE = 3 # Anonymous/unauthenticated users (public access)
|
||||
AZURE_AD_GROUP_PRINCIPAL_TYPE = 4 # Azure Active Directory security groups
|
||||
SHAREPOINT_GROUP_PRINCIPAL_TYPE = 8 # SharePoint site groups (local to the site)
|
||||
MICROSOFT_DOMAIN = ".onmicrosoft"
|
||||
# Limited Access role types: limited access is a pass-through (traversal) permission, not an actual permission
|
||||
LIMITED_ACCESS_ROLE_TYPES = [1, 9]
|
||||
LIMITED_ACCESS_ROLE_NAMES = ["Limited Access", "Web-Only Limited Access"]
|
||||
|
||||
|
||||
class SharepointGroup(BaseModel):
    """A group principal found while resolving SharePoint permissions.

    The model is frozen (immutable), which lets this module keep instances
    in sets.
    """

    model_config = {"frozen": True}

    # Name under which the group's members are reported by this module.
    name: str
    # Identifier used by this module to look the group up again when expanding it.
    login_name: str
    # Principal type code; compared against the *_PRINCIPAL_TYPE constants.
    principal_type: int
|
||||
|
||||
|
||||
class GroupsResult(BaseModel):
    """Result of expanding a set of groups into their member emails."""

    # Group name -> email addresses collected for that group.
    groups_to_emails: dict[str, set[str]]
    # True when a group with a public login name was encountered.
    found_public_group: bool
|
||||
|
||||
|
||||
def _get_azuread_group_guid_by_name(
|
||||
graph_client: GraphClient, group_name: str
|
||||
) -> str | None:
|
||||
try:
|
||||
# Search for groups by display name
|
||||
groups = sleep_and_retry(
|
||||
graph_client.groups.filter(f"displayName eq '{group_name}'").get(),
|
||||
"get_azuread_group_guid_by_name",
|
||||
)
|
||||
|
||||
if groups and len(groups) > 0:
|
||||
return groups[0].id
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get Azure AD group GUID for name {group_name}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _extract_guid_from_claims_token(claims_token: str) -> str | None:
|
||||
|
||||
try:
|
||||
# Pattern to match GUID in claims token
|
||||
# Claims tokens often have format: c:0o.c|provider|GUID_suffix
|
||||
guid_pattern = r"([0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12})"
|
||||
|
||||
match = re.search(guid_pattern, claims_token, re.IGNORECASE)
|
||||
if match:
|
||||
return match.group(1)
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to extract GUID from claims token {claims_token}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _get_group_guid_from_identifier(
|
||||
graph_client: GraphClient, identifier: str
|
||||
) -> str | None:
|
||||
try:
|
||||
# Check if it's already a GUID
|
||||
guid_pattern = r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$"
|
||||
if re.match(guid_pattern, identifier, re.IGNORECASE):
|
||||
return identifier
|
||||
|
||||
# Check if it's a SharePoint claims token
|
||||
if identifier.startswith("c:0") and "|" in identifier:
|
||||
guid = _extract_guid_from_claims_token(identifier)
|
||||
if guid:
|
||||
logger.info(f"Extracted GUID {guid} from claims token {identifier}")
|
||||
return guid
|
||||
|
||||
# Try to search by display name as fallback
|
||||
return _get_azuread_group_guid_by_name(graph_client, identifier)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get group GUID from identifier {identifier}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _get_security_group_owners(graph_client: GraphClient, group_id: str) -> list[str]:
|
||||
try:
|
||||
# Get group owners using Graph API
|
||||
group = graph_client.groups[group_id]
|
||||
owners = sleep_and_retry(
|
||||
group.owners.get_all(page_loaded=lambda _: None),
|
||||
"get_security_group_owners",
|
||||
)
|
||||
|
||||
owner_emails: list[str] = []
|
||||
logger.info(f"Owners: {owners}")
|
||||
|
||||
for owner in owners:
|
||||
owner_data = owner.to_json()
|
||||
|
||||
# Extract email from the JSON data
|
||||
mail: str | None = owner_data.get("mail")
|
||||
user_principal_name: str | None = owner_data.get("userPrincipalName")
|
||||
|
||||
# Check if owner is a user and has an email
|
||||
if mail:
|
||||
if MICROSOFT_DOMAIN in mail:
|
||||
mail = mail.replace(MICROSOFT_DOMAIN, "")
|
||||
owner_emails.append(mail)
|
||||
elif user_principal_name:
|
||||
if MICROSOFT_DOMAIN in user_principal_name:
|
||||
user_principal_name = user_principal_name.replace(
|
||||
MICROSOFT_DOMAIN, ""
|
||||
)
|
||||
owner_emails.append(user_principal_name)
|
||||
|
||||
logger.info(
|
||||
f"Retrieved {len(owner_emails)} owners from security group {group_id}"
|
||||
)
|
||||
return owner_emails
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get security group owners for group {group_id}: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def _get_sharepoint_list_item_id(drive_item: DriveItem) -> str | None:
|
||||
|
||||
try:
|
||||
# First try to get the list item directly from the drive item
|
||||
if hasattr(drive_item, "listItem"):
|
||||
list_item = drive_item.listItem
|
||||
if list_item:
|
||||
# Load the list item properties to get the ID
|
||||
sleep_and_retry(list_item.get(), "get_sharepoint_list_item_id")
|
||||
if hasattr(list_item, "id") and list_item.id:
|
||||
return str(list_item.id)
|
||||
|
||||
# The SharePoint list item ID is typically available in the sharepointIds property
|
||||
sharepoint_ids = getattr(drive_item, "sharepoint_ids", None)
|
||||
if sharepoint_ids and hasattr(sharepoint_ids, "listItemId"):
|
||||
return sharepoint_ids.listItemId
|
||||
|
||||
# Alternative: try to get it from the properties
|
||||
properties = getattr(drive_item, "properties", None)
|
||||
if properties:
|
||||
# Sometimes the SharePoint list item ID is in the properties
|
||||
for prop_name, prop_value in properties.items():
|
||||
if "listitemid" in prop_name.lower():
|
||||
return str(prop_value)
|
||||
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error getting SharePoint list item ID for item {drive_item.id}: {e}"
|
||||
)
|
||||
raise e
|
||||
|
||||
|
||||
def _is_public_item(drive_item: DriveItem) -> bool:
|
||||
is_public = False
|
||||
try:
|
||||
permissions = sleep_and_retry(
|
||||
drive_item.permissions.get_all(page_loaded=lambda _: None), "is_public_item"
|
||||
)
|
||||
for permission in permissions:
|
||||
if permission.link and (
|
||||
permission.link.scope == "anonymous"
|
||||
or permission.link.scope == "organization"
|
||||
):
|
||||
is_public = True
|
||||
break
|
||||
return is_public
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check if item {drive_item.id} is public: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def _is_public_login_name(login_name: str) -> bool:
|
||||
# Patterns that indicate public access
|
||||
# This list is derived from the below link
|
||||
# https://learn.microsoft.com/en-us/answers/questions/2085339/guid-in-the-loginname-of-site-user-everyone-except
|
||||
public_login_patterns: list[str] = [
|
||||
"c:0-.f|rolemanager|spo-grid-all-users/",
|
||||
"c:0(.s|true",
|
||||
]
|
||||
for pattern in public_login_patterns:
|
||||
if pattern in login_name:
|
||||
logger.info(f"Login name {login_name} is public")
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# Azure AD allows multiple groups to share the same display name, so we append the GUID to the name
|
||||
def _get_group_name_with_suffix(
|
||||
login_name: str, group_name: str, graph_client: GraphClient
|
||||
) -> str:
|
||||
ad_group_suffix = _get_group_guid_from_identifier(graph_client, login_name)
|
||||
return f"{group_name}_{ad_group_suffix}"
|
||||
|
||||
|
||||
def _get_sharepoint_groups(
|
||||
client_context: ClientContext, group_name: str, graph_client: GraphClient
|
||||
) -> tuple[set[SharepointGroup], set[str]]:
|
||||
|
||||
groups: set[SharepointGroup] = set()
|
||||
user_emails: set[str] = set()
|
||||
|
||||
def process_users(users: list[Any]) -> None:
|
||||
nonlocal groups, user_emails
|
||||
|
||||
for user in users:
|
||||
logger.debug(f"User: {user.to_json()}")
|
||||
if user.principal_type == USER_PRINCIPAL_TYPE and hasattr(
|
||||
user, "user_principal_name"
|
||||
):
|
||||
if user.user_principal_name:
|
||||
email = user.user_principal_name
|
||||
if MICROSOFT_DOMAIN in email:
|
||||
email = email.replace(MICROSOFT_DOMAIN, "")
|
||||
user_emails.add(email)
|
||||
else:
|
||||
logger.warning(
|
||||
f"User don't have a user principal name: {user.login_name}"
|
||||
)
|
||||
elif user.principal_type in [
|
||||
AZURE_AD_GROUP_PRINCIPAL_TYPE,
|
||||
SHAREPOINT_GROUP_PRINCIPAL_TYPE,
|
||||
]:
|
||||
name = user.title
|
||||
if user.principal_type == AZURE_AD_GROUP_PRINCIPAL_TYPE:
|
||||
name = _get_group_name_with_suffix(
|
||||
user.login_name, name, graph_client
|
||||
)
|
||||
groups.add(
|
||||
SharepointGroup(
|
||||
login_name=user.login_name,
|
||||
principal_type=user.principal_type,
|
||||
name=name,
|
||||
)
|
||||
)
|
||||
|
||||
group = client_context.web.site_groups.get_by_name(group_name)
|
||||
sleep_and_retry(
|
||||
group.users.get_all(page_loaded=process_users), "get_sharepoint_groups"
|
||||
)
|
||||
|
||||
return groups, user_emails
|
||||
|
||||
|
||||
def _get_azuread_groups(
|
||||
graph_client: GraphClient, group_name: str
|
||||
) -> tuple[set[SharepointGroup], set[str]]:
|
||||
|
||||
group_id = _get_group_guid_from_identifier(graph_client, group_name)
|
||||
if not group_id:
|
||||
logger.error(f"Failed to get Azure AD group GUID for name {group_name}")
|
||||
return set(), set()
|
||||
group = graph_client.groups[group_id]
|
||||
groups: set[SharepointGroup] = set()
|
||||
user_emails: set[str] = set()
|
||||
|
||||
def process_members(members: list[Any]) -> None:
|
||||
nonlocal groups, user_emails
|
||||
|
||||
for member in members:
|
||||
member_data = member.to_json()
|
||||
logger.debug(f"Member: {member_data}")
|
||||
# Check for user-specific attributes
|
||||
user_principal_name = member_data.get("userPrincipalName")
|
||||
mail = member_data.get("mail")
|
||||
display_name = member_data.get("displayName") or member_data.get(
|
||||
"display_name"
|
||||
)
|
||||
|
||||
# Check object attributes directly (if available)
|
||||
is_user = False
|
||||
is_group = False
|
||||
|
||||
# Users typically have userPrincipalName or mail
|
||||
if user_principal_name or (mail and "@" in str(mail)):
|
||||
is_user = True
|
||||
# Groups typically have displayName but no userPrincipalName
|
||||
elif display_name and not user_principal_name:
|
||||
# Additional check: try to access group-specific properties
|
||||
if (
|
||||
hasattr(member, "groupTypes")
|
||||
or member_data.get("groupTypes") is not None
|
||||
):
|
||||
is_group = True
|
||||
# Or check if it has an 'id' field typical for groups
|
||||
elif member_data.get("id") and not user_principal_name:
|
||||
is_group = True
|
||||
|
||||
# Check the object type name (fallback)
|
||||
if not is_user and not is_group:
|
||||
obj_type = type(member).__name__.lower()
|
||||
if "user" in obj_type:
|
||||
is_user = True
|
||||
elif "group" in obj_type:
|
||||
is_group = True
|
||||
|
||||
# Process based on identification
|
||||
if is_user:
|
||||
if user_principal_name:
|
||||
email = user_principal_name
|
||||
if MICROSOFT_DOMAIN in email:
|
||||
email = email.replace(MICROSOFT_DOMAIN, "")
|
||||
user_emails.add(email)
|
||||
elif mail:
|
||||
email = mail
|
||||
if MICROSOFT_DOMAIN in email:
|
||||
email = email.replace(MICROSOFT_DOMAIN, "")
|
||||
user_emails.add(email)
|
||||
logger.info(f"Added user: {user_principal_name or mail}")
|
||||
elif is_group:
|
||||
if not display_name:
|
||||
logger.error(f"No display name for group: {member_data.get('id')}")
|
||||
continue
|
||||
name = _get_group_name_with_suffix(
|
||||
member_data.get("id", ""), display_name, graph_client
|
||||
)
|
||||
groups.add(
|
||||
SharepointGroup(
|
||||
login_name=member_data.get("id", ""), # Use ID for groups
|
||||
principal_type=AZURE_AD_GROUP_PRINCIPAL_TYPE,
|
||||
name=name,
|
||||
)
|
||||
)
|
||||
logger.info(f"Added group: {name}")
|
||||
else:
|
||||
# Log unidentified members for debugging
|
||||
logger.warning(f"Could not identify member type for: {member_data}")
|
||||
|
||||
sleep_and_retry(
|
||||
group.members.get_all(page_loaded=process_members), "get_azuread_groups"
|
||||
)
|
||||
|
||||
owner_emails = _get_security_group_owners(graph_client, group_id)
|
||||
user_emails.update(owner_emails)
|
||||
|
||||
return groups, user_emails
|
||||
|
||||
|
||||
def _get_groups_and_members_recursively(
|
||||
client_context: ClientContext,
|
||||
graph_client: GraphClient,
|
||||
groups: set[SharepointGroup],
|
||||
is_group_sync: bool = False,
|
||||
) -> GroupsResult:
|
||||
"""
|
||||
Get all groups and their members recursively.
|
||||
"""
|
||||
group_queue: deque[SharepointGroup] = deque(groups)
|
||||
visited_groups: set[str] = set()
|
||||
visited_group_name_to_emails: dict[str, set[str]] = {}
|
||||
found_public_group = False
|
||||
while group_queue:
|
||||
group = group_queue.popleft()
|
||||
if group.login_name in visited_groups:
|
||||
continue
|
||||
visited_groups.add(group.login_name)
|
||||
visited_group_name_to_emails[group.name] = set()
|
||||
logger.info(
|
||||
f"Processing group: {group.name} principal type: {group.principal_type}"
|
||||
)
|
||||
if group.principal_type == SHAREPOINT_GROUP_PRINCIPAL_TYPE:
|
||||
group_info, user_emails = _get_sharepoint_groups(
|
||||
client_context, group.login_name, graph_client
|
||||
)
|
||||
visited_group_name_to_emails[group.name].update(user_emails)
|
||||
if group_info:
|
||||
group_queue.extend(group_info)
|
||||
if group.principal_type == AZURE_AD_GROUP_PRINCIPAL_TYPE:
|
||||
try:
|
||||
# if the site is public, we have default groups assigned to it, so we return early
|
||||
if _is_public_login_name(group.login_name):
|
||||
found_public_group = True
|
||||
if not is_group_sync:
|
||||
return GroupsResult(
|
||||
groups_to_emails={}, found_public_group=True
|
||||
)
|
||||
else:
|
||||
# we don't want to sync public groups, so we skip them
|
||||
continue
|
||||
group_info, user_emails = _get_azuread_groups(
|
||||
graph_client, group.login_name
|
||||
)
|
||||
visited_group_name_to_emails[group.name].update(user_emails)
|
||||
if group_info:
|
||||
group_queue.extend(group_info)
|
||||
except ClientRequestException as e:
|
||||
# If the group is not found, we skip it. There is a chance that group is still referenced
|
||||
# in sharepoint but it is removed from Azure AD. There is no actual documentation on this, but based on
|
||||
# our testing we have seen this happen.
|
||||
if e.response is not None and e.response.status_code == 404:
|
||||
logger.warning(f"Group {group.login_name} not found")
|
||||
continue
|
||||
raise e
|
||||
|
||||
return GroupsResult(
|
||||
groups_to_emails=visited_group_name_to_emails,
|
||||
found_public_group=found_public_group,
|
||||
)
|
||||
|
||||
|
||||
def get_external_access_from_sharepoint(
    client_context: ClientContext,
    graph_client: GraphClient,
    drive_name: str | None,
    drive_item: DriveItem | None,
    site_page: dict[str, Any] | None,
    add_prefix: bool = False,
) -> ExternalAccess:
    """
    Get external access information from SharePoint.

    Either a drive item (``drive_item`` together with ``drive_name``) or a
    ``site_page`` must be supplied; the drive item takes precedence.

    Args:
        client_context: SharePoint context used to read role assignments.
        graph_client: Graph client used to resolve Azure AD groups.
        drive_name: Title of the list holding ``drive_item``
            (``"Shared Documents"`` is mapped to ``"Documents"``).
        drive_item: Drive item whose permissions are wanted.
        site_page: Site page data; its ``webUrl`` locates the page file.
        add_prefix: When True, group names are passed through
            ``build_ext_group_name_for_onyx`` before being returned.

    Returns:
        ``ExternalAccess`` with ``is_public=True`` and empty sets when the item
        has a public link or a public group is assigned; otherwise the user
        emails and lower-cased group ids that have access.

    Raises:
        RuntimeError: when neither a drive item nor a site page is given, the
            drive item's list item ID cannot be determined, or the site page
            has no ``webUrl``.
    """
    groups: set[SharepointGroup] = set()
    user_emails: set[str] = set()
    group_ids: set[str] = set()

    # Add all members to a processing set first
    def add_user_and_group_to_sets(
        role_assignments: RoleAssignmentCollection,
    ) -> None:
        for assignment in role_assignments:
            logger.debug(f"Assignment: {assignment.to_json()}")

            bindings = assignment.role_definition_bindings
            if bindings:
                is_limited_access = all(
                    binding.role_type_kind in LIMITED_ACCESS_ROLE_TYPES
                    and binding.name in LIMITED_ACCESS_ROLE_NAMES
                    for binding in bindings
                )
                # Skip if the role is only Limited Access, because this is not an
                # actual permission; it is a traverse-through permission.
                if is_limited_access:
                    logger.info(
                        "Skipping assignment because it has only Limited Access role"
                    )
                    continue

            member = assignment.member
            if not member:
                continue

            if member.principal_type == USER_PRINCIPAL_TYPE and hasattr(
                member, "user_principal_name"
            ):
                email = member.user_principal_name
                if not email:
                    # A missing UPN would make the substring test below raise
                    # TypeError and abort the whole permission lookup.
                    logger.warning(
                        "User does not have a user principal name: "
                        f"{getattr(member, 'login_name', None)}"
                    )
                    continue
                user_emails.add(email.replace(MICROSOFT_DOMAIN, ""))
            elif member.principal_type in [
                AZURE_AD_GROUP_PRINCIPAL_TYPE,
                SHAREPOINT_GROUP_PRINCIPAL_TYPE,
            ]:
                name = member.title
                if member.principal_type == AZURE_AD_GROUP_PRINCIPAL_TYPE:
                    name = _get_group_name_with_suffix(
                        member.login_name, name, graph_client
                    )
                groups.add(
                    SharepointGroup(
                        login_name=member.login_name,
                        principal_type=member.principal_type,
                        name=name,
                    )
                )

    if drive_item and drive_name:
        # Check whether the item has any public links; if so, return early
        if _is_public_item(drive_item):
            logger.info(f"Item {drive_item.id} is public")
            return ExternalAccess(
                external_user_emails=set(),
                external_user_group_ids=set(),
                is_public=True,
            )

        item_id = _get_sharepoint_list_item_id(drive_item)
        if not item_id:
            raise RuntimeError(
                f"Failed to get SharePoint list item ID for item {drive_item.id}"
            )

        if drive_name == "Shared Documents":
            drive_name = "Documents"

        item = client_context.web.lists.get_by_title(drive_name).items.get_by_id(
            item_id
        )
    elif site_page:
        site_url = site_page.get("webUrl")
        if not site_url:
            raise RuntimeError("Site page has no webUrl; cannot resolve its permissions")
        # Prefer server-relative URL to avoid OData filters that break on apostrophes
        server_relative_url = unquote(urlparse(site_url).path)
        file_obj = client_context.web.get_file_by_server_relative_url(
            server_relative_url
        )
        item = file_obj.listItemAllFields
    else:
        raise RuntimeError("No drive item or site page provided")

    sleep_and_retry(
        item.role_assignments.expand(["Member", "RoleDefinitionBindings"]).get_all(
            page_loaded=add_user_and_group_to_sets,
        ),
        "get_external_access_from_sharepoint",
    )

    groups_and_members: GroupsResult = _get_groups_and_members_recursively(
        client_context, graph_client, groups
    )

    # If the site is public, we have default groups assigned to it, so we return early
    if groups_and_members.found_public_group:
        return ExternalAccess(
            external_user_emails=set(),
            external_user_group_ids=set(),
            is_public=True,
        )

    for group_name in groups_and_members.groups_to_emails:
        if add_prefix:
            group_name = build_ext_group_name_for_onyx(
                group_name, DocumentSource.SHAREPOINT
            )
        group_ids.add(group_name.lower())

    logger.info(f"User emails: {len(user_emails)}")
    logger.info(f"Group IDs: {len(group_ids)}")

    return ExternalAccess(
        external_user_emails=user_emails,
        external_user_group_ids=group_ids,
        is_public=False,
    )
|
||||
|
||||
|
||||
def get_sharepoint_external_groups(
    client_context: ClientContext, graph_client: GraphClient
) -> list[ExternalUserGroup]:
    """Build the external user groups for a SharePoint site.

    Two sources are combined:
      1. Groups assigned on the site's role assignments (ignoring assignments
         that only grant Limited Access), expanded recursively.
      2. Every Azure AD group returned by the Graph client that was not
         already covered by (1), with its direct members.

    Returns:
        One ``ExternalUserGroup`` per group; groups from (1) come first.
    """
    groups: set[SharepointGroup] = set()

    def add_group_to_sets(role_assignments: RoleAssignmentCollection) -> None:
        for assignment in role_assignments:
            bindings = assignment.role_definition_bindings
            if bindings:
                is_limited_access = all(
                    binding.role_type_kind in LIMITED_ACCESS_ROLE_TYPES
                    and binding.name in LIMITED_ACCESS_ROLE_NAMES
                    for binding in bindings
                )
                # Skip if the role assignment is only Limited Access, because this is
                # not an actual permission; it is a traverse-through permission.
                if is_limited_access:
                    logger.info(
                        "Skipping assignment because it has only Limited Access role"
                    )
                    continue

            member = assignment.member
            if not member:
                continue

            if member.principal_type in [
                AZURE_AD_GROUP_PRINCIPAL_TYPE,
                SHAREPOINT_GROUP_PRINCIPAL_TYPE,
            ]:
                name = member.title
                if member.principal_type == AZURE_AD_GROUP_PRINCIPAL_TYPE:
                    name = _get_group_name_with_suffix(
                        member.login_name, name, graph_client
                    )

                groups.add(
                    SharepointGroup(
                        login_name=member.login_name,
                        principal_type=member.principal_type,
                        name=name,
                    )
                )

    sleep_and_retry(
        client_context.web.role_assignments.expand(
            ["Member", "RoleDefinitionBindings"]
        ).get_all(page_loaded=add_group_to_sets),
        "get_sharepoint_external_groups",
    )
    groups_and_members: GroupsResult = _get_groups_and_members_recursively(
        client_context, graph_client, groups, is_group_sync=True
    )

    # get all Azure AD groups because if any group is assigned to the drive item, we don't want to miss them
    # We can't assign sharepoint groups to drive items or drives, so we don't need to get all sharepoint groups
    azure_ad_groups = sleep_and_retry(
        graph_client.groups.get_all(page_loaded=lambda _: None),
        "get_sharepoint_external_groups:get_azure_ad_groups",
    )
    logger.info(f"Azure AD Groups: {len(azure_ad_groups)}")

    identified_groups: set[str] = set(groups_and_members.groups_to_emails.keys())
    ad_groups_to_emails: dict[str, set[str]] = {}
    for group in azure_ad_groups:
        # Azure AD allows multiple groups to share a display name, so the GUID is appended.
        name = _get_group_name_with_suffix(group.id, group.display_name, graph_client)

        # Already-identified AD groups are keyed by their suffixed name, so the
        # suffixed name (not the bare display name) must be compared. Otherwise
        # the group is never recognised and ends up in the result twice.
        if name in identified_groups:
            continue

        members = sleep_and_retry(
            group.members.get_all(page_loaded=lambda _: None),
            "get_sharepoint_external_groups:get_azure_ad_groups:get_members",
        )
        for member in members:
            member_data = member.to_json()
            # Prefer the UPN, fall back to mail.
            address = member_data.get("userPrincipalName") or member_data.get("mail")
            # The entry is created on the first member, so groups without
            # members are not emitted.
            group_emails = ad_groups_to_emails.setdefault(name, set())
            if address:
                group_emails.add(address.replace(MICROSOFT_DOMAIN, ""))

    external_user_groups: list[ExternalUserGroup] = []
    for group_to_emails in (groups_and_members.groups_to_emails, ad_groups_to_emails):
        for group_name, emails in group_to_emails.items():
            external_user_groups.append(
                ExternalUserGroup(
                    id=group_name,
                    user_emails=list(emails),
                )
            )

    return external_user_groups
|
||||
@@ -11,6 +11,8 @@ from ee.onyx.configs.app_configs import GITHUB_PERMISSION_DOC_SYNC_FREQUENCY
|
||||
from ee.onyx.configs.app_configs import GITHUB_PERMISSION_GROUP_SYNC_FREQUENCY
|
||||
from ee.onyx.configs.app_configs import GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY
|
||||
from ee.onyx.configs.app_configs import JIRA_PERMISSION_DOC_SYNC_FREQUENCY
|
||||
from ee.onyx.configs.app_configs import SHAREPOINT_PERMISSION_DOC_SYNC_FREQUENCY
|
||||
from ee.onyx.configs.app_configs import SHAREPOINT_PERMISSION_GROUP_SYNC_FREQUENCY
|
||||
from ee.onyx.configs.app_configs import SLACK_PERMISSION_DOC_SYNC_FREQUENCY
|
||||
from ee.onyx.configs.app_configs import TEAMS_PERMISSION_DOC_SYNC_FREQUENCY
|
||||
from ee.onyx.external_permissions.confluence.doc_sync import confluence_doc_sync
|
||||
@@ -29,6 +31,8 @@ from ee.onyx.external_permissions.perm_sync_types import GroupSyncFuncType
|
||||
from ee.onyx.external_permissions.salesforce.postprocessing import (
|
||||
censor_salesforce_chunks,
|
||||
)
|
||||
from ee.onyx.external_permissions.sharepoint.doc_sync import sharepoint_doc_sync
|
||||
from ee.onyx.external_permissions.sharepoint.group_sync import sharepoint_group_sync
|
||||
from ee.onyx.external_permissions.slack.doc_sync import slack_doc_sync
|
||||
from ee.onyx.external_permissions.teams.doc_sync import teams_doc_sync
|
||||
from onyx.configs.constants import DocumentSource
|
||||
@@ -156,6 +160,18 @@ _SOURCE_TO_SYNC_CONFIG: dict[DocumentSource, SyncConfig] = {
|
||||
initial_index_should_sync=True,
|
||||
),
|
||||
),
|
||||
DocumentSource.SHAREPOINT: SyncConfig(
|
||||
doc_sync_config=DocSyncConfig(
|
||||
doc_sync_frequency=SHAREPOINT_PERMISSION_DOC_SYNC_FREQUENCY,
|
||||
doc_sync_func=sharepoint_doc_sync,
|
||||
initial_index_should_sync=True,
|
||||
),
|
||||
group_sync_config=GroupSyncConfig(
|
||||
group_sync_frequency=SHAREPOINT_PERMISSION_GROUP_SYNC_FREQUENCY,
|
||||
group_sync_func=sharepoint_group_sync,
|
||||
group_sync_is_cc_pair_agnostic=False,
|
||||
),
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -206,7 +206,7 @@ def _handle_standard_answers(
|
||||
|
||||
restate_question_blocks = get_restate_blocks(
|
||||
msg=query_msg.message,
|
||||
is_bot_msg=message_info.is_bot_msg,
|
||||
is_slash_command=message_info.is_slash_command,
|
||||
)
|
||||
|
||||
answer_blocks = build_standard_answer_blocks(
|
||||
|
||||
@@ -67,7 +67,7 @@ def generate_chat_messages_report(
|
||||
file_id = file_store.save_file(
|
||||
content=temp_file,
|
||||
display_name=file_name,
|
||||
file_origin=FileOrigin.OTHER,
|
||||
file_origin=FileOrigin.GENERATED_REPORT,
|
||||
file_type="text/csv",
|
||||
)
|
||||
|
||||
@@ -99,7 +99,7 @@ def generate_user_report(
|
||||
file_id = file_store.save_file(
|
||||
content=temp_file,
|
||||
display_name=file_name,
|
||||
file_origin=FileOrigin.OTHER,
|
||||
file_origin=FileOrigin.GENERATED_REPORT,
|
||||
file_type="text/csv",
|
||||
)
|
||||
|
||||
|
||||
@@ -1,34 +1,5 @@
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
from shared_configs.enums import EmbedTextType
|
||||
|
||||
|
||||
MODEL_WARM_UP_STRING = "hi " * 512
|
||||
INFORMATION_CONTENT_MODEL_WARM_UP_STRING = "hi " * 16
|
||||
DEFAULT_OPENAI_MODEL = "text-embedding-3-small"
|
||||
DEFAULT_COHERE_MODEL = "embed-english-light-v3.0"
|
||||
DEFAULT_VOYAGE_MODEL = "voyage-large-2-instruct"
|
||||
DEFAULT_VERTEX_MODEL = "text-embedding-005"
|
||||
|
||||
|
||||
class EmbeddingModelTextType:
    """Translates the generic query/passage text types into provider-specific labels."""

    # provider -> {generic text type -> label string for that provider}
    PROVIDER_TEXT_TYPE_MAP = {
        EmbeddingProvider.COHERE: {
            EmbedTextType.QUERY: "search_query",
            EmbedTextType.PASSAGE: "search_document",
        },
        EmbeddingProvider.VOYAGE: {
            EmbedTextType.QUERY: "query",
            EmbedTextType.PASSAGE: "document",
        },
        EmbeddingProvider.GOOGLE: {
            EmbedTextType.QUERY: "RETRIEVAL_QUERY",
            EmbedTextType.PASSAGE: "RETRIEVAL_DOCUMENT",
        },
    }

    @staticmethod
    def get_type(provider: EmbeddingProvider, text_type: EmbedTextType) -> str:
        """Return the label for ``text_type`` under ``provider``.

        Raises ``KeyError`` when the provider or text type is not in the map.
        """
        labels_for_provider = EmbeddingModelTextType.PROVIDER_TEXT_TYPE_MAP[provider]
        return labels_for_provider[text_type]
|
||||
|
||||
|
||||
class GPUStatus:
|
||||
|
||||
@@ -1,55 +1,30 @@
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from types import TracebackType
|
||||
from typing import cast
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
|
||||
import aioboto3 # type: ignore
|
||||
import httpx
|
||||
import openai
|
||||
import vertexai # type: ignore
|
||||
import voyageai # type: ignore
|
||||
from cohere import AsyncClient as CohereAsyncClient
|
||||
from fastapi import APIRouter
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
from google.oauth2 import service_account # type: ignore
|
||||
from litellm import aembedding
|
||||
from litellm.exceptions import RateLimitError
|
||||
from retry import retry
|
||||
from sentence_transformers import CrossEncoder # type: ignore
|
||||
from sentence_transformers import SentenceTransformer # type: ignore
|
||||
from vertexai.language_models import TextEmbeddingInput # type: ignore
|
||||
from vertexai.language_models import TextEmbeddingModel # type: ignore
|
||||
|
||||
from model_server.constants import DEFAULT_COHERE_MODEL
|
||||
from model_server.constants import DEFAULT_OPENAI_MODEL
|
||||
from model_server.constants import DEFAULT_VERTEX_MODEL
|
||||
from model_server.constants import DEFAULT_VOYAGE_MODEL
|
||||
from model_server.constants import EmbeddingModelTextType
|
||||
from model_server.constants import EmbeddingProvider
|
||||
from model_server.utils import pass_aws_key
|
||||
from model_server.utils import simple_log_function_time
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import API_BASED_EMBEDDING_TIMEOUT
|
||||
from shared_configs.configs import INDEXING_ONLY
|
||||
from shared_configs.configs import OPENAI_EMBEDDING_TIMEOUT
|
||||
from shared_configs.configs import VERTEXAI_EMBEDDING_LOCAL_BATCH_SIZE
|
||||
from shared_configs.enums import EmbedTextType
|
||||
from shared_configs.enums import RerankerProvider
|
||||
from shared_configs.model_server_models import Embedding
|
||||
from shared_configs.model_server_models import EmbedRequest
|
||||
from shared_configs.model_server_models import EmbedResponse
|
||||
from shared_configs.model_server_models import RerankRequest
|
||||
from shared_configs.model_server_models import RerankResponse
|
||||
from shared_configs.utils import batch_list
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/encoder")
|
||||
|
||||
|
||||
_GLOBAL_MODELS_DICT: dict[str, "SentenceTransformer"] = {}
|
||||
_RERANK_MODEL: Optional["CrossEncoder"] = None
|
||||
|
||||
@@ -57,315 +32,6 @@ _RERANK_MODEL: Optional["CrossEncoder"] = None
|
||||
_RETRY_DELAY = 10 if INDEXING_ONLY else 0.1
|
||||
_RETRY_TRIES = 10 if INDEXING_ONLY else 2
|
||||
|
||||
# OpenAI only allows 2048 embeddings to be computed at once
|
||||
_OPENAI_MAX_INPUT_LEN = 2048
|
||||
# Cohere allows up to 96 embeddings in a single embedding call
|
||||
_COHERE_MAX_INPUT_LEN = 96
|
||||
|
||||
# Authentication error string constants
|
||||
_AUTH_ERROR_401 = "401"
|
||||
_AUTH_ERROR_UNAUTHORIZED = "unauthorized"
|
||||
_AUTH_ERROR_INVALID_API_KEY = "invalid api key"
|
||||
_AUTH_ERROR_PERMISSION = "permission"
|
||||
|
||||
|
||||
def is_authentication_error(error: Exception) -> bool:
    """Check if an exception is related to authentication issues.

    The check is a case-insensitive substring search of ``str(error)`` for any
    of the ``_AUTH_ERROR_*`` markers.

    Args:
        error: The exception to check

    Returns:
        bool: True if the error appears to be authentication-related
    """
    lowered_message = str(error).lower()
    auth_markers = (
        _AUTH_ERROR_401,
        _AUTH_ERROR_UNAUTHORIZED,
        _AUTH_ERROR_INVALID_API_KEY,
        _AUTH_ERROR_PERMISSION,
    )
    return any(marker in lowered_message for marker in auth_markers)
|
||||
|
||||
|
||||
def format_embedding_error(
    error: Exception,
    service_name: str,
    model: str | None,
    provider: EmbeddingProvider,
    sanitized_api_key: str | None = None,
    status_code: int | None = None,
) -> str:
    """
    Format a standardized error string for embedding errors.

    When ``status_code`` is truthy the message is labelled as an HTTP error
    and shows the status; otherwise it is labelled as an exception and shows
    the type of ``error``.
    """
    if status_code:
        kind = "HTTP error"
        detail = f"Status {status_code}"
    else:
        kind = "Exception"
        detail = f"{type(error)}"

    return (
        f"{kind} embedding text with {service_name} - {detail}: "
        f"Model: {model} "
        f"Provider: {provider} "
        f"API Key: {sanitized_api_key} "
        f"Exception: {error}"
    )
|
||||
|
||||
|
||||
# Custom exception for authentication errors
|
||||
class AuthenticationError(Exception):
    """Raised when authentication fails with a provider.

    The provider name and the detail message are kept on the instance as
    ``provider`` and ``message``.
    """

    def __init__(self, provider: str, message: str = "API key is invalid or expired"):
        super().__init__(f"{provider} authentication failed: {message}")
        self.provider = provider
        self.message = message
|
||||
|
||||
|
||||
class CloudEmbedding:
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
provider: EmbeddingProvider,
|
||||
api_url: str | None = None,
|
||||
api_version: str | None = None,
|
||||
timeout: int = API_BASED_EMBEDDING_TIMEOUT,
|
||||
) -> None:
|
||||
self.provider = provider
|
||||
self.api_key = api_key
|
||||
self.api_url = api_url
|
||||
self.api_version = api_version
|
||||
self.timeout = timeout
|
||||
self.http_client = httpx.AsyncClient(timeout=timeout)
|
||||
self._closed = False
|
||||
self.sanitized_api_key = api_key[:4] + "********" + api_key[-4:]
|
||||
|
||||
async def _embed_openai(
|
||||
self, texts: list[str], model: str | None, reduced_dimension: int | None
|
||||
) -> list[Embedding]:
|
||||
if not model:
|
||||
model = DEFAULT_OPENAI_MODEL
|
||||
|
||||
# Use the OpenAI specific timeout for this one
|
||||
client = openai.AsyncOpenAI(
|
||||
api_key=self.api_key, timeout=OPENAI_EMBEDDING_TIMEOUT
|
||||
)
|
||||
|
||||
final_embeddings: list[Embedding] = []
|
||||
|
||||
for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN):
|
||||
response = await client.embeddings.create(
|
||||
input=text_batch,
|
||||
model=model,
|
||||
dimensions=reduced_dimension or openai.NOT_GIVEN,
|
||||
)
|
||||
final_embeddings.extend(
|
||||
[embedding.embedding for embedding in response.data]
|
||||
)
|
||||
return final_embeddings
|
||||
|
||||
async def _embed_cohere(
|
||||
self, texts: list[str], model: str | None, embedding_type: str
|
||||
) -> list[Embedding]:
|
||||
if not model:
|
||||
model = DEFAULT_COHERE_MODEL
|
||||
|
||||
client = CohereAsyncClient(api_key=self.api_key)
|
||||
|
||||
final_embeddings: list[Embedding] = []
|
||||
for text_batch in batch_list(texts, _COHERE_MAX_INPUT_LEN):
|
||||
# Does not use the same tokenizer as the Onyx API server but it's approximately the same
|
||||
# empirically it's only off by a very few tokens so it's not a big deal
|
||||
response = await client.embed(
|
||||
texts=text_batch,
|
||||
model=model,
|
||||
input_type=embedding_type,
|
||||
truncate="END",
|
||||
)
|
||||
final_embeddings.extend(cast(list[Embedding], response.embeddings))
|
||||
return final_embeddings
|
||||
|
||||
async def _embed_voyage(
|
||||
self, texts: list[str], model: str | None, embedding_type: str
|
||||
) -> list[Embedding]:
|
||||
if not model:
|
||||
model = DEFAULT_VOYAGE_MODEL
|
||||
|
||||
client = voyageai.AsyncClient(
|
||||
api_key=self.api_key, timeout=API_BASED_EMBEDDING_TIMEOUT
|
||||
)
|
||||
|
||||
response = await client.embed(
|
||||
texts=texts,
|
||||
model=model,
|
||||
input_type=embedding_type,
|
||||
truncation=True,
|
||||
)
|
||||
return response.embeddings
|
||||
|
||||
async def _embed_azure(
|
||||
self, texts: list[str], model: str | None
|
||||
) -> list[Embedding]:
|
||||
response = await aembedding(
|
||||
model=model,
|
||||
input=texts,
|
||||
timeout=API_BASED_EMBEDDING_TIMEOUT,
|
||||
api_key=self.api_key,
|
||||
api_base=self.api_url,
|
||||
api_version=self.api_version,
|
||||
)
|
||||
embeddings = [embedding["embedding"] for embedding in response.data]
|
||||
return embeddings
|
||||
|
||||
async def _embed_vertex(
|
||||
self, texts: list[str], model: str | None, embedding_type: str
|
||||
) -> list[Embedding]:
|
||||
if not model:
|
||||
model = DEFAULT_VERTEX_MODEL
|
||||
|
||||
credentials = service_account.Credentials.from_service_account_info(
|
||||
json.loads(self.api_key)
|
||||
)
|
||||
project_id = json.loads(self.api_key)["project_id"]
|
||||
vertexai.init(project=project_id, credentials=credentials)
|
||||
client = TextEmbeddingModel.from_pretrained(model)
|
||||
|
||||
inputs = [TextEmbeddingInput(text, embedding_type) for text in texts]
|
||||
|
||||
# Split into batches of 25 texts
|
||||
max_texts_per_batch = VERTEXAI_EMBEDDING_LOCAL_BATCH_SIZE
|
||||
batches = [
|
||||
inputs[i : i + max_texts_per_batch]
|
||||
for i in range(0, len(inputs), max_texts_per_batch)
|
||||
]
|
||||
|
||||
# Dispatch all embedding calls asynchronously at once
|
||||
tasks = [
|
||||
client.get_embeddings_async(batch, auto_truncate=True) for batch in batches
|
||||
]
|
||||
|
||||
# Wait for all tasks to complete in parallel
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
return [embedding.values for batch in results for embedding in batch]
|
||||
|
||||
async def _embed_litellm_proxy(
|
||||
self, texts: list[str], model_name: str | None
|
||||
) -> list[Embedding]:
|
||||
if not model_name:
|
||||
raise ValueError("Model name is required for LiteLLM proxy embedding.")
|
||||
|
||||
if not self.api_url:
|
||||
raise ValueError("API URL is required for LiteLLM proxy embedding.")
|
||||
|
||||
headers = (
|
||||
{} if not self.api_key else {"Authorization": f"Bearer {self.api_key}"}
|
||||
)
|
||||
|
||||
response = await self.http_client.post(
|
||||
self.api_url,
|
||||
json={
|
||||
"model": model_name,
|
||||
"input": texts,
|
||||
},
|
||||
headers=headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
return [embedding["embedding"] for embedding in result["data"]]
|
||||
|
||||
@retry(tries=_RETRY_TRIES, delay=_RETRY_DELAY)
|
||||
async def embed(
|
||||
self,
|
||||
*,
|
||||
texts: list[str],
|
||||
text_type: EmbedTextType,
|
||||
model_name: str | None = None,
|
||||
deployment_name: str | None = None,
|
||||
reduced_dimension: int | None = None,
|
||||
) -> list[Embedding]:
|
||||
try:
|
||||
if self.provider == EmbeddingProvider.OPENAI:
|
||||
return await self._embed_openai(texts, model_name, reduced_dimension)
|
||||
elif self.provider == EmbeddingProvider.AZURE:
|
||||
return await self._embed_azure(texts, f"azure/{deployment_name}")
|
||||
elif self.provider == EmbeddingProvider.LITELLM:
|
||||
return await self._embed_litellm_proxy(texts, model_name)
|
||||
|
||||
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
|
||||
if self.provider == EmbeddingProvider.COHERE:
|
||||
return await self._embed_cohere(texts, model_name, embedding_type)
|
||||
elif self.provider == EmbeddingProvider.VOYAGE:
|
||||
return await self._embed_voyage(texts, model_name, embedding_type)
|
||||
elif self.provider == EmbeddingProvider.GOOGLE:
|
||||
return await self._embed_vertex(texts, model_name, embedding_type)
|
||||
else:
|
||||
raise ValueError(f"Unsupported provider: {self.provider}")
|
||||
except openai.AuthenticationError:
|
||||
raise AuthenticationError(provider="OpenAI")
|
||||
except httpx.HTTPStatusError as e:
|
||||
if e.response.status_code == 401:
|
||||
raise AuthenticationError(provider=str(self.provider))
|
||||
|
||||
error_string = format_embedding_error(
|
||||
e,
|
||||
str(self.provider),
|
||||
model_name or deployment_name,
|
||||
self.provider,
|
||||
sanitized_api_key=self.sanitized_api_key,
|
||||
status_code=e.response.status_code,
|
||||
)
|
||||
logger.error(error_string)
|
||||
logger.debug(f"Exception texts: {texts}")
|
||||
|
||||
raise RuntimeError(error_string)
|
||||
except Exception as e:
|
||||
if is_authentication_error(e):
|
||||
raise AuthenticationError(provider=str(self.provider))
|
||||
|
||||
error_string = format_embedding_error(
|
||||
e,
|
||||
str(self.provider),
|
||||
model_name or deployment_name,
|
||||
self.provider,
|
||||
sanitized_api_key=self.sanitized_api_key,
|
||||
)
|
||||
logger.error(error_string)
|
||||
logger.debug(f"Exception texts: {texts}")
|
||||
|
||||
raise RuntimeError(error_string)
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
api_key: str,
|
||||
provider: EmbeddingProvider,
|
||||
api_url: str | None = None,
|
||||
api_version: str | None = None,
|
||||
) -> "CloudEmbedding":
|
||||
logger.debug(f"Creating Embedding instance for provider: {provider}")
|
||||
return CloudEmbedding(api_key, provider, api_url, api_version)
|
||||
|
||||
async def aclose(self) -> None:
|
||||
"""Explicitly close the client."""
|
||||
if not self._closed:
|
||||
await self.http_client.aclose()
|
||||
self._closed = True
|
||||
|
||||
async def __aenter__(self) -> "CloudEmbedding":
|
||||
return self
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_val: BaseException | None,
|
||||
exc_tb: TracebackType | None,
|
||||
) -> None:
|
||||
await self.aclose()
|
||||
|
||||
def __del__(self) -> None:
|
||||
"""Finalizer to warn about unclosed clients."""
|
||||
if not self._closed:
|
||||
logger.warning(
|
||||
"CloudEmbedding was not properly closed. Use 'async with' or call aclose()"
|
||||
)
|
||||
|
||||
|
||||
def get_embedding_model(
|
||||
model_name: str,
|
||||
@@ -404,20 +70,34 @@ def get_local_reranking_model(
|
||||
return _RERANK_MODEL
|
||||
|
||||
|
||||
ENCODING_RETRIES = 3
|
||||
ENCODING_RETRY_DELAY = 0.1
|
||||
|
||||
|
||||
def _concurrent_embedding(
|
||||
texts: list[str], model: "SentenceTransformer", normalize_embeddings: bool
|
||||
) -> Any:
|
||||
"""Synchronous wrapper for concurrent_embedding to use with run_in_executor."""
|
||||
for _ in range(ENCODING_RETRIES):
|
||||
try:
|
||||
return model.encode(texts, normalize_embeddings=normalize_embeddings)
|
||||
except RuntimeError as e:
|
||||
# There is a concurrency bug in the SentenceTransformer library that causes
|
||||
# the model to fail to encode texts. It's pretty rare and we want to allow
|
||||
# concurrent embedding, hence we retry (the specific error is
|
||||
# "RuntimeError: Already borrowed" and occurs in the transformers library)
|
||||
logger.error(f"Error encoding texts, retrying: {e}")
|
||||
time.sleep(ENCODING_RETRY_DELAY)
|
||||
return model.encode(texts, normalize_embeddings=normalize_embeddings)
|
||||
|
||||
|
||||
@simple_log_function_time()
|
||||
async def embed_text(
|
||||
texts: list[str],
|
||||
text_type: EmbedTextType,
|
||||
model_name: str | None,
|
||||
deployment_name: str | None,
|
||||
max_context_length: int,
|
||||
normalize_embeddings: bool,
|
||||
api_key: str | None,
|
||||
provider_type: EmbeddingProvider | None,
|
||||
prefix: str | None,
|
||||
api_url: str | None,
|
||||
api_version: str | None,
|
||||
reduced_dimension: int | None,
|
||||
gpu_type: str = "UNKNOWN",
|
||||
) -> list[Embedding]:
|
||||
if not all(texts):
|
||||
@@ -434,52 +114,10 @@ async def embed_text(
|
||||
for text in texts:
|
||||
total_chars += len(text)
|
||||
|
||||
if provider_type is not None:
|
||||
logger.info(
|
||||
f"Embedding {len(texts)} texts with {total_chars} total characters with provider: {provider_type}"
|
||||
)
|
||||
# Only local models should call this function now
|
||||
# API providers should go directly to API server
|
||||
|
||||
if api_key is None:
|
||||
logger.error("API key not provided for cloud model")
|
||||
raise RuntimeError("API key not provided for cloud model")
|
||||
|
||||
if prefix:
|
||||
logger.warning("Prefix provided for cloud model, which is not supported")
|
||||
raise ValueError(
|
||||
"Prefix string is not valid for cloud models. "
|
||||
"Cloud models take an explicit text type instead."
|
||||
)
|
||||
|
||||
async with CloudEmbedding(
|
||||
api_key=api_key,
|
||||
provider=provider_type,
|
||||
api_url=api_url,
|
||||
api_version=api_version,
|
||||
) as cloud_model:
|
||||
embeddings = await cloud_model.embed(
|
||||
texts=texts,
|
||||
model_name=model_name,
|
||||
deployment_name=deployment_name,
|
||||
text_type=text_type,
|
||||
reduced_dimension=reduced_dimension,
|
||||
)
|
||||
|
||||
if any(embedding is None for embedding in embeddings):
|
||||
error_message = "Embeddings contain None values\n"
|
||||
error_message += "Corresponding texts:\n"
|
||||
error_message += "\n".join(texts)
|
||||
logger.error(error_message)
|
||||
raise ValueError(error_message)
|
||||
|
||||
elapsed = time.monotonic() - start
|
||||
logger.info(
|
||||
f"event=embedding_provider "
|
||||
f"texts={len(texts)} "
|
||||
f"chars={total_chars} "
|
||||
f"provider={provider_type} "
|
||||
f"elapsed={elapsed:.2f}"
|
||||
)
|
||||
elif model_name is not None:
|
||||
if model_name is not None:
|
||||
logger.info(
|
||||
f"Embedding {len(texts)} texts with {total_chars} total characters with local model: {model_name}"
|
||||
)
|
||||
@@ -492,8 +130,8 @@ async def embed_text(
|
||||
# Run CPU-bound embedding in a thread pool
|
||||
embeddings_vectors = await asyncio.get_event_loop().run_in_executor(
|
||||
None,
|
||||
lambda: local_model.encode(
|
||||
prefixed_texts, normalize_embeddings=normalize_embeddings
|
||||
lambda: _concurrent_embedding(
|
||||
prefixed_texts, local_model, normalize_embeddings
|
||||
),
|
||||
)
|
||||
embeddings = [
|
||||
@@ -515,10 +153,8 @@ async def embed_text(
|
||||
f"elapsed={elapsed:.2f}"
|
||||
)
|
||||
else:
|
||||
logger.error("Neither model name nor provider specified for embedding")
|
||||
raise ValueError(
|
||||
"Either model name or provider must be provided to run embeddings."
|
||||
)
|
||||
logger.error("Model name not specified for embedding")
|
||||
raise ValueError("Model name must be provided to run embeddings.")
|
||||
|
||||
return embeddings
|
||||
|
||||
@@ -533,77 +169,6 @@ async def local_rerank(query: str, docs: list[str], model_name: str) -> list[flo
|
||||
)
|
||||
|
||||
|
||||
async def cohere_rerank_api(
|
||||
query: str, docs: list[str], model_name: str, api_key: str
|
||||
) -> list[float]:
|
||||
cohere_client = CohereAsyncClient(api_key=api_key)
|
||||
response = await cohere_client.rerank(query=query, documents=docs, model=model_name)
|
||||
results = response.results
|
||||
sorted_results = sorted(results, key=lambda item: item.index)
|
||||
return [result.relevance_score for result in sorted_results]
|
||||
|
||||
|
||||
async def cohere_rerank_aws(
|
||||
query: str,
|
||||
docs: list[str],
|
||||
model_name: str,
|
||||
region_name: str,
|
||||
aws_access_key_id: str,
|
||||
aws_secret_access_key: str,
|
||||
) -> list[float]:
|
||||
session = aioboto3.Session(
|
||||
aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key
|
||||
)
|
||||
async with session.client(
|
||||
"bedrock-runtime", region_name=region_name
|
||||
) as bedrock_client:
|
||||
body = json.dumps(
|
||||
{
|
||||
"query": query,
|
||||
"documents": docs,
|
||||
"api_version": 2,
|
||||
}
|
||||
)
|
||||
# Invoke the Bedrock model asynchronously
|
||||
response = await bedrock_client.invoke_model(
|
||||
modelId=model_name,
|
||||
accept="application/json",
|
||||
contentType="application/json",
|
||||
body=body,
|
||||
)
|
||||
|
||||
# Read the response asynchronously
|
||||
response_body = json.loads(await response["body"].read())
|
||||
|
||||
# Extract and sort the results
|
||||
results = response_body.get("results", [])
|
||||
sorted_results = sorted(results, key=lambda item: item["index"])
|
||||
|
||||
return [result["relevance_score"] for result in sorted_results]
|
||||
|
||||
|
||||
async def litellm_rerank(
|
||||
query: str, docs: list[str], api_url: str, model_name: str, api_key: str | None
|
||||
) -> list[float]:
|
||||
headers = {} if not api_key else {"Authorization": f"Bearer {api_key}"}
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
api_url,
|
||||
json={
|
||||
"model": model_name,
|
||||
"query": query,
|
||||
"documents": docs,
|
||||
},
|
||||
headers=headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
return [
|
||||
item["relevance_score"]
|
||||
for item in sorted(result["results"], key=lambda x: x["index"])
|
||||
]
|
||||
|
||||
|
||||
@router.post("/bi-encoder-embed")
|
||||
async def route_bi_encoder_embed(
|
||||
request: Request,
|
||||
@@ -615,6 +180,13 @@ async def route_bi_encoder_embed(
|
||||
async def process_embed_request(
|
||||
embed_request: EmbedRequest, gpu_type: str = "UNKNOWN"
|
||||
) -> EmbedResponse:
|
||||
# Only local models should use this endpoint - API providers should make direct API calls
|
||||
if embed_request.provider_type is not None:
|
||||
raise ValueError(
|
||||
f"Model server embedding endpoint should only be used for local models. "
|
||||
f"API provider '{embed_request.provider_type}' should make direct API calls instead."
|
||||
)
|
||||
|
||||
if not embed_request.texts:
|
||||
raise HTTPException(status_code=400, detail="No texts to be embedded")
|
||||
|
||||
@@ -632,26 +204,12 @@ async def process_embed_request(
|
||||
embeddings = await embed_text(
|
||||
texts=embed_request.texts,
|
||||
model_name=embed_request.model_name,
|
||||
deployment_name=embed_request.deployment_name,
|
||||
max_context_length=embed_request.max_context_length,
|
||||
normalize_embeddings=embed_request.normalize_embeddings,
|
||||
api_key=embed_request.api_key,
|
||||
provider_type=embed_request.provider_type,
|
||||
text_type=embed_request.text_type,
|
||||
api_url=embed_request.api_url,
|
||||
api_version=embed_request.api_version,
|
||||
reduced_dimension=embed_request.reduced_dimension,
|
||||
prefix=prefix,
|
||||
gpu_type=gpu_type,
|
||||
)
|
||||
return EmbedResponse(embeddings=embeddings)
|
||||
except AuthenticationError as e:
|
||||
# Handle authentication errors consistently
|
||||
logger.error(f"Authentication error: {e.provider}")
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail=f"Authentication failed: {e.message}",
|
||||
)
|
||||
except RateLimitError as e:
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
@@ -669,6 +227,13 @@ async def process_embed_request(
|
||||
@router.post("/cross-encoder-scores")
|
||||
async def process_rerank_request(rerank_request: RerankRequest) -> RerankResponse:
|
||||
"""Cross encoders can be purely black box from the app perspective"""
|
||||
# Only local models should use this endpoint - API providers should make direct API calls
|
||||
if rerank_request.provider_type is not None:
|
||||
raise ValueError(
|
||||
f"Model server reranking endpoint should only be used for local models. "
|
||||
f"API provider '{rerank_request.provider_type}' should make direct API calls instead."
|
||||
)
|
||||
|
||||
if INDEXING_ONLY:
|
||||
raise RuntimeError("Indexing model server should not call intent endpoint")
|
||||
|
||||
@@ -680,55 +245,13 @@ async def process_rerank_request(rerank_request: RerankRequest) -> RerankRespons
|
||||
raise ValueError("Empty documents cannot be reranked.")
|
||||
|
||||
try:
|
||||
if rerank_request.provider_type is None:
|
||||
sim_scores = await local_rerank(
|
||||
query=rerank_request.query,
|
||||
docs=rerank_request.documents,
|
||||
model_name=rerank_request.model_name,
|
||||
)
|
||||
return RerankResponse(scores=sim_scores)
|
||||
elif rerank_request.provider_type == RerankerProvider.LITELLM:
|
||||
if rerank_request.api_url is None:
|
||||
raise ValueError("API URL is required for LiteLLM reranking.")
|
||||
|
||||
sim_scores = await litellm_rerank(
|
||||
query=rerank_request.query,
|
||||
docs=rerank_request.documents,
|
||||
api_url=rerank_request.api_url,
|
||||
model_name=rerank_request.model_name,
|
||||
api_key=rerank_request.api_key,
|
||||
)
|
||||
|
||||
return RerankResponse(scores=sim_scores)
|
||||
|
||||
elif rerank_request.provider_type == RerankerProvider.COHERE:
|
||||
if rerank_request.api_key is None:
|
||||
raise RuntimeError("Cohere Rerank Requires an API Key")
|
||||
sim_scores = await cohere_rerank_api(
|
||||
query=rerank_request.query,
|
||||
docs=rerank_request.documents,
|
||||
model_name=rerank_request.model_name,
|
||||
api_key=rerank_request.api_key,
|
||||
)
|
||||
return RerankResponse(scores=sim_scores)
|
||||
|
||||
elif rerank_request.provider_type == RerankerProvider.BEDROCK:
|
||||
if rerank_request.api_key is None:
|
||||
raise RuntimeError("Bedrock Rerank Requires an API Key")
|
||||
aws_access_key_id, aws_secret_access_key, aws_region = pass_aws_key(
|
||||
rerank_request.api_key
|
||||
)
|
||||
sim_scores = await cohere_rerank_aws(
|
||||
query=rerank_request.query,
|
||||
docs=rerank_request.documents,
|
||||
model_name=rerank_request.model_name,
|
||||
region_name=aws_region,
|
||||
aws_access_key_id=aws_access_key_id,
|
||||
aws_secret_access_key=aws_secret_access_key,
|
||||
)
|
||||
return RerankResponse(scores=sim_scores)
|
||||
else:
|
||||
raise ValueError(f"Unsupported provider: {rerank_request.provider_type}")
|
||||
# At this point, provider_type is None, so handle local reranking
|
||||
sim_scores = await local_rerank(
|
||||
query=rerank_request.query,
|
||||
docs=rerank_request.documents,
|
||||
model_name=rerank_request.model_name,
|
||||
)
|
||||
return RerankResponse(scores=sim_scores)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Error during reranking process:\n{str(e)}")
|
||||
|
||||
@@ -34,8 +34,8 @@ from shared_configs.configs import SENTRY_DSN
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
|
||||
|
||||
HF_CACHE_PATH = Path(os.path.expanduser("~")) / ".cache/huggingface"
|
||||
TEMP_HF_CACHE_PATH = Path(os.path.expanduser("~")) / ".cache/temp_huggingface"
|
||||
HF_CACHE_PATH = Path(".cache/huggingface")
|
||||
TEMP_HF_CACHE_PATH = Path(".cache/temp_huggingface")
|
||||
|
||||
transformer_logging.set_verbosity_error()
|
||||
|
||||
|
||||
@@ -70,32 +70,3 @@ def get_gpu_type() -> str:
|
||||
return GPUStatus.MAC_MPS
|
||||
|
||||
return GPUStatus.NONE
|
||||
|
||||
|
||||
def pass_aws_key(api_key: str) -> tuple[str, str, str]:
|
||||
"""Parse AWS API key string into components.
|
||||
|
||||
Args:
|
||||
api_key: String in format 'aws_ACCESSKEY_SECRETKEY_REGION'
|
||||
|
||||
Returns:
|
||||
Tuple of (access_key, secret_key, region)
|
||||
|
||||
Raises:
|
||||
ValueError: If key format is invalid
|
||||
"""
|
||||
if not api_key.startswith("aws"):
|
||||
raise ValueError("API key must start with 'aws' prefix")
|
||||
|
||||
parts = api_key.split("_")
|
||||
if len(parts) != 4:
|
||||
raise ValueError(
|
||||
f"API key must be in format 'aws_ACCESSKEY_SECRETKEY_REGION', got {len(parts) - 1} parts"
|
||||
"this is an onyx specific format for formatting the aws secrets for bedrock"
|
||||
)
|
||||
|
||||
try:
|
||||
_, aws_access_key_id, aws_secret_access_key, aws_region = parts
|
||||
return aws_access_key_id, aws_secret_access_key, aws_region
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to parse AWS key components: {str(e)}")
|
||||
|
||||
@@ -1,50 +0,0 @@
|
||||
from collections.abc import Hashable
|
||||
from datetime import datetime
|
||||
|
||||
from langgraph.types import Send
|
||||
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
|
||||
AnswerQuestionOutput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_individual_sub_answer.states import (
|
||||
SubQuestionAnsweringInput,
|
||||
)
|
||||
from onyx.agents.agent_search.deep_search.initial.generate_initial_answer.states import (
|
||||
SubQuestionRetrievalState,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
|
||||
|
||||
|
||||
def parallelize_initial_sub_question_answering(
|
||||
state: SubQuestionRetrievalState,
|
||||
) -> list[Send | Hashable]:
|
||||
"""
|
||||
LangGraph edge to parallelize the initial sub-question answering. If there are no sub-questions,
|
||||
we send empty answers to the initial answer generation, and that answer would be generated
|
||||
solely based on the documents retrieved for the original question.
|
||||
"""
|
||||
edge_start_time = datetime.now()
|
||||
if len(state.initial_sub_questions) > 0:
|
||||
return [
|
||||
Send(
|
||||
"answer_query_subgraph",
|
||||
SubQuestionAnsweringInput(
|
||||
question=question,
|
||||
question_id=make_question_id(0, question_num + 1),
|
||||
log_messages=[
|
||||
f"{edge_start_time} -- Main Edge - Parallelize Initial Sub-question Answering"
|
||||
],
|
||||
),
|
||||
)
|
||||
for question_num, question in enumerate(state.initial_sub_questions)
|
||||
]
|
||||
|
||||
else:
|
||||
return [
|
||||
Send(
|
||||
"ingest_answers",
|
||||
AnswerQuestionOutput(
|
||||
answer_results=[],
|
||||
),
|
||||
)
|
||||
]
|
||||
@@ -40,7 +40,7 @@ def parallelize_initial_sub_question_answering(
|
||||
else:
|
||||
return [
|
||||
Send(
|
||||
"ingest_answers",
|
||||
"format_initial_sub_question_answers",
|
||||
AnswerQuestionOutput(
|
||||
answer_results=[],
|
||||
),
|
||||
|
||||
@@ -43,36 +43,6 @@ def route_initial_tool_choice(
|
||||
return "call_tool"
|
||||
|
||||
|
||||
def parallelize_initial_sub_question_answering(
|
||||
state: MainState,
|
||||
) -> list[Send | Hashable]:
|
||||
edge_start_time = datetime.now()
|
||||
if len(state.initial_sub_questions) > 0:
|
||||
return [
|
||||
Send(
|
||||
"answer_query_subgraph",
|
||||
SubQuestionAnsweringInput(
|
||||
question=question,
|
||||
question_id=make_question_id(0, question_num + 1),
|
||||
log_messages=[
|
||||
f"{edge_start_time} -- Main Edge - Parallelize Initial Sub-question Answering"
|
||||
],
|
||||
),
|
||||
)
|
||||
for question_num, question in enumerate(state.initial_sub_questions)
|
||||
]
|
||||
|
||||
else:
|
||||
return [
|
||||
Send(
|
||||
"ingest_answers",
|
||||
AnswerQuestionOutput(
|
||||
answer_results=[],
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
# Define the function that determines whether to continue or not
|
||||
def continue_to_refined_answer_or_end(
|
||||
state: RequireRefinemenEvalUpdate,
|
||||
|
||||
@@ -7,6 +7,17 @@ from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
from onyx.utils.special_types import JSON_ro
|
||||
|
||||
|
||||
def remove_user_from_invited_users(email: str) -> int:
|
||||
try:
|
||||
store = get_kv_store()
|
||||
user_emails = cast(list, store.load(KV_USER_STORE_KEY))
|
||||
remaining_users = [user for user in user_emails if user != email]
|
||||
store.store(KV_USER_STORE_KEY, cast(JSON_ro, remaining_users))
|
||||
return len(remaining_users)
|
||||
except KvKeyNotFoundError:
|
||||
return 0
|
||||
|
||||
|
||||
def get_invited_users() -> list[str]:
|
||||
try:
|
||||
store = get_kv_store()
|
||||
|
||||
@@ -60,6 +60,7 @@ from onyx.auth.api_key import get_hashed_api_key_from_request
|
||||
from onyx.auth.email_utils import send_forgot_password_email
|
||||
from onyx.auth.email_utils import send_user_verification_email
|
||||
from onyx.auth.invited_users import get_invited_users
|
||||
from onyx.auth.invited_users import remove_user_from_invited_users
|
||||
from onyx.auth.schemas import AuthBackend
|
||||
from onyx.auth.schemas import UserCreate
|
||||
from onyx.auth.schemas import UserRole
|
||||
@@ -241,7 +242,7 @@ def verify_email_domain(email: str) -> None:
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Email is not valid",
|
||||
)
|
||||
domain = email.split("@")[-1]
|
||||
domain = email.split("@")[-1].lower()
|
||||
if domain not in VALID_EMAIL_DOMAINS:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
@@ -350,6 +351,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
role=user_create.role,
|
||||
)
|
||||
user = await self.update(user_update, user)
|
||||
remove_user_from_invited_users(user_create.email)
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
return user
|
||||
@@ -527,7 +529,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
):
|
||||
await self.user_db.update(user, {"oidc_expiry": None})
|
||||
user.oidc_expiry = None # type: ignore
|
||||
|
||||
remove_user_from_invited_users(user.email)
|
||||
if token:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
@@ -231,10 +231,7 @@ class DynamicTenantScheduler(PersistentScheduler):
|
||||
True if equivalent, False if not."""
|
||||
current_tasks = set(name for name, _ in schedule1)
|
||||
new_tasks = set(schedule2.keys())
|
||||
if current_tasks != new_tasks:
|
||||
return False
|
||||
|
||||
return True
|
||||
return current_tasks == new_tasks
|
||||
|
||||
|
||||
@beat_init.connect
|
||||
|
||||
@@ -32,7 +32,6 @@ from onyx.db.indexing_coordination import IndexingCoordination
|
||||
from onyx.redis.redis_connector_delete import RedisConnectorDelete
|
||||
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
|
||||
from onyx.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
|
||||
from onyx.redis.redis_connector_index import RedisConnectorIndex
|
||||
from onyx.redis.redis_connector_prune import RedisConnectorPrune
|
||||
from onyx.redis.redis_connector_stop import RedisConnectorStop
|
||||
from onyx.redis.redis_document_set import RedisDocumentSet
|
||||
@@ -161,7 +160,6 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
RedisUserGroup.reset_all(r)
|
||||
RedisConnectorDelete.reset_all(r)
|
||||
RedisConnectorPrune.reset_all(r)
|
||||
RedisConnectorIndex.reset_all(r)
|
||||
RedisConnectorStop.reset_all(r)
|
||||
RedisConnectorPermissionSync.reset_all(r)
|
||||
RedisConnectorExternalGroupSync.reset_all(r)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from pathlib import Path
|
||||
@@ -8,10 +10,12 @@ import httpx
|
||||
|
||||
from onyx.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
|
||||
from onyx.configs.app_configs import VESPA_REQUEST_TIMEOUT
|
||||
from onyx.connectors.connector_runner import batched_doc_ids
|
||||
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
rate_limit_builder,
|
||||
)
|
||||
from onyx.connectors.interfaces import BaseConnector
|
||||
from onyx.connectors.interfaces import CheckpointedConnector
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
@@ -22,12 +26,14 @@ from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
PRUNING_CHECKPOINTED_BATCH_SIZE = 32
|
||||
|
||||
|
||||
def document_batch_to_ids(
|
||||
doc_batch: list[Document],
|
||||
) -> set[str]:
|
||||
return {doc.id for doc in doc_batch}
|
||||
doc_batch: Iterator[list[Document]],
|
||||
) -> Generator[set[str], None, None]:
|
||||
for doc_list in doc_batch:
|
||||
yield {doc.id for doc in doc_list}
|
||||
|
||||
|
||||
def extract_ids_from_runnable_connector(
|
||||
@@ -46,33 +52,50 @@ def extract_ids_from_runnable_connector(
|
||||
for metadata_batch in runnable_connector.retrieve_all_slim_documents():
|
||||
all_connector_doc_ids.update({doc.id for doc in metadata_batch})
|
||||
|
||||
doc_batch_generator = None
|
||||
doc_batch_id_generator = None
|
||||
|
||||
if isinstance(runnable_connector, LoadConnector):
|
||||
doc_batch_generator = runnable_connector.load_from_state()
|
||||
doc_batch_id_generator = document_batch_to_ids(
|
||||
runnable_connector.load_from_state()
|
||||
)
|
||||
elif isinstance(runnable_connector, PollConnector):
|
||||
start = datetime(1970, 1, 1, tzinfo=timezone.utc).timestamp()
|
||||
end = datetime.now(timezone.utc).timestamp()
|
||||
doc_batch_generator = runnable_connector.poll_source(start=start, end=end)
|
||||
doc_batch_id_generator = document_batch_to_ids(
|
||||
runnable_connector.poll_source(start=start, end=end)
|
||||
)
|
||||
elif isinstance(runnable_connector, CheckpointedConnector):
|
||||
start = datetime(1970, 1, 1, tzinfo=timezone.utc).timestamp()
|
||||
end = datetime.now(timezone.utc).timestamp()
|
||||
checkpoint = runnable_connector.build_dummy_checkpoint()
|
||||
checkpoint_generator = runnable_connector.load_from_checkpoint(
|
||||
start=start, end=end, checkpoint=checkpoint
|
||||
)
|
||||
doc_batch_id_generator = batched_doc_ids(
|
||||
checkpoint_generator, batch_size=PRUNING_CHECKPOINTED_BATCH_SIZE
|
||||
)
|
||||
else:
|
||||
raise RuntimeError("Pruning job could not find a valid runnable_connector.")
|
||||
|
||||
doc_batch_processing_func = document_batch_to_ids
|
||||
# this function is called per batch for rate limiting
|
||||
def doc_batch_processing_func(doc_batch_ids: set[str]) -> set[str]:
|
||||
return doc_batch_ids
|
||||
|
||||
if MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE:
|
||||
doc_batch_processing_func = rate_limit_builder(
|
||||
max_calls=MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE, period=60
|
||||
)(document_batch_to_ids)
|
||||
for doc_batch in doc_batch_generator:
|
||||
)(lambda x: x)
|
||||
for doc_batch_ids in doc_batch_id_generator:
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
raise RuntimeError(
|
||||
"extract_ids_from_runnable_connector: Stop signal detected"
|
||||
)
|
||||
|
||||
all_connector_doc_ids.update(doc_batch_processing_func(doc_batch))
|
||||
all_connector_doc_ids.update(doc_batch_processing_func(doc_batch_ids))
|
||||
|
||||
if callback:
|
||||
callback.progress("extract_ids_from_runnable_connector", len(doc_batch))
|
||||
callback.progress("extract_ids_from_runnable_connector", len(doc_batch_ids))
|
||||
|
||||
return all_connector_doc_ids
|
||||
|
||||
|
||||
@@ -193,12 +193,7 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str) -> bool | N
|
||||
task_logger.info(
|
||||
"Timed out waiting for tasks blocking deletion. Resetting blocking fences."
|
||||
)
|
||||
search_settings_list = get_all_search_settings(db_session)
|
||||
for search_settings in search_settings_list:
|
||||
redis_connector_index = redis_connector.new_index(
|
||||
search_settings.id
|
||||
)
|
||||
redis_connector_index.reset()
|
||||
|
||||
redis_connector.prune.reset()
|
||||
redis_connector.permissions.reset()
|
||||
redis_connector.external_group_sync.reset()
|
||||
|
||||
@@ -2,7 +2,6 @@ import multiprocessing
|
||||
import os
|
||||
import time
|
||||
import traceback
|
||||
from http import HTTPStatus
|
||||
from time import sleep
|
||||
|
||||
import sentry_sdk
|
||||
@@ -22,7 +21,7 @@ from onyx.background.celery.tasks.models import SimpleJobResult
|
||||
from onyx.background.indexing.job_client import SimpleJob
|
||||
from onyx.background.indexing.job_client import SimpleJobClient
|
||||
from onyx.background.indexing.job_client import SimpleJobException
|
||||
from onyx.background.indexing.run_docfetching import run_indexing_entrypoint
|
||||
from onyx.background.indexing.run_docfetching import run_docfetching_entrypoint
|
||||
from onyx.configs.constants import CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
@@ -34,7 +33,6 @@ from onyx.db.index_attempt import mark_attempt_canceled
|
||||
from onyx.db.index_attempt import mark_attempt_failed
|
||||
from onyx.db.indexing_coordination import IndexingCoordination
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.redis.redis_connector_index import RedisConnectorIndex
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
from shared_configs.configs import SENTRY_DSN
|
||||
@@ -156,7 +154,6 @@ def _docfetching_task(
|
||||
)
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
redis_connector.new_index(search_settings_id)
|
||||
|
||||
# TODO: remove all fences, cause all signals to be set in postgres
|
||||
if redis_connector.delete.fenced:
|
||||
@@ -214,7 +211,7 @@ def _docfetching_task(
|
||||
)
|
||||
|
||||
# This is where the heavy/real work happens
|
||||
run_indexing_entrypoint(
|
||||
run_docfetching_entrypoint(
|
||||
app,
|
||||
index_attempt_id,
|
||||
tenant_id,
|
||||
@@ -261,7 +258,7 @@ def _docfetching_task(
|
||||
def process_job_result(
|
||||
job: SimpleJob,
|
||||
connector_source: str | None,
|
||||
redis_connector_index: RedisConnectorIndex,
|
||||
index_attempt_id: int,
|
||||
log_builder: ConnectorIndexingLogBuilder,
|
||||
) -> SimpleJobResult:
|
||||
result = SimpleJobResult()
|
||||
@@ -278,13 +275,11 @@ def process_job_result(
|
||||
|
||||
# In EKS, there is an edge case where successful tasks return exit
|
||||
# code 1 in the cloud due to the set_spawn_method not sticking.
|
||||
# We've since worked around this, but the following is a safe way to
|
||||
# work around this issue. Basically, we ignore the job error state
|
||||
# if the completion signal is OK.
|
||||
status_int = redis_connector_index.get_completion()
|
||||
if status_int:
|
||||
status_enum = HTTPStatus(status_int)
|
||||
if status_enum == HTTPStatus.OK:
|
||||
# Workaround: check that the total number of batches is set, since this only
|
||||
# happens when docfetching completed successfully
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
index_attempt = get_index_attempt(db_session, index_attempt_id)
|
||||
if index_attempt and index_attempt.total_batches is not None:
|
||||
ignore_exitcode = True
|
||||
|
||||
if ignore_exitcode:
|
||||
@@ -300,7 +295,11 @@ def process_job_result(
|
||||
if result.exit_code is not None:
|
||||
result.status = IndexingWatchdogTerminalStatus.from_code(result.exit_code)
|
||||
|
||||
result.exception_str = job.exception()
|
||||
job_level_exception = job.exception()
|
||||
result.exception_str = (
|
||||
f"Docfetching returned exit code {result.exit_code} "
|
||||
f"with exception: {job_level_exception}"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@@ -458,9 +457,6 @@ def docfetching_proxy_task(
|
||||
)
|
||||
)
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
redis_connector_index = redis_connector.new_index(search_settings_id)
|
||||
|
||||
# Track the last time memory info was emitted
|
||||
last_memory_emit_time = 0.0
|
||||
|
||||
@@ -487,7 +483,7 @@ def docfetching_proxy_task(
|
||||
if job.done():
|
||||
try:
|
||||
result = process_job_result(
|
||||
job, result.connector_source, redis_connector_index, log_builder
|
||||
job, result.connector_source, index_attempt_id, log_builder
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
|
||||
@@ -5,6 +5,9 @@ from sqlalchemy import update
|
||||
from onyx.configs.constants import INDEXING_WORKER_HEARTBEAT_INTERVAL
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import IndexAttempt
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def start_heartbeat(index_attempt_id: int) -> tuple[threading.Thread, threading.Event]:
|
||||
@@ -21,9 +24,15 @@ def start_heartbeat(index_attempt_id: int) -> tuple[threading.Thread, threading.
|
||||
.values(heartbeat_counter=IndexAttempt.heartbeat_counter + 1)
|
||||
)
|
||||
db_session.commit()
|
||||
logger.debug(
|
||||
"Updated heartbeat counter for index attempt %s",
|
||||
index_attempt_id,
|
||||
)
|
||||
except Exception:
|
||||
# Silently continue if heartbeat fails
|
||||
pass
|
||||
logger.exception(
|
||||
"Failed to update heartbeat counter for index attempt %s",
|
||||
index_attempt_id,
|
||||
)
|
||||
|
||||
thread = threading.Thread(target=heartbeat_loop, daemon=True)
|
||||
thread.start()
|
||||
|
||||
@@ -4,7 +4,6 @@ from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from http import HTTPStatus
|
||||
from typing import Any
|
||||
|
||||
from celery import shared_task
|
||||
@@ -16,6 +15,8 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_find_task
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
|
||||
from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
|
||||
from onyx.background.celery.tasks.docprocessing.heartbeat import start_heartbeat
|
||||
@@ -66,6 +67,7 @@ from onyx.db.index_attempt import mark_attempt_failed
|
||||
from onyx.db.index_attempt import mark_attempt_partially_succeeded
|
||||
from onyx.db.index_attempt import mark_attempt_succeeded
|
||||
from onyx.db.indexing_coordination import CoordinationStatus
|
||||
from onyx.db.indexing_coordination import INDEXING_PROGRESS_TIMEOUT_HOURS
|
||||
from onyx.db.indexing_coordination import IndexingCoordination
|
||||
from onyx.db.models import IndexAttempt
|
||||
from onyx.db.search_settings import get_active_search_settings_list
|
||||
@@ -90,7 +92,6 @@ from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
|
||||
from onyx.redis.redis_utils import is_fence
|
||||
from onyx.server.runtime.onyx_runtime import OnyxRuntime
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.logger import TaskAttemptSingleton
|
||||
from onyx.utils.middleware import make_randomized_onyx_request_id
|
||||
from onyx.utils.telemetry import optional_telemetry
|
||||
from onyx.utils.telemetry import RecordType
|
||||
@@ -98,10 +99,16 @@ from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
|
||||
from shared_configs.configs import INDEXING_MODEL_SERVER_PORT
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.contextvars import INDEX_ATTEMPT_INFO_CONTEXTVAR
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
USER_FILE_INDEXING_LIMIT = 100
|
||||
DOCPROCESSING_STALL_TIMEOUT_MULTIPLIER = 4
|
||||
DOCPROCESSING_HEARTBEAT_TIMEOUT_MULTIPLIER = 24
|
||||
# Heartbeat timeout: if no heartbeat received for 30 minutes, consider it dead
|
||||
# This should be much longer than INDEXING_WORKER_HEARTBEAT_INTERVAL (30s)
|
||||
HEARTBEAT_TIMEOUT_SECONDS = 30 * 60 # 30 minutes
|
||||
|
||||
|
||||
def _get_fence_validation_block_expiration() -> int:
|
||||
@@ -133,14 +140,10 @@ def validate_active_indexing_attempts(
|
||||
every INDEXING_WORKER_HEARTBEAT_INTERVAL seconds.
|
||||
"""
|
||||
logger.info("Validating active indexing attempts")
|
||||
# Heartbeat timeout: if no heartbeat received for 5 minutes, consider it dead
|
||||
# This should be much longer than INDEXING_WORKER_HEARTBEAT_INTERVAL (30s)
|
||||
HEARTBEAT_TIMEOUT_SECONDS = 30 * 60 # 30 minutes
|
||||
|
||||
heartbeat_timeout_seconds = HEARTBEAT_TIMEOUT_SECONDS
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
cutoff_time = datetime.now(timezone.utc) - timedelta(
|
||||
seconds=HEARTBEAT_TIMEOUT_SECONDS
|
||||
)
|
||||
|
||||
# Find all active indexing attempts
|
||||
active_attempts = (
|
||||
@@ -199,6 +202,15 @@ def validate_active_indexing_attempts(
|
||||
)
|
||||
continue
|
||||
|
||||
if fresh_attempt.total_batches and fresh_attempt.completed_batches == 0:
|
||||
heartbeat_timeout_seconds = (
|
||||
HEARTBEAT_TIMEOUT_SECONDS
|
||||
* DOCPROCESSING_HEARTBEAT_TIMEOUT_MULTIPLIER
|
||||
)
|
||||
cutoff_time = datetime.now(timezone.utc) - timedelta(
|
||||
seconds=heartbeat_timeout_seconds
|
||||
)
|
||||
|
||||
# Heartbeat hasn't advanced - check if it's been too long
|
||||
if last_check_time >= cutoff_time:
|
||||
task_logger.debug(
|
||||
@@ -208,7 +220,7 @@ def validate_active_indexing_attempts(
|
||||
|
||||
# No heartbeat for too long - mark as failed
|
||||
failure_reason = (
|
||||
f"No heartbeat received for {HEARTBEAT_TIMEOUT_SECONDS} seconds"
|
||||
f"No heartbeat received for {heartbeat_timeout_seconds} seconds"
|
||||
)
|
||||
|
||||
task_logger.warning(
|
||||
@@ -257,7 +269,7 @@ class ConnectorIndexingLogBuilder:
|
||||
|
||||
|
||||
def monitor_indexing_attempt_progress(
|
||||
attempt: IndexAttempt, tenant_id: str, db_session: Session
|
||||
attempt: IndexAttempt, tenant_id: str, db_session: Session, task: Task
|
||||
) -> None:
|
||||
"""
|
||||
TODO: rewrite this docstring
|
||||
@@ -316,7 +328,9 @@ def monitor_indexing_attempt_progress(
|
||||
|
||||
# Check task completion using Celery
|
||||
try:
|
||||
check_indexing_completion(attempt.id, coordination_status, storage, tenant_id)
|
||||
check_indexing_completion(
|
||||
attempt.id, coordination_status, storage, tenant_id, task
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to monitor document processing completion: "
|
||||
@@ -350,6 +364,7 @@ def check_indexing_completion(
|
||||
coordination_status: CoordinationStatus,
|
||||
storage: DocumentBatchStorage,
|
||||
tenant_id: str,
|
||||
task: Task,
|
||||
) -> None:
|
||||
|
||||
logger.info(
|
||||
@@ -376,20 +391,78 @@ def check_indexing_completion(
|
||||
|
||||
# Update progress tracking and check for stalls
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
# Update progress tracking
|
||||
stalled_timeout_hours = INDEXING_PROGRESS_TIMEOUT_HOURS
|
||||
# Index attempts that are waiting between docfetching and
|
||||
# docprocessing get a generous stalling timeout
|
||||
if batches_total is not None and batches_processed == 0:
|
||||
stalled_timeout_hours = (
|
||||
stalled_timeout_hours * DOCPROCESSING_STALL_TIMEOUT_MULTIPLIER
|
||||
)
|
||||
|
||||
timed_out = not IndexingCoordination.update_progress_tracking(
|
||||
db_session, index_attempt_id, batches_processed
|
||||
db_session,
|
||||
index_attempt_id,
|
||||
batches_processed,
|
||||
timeout_hours=stalled_timeout_hours,
|
||||
)
|
||||
|
||||
# Check for stalls (3-6 hour timeout)
|
||||
if timed_out:
|
||||
logger.error(
|
||||
f"Indexing attempt {index_attempt_id} has been indexing for 3-6 hours without progress. "
|
||||
f"Marking it as failed."
|
||||
)
|
||||
mark_attempt_failed(
|
||||
index_attempt_id, db_session, failure_reason="Stalled indexing"
|
||||
)
|
||||
# Check for stalls (3-6 hour timeout). Only applies to in-progress attempts.
|
||||
attempt = get_index_attempt(db_session, index_attempt_id)
|
||||
if attempt and timed_out:
|
||||
if attempt.status == IndexingStatus.IN_PROGRESS:
|
||||
logger.error(
|
||||
f"Indexing attempt {index_attempt_id} has been indexing for "
|
||||
f"{stalled_timeout_hours//2}-{stalled_timeout_hours} hours without progress. "
|
||||
f"Marking it as failed."
|
||||
)
|
||||
mark_attempt_failed(
|
||||
index_attempt_id, db_session, failure_reason="Stalled indexing"
|
||||
)
|
||||
elif (
|
||||
attempt.status == IndexingStatus.NOT_STARTED and attempt.celery_task_id
|
||||
):
|
||||
# Check if the task exists in the celery queue
|
||||
# This handles the case where Redis dies after task creation but before task execution
|
||||
redis_celery = task.app.broker_connection().channel().client # type: ignore
|
||||
task_exists = celery_find_task(
|
||||
attempt.celery_task_id,
|
||||
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING,
|
||||
redis_celery,
|
||||
)
|
||||
unacked_task_ids = celery_get_unacked_task_ids(
|
||||
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING, redis_celery
|
||||
)
|
||||
|
||||
if not task_exists and attempt.celery_task_id not in unacked_task_ids:
|
||||
# there is a race condition where the docfetching task has been taken off
|
||||
# the queues (i.e. started) but the indexing attempt still has a status of
|
||||
# Not Started because the switch to in progress takes like 0.1 seconds.
|
||||
# sleep a bit and confirm that the attempt is still not in progress.
|
||||
time.sleep(1)
|
||||
attempt = get_index_attempt(db_session, index_attempt_id)
|
||||
if attempt and attempt.status == IndexingStatus.NOT_STARTED:
|
||||
logger.error(
|
||||
f"Task {attempt.celery_task_id} attached to indexing attempt "
|
||||
f"{index_attempt_id} does not exist in the queue. "
|
||||
f"Marking indexing attempt as failed."
|
||||
)
|
||||
mark_attempt_failed(
|
||||
index_attempt_id,
|
||||
db_session,
|
||||
failure_reason="Task not in queue",
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"Indexing attempt {index_attempt_id} is {attempt.status}. 3-6 hours without heartbeat "
|
||||
"but task is in the queue. Likely underprovisioned docfetching worker."
|
||||
)
|
||||
# Update last progress time so we won't time out again for another 3 hours
|
||||
IndexingCoordination.update_progress_tracking(
|
||||
db_session,
|
||||
index_attempt_id,
|
||||
batches_processed,
|
||||
force_update_progress=True,
|
||||
)
|
||||
|
||||
# check again on the next check_for_indexing task
|
||||
# TODO: on the cloud this is currently 25 minutes at most, which
|
||||
@@ -432,7 +505,14 @@ def check_indexing_completion(
|
||||
ConnectorCredentialPairStatus.SCHEDULED,
|
||||
ConnectorCredentialPairStatus.INITIAL_INDEXING,
|
||||
]:
|
||||
cc_pair.status = ConnectorCredentialPairStatus.ACTIVE
|
||||
# User file connectors must be paused on success
|
||||
# NOTE: _run_indexing doesn't update connectors if the index attempt is the future embedding model
|
||||
# TODO: figure out why this doesn't pause connectors during swap
|
||||
cc_pair.status = (
|
||||
ConnectorCredentialPairStatus.PAUSED
|
||||
if cc_pair.is_user_file
|
||||
else ConnectorCredentialPairStatus.ACTIVE
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
# Clear repeated error state on success
|
||||
@@ -449,15 +529,6 @@ def check_indexing_completion(
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# TODO: make it so we don't need this (might already be true)
|
||||
redis_connector = RedisConnector(
|
||||
tenant_id, attempt.connector_credential_pair_id
|
||||
)
|
||||
redis_connector_index = redis_connector.new_index(
|
||||
attempt.search_settings_id
|
||||
)
|
||||
redis_connector_index.set_generator_complete(HTTPStatus.OK.value)
|
||||
|
||||
# Clean up FileStore storage (still needed for document batches during transition)
|
||||
try:
|
||||
logger.info(f"Cleaning up storage after indexing completion: {storage}")
|
||||
@@ -811,7 +882,9 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
|
||||
for attempt in active_attempts:
|
||||
try:
|
||||
monitor_indexing_attempt_progress(attempt, tenant_id, db_session)
|
||||
monitor_indexing_attempt_progress(
|
||||
attempt, tenant_id, db_session, self
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(f"Error monitoring attempt {attempt.id}")
|
||||
|
||||
@@ -1015,9 +1088,12 @@ def docprocessing_task(
|
||||
# Start heartbeat for this indexing attempt
|
||||
heartbeat_thread, stop_event = start_heartbeat(index_attempt_id)
|
||||
try:
|
||||
# Cannot use the TaskSingleton approach here because the worker is multithreaded
|
||||
token = INDEX_ATTEMPT_INFO_CONTEXTVAR.set((cc_pair_id, index_attempt_id))
|
||||
_docprocessing_task(index_attempt_id, cc_pair_id, tenant_id, batch_num)
|
||||
finally:
|
||||
stop_heartbeat(heartbeat_thread, stop_event) # Stop heartbeat before exiting
|
||||
INDEX_ATTEMPT_INFO_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
def _docprocessing_task(
|
||||
@@ -1028,9 +1104,6 @@ def _docprocessing_task(
|
||||
) -> None:
|
||||
start_time = time.monotonic()
|
||||
|
||||
# set the indexing attempt ID so that all log messages from this process
|
||||
# will have it added as a prefix
|
||||
TaskAttemptSingleton.set_cc_and_index_id(index_attempt_id, cc_pair_id)
|
||||
if tenant_id:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
@@ -1085,12 +1158,8 @@ def _docprocessing_task(
|
||||
f"Index attempt {index_attempt_id} is not running, status {index_attempt.status}"
|
||||
)
|
||||
|
||||
redis_connector_index = redis_connector.new_index(
|
||||
index_attempt.search_settings.id
|
||||
)
|
||||
|
||||
cross_batch_db_lock: RedisLock = r.lock(
|
||||
redis_connector_index.db_lock_key,
|
||||
redis_connector.db_lock_key(index_attempt.search_settings.id),
|
||||
timeout=CELERY_INDEXING_LOCK_TIMEOUT,
|
||||
thread_local=False,
|
||||
)
|
||||
@@ -1230,17 +1299,6 @@ def _docprocessing_task(
|
||||
f"attempt={index_attempt_id} "
|
||||
)
|
||||
|
||||
# on failure, signal completion with an error to unblock the watchdog
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
index_attempt = get_index_attempt(db_session, index_attempt_id)
|
||||
if index_attempt and index_attempt.search_settings:
|
||||
redis_connector_index = redis_connector.new_index(
|
||||
index_attempt.search_settings.id
|
||||
)
|
||||
redis_connector_index.set_generator_complete(
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value
|
||||
)
|
||||
|
||||
raise
|
||||
finally:
|
||||
if per_batch_lock and per_batch_lock.owned():
|
||||
|
||||
@@ -47,7 +47,6 @@ from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import SyncStatus
|
||||
from onyx.db.enums import SyncType
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.db.sync_record import insert_sync_record
|
||||
from onyx.db.sync_record import update_sync_record_status
|
||||
from onyx.db.tag import delete_orphan_tags__no_commit
|
||||
@@ -519,9 +518,6 @@ def connector_pruning_generator_task(
|
||||
cc_pair.credential,
|
||||
)
|
||||
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
redis_connector.new_index(search_settings.id)
|
||||
|
||||
callback = PruneCallback(
|
||||
0,
|
||||
redis_connector,
|
||||
|
||||
@@ -153,10 +153,9 @@ class SimpleJob:
|
||||
if self._exception is None and self.queue and not self.queue.empty():
|
||||
self._exception = self.queue.get() # Get exception from queue
|
||||
|
||||
if self._exception:
|
||||
return self._exception
|
||||
|
||||
return f"Job with ID '{self.id}' did not report an exception."
|
||||
return (
|
||||
self._exception or f"Job with ID '{self.id}' did not report an exception."
|
||||
)
|
||||
|
||||
|
||||
class SimpleJobClient:
|
||||
|
||||
@@ -71,13 +71,13 @@ from onyx.natural_language_processing.search_nlp_models import (
|
||||
InformationContentClassificationModel,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.logger import TaskAttemptSingleton
|
||||
from onyx.utils.middleware import make_randomized_onyx_request_id
|
||||
from onyx.utils.telemetry import create_milestone_and_report
|
||||
from onyx.utils.telemetry import optional_telemetry
|
||||
from onyx.utils.telemetry import RecordType
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import INDEX_ATTEMPT_INFO_CONTEXTVAR
|
||||
|
||||
logger = setup_logger(propagate=False)
|
||||
|
||||
@@ -226,8 +226,12 @@ def _check_connector_and_attempt_status(
|
||||
raise ConnectorStopSignal(f"Index attempt {index_attempt_id} was canceled")
|
||||
|
||||
if index_attempt_loop.status != IndexingStatus.IN_PROGRESS:
|
||||
error_str = ""
|
||||
if index_attempt_loop.error_msg:
|
||||
error_str = f" Original error: {index_attempt_loop.error_msg}"
|
||||
|
||||
raise RuntimeError(
|
||||
f"Index Attempt is not running, status is {index_attempt_loop.status}"
|
||||
f"Index Attempt is not running, status is {index_attempt_loop.status}.{error_str}"
|
||||
)
|
||||
|
||||
if index_attempt_loop.celery_task_id is None:
|
||||
@@ -267,7 +271,7 @@ def _check_failure_threshold(
|
||||
|
||||
# NOTE: this is the old run_indexing function that the new decoupled approach
|
||||
# is based on. Leaving this for comparison purposes, but if you see this comment
|
||||
# has been here for >1 month, please delete this function.
|
||||
# has been here for >2 month, please delete this function.
|
||||
def _run_indexing(
|
||||
db_session: Session,
|
||||
index_attempt_id: int,
|
||||
@@ -832,7 +836,7 @@ def _run_indexing(
|
||||
)
|
||||
|
||||
|
||||
def run_indexing_entrypoint(
|
||||
def run_docfetching_entrypoint(
|
||||
app: Celery,
|
||||
index_attempt_id: int,
|
||||
tenant_id: str,
|
||||
@@ -847,8 +851,8 @@ def run_indexing_entrypoint(
|
||||
|
||||
# set the indexing attempt ID so that all log messages from this process
|
||||
# will have it added as a prefix
|
||||
TaskAttemptSingleton.set_cc_and_index_id(
|
||||
index_attempt_id, connector_credential_pair_id
|
||||
token = INDEX_ATTEMPT_INFO_CONTEXTVAR.set(
|
||||
(connector_credential_pair_id, index_attempt_id)
|
||||
)
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
attempt = transition_attempt_to_in_progress(index_attempt_id, db_session)
|
||||
@@ -886,6 +890,8 @@ def run_indexing_entrypoint(
|
||||
f"credentials='{credential_id}'"
|
||||
)
|
||||
|
||||
INDEX_ATTEMPT_INFO_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
def connector_document_extraction(
|
||||
app: Celery,
|
||||
@@ -1350,6 +1356,9 @@ def reissue_old_batches(
|
||||
)
|
||||
path_info = batch_storage.extract_path_info(batch_id)
|
||||
if path_info is None:
|
||||
logger.warning(
|
||||
f"Could not extract path info from batch {batch_id}, skipping"
|
||||
)
|
||||
continue
|
||||
if path_info.cc_pair_id != cc_pair_id:
|
||||
raise RuntimeError(f"Batch {batch_id} is not for cc pair {cc_pair_id}")
|
||||
|
||||
@@ -108,7 +108,11 @@ _VALID_EMAIL_DOMAINS_STR = (
|
||||
os.environ.get("VALID_EMAIL_DOMAINS", "") or _VALID_EMAIL_DOMAIN
|
||||
)
|
||||
VALID_EMAIL_DOMAINS = (
|
||||
[domain.strip() for domain in _VALID_EMAIL_DOMAINS_STR.split(",")]
|
||||
[
|
||||
domain.strip().lower()
|
||||
for domain in _VALID_EMAIL_DOMAINS_STR.split(",")
|
||||
if domain.strip()
|
||||
]
|
||||
if _VALID_EMAIL_DOMAINS_STR
|
||||
else []
|
||||
)
|
||||
@@ -121,6 +125,8 @@ OAUTH_CLIENT_SECRET = (
|
||||
os.environ.get("OAUTH_CLIENT_SECRET", os.environ.get("GOOGLE_OAUTH_CLIENT_SECRET"))
|
||||
or ""
|
||||
)
|
||||
# OpenID Connect configuration URL for Okta Profile Tool and other OIDC integrations
|
||||
OPENID_CONFIG_URL = os.environ.get("OPENID_CONFIG_URL") or ""
|
||||
|
||||
USER_AUTH_SECRET = os.environ.get("USER_AUTH_SECRET", "")
|
||||
|
||||
@@ -359,6 +365,12 @@ POLL_CONNECTOR_OFFSET = 30 # Minutes overlap between poll windows
|
||||
# only very select connectors are enabled and admins cannot add other connector types
|
||||
ENABLED_CONNECTOR_TYPES = os.environ.get("ENABLED_CONNECTOR_TYPES") or ""
|
||||
|
||||
# If set to true, curators can only access and edit assistants that they created
|
||||
CURATORS_CANNOT_VIEW_OR_EDIT_NON_OWNED_ASSISTANTS = (
|
||||
os.environ.get("CURATORS_CANNOT_VIEW_OR_EDIT_NON_OWNED_ASSISTANTS", "").lower()
|
||||
== "true"
|
||||
)
|
||||
|
||||
# Some calls to get information on expert users are quite costly especially with rate limiting
|
||||
# Since experts are not used in the actual user experience, currently it is turned off
|
||||
# for some connectors
|
||||
@@ -611,6 +623,17 @@ AVERAGE_SUMMARY_EMBEDDINGS = (
|
||||
|
||||
MAX_TOKENS_FOR_FULL_INCLUSION = 4096
|
||||
|
||||
|
||||
#####
|
||||
# Tool Configs
|
||||
#####
|
||||
OKTA_PROFILE_TOOL_ENABLED = (
|
||||
os.environ.get("OKTA_PROFILE_TOOL_ENABLED", "").lower() == "true"
|
||||
)
|
||||
# API token for SSWS auth to Okta Admin API. If set, Users API will be used to enrich profile.
|
||||
OKTA_API_TOKEN = os.environ.get("OKTA_API_TOKEN") or ""
|
||||
|
||||
|
||||
#####
|
||||
# Miscellaneous
|
||||
#####
|
||||
|
||||
@@ -25,6 +25,28 @@ TimeRange = tuple[datetime, datetime]
|
||||
CT = TypeVar("CT", bound=ConnectorCheckpoint)
|
||||
|
||||
|
||||
def batched_doc_ids(
|
||||
checkpoint_connector_generator: CheckpointOutput[CT],
|
||||
batch_size: int,
|
||||
) -> Generator[set[str], None, None]:
|
||||
batch: set[str] = set()
|
||||
for document, failure, next_checkpoint in CheckpointOutputWrapper[CT]()(
|
||||
checkpoint_connector_generator
|
||||
):
|
||||
if document is not None:
|
||||
batch.add(document.id)
|
||||
elif (
|
||||
failure and failure.failed_document and failure.failed_document.document_id
|
||||
):
|
||||
batch.add(failure.failed_document.document_id)
|
||||
|
||||
if len(batch) >= batch_size:
|
||||
yield batch
|
||||
batch = set()
|
||||
if len(batch) > 0:
|
||||
yield batch
|
||||
|
||||
|
||||
class CheckpointOutputWrapper(Generic[CT]):
|
||||
"""
|
||||
Wraps a CheckpointOutput generator to give things back in a more digestible format,
|
||||
|
||||
@@ -24,6 +24,7 @@ from onyx.file_processing.image_utils import store_image_and_create_section
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@@ -72,6 +73,7 @@ def _process_file(
|
||||
file: IO[Any],
|
||||
metadata: dict[str, Any] | None,
|
||||
pdf_pass: str | None,
|
||||
file_type: str | None,
|
||||
) -> list[Document]:
|
||||
"""
|
||||
Process a file and return a list of Documents.
|
||||
@@ -148,6 +150,7 @@ def _process_file(
|
||||
file=file,
|
||||
file_name=file_name,
|
||||
pdf_pass=pdf_pass,
|
||||
content_type=file_type,
|
||||
)
|
||||
|
||||
# Each file may have file-specific ONYX_METADATA https://docs.onyx.app/connectors/file
|
||||
@@ -229,21 +232,18 @@ class LocalFileConnector(LoadConnector):
|
||||
|
||||
# Note: file_names is a required parameter, but should not break backwards compatibility.
|
||||
# If add_file_names migration is not run, old file connector configs will not have file_names.
|
||||
# This is fine because the configs are not re-used to instantiate the connector.
|
||||
# file_names is only used for display purposes in the UI and file_locations is used as a fallback.
|
||||
def __init__(
|
||||
self,
|
||||
file_locations: list[Path | str],
|
||||
file_names: list[
|
||||
str
|
||||
], # Must accept this parameter as connector_specific_config is unpacked as args
|
||||
zip_metadata: dict[str, Any],
|
||||
file_names: list[str] | None = None,
|
||||
zip_metadata: dict[str, Any] | None = None,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
) -> None:
|
||||
self.file_locations = [str(loc) for loc in file_locations]
|
||||
self.batch_size = batch_size
|
||||
self.pdf_pass: str | None = None
|
||||
self.zip_metadata = zip_metadata
|
||||
self.zip_metadata = zip_metadata or {}
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
self.pdf_pass = credentials.get("pdf_password")
|
||||
@@ -278,6 +278,7 @@ class LocalFileConnector(LoadConnector):
|
||||
file=file_io,
|
||||
metadata=metadata,
|
||||
pdf_pass=self.pdf_pass,
|
||||
file_type=file_record.file_type,
|
||||
)
|
||||
documents.extend(new_docs)
|
||||
|
||||
|
||||
@@ -119,7 +119,19 @@ class LoopioConnector(LoadConnector, PollConnector):
|
||||
part["name"] for part in entry["location"].values() if part
|
||||
)
|
||||
|
||||
answer = parse_html_page_basic(entry.get("answer", {}).get("text", ""))
|
||||
answer_text = entry.get("answer", {}).get("text", "")
|
||||
if not answer_text:
|
||||
logger.warning(
|
||||
f"The Library entry {entry['id']} has no answer text. Skipping."
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
answer = parse_html_page_basic(answer_text)
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing HTML for entry {entry['id']}: {e}")
|
||||
continue
|
||||
|
||||
questions = [
|
||||
question.get("text").replace("\xa0", " ")
|
||||
for question in entry["questions"]
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import csv
|
||||
import gc
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
@@ -28,8 +29,12 @@ from onyx.connectors.salesforce.doc_conversion import ID_PREFIX
|
||||
from onyx.connectors.salesforce.onyx_salesforce import OnyxSalesforce
|
||||
from onyx.connectors.salesforce.salesforce_calls import fetch_all_csvs_in_parallel
|
||||
from onyx.connectors.salesforce.sqlite_functions import OnyxSalesforceSQLite
|
||||
from onyx.connectors.salesforce.utils import ACCOUNT_OBJECT_TYPE
|
||||
from onyx.connectors.salesforce.utils import BASE_DATA_PATH
|
||||
from onyx.connectors.salesforce.utils import get_sqlite_db_path
|
||||
from onyx.connectors.salesforce.utils import MODIFIED_FIELD
|
||||
from onyx.connectors.salesforce.utils import NAME_FIELD
|
||||
from onyx.connectors.salesforce.utils import USER_OBJECT_TYPE
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
@@ -38,27 +43,27 @@ from shared_configs.configs import MULTI_TENANT
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
_DEFAULT_PARENT_OBJECT_TYPES = ["Account"]
|
||||
_DEFAULT_PARENT_OBJECT_TYPES = [ACCOUNT_OBJECT_TYPE]
|
||||
|
||||
_DEFAULT_ATTRIBUTES_TO_KEEP: dict[str, dict[str, str]] = {
|
||||
"Opportunity": {
|
||||
"Account": "account",
|
||||
ACCOUNT_OBJECT_TYPE: "account",
|
||||
"FiscalQuarter": "fiscal_quarter",
|
||||
"FiscalYear": "fiscal_year",
|
||||
"IsClosed": "is_closed",
|
||||
"Name": "name",
|
||||
NAME_FIELD: "name",
|
||||
"StageName": "stage_name",
|
||||
"Type": "type",
|
||||
"Amount": "amount",
|
||||
"CloseDate": "close_date",
|
||||
"Probability": "probability",
|
||||
"CreatedDate": "created_date",
|
||||
"LastModifiedDate": "last_modified_date",
|
||||
MODIFIED_FIELD: "last_modified_date",
|
||||
},
|
||||
"Contact": {
|
||||
"Account": "account",
|
||||
ACCOUNT_OBJECT_TYPE: "account",
|
||||
"CreatedDate": "created_date",
|
||||
"LastModifiedDate": "last_modified_date",
|
||||
MODIFIED_FIELD: "last_modified_date",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -74,19 +79,77 @@ class SalesforceConnectorContext:
|
||||
parent_to_child_types: dict[str, set[str]] = {} # map from parent to child types
|
||||
child_to_parent_types: dict[str, set[str]] = {} # map from child to parent types
|
||||
parent_reference_fields_by_type: dict[str, dict[str, list[str]]] = {}
|
||||
type_to_queryable_fields: dict[str, list[str]] = {}
|
||||
type_to_queryable_fields: dict[str, set[str]] = {}
|
||||
prefix_to_type: dict[str, str] = {} # infer the object type of an id immediately
|
||||
|
||||
parent_to_child_relationships: dict[str, set[str]] = (
|
||||
{}
|
||||
) # map from parent to child relationships
|
||||
parent_to_relationship_queryable_fields: dict[str, dict[str, list[str]]] = (
|
||||
parent_to_relationship_queryable_fields: dict[str, dict[str, set[str]]] = (
|
||||
{}
|
||||
) # map from relationship to queryable fields
|
||||
|
||||
parent_child_names_to_relationships: dict[str, str] = {}
|
||||
|
||||
|
||||
def _extract_fields_and_associations_from_config(
|
||||
config: dict[str, Any], object_type: str
|
||||
) -> tuple[list[str] | None, dict[str, list[str]]]:
|
||||
"""
|
||||
Extract fields and associations for a specific object type from custom config.
|
||||
|
||||
Returns:
|
||||
tuple of (fields_list, associations_dict)
|
||||
- fields_list: List of fields to query, or None if not specified (use all)
|
||||
- associations_dict: Dict mapping association names to their config
|
||||
"""
|
||||
if object_type not in config:
|
||||
return None, {}
|
||||
|
||||
obj_config = config[object_type]
|
||||
fields = obj_config.get("fields")
|
||||
associations = obj_config.get("associations", {})
|
||||
|
||||
return fields, associations
|
||||
|
||||
|
||||
def _validate_custom_query_config(config: dict[str, Any]) -> None:
|
||||
"""
|
||||
Validate the structure of the custom query configuration.
|
||||
"""
|
||||
|
||||
for object_type, obj_config in config.items():
|
||||
if not isinstance(obj_config, dict):
|
||||
raise ValueError(
|
||||
f"top level object {object_type} must be mapped to a dictionary"
|
||||
)
|
||||
|
||||
# Check if fields is a list when present
|
||||
if "fields" in obj_config:
|
||||
if not isinstance(obj_config["fields"], list):
|
||||
raise ValueError("if fields key exists, value must be a list")
|
||||
for v in obj_config["fields"]:
|
||||
if not isinstance(v, str):
|
||||
raise ValueError(f"if fields list value {v} is not a string")
|
||||
|
||||
# Check if associations is a dict when present
|
||||
if "associations" in obj_config:
|
||||
if not isinstance(obj_config["associations"], dict):
|
||||
raise ValueError(
|
||||
"if associations key exists, value must be a dictionary"
|
||||
)
|
||||
for assoc_name, assoc_fields in obj_config["associations"].items():
|
||||
if not isinstance(assoc_fields, list):
|
||||
raise ValueError(
|
||||
f"associations list value {assoc_fields} for key {assoc_name} is not a list"
|
||||
)
|
||||
for v in assoc_fields:
|
||||
if not isinstance(v, str):
|
||||
raise ValueError(
|
||||
f"if associations list value {v} is not a string"
|
||||
)
|
||||
|
||||
|
||||
class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
"""Approach outline
|
||||
|
||||
@@ -134,14 +197,25 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
self,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
requested_objects: list[str] = [],
|
||||
custom_query_config: str | None = None,
|
||||
) -> None:
|
||||
self.batch_size = batch_size
|
||||
self._sf_client: OnyxSalesforce | None = None
|
||||
self.parent_object_list = (
|
||||
[obj.capitalize() for obj in requested_objects]
|
||||
if requested_objects
|
||||
else _DEFAULT_PARENT_OBJECT_TYPES
|
||||
)
|
||||
|
||||
# Validate and store custom query config
|
||||
if custom_query_config:
|
||||
config_json = json.loads(custom_query_config)
|
||||
self.custom_query_config: dict[str, Any] | None = config_json
|
||||
# If custom query config is provided, use the object types from it
|
||||
self.parent_object_list = list(config_json.keys())
|
||||
else:
|
||||
self.custom_query_config = None
|
||||
# Use the traditional requested_objects approach
|
||||
self.parent_object_list = (
|
||||
[obj.strip().capitalize() for obj in requested_objects]
|
||||
if requested_objects
|
||||
else _DEFAULT_PARENT_OBJECT_TYPES
|
||||
)
|
||||
|
||||
def load_credentials(
|
||||
self,
|
||||
@@ -187,7 +261,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
@staticmethod
|
||||
def _download_object_csvs(
|
||||
all_types_to_filter: dict[str, bool],
|
||||
queryable_fields_by_type: dict[str, list[str]],
|
||||
queryable_fields_by_type: dict[str, set[str]],
|
||||
directory: str,
|
||||
sf_client: OnyxSalesforce,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
@@ -325,9 +399,9 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
# all_types.update(child_types.keys())
|
||||
|
||||
# # Always want to make sure user is grabbed for permissioning purposes
|
||||
# all_types.add("User")
|
||||
# all_types.add(USER_OBJECT_TYPE)
|
||||
# # Always want to make sure account is grabbed for reference purposes
|
||||
# all_types.add("Account")
|
||||
# all_types.add(ACCOUNT_OBJECT_TYPE)
|
||||
|
||||
# logger.info(f"All object types: num={len(all_types)} list={all_types}")
|
||||
|
||||
@@ -351,7 +425,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
# all_types.update(child_types)
|
||||
|
||||
# # Always want to make sure user is grabbed for permissioning purposes
|
||||
# all_types.add("User")
|
||||
# all_types.add(USER_OBJECT_TYPE)
|
||||
|
||||
# logger.info(f"All object types: num={len(all_types)} list={all_types}")
|
||||
|
||||
@@ -364,7 +438,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
) -> GenerateDocumentsOutput:
|
||||
type_to_processed: dict[str, int] = {}
|
||||
|
||||
logger.info("_fetch_from_salesforce starting.")
|
||||
logger.info("_fetch_from_salesforce starting (full sync).")
|
||||
if not self._sf_client:
|
||||
raise RuntimeError("self._sf_client is None!")
|
||||
|
||||
@@ -548,7 +622,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
) -> GenerateDocumentsOutput:
|
||||
type_to_processed: dict[str, int] = {}
|
||||
|
||||
logger.info("_fetch_from_salesforce starting.")
|
||||
logger.info("_fetch_from_salesforce starting (delta sync).")
|
||||
if not self._sf_client:
|
||||
raise RuntimeError("self._sf_client is None!")
|
||||
|
||||
@@ -677,7 +751,9 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
try:
|
||||
last_modified_by_id = record["LastModifiedById"]
|
||||
user_record = self.sf_client.query_object(
|
||||
"User", last_modified_by_id, ctx.type_to_queryable_fields
|
||||
USER_OBJECT_TYPE,
|
||||
last_modified_by_id,
|
||||
ctx.type_to_queryable_fields,
|
||||
)
|
||||
if user_record:
|
||||
primary_owner = BasicExpertInfo.from_dict(user_record)
|
||||
@@ -792,7 +868,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
parent_reference_fields_by_type: dict[str, dict[str, list[str]]] = (
|
||||
{}
|
||||
) # for a given object, the fields reference parent objects
|
||||
type_to_queryable_fields: dict[str, list[str]] = {}
|
||||
type_to_queryable_fields: dict[str, set[str]] = {}
|
||||
prefix_to_type: dict[str, str] = {}
|
||||
|
||||
parent_to_child_relationships: dict[str, set[str]] = (
|
||||
@@ -802,15 +878,13 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
# relationship keys are formatted as "parent__relationship"
|
||||
# we have to do this because relationship names are not unique!
|
||||
# values are a dict of relationship names to a list of queryable fields
|
||||
parent_to_relationship_queryable_fields: dict[str, dict[str, list[str]]] = {}
|
||||
parent_to_relationship_queryable_fields: dict[str, dict[str, set[str]]] = {}
|
||||
|
||||
parent_child_names_to_relationships: dict[str, str] = {}
|
||||
|
||||
full_sync = False
|
||||
if start is None and end is None:
|
||||
full_sync = True
|
||||
full_sync = start is None and end is None
|
||||
|
||||
# Step 1 - make a list of all the types to download (parent + direct child + "User")
|
||||
# Step 1 - make a list of all the types to download (parent + direct child + USER_OBJECT_TYPE)
|
||||
# prefixes = {}
|
||||
|
||||
global_description = sf_client.describe()
|
||||
@@ -831,16 +905,63 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
logger.info(f"Parent object types: num={len(parent_types)} list={parent_types}")
|
||||
for parent_type in parent_types:
|
||||
# parent_onyx_sf_type = OnyxSalesforceType(parent_type, sf_client)
|
||||
type_to_queryable_fields[parent_type] = (
|
||||
sf_client.get_queryable_fields_by_type(parent_type)
|
||||
)
|
||||
|
||||
child_types_working = sf_client.get_children_of_sf_type(parent_type)
|
||||
logger.debug(
|
||||
f"Found {len(child_types_working)} child types for {parent_type}"
|
||||
)
|
||||
custom_fields: list[str] | None = []
|
||||
associations_config: dict[str, list[str]] | None = None
|
||||
|
||||
# parent_to_child_relationships[parent_type] = child_types_working
|
||||
# Set queryable fields for parent type
|
||||
if self.custom_query_config:
|
||||
custom_fields, associations_config = (
|
||||
_extract_fields_and_associations_from_config(
|
||||
self.custom_query_config, parent_type
|
||||
)
|
||||
)
|
||||
custom_fields = custom_fields or []
|
||||
|
||||
# Get custom fields for parent type
|
||||
field_set = set(custom_fields)
|
||||
# these are expected and used during doc conversion
|
||||
field_set.add(NAME_FIELD)
|
||||
field_set.add(MODIFIED_FIELD)
|
||||
|
||||
# Use only the specified fields
|
||||
type_to_queryable_fields[parent_type] = field_set
|
||||
logger.info(f"Using custom fields for {parent_type}: {field_set}")
|
||||
else:
|
||||
# Use all queryable fields
|
||||
type_to_queryable_fields[parent_type] = (
|
||||
sf_client.get_queryable_fields_by_type(parent_type)
|
||||
)
|
||||
logger.info(f"Using all fields for {parent_type}")
|
||||
|
||||
child_types_all = sf_client.get_children_of_sf_type(parent_type)
|
||||
logger.debug(f"Found {len(child_types_all)} child types for {parent_type}")
|
||||
logger.debug(f"child types: {child_types_all}")
|
||||
|
||||
child_types_working = child_types_all.copy()
|
||||
if associations_config is not None:
|
||||
child_types_working = {
|
||||
k: v for k, v in child_types_all.items() if k in associations_config
|
||||
}
|
||||
any_not_found = False
|
||||
for k in associations_config:
|
||||
if k not in child_types_working:
|
||||
any_not_found = True
|
||||
logger.warning(f"Association {k} not found in {parent_type}")
|
||||
if any_not_found:
|
||||
queryable_fields = sf_client.get_queryable_fields_by_type(
|
||||
parent_type
|
||||
)
|
||||
raise RuntimeError(
|
||||
f"Associations {associations_config} not found in {parent_type} "
|
||||
"make sure your parent-child associations are in the right order"
|
||||
# f"with child objects {child_types_all}"
|
||||
# f" and fields {queryable_fields}"
|
||||
)
|
||||
|
||||
parent_to_child_relationships[parent_type] = set()
|
||||
parent_to_child_types[parent_type] = set()
|
||||
parent_to_relationship_queryable_fields[parent_type] = {}
|
||||
|
||||
for child_type, child_relationship in child_types_working.items():
|
||||
child_type = cast(str, child_type)
|
||||
@@ -848,8 +969,6 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
# onyx_sf_type = OnyxSalesforceType(child_type, sf_client)
|
||||
|
||||
# map parent name to child name
|
||||
if parent_type not in parent_to_child_types:
|
||||
parent_to_child_types[parent_type] = set()
|
||||
parent_to_child_types[parent_type].add(child_type)
|
||||
|
||||
# reverse map child name to parent name
|
||||
@@ -858,19 +977,25 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
child_to_parent_types[child_type].add(parent_type)
|
||||
|
||||
# map parent name to child relationship
|
||||
if parent_type not in parent_to_child_relationships:
|
||||
parent_to_child_relationships[parent_type] = set()
|
||||
parent_to_child_relationships[parent_type].add(child_relationship)
|
||||
|
||||
# map relationship to queryable fields of the target table
|
||||
queryable_fields = sf_client.get_queryable_fields_by_type(child_type)
|
||||
if config_fields := (
|
||||
associations_config and associations_config.get(child_type)
|
||||
):
|
||||
field_set = set(config_fields)
|
||||
# these are expected and used during doc conversion
|
||||
field_set.add(NAME_FIELD)
|
||||
field_set.add(MODIFIED_FIELD)
|
||||
queryable_fields = field_set
|
||||
else:
|
||||
queryable_fields = sf_client.get_queryable_fields_by_type(
|
||||
child_type
|
||||
)
|
||||
|
||||
if child_relationship in parent_to_relationship_queryable_fields:
|
||||
raise RuntimeError(f"{child_relationship=} already exists")
|
||||
|
||||
if parent_type not in parent_to_relationship_queryable_fields:
|
||||
parent_to_relationship_queryable_fields[parent_type] = {}
|
||||
|
||||
parent_to_relationship_queryable_fields[parent_type][
|
||||
child_relationship
|
||||
] = queryable_fields
|
||||
@@ -894,14 +1019,22 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
all_types.update(child_types)
|
||||
|
||||
# NOTE(rkuo): should this be an implicit parent type?
|
||||
all_types.add("User") # Always add User for permissioning purposes
|
||||
all_types.add("Account") # Always add Account for reference purposes
|
||||
all_types.add(USER_OBJECT_TYPE) # Always add User for permissioning purposes
|
||||
all_types.add(ACCOUNT_OBJECT_TYPE) # Always add Account for reference purposes
|
||||
|
||||
logger.info(f"All object types: num={len(all_types)} list={all_types}")
|
||||
|
||||
# Ensure User and Account have queryable fields if they weren't already processed
|
||||
essential_types = [USER_OBJECT_TYPE, ACCOUNT_OBJECT_TYPE]
|
||||
for essential_type in essential_types:
|
||||
if essential_type not in type_to_queryable_fields:
|
||||
type_to_queryable_fields[essential_type] = (
|
||||
sf_client.get_queryable_fields_by_type(essential_type)
|
||||
)
|
||||
|
||||
# 1.1 - Detect all fields in child types which reference a parent type.
|
||||
# build dicts to detect relationships between parent and child
|
||||
for child_type in child_types:
|
||||
for child_type in child_types.union(essential_types):
|
||||
# onyx_sf_type = OnyxSalesforceType(child_type, sf_client)
|
||||
parent_reference_fields = sf_client.get_parent_reference_fields(
|
||||
child_type, parent_types
|
||||
@@ -1003,6 +1136,32 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
|
||||
yield doc_metadata_list
|
||||
|
||||
def validate_connector_settings(self) -> None:
    """
    Validate that the Salesforce credentials and connector settings are correct.

    Two checks are made, in order:
      1. an authenticated `describe()` call through `self.sf_client`
      2. if `self.custom_query_config` is set, a structural check of it via
         `_validate_custom_query_config`

    Raises:
        ConnectorMissingCredentialError: if either check fails; the
            underlying exception is chained as the cause.
    """

    try:
        # describe() serves as an arbitrary authenticated request to verify credentials
        self.sf_client.describe()
    except Exception as e:
        # The trailing space matters: the two literals are concatenated.
        raise ConnectorMissingCredentialError(
            "Failed to validate Salesforce credentials. Please check your "
            f"credentials and try again. Error: {e}"
        ) from e

    if self.custom_query_config:
        try:
            _validate_custom_query_config(self.custom_query_config)
        except Exception as e:
            # NOTE(review): a malformed config is reported with the
            # credential error type; kept so existing callers that catch it
            # keep working - confirm whether a validation-specific error
            # would be more appropriate.
            raise ConnectorMissingCredentialError(
                "Failed to validate Salesforce custom query config. Please check your "
                f"config and try again. Error: {e}"
            ) from e

    logger.info("Salesforce credentials validated successfully.")
|
||||
|
||||
# @override
|
||||
# def load_from_checkpoint(
|
||||
# self,
|
||||
@@ -1032,7 +1191,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
connector = SalesforceConnector(requested_objects=["Account"])
|
||||
connector = SalesforceConnector(requested_objects=[ACCOUNT_OBJECT_TYPE])
|
||||
|
||||
connector.load_credentials(
|
||||
{
|
||||
|
||||
@@ -10,6 +10,8 @@ from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.connectors.salesforce.onyx_salesforce import OnyxSalesforce
|
||||
from onyx.connectors.salesforce.sqlite_functions import OnyxSalesforceSQLite
|
||||
from onyx.connectors.salesforce.utils import MODIFIED_FIELD
|
||||
from onyx.connectors.salesforce.utils import NAME_FIELD
|
||||
from onyx.connectors.salesforce.utils import SalesforceObject
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -140,7 +142,7 @@ def _extract_primary_owner(
|
||||
first_name=user_data.get("FirstName"),
|
||||
last_name=user_data.get("LastName"),
|
||||
email=user_data.get("Email"),
|
||||
display_name=user_data.get("Name"),
|
||||
display_name=user_data.get(NAME_FIELD),
|
||||
)
|
||||
|
||||
# Check if all fields are None
|
||||
@@ -166,8 +168,8 @@ def convert_sf_query_result_to_doc(
|
||||
"""Generates a yieldable Document from query results"""
|
||||
|
||||
base_url = f"https://{sf_client.sf_instance}"
|
||||
extracted_doc_updated_at = time_str_to_utc(record["LastModifiedDate"])
|
||||
extracted_semantic_identifier = record.get("Name", "Unknown Object")
|
||||
extracted_doc_updated_at = time_str_to_utc(record[MODIFIED_FIELD])
|
||||
extracted_semantic_identifier = record.get(NAME_FIELD, "Unknown Object")
|
||||
|
||||
sections = [_extract_section(record, f"{base_url}/{record_id}")]
|
||||
for child_record_key, child_record in child_records.items():
|
||||
@@ -205,8 +207,8 @@ def convert_sf_object_to_doc(
|
||||
salesforce_id = object_dict["Id"]
|
||||
onyx_salesforce_id = f"{ID_PREFIX}{salesforce_id}"
|
||||
base_url = f"https://{sf_instance}"
|
||||
extracted_doc_updated_at = time_str_to_utc(object_dict["LastModifiedDate"])
|
||||
extracted_semantic_identifier = object_dict.get("Name", "Unknown Object")
|
||||
extracted_doc_updated_at = time_str_to_utc(object_dict[MODIFIED_FIELD])
|
||||
extracted_semantic_identifier = object_dict.get(NAME_FIELD, "Unknown Object")
|
||||
|
||||
sections = [_extract_section(sf_object.data, f"{base_url}/{sf_object.id}")]
|
||||
for id in sf_db.get_child_ids(sf_object.id):
|
||||
|
||||
@@ -60,7 +60,7 @@ class OnyxSalesforce(Salesforce):
|
||||
return True
|
||||
|
||||
for suffix in SALESFORCE_BLACKLISTED_SUFFIXES:
|
||||
if object_type_lower.endswith(prefix):
|
||||
if object_type_lower.endswith(suffix):
|
||||
return True
|
||||
|
||||
return False
|
||||
@@ -112,7 +112,7 @@ class OnyxSalesforce(Salesforce):
|
||||
object_id: str,
|
||||
sf_type: str,
|
||||
child_relationships: list[str],
|
||||
relationships_to_fields: dict[str, list[str]],
|
||||
relationships_to_fields: dict[str, set[str]],
|
||||
) -> str:
|
||||
"""Returns a SOQL query given the object id, type and child relationships.
|
||||
|
||||
@@ -148,7 +148,7 @@ class OnyxSalesforce(Salesforce):
|
||||
self,
|
||||
object_type: str,
|
||||
object_id: str,
|
||||
type_to_queryable_fields: dict[str, list[str]],
|
||||
type_to_queryable_fields: dict[str, set[str]],
|
||||
) -> dict[str, Any] | None:
|
||||
record: dict[str, Any] = {}
|
||||
|
||||
@@ -172,7 +172,7 @@ class OnyxSalesforce(Salesforce):
|
||||
object_id: str,
|
||||
sf_type: str,
|
||||
child_relationships: list[str],
|
||||
relationships_to_fields: dict[str, list[str]],
|
||||
relationships_to_fields: dict[str, set[str]],
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""There's a limit on the number of subqueries we can put in a single query."""
|
||||
child_records: dict[str, dict[str, Any]] = {}
|
||||
@@ -264,10 +264,10 @@ class OnyxSalesforce(Salesforce):
|
||||
time.sleep(3)
|
||||
raise
|
||||
|
||||
def get_queryable_fields_by_type(self, name: str) -> list[str]:
|
||||
def get_queryable_fields_by_type(self, name: str) -> set[str]:
|
||||
object_description = self.describe_type(name)
|
||||
if object_description is None:
|
||||
return []
|
||||
return set()
|
||||
|
||||
fields: list[dict[str, Any]] = object_description["fields"]
|
||||
valid_fields: set[str] = set()
|
||||
@@ -286,7 +286,7 @@ class OnyxSalesforce(Salesforce):
|
||||
if field_name:
|
||||
valid_fields.add(field_name)
|
||||
|
||||
return list(valid_fields - field_names_to_remove)
|
||||
return valid_fields - field_names_to_remove
|
||||
|
||||
def get_children_of_sf_type(self, sf_type: str) -> dict[str, str]:
|
||||
"""Returns a dict of child object names to relationship names.
|
||||
|
||||
@@ -14,6 +14,7 @@ from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
rate_limit_builder,
|
||||
)
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.salesforce.utils import MODIFIED_FIELD
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.retry_wrapper import retry_builder
|
||||
|
||||
@@ -54,12 +55,12 @@ def _build_created_date_time_filter_for_salesforce(
|
||||
|
||||
|
||||
def _make_time_filter_for_sf_type(
|
||||
queryable_fields: list[str],
|
||||
queryable_fields: set[str],
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
) -> str | None:
|
||||
|
||||
if "LastModifiedDate" in queryable_fields:
|
||||
if MODIFIED_FIELD in queryable_fields:
|
||||
return _build_last_modified_time_filter_for_salesforce(start, end)
|
||||
|
||||
if "CreatedDate" in queryable_fields:
|
||||
@@ -69,14 +70,14 @@ def _make_time_filter_for_sf_type(
|
||||
|
||||
|
||||
def _make_time_filtered_query(
|
||||
queryable_fields: list[str], sf_type: str, time_filter: str
|
||||
queryable_fields: set[str], sf_type: str, time_filter: str
|
||||
) -> str:
|
||||
query = f"SELECT {', '.join(queryable_fields)} FROM {sf_type}{time_filter}"
|
||||
return query
|
||||
|
||||
|
||||
def get_object_by_id_query(
|
||||
object_id: str, sf_type: str, queryable_fields: list[str]
|
||||
object_id: str, sf_type: str, queryable_fields: set[str]
|
||||
) -> str:
|
||||
query = (
|
||||
f"SELECT {', '.join(queryable_fields)} FROM {sf_type} WHERE Id = '{object_id}'"
|
||||
@@ -193,7 +194,7 @@ def _bulk_retrieve_from_salesforce(
|
||||
def fetch_all_csvs_in_parallel(
|
||||
sf_client: Salesforce,
|
||||
all_types_to_filter: dict[str, bool],
|
||||
queryable_fields_by_type: dict[str, list[str]],
|
||||
queryable_fields_by_type: dict[str, set[str]],
|
||||
start: SecondsSinceUnixEpoch | None,
|
||||
end: SecondsSinceUnixEpoch | None,
|
||||
target_dir: str,
|
||||
|
||||
@@ -8,11 +8,15 @@ from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.salesforce.utils import ACCOUNT_OBJECT_TYPE
|
||||
from onyx.connectors.salesforce.utils import NAME_FIELD
|
||||
from onyx.connectors.salesforce.utils import SalesforceObject
|
||||
from onyx.connectors.salesforce.utils import USER_OBJECT_TYPE
|
||||
from onyx.connectors.salesforce.utils import validate_salesforce_id
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.utils import batch_list
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@@ -567,7 +571,7 @@ class OnyxSalesforceSQLite:
|
||||
uncommitted_rows = 0
|
||||
|
||||
# If we're updating User objects, update the email map
|
||||
if object_type == "User":
|
||||
if object_type == USER_OBJECT_TYPE:
|
||||
OnyxSalesforceSQLite._update_user_email_map(cursor)
|
||||
|
||||
return updated_ids
|
||||
@@ -619,7 +623,7 @@ class OnyxSalesforceSQLite:
|
||||
with self._conn:
|
||||
cursor = self._conn.cursor()
|
||||
# Get the object data and account data
|
||||
if object_type == "Account" or isChild:
|
||||
if object_type == ACCOUNT_OBJECT_TYPE or isChild:
|
||||
cursor.execute(
|
||||
"SELECT data FROM salesforce_objects WHERE id = ?", (object_id,)
|
||||
)
|
||||
@@ -638,7 +642,7 @@ class OnyxSalesforceSQLite:
|
||||
|
||||
data = json.loads(result[0][0])
|
||||
|
||||
if object_type != "Account":
|
||||
if object_type != ACCOUNT_OBJECT_TYPE:
|
||||
|
||||
# convert any account ids of the relationships back into data fields, with name
|
||||
for row in result:
|
||||
@@ -647,14 +651,14 @@ class OnyxSalesforceSQLite:
|
||||
if len(row) < 3:
|
||||
continue
|
||||
|
||||
if row[1] and row[2] and row[2] == "Account":
|
||||
if row[1] and row[2] and row[2] == ACCOUNT_OBJECT_TYPE:
|
||||
data["AccountId"] = row[1]
|
||||
cursor.execute(
|
||||
"SELECT data FROM salesforce_objects WHERE id = ?",
|
||||
(row[1],),
|
||||
)
|
||||
account_data = json.loads(cursor.fetchone()[0])
|
||||
data["Account"] = account_data.get("Name", "")
|
||||
data[ACCOUNT_OBJECT_TYPE] = account_data.get(NAME_FIELD, "")
|
||||
|
||||
return SalesforceObject(id=object_id, type=object_type, data=data)
|
||||
|
||||
|
||||
@@ -2,6 +2,11 @@ import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
NAME_FIELD = "Name"
|
||||
MODIFIED_FIELD = "LastModifiedDate"
|
||||
ACCOUNT_OBJECT_TYPE = "Account"
|
||||
USER_OBJECT_TYPE = "User"
|
||||
|
||||
|
||||
@dataclass
|
||||
class SalesforceObject:
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
38
backend/onyx/connectors/sharepoint/connector_utils.py
Normal file
38
backend/onyx/connectors/sharepoint/connector_utils.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from typing import Any
|
||||
|
||||
from office365.graph_client import GraphClient # type: ignore[import-untyped]
|
||||
from office365.onedrive.driveitems.driveItem import DriveItem # type: ignore[import-untyped]
|
||||
from office365.sharepoint.client_context import ClientContext # type: ignore[import-untyped]
|
||||
|
||||
from onyx.connectors.models import ExternalAccess
|
||||
from onyx.utils.variable_functionality import (
|
||||
fetch_versioned_implementation_with_fallback,
|
||||
)
|
||||
|
||||
|
||||
def get_sharepoint_external_access(
    ctx: ClientContext,
    graph_client: GraphClient,
    drive_item: DriveItem | None = None,
    drive_name: str | None = None,
    site_page: dict[str, Any] | None = None,
    add_prefix: bool = False,
) -> ExternalAccess:
    """
    Resolve the external access for a SharePoint drive item or site page.

    The lookup is delegated to `get_external_access_from_sharepoint`, fetched
    by name from `onyx.external_permissions.sharepoint.permission_utils`.
    When that versioned implementation is not available, a fallback that
    returns `ExternalAccess.empty()` is used instead.

    Raises:
        ValueError: if a drive item is supplied but its id is None.
    """
    if drive_item and drive_item.id is None:
        raise ValueError("DriveItem ID is required")

    def _empty_access(*args: Any, **kwargs: Any) -> ExternalAccess:
        # Used when no versioned implementation can be fetched.
        return ExternalAccess.empty()

    # Get external access using the EE implementation
    resolve_access = fetch_versioned_implementation_with_fallback(
        "onyx.external_permissions.sharepoint.permission_utils",
        "get_external_access_from_sharepoint",
        fallback=_empty_access,
    )

    # NOTE(review): arguments are positional and drive_name precedes
    # drive_item here, unlike this function's own parameter order -
    # presumably this matches the target's signature; confirm.
    return resolve_access(
        ctx, graph_client, drive_name, drive_item, site_page, add_prefix
    )
|
||||
@@ -267,6 +267,7 @@ class IndexingCoordination:
|
||||
index_attempt_id: int,
|
||||
current_batches_completed: int,
|
||||
timeout_hours: int = INDEXING_PROGRESS_TIMEOUT_HOURS,
|
||||
force_update_progress: bool = False,
|
||||
) -> bool:
|
||||
"""
|
||||
Update progress tracking for stall detection.
|
||||
@@ -281,7 +282,8 @@ class IndexingCoordination:
|
||||
current_time = get_db_current_time(db_session)
|
||||
|
||||
# No progress - check if this is the first time tracking
|
||||
if attempt.last_progress_time is None:
|
||||
# or if the caller wants to simulate guaranteed progress
|
||||
if attempt.last_progress_time is None or force_update_progress:
|
||||
# First time tracking - initialize
|
||||
attempt.last_progress_time = current_time
|
||||
attempt.last_batches_completed_count = current_batches_completed
|
||||
|
||||
@@ -1293,6 +1293,7 @@ class Tag(Base):
|
||||
source: Mapped[DocumentSource] = mapped_column(
|
||||
Enum(DocumentSource, native_enum=False)
|
||||
)
|
||||
is_list: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
|
||||
documents = relationship(
|
||||
"Document",
|
||||
@@ -1302,7 +1303,11 @@ class Tag(Base):
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"tag_key", "tag_value", "source", name="_tag_key_value_source_uc"
|
||||
"tag_key",
|
||||
"tag_value",
|
||||
"source",
|
||||
"is_list",
|
||||
name="_tag_key_value_source_list_uc",
|
||||
),
|
||||
)
|
||||
|
||||
@@ -1685,12 +1690,14 @@ class IndexAttempt(Base):
|
||||
# can be taken to the FileStore to grab the actual checkpoint value
|
||||
checkpoint_pointer: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
# NEW: Database-based coordination fields (replacing Redis fencing)
|
||||
# Database-based coordination fields (replacing Redis fencing)
|
||||
celery_task_id: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
cancellation_requested: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
|
||||
# NEW: Batch coordination fields (replacing FileStore state)
|
||||
# Batch coordination fields
|
||||
# Once this is set, docfetching has completed
|
||||
total_batches: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
# batches that are fully indexed (i.e. have completed docfetching and docprocessing)
|
||||
completed_batches: Mapped[int] = mapped_column(Integer, default=0)
|
||||
# TODO: unused, remove this column
|
||||
total_failures_batch_level: Mapped[int] = mapped_column(Integer, default=0)
|
||||
@@ -1702,7 +1709,7 @@ class IndexAttempt(Base):
|
||||
)
|
||||
last_batches_completed_count: Mapped[int] = mapped_column(Integer, default=0)
|
||||
|
||||
# NEW: Heartbeat tracking for worker liveness detection
|
||||
# Heartbeat tracking for worker liveness detection
|
||||
heartbeat_counter: Mapped[int] = mapped_column(Integer, default=0)
|
||||
last_heartbeat_value: Mapped[int] = mapped_column(Integer, default=0)
|
||||
last_heartbeat_time: Mapped[datetime.datetime | None] = mapped_column(
|
||||
|
||||
@@ -15,6 +15,7 @@ from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.app_configs import CURATORS_CANNOT_VIEW_OR_EDIT_NON_OWNED_ASSISTANTS
|
||||
from onyx.configs.app_configs import DISABLE_AUTH
|
||||
from onyx.configs.chat_configs import BING_API_KEY
|
||||
from onyx.configs.chat_configs import CONTEXT_CHUNKS_ABOVE
|
||||
@@ -96,6 +97,14 @@ def _add_user_filters(
|
||||
where_clause = Persona.is_public == True # noqa: E712
|
||||
return stmt.where(where_clause)
|
||||
|
||||
# If curator ownership restriction is enabled, curators can only access their own assistants
|
||||
if CURATORS_CANNOT_VIEW_OR_EDIT_NON_OWNED_ASSISTANTS and user.role in [
|
||||
UserRole.CURATOR,
|
||||
UserRole.GLOBAL_CURATOR,
|
||||
]:
|
||||
where_clause = (Persona.user_id == user.id) | (Persona.user_id.is_(None))
|
||||
return stmt.where(where_clause)
|
||||
|
||||
where_clause = User__UserGroup.user_id == user.id
|
||||
if user.role == UserRole.CURATOR and get_editable:
|
||||
where_clause &= User__UserGroup.is_curator == True # noqa: E712
|
||||
|
||||
@@ -47,11 +47,12 @@ def create_or_add_document_tag(
|
||||
Tag.tag_key == tag_key,
|
||||
Tag.tag_value == tag_value,
|
||||
Tag.source == source,
|
||||
Tag.is_list.is_(False),
|
||||
)
|
||||
tag = db_session.execute(tag_stmt).scalar_one_or_none()
|
||||
|
||||
if not tag:
|
||||
tag = Tag(tag_key=tag_key, tag_value=tag_value, source=source)
|
||||
tag = Tag(tag_key=tag_key, tag_value=tag_value, source=source, is_list=False)
|
||||
db_session.add(tag)
|
||||
|
||||
if tag not in document.tags:
|
||||
@@ -82,6 +83,7 @@ def create_or_add_document_tag_list(
|
||||
Tag.tag_key == tag_key,
|
||||
Tag.tag_value.in_(valid_tag_values),
|
||||
Tag.source == source,
|
||||
Tag.is_list.is_(True),
|
||||
)
|
||||
existing_tags = list(db_session.execute(existing_tags_stmt).scalars().all())
|
||||
existing_tag_values = {tag.tag_value for tag in existing_tags}
|
||||
@@ -89,7 +91,9 @@ def create_or_add_document_tag_list(
|
||||
new_tags = []
|
||||
for tag_value in valid_tag_values:
|
||||
if tag_value not in existing_tag_values:
|
||||
new_tag = Tag(tag_key=tag_key, tag_value=tag_value, source=source)
|
||||
new_tag = Tag(
|
||||
tag_key=tag_key, tag_value=tag_value, source=source, is_list=True
|
||||
)
|
||||
db_session.add(new_tag)
|
||||
new_tags.append(new_tag)
|
||||
existing_tag_values.add(tag_value)
|
||||
@@ -109,6 +113,45 @@ def create_or_add_document_tag_list(
|
||||
return all_tags
|
||||
|
||||
|
||||
def upsert_document_tags(
|
||||
document_id: str,
|
||||
source: DocumentSource,
|
||||
metadata: dict[str, str | list[str]],
|
||||
db_session: Session,
|
||||
) -> list[Tag]:
|
||||
document = db_session.get(Document, document_id)
|
||||
if not document:
|
||||
raise ValueError("Invalid Document, cannot attach Tags")
|
||||
|
||||
old_tag_ids: set[int] = {tag.id for tag in document.tags}
|
||||
|
||||
new_tags: list[Tag] = []
|
||||
new_tag_ids: set[int] = set()
|
||||
for k, v in metadata.items():
|
||||
if isinstance(v, list):
|
||||
new_tags.extend(
|
||||
create_or_add_document_tag_list(k, v, source, document_id, db_session)
|
||||
)
|
||||
new_tag_ids.update({tag.id for tag in new_tags})
|
||||
continue
|
||||
|
||||
new_tag = create_or_add_document_tag(k, v, source, document_id, db_session)
|
||||
if new_tag:
|
||||
new_tag_ids.add(new_tag.id)
|
||||
new_tags.append(new_tag)
|
||||
|
||||
delete_tags = old_tag_ids - new_tag_ids
|
||||
if delete_tags:
|
||||
delete_stmt = delete(Document__Tag).where(
|
||||
Document__Tag.document_id == document_id,
|
||||
Document__Tag.tag_id.in_(delete_tags),
|
||||
)
|
||||
db_session.execute(delete_stmt)
|
||||
db_session.commit()
|
||||
|
||||
return new_tags
|
||||
|
||||
|
||||
def find_tags(
|
||||
tag_key_prefix: str | None,
|
||||
tag_value_prefix: str | None,
|
||||
@@ -147,24 +190,37 @@ def find_tags(
|
||||
def get_structured_tags_for_document(
|
||||
document_id: str, db_session: Session
|
||||
) -> dict[str, str | list[str]]:
|
||||
"""Essentially returns the document metadata from postgres."""
|
||||
document = db_session.get(Document, document_id)
|
||||
if not document:
|
||||
raise ValueError("Invalid Document, cannot find tags")
|
||||
|
||||
document_metadata: dict[str, Any] = {}
|
||||
for tag in document.tags:
|
||||
if tag.tag_key in document_metadata:
|
||||
# NOTE: we convert to list if there are multiple values for the same key
|
||||
# Thus, it won't know if a tag is a list if it only contains one value
|
||||
if isinstance(document_metadata[tag.tag_key], str):
|
||||
document_metadata[tag.tag_key] = [
|
||||
document_metadata[tag.tag_key],
|
||||
tag.tag_value,
|
||||
]
|
||||
else:
|
||||
document_metadata[tag.tag_key].append(tag.tag_value)
|
||||
else:
|
||||
document_metadata[tag.tag_key] = tag.tag_value
|
||||
if tag.is_list:
|
||||
document_metadata.setdefault(tag.tag_key, [])
|
||||
# should always be a list (if tag.is_list is always True for this key), but just in case
|
||||
if not isinstance(document_metadata[tag.tag_key], list):
|
||||
logger.warning(
|
||||
"Inconsistent is_list for document %s, tag_key %s",
|
||||
document_id,
|
||||
tag.tag_key,
|
||||
)
|
||||
document_metadata[tag.tag_key] = [document_metadata[tag.tag_key]]
|
||||
document_metadata[tag.tag_key].append(tag.tag_value)
|
||||
continue
|
||||
|
||||
# set value (ignore duplicate keys, though there should be none)
|
||||
document_metadata.setdefault(tag.tag_key, tag.tag_value)
|
||||
|
||||
# should always be a value, but just in case (treat it as a list in this case)
|
||||
if isinstance(document_metadata[tag.tag_key], list):
|
||||
logger.warning(
|
||||
"Inconsistent is_list for document %s, tag_key %s",
|
||||
document_id,
|
||||
tag.tag_key,
|
||||
)
|
||||
document_metadata[tag.tag_key] = [document_metadata[tag.tag_key]]
|
||||
return document_metadata
|
||||
|
||||
|
||||
|
||||
@@ -12,8 +12,7 @@ from sqlalchemy.sql import expression
|
||||
from sqlalchemy.sql.elements import ColumnElement
|
||||
from sqlalchemy.sql.elements import KeyedColumnElement
|
||||
|
||||
from onyx.auth.invited_users import get_invited_users
|
||||
from onyx.auth.invited_users import write_invited_users
|
||||
from onyx.auth.invited_users import remove_user_from_invited_users
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.db.api_key import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
||||
from onyx.db.models import DocumentSet__User
|
||||
@@ -342,10 +341,4 @@ def delete_user_from_db(
|
||||
|
||||
# NOTE: edge case may exist with race conditions
|
||||
# with this `invited user` scheme generally.
|
||||
user_emails = get_invited_users()
|
||||
remaining_users = [
|
||||
remaining_user_email
|
||||
for remaining_user_email in user_emails
|
||||
if remaining_user_email != user_to_delete.email
|
||||
]
|
||||
write_invited_users(remaining_users)
|
||||
remove_user_from_invited_users(user_to_delete.email)
|
||||
|
||||
@@ -17,11 +17,11 @@ from typing import NamedTuple
|
||||
from zipfile import BadZipFile
|
||||
|
||||
import chardet
|
||||
import docx # type: ignore
|
||||
import openpyxl # type: ignore
|
||||
import pptx # type: ignore
|
||||
from docx import Document as DocxDocument
|
||||
from fastapi import UploadFile
|
||||
from markitdown import FileConversionException
|
||||
from markitdown import MarkItDown
|
||||
from markitdown import UnsupportedFormatException
|
||||
from PIL import Image
|
||||
from pypdf import PdfReader
|
||||
from pypdf.errors import PdfStreamError
|
||||
@@ -29,6 +29,7 @@ from pypdf.errors import PdfStreamError
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.configs.constants import ONYX_METADATA_FILENAME
|
||||
from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
|
||||
from onyx.file_processing.file_validation import TEXT_MIME_TYPE
|
||||
from onyx.file_processing.html_utils import parse_html_page_basic
|
||||
from onyx.file_processing.unstructured import get_unstructured_api_key
|
||||
from onyx.file_processing.unstructured import unstructured_to_text
|
||||
@@ -83,11 +84,6 @@ IMAGE_MEDIA_TYPES = [
|
||||
"image/webp",
|
||||
]
|
||||
|
||||
KNOWN_OPENPYXL_BUGS = [
|
||||
"Value must be either numerical or a string containing a wildcard",
|
||||
"File contains no valid workbook part",
|
||||
]
|
||||
|
||||
|
||||
class OnyxExtensionType(IntFlag):
|
||||
Plain = auto()
|
||||
@@ -149,6 +145,13 @@ def is_macos_resource_fork_file(file_name: str) -> bool:
|
||||
)
|
||||
|
||||
|
||||
def to_bytesio(stream: IO[bytes]) -> BytesIO:
|
||||
if isinstance(stream, BytesIO):
|
||||
return stream
|
||||
data = stream.read() # consumes the stream!
|
||||
return BytesIO(data)
|
||||
|
||||
|
||||
def load_files_from_zip(
|
||||
zip_file_io: IO,
|
||||
ignore_macos_resource_fork_files: bool = True,
|
||||
@@ -305,19 +308,38 @@ def read_pdf_file(
|
||||
return "", metadata, []
|
||||
|
||||
|
||||
def extract_docx_images(docx_bytes: IO[Any]) -> list[tuple[bytes, str]]:
|
||||
"""
|
||||
Given the bytes of a docx file, extract all the images.
|
||||
Returns a list of tuples (image_bytes, image_name).
|
||||
"""
|
||||
out = []
|
||||
try:
|
||||
with zipfile.ZipFile(docx_bytes) as z:
|
||||
for name in z.namelist():
|
||||
if name.startswith("word/media/"):
|
||||
out.append((z.read(name), name.split("/")[-1]))
|
||||
except Exception:
|
||||
logger.exception("Failed to extract all docx images")
|
||||
return out
|
||||
|
||||
|
||||
def docx_to_text_and_images(
|
||||
file: IO[Any], file_name: str = ""
|
||||
) -> tuple[str, Sequence[tuple[bytes, str]]]:
|
||||
"""
|
||||
Extract text from a docx. If embed_images=True, also extract inline images.
|
||||
Extract text from a docx.
|
||||
Return (text_content, list_of_images).
|
||||
"""
|
||||
paragraphs = []
|
||||
embedded_images: list[tuple[bytes, str]] = []
|
||||
|
||||
md = MarkItDown(enable_plugins=False)
|
||||
try:
|
||||
doc = docx.Document(file)
|
||||
except (BadZipFile, ValueError) as e:
|
||||
doc = md.convert(to_bytesio(file))
|
||||
except (
|
||||
BadZipFile,
|
||||
ValueError,
|
||||
FileConversionException,
|
||||
UnsupportedFormatException,
|
||||
) as e:
|
||||
logger.warning(
|
||||
f"Failed to extract docx {file_name or 'docx file'}: {e}. Attempting to read as text file."
|
||||
)
|
||||
@@ -330,96 +352,44 @@ def docx_to_text_and_images(
|
||||
)
|
||||
return text_content_raw or "", []
|
||||
|
||||
# Grab text from paragraphs
|
||||
for paragraph in doc.paragraphs:
|
||||
paragraphs.append(paragraph.text)
|
||||
|
||||
# Reset position so we can re-load the doc (python-docx has read the stream)
|
||||
# Note: if python-docx has fully consumed the stream, you may need to open it again from memory.
|
||||
# For large docs, a more robust approach is needed.
|
||||
# This is a simplified example.
|
||||
|
||||
for rel_id, rel in doc.part.rels.items():
|
||||
if "image" in rel.reltype:
|
||||
# Skip images that are linked rather than embedded (TargetMode="External")
|
||||
if getattr(rel, "is_external", False):
|
||||
continue
|
||||
|
||||
try:
|
||||
# image is typically in rel.target_part.blob
|
||||
image_bytes = rel.target_part.blob
|
||||
except ValueError:
|
||||
# Safeguard against relationships that lack an internal target_part
|
||||
# (e.g., external relationships or other anomalies)
|
||||
continue
|
||||
|
||||
image_name = rel.target_part.partname
|
||||
# store
|
||||
embedded_images.append((image_bytes, os.path.basename(str(image_name))))
|
||||
|
||||
text_content = "\n".join(paragraphs)
|
||||
return text_content, embedded_images
|
||||
file.seek(0)
|
||||
return doc.markdown, extract_docx_images(to_bytesio(file))
|
||||
|
||||
|
||||
def pptx_to_text(file: IO[Any], file_name: str = "") -> str:
|
||||
md = MarkItDown(enable_plugins=False)
|
||||
try:
|
||||
presentation = pptx.Presentation(file)
|
||||
except BadZipFile as e:
|
||||
presentation = md.convert(to_bytesio(file))
|
||||
except (
|
||||
BadZipFile,
|
||||
ValueError,
|
||||
FileConversionException,
|
||||
UnsupportedFormatException,
|
||||
) as e:
|
||||
error_str = f"Failed to extract text from {file_name or 'pptx file'}: {e}"
|
||||
logger.warning(error_str)
|
||||
return ""
|
||||
text_content = []
|
||||
for slide_number, slide in enumerate(presentation.slides, start=1):
|
||||
slide_text = f"\nSlide {slide_number}:\n"
|
||||
for shape in slide.shapes:
|
||||
if hasattr(shape, "text"):
|
||||
slide_text += shape.text + "\n"
|
||||
text_content.append(slide_text)
|
||||
return TEXT_SECTION_SEPARATOR.join(text_content)
|
||||
return presentation.markdown
|
||||
|
||||
|
||||
def xlsx_to_text(file: IO[Any], file_name: str = "") -> str:
|
||||
md = MarkItDown(enable_plugins=False)
|
||||
try:
|
||||
workbook = openpyxl.load_workbook(file, read_only=True)
|
||||
except BadZipFile as e:
|
||||
workbook = md.convert(to_bytesio(file))
|
||||
except (
|
||||
BadZipFile,
|
||||
ValueError,
|
||||
FileConversionException,
|
||||
UnsupportedFormatException,
|
||||
) as e:
|
||||
error_str = f"Failed to extract text from {file_name or 'xlsx file'}: {e}"
|
||||
if file_name.startswith("~"):
|
||||
logger.debug(error_str + " (this is expected for files with ~)")
|
||||
else:
|
||||
logger.warning(error_str)
|
||||
return ""
|
||||
except Exception as e:
|
||||
if any(s in str(e) for s in KNOWN_OPENPYXL_BUGS):
|
||||
logger.error(
|
||||
f"Failed to extract text from {file_name or 'xlsx file'}. This happens due to a bug in openpyxl. {e}"
|
||||
)
|
||||
return ""
|
||||
raise e
|
||||
|
||||
text_content = []
|
||||
for sheet in workbook.worksheets:
|
||||
rows = []
|
||||
num_empty_consecutive_rows = 0
|
||||
for row in sheet.iter_rows(min_row=1, values_only=True):
|
||||
row_str = ",".join(str(cell or "") for cell in row)
|
||||
|
||||
# Only add the row if there are any values in the cells
|
||||
if len(row_str) >= len(row):
|
||||
rows.append(row_str)
|
||||
num_empty_consecutive_rows = 0
|
||||
else:
|
||||
num_empty_consecutive_rows += 1
|
||||
|
||||
if num_empty_consecutive_rows > 100:
|
||||
# handle massive excel sheets with mostly empty cells
|
||||
logger.warning(
|
||||
f"Found {num_empty_consecutive_rows} empty rows in {file_name},"
|
||||
" skipping rest of file"
|
||||
)
|
||||
break
|
||||
sheet_str = "\n".join(rows)
|
||||
text_content.append(sheet_str)
|
||||
return TEXT_SECTION_SEPARATOR.join(text_content)
|
||||
return workbook.markdown
|
||||
|
||||
|
||||
def eml_to_text(file: IO[Any]) -> str:
|
||||
@@ -472,9 +442,9 @@ def extract_file_text(
|
||||
"""
|
||||
extension_to_function: dict[str, Callable[[IO[Any]], str]] = {
|
||||
".pdf": pdf_to_text,
|
||||
".docx": lambda f: docx_to_text_and_images(f)[0], # no images
|
||||
".pptx": pptx_to_text,
|
||||
".xlsx": xlsx_to_text,
|
||||
".docx": lambda f: docx_to_text_and_images(f, file_name)[0], # no images
|
||||
".pptx": lambda f: pptx_to_text(f, file_name),
|
||||
".xlsx": lambda f: xlsx_to_text(f, file_name),
|
||||
".eml": eml_to_text,
|
||||
".epub": epub_to_text,
|
||||
".html": parse_html_page_basic,
|
||||
@@ -523,10 +493,23 @@ class ExtractionResult(NamedTuple):
|
||||
metadata: dict[str, Any]
|
||||
|
||||
|
||||
def extract_result_from_text_file(file: IO[Any]) -> ExtractionResult:
|
||||
encoding = detect_encoding(file)
|
||||
text_content_raw, file_metadata = read_text_file(
|
||||
file, encoding=encoding, ignore_onyx_metadata=False
|
||||
)
|
||||
return ExtractionResult(
|
||||
text_content=text_content_raw,
|
||||
embedded_images=[],
|
||||
metadata=file_metadata,
|
||||
)
|
||||
|
||||
|
||||
def extract_text_and_images(
|
||||
file: IO[Any],
|
||||
file_name: str,
|
||||
pdf_pass: str | None = None,
|
||||
content_type: str | None = None,
|
||||
) -> ExtractionResult:
|
||||
"""
|
||||
Primary new function for the updated connector.
|
||||
@@ -547,13 +530,20 @@ def extract_text_and_images(
|
||||
)
|
||||
file.seek(0) # Reset file pointer just in case
|
||||
|
||||
# When we upload a document via a connector or MyDocuments, we extract and store the content of files
|
||||
# with content types in UploadMimeTypes.DOCUMENT_MIME_TYPES as plain text files.
|
||||
# As a result, the file name extension may differ from the original content type.
|
||||
# We process files with a plain text content type first to handle this scenario.
|
||||
if content_type == TEXT_MIME_TYPE:
|
||||
return extract_result_from_text_file(file)
|
||||
|
||||
# Default processing
|
||||
try:
|
||||
extension = get_file_ext(file_name)
|
||||
|
||||
# docx example for embedded images
|
||||
if extension == ".docx":
|
||||
text_content, images = docx_to_text_and_images(file)
|
||||
text_content, images = docx_to_text_and_images(file, file_name)
|
||||
return ExtractionResult(
|
||||
text_content=text_content, embedded_images=images, metadata={}
|
||||
)
|
||||
@@ -605,15 +595,7 @@ def extract_text_and_images(
|
||||
|
||||
# If we reach here and it's a recognized text extension
|
||||
if is_text_file_extension(file_name):
|
||||
encoding = detect_encoding(file)
|
||||
text_content_raw, file_metadata = read_text_file(
|
||||
file, encoding=encoding, ignore_onyx_metadata=False
|
||||
)
|
||||
return ExtractionResult(
|
||||
text_content=text_content_raw,
|
||||
embedded_images=[],
|
||||
metadata=file_metadata,
|
||||
)
|
||||
return extract_result_from_text_file(file)
|
||||
|
||||
# If it's an image file or something else, we do not parse embedded images from them
|
||||
# just return empty text
|
||||
|
||||
@@ -21,6 +21,9 @@ EXCLUDED_IMAGE_TYPES = [
|
||||
"image/avif",
|
||||
]
|
||||
|
||||
# Text MIME types
|
||||
TEXT_MIME_TYPE = "text/plain"
|
||||
|
||||
|
||||
def is_valid_image_type(mime_type: str) -> bool:
|
||||
"""
|
||||
@@ -32,9 +35,11 @@ def is_valid_image_type(mime_type: str) -> bool:
|
||||
Returns:
|
||||
True if the MIME type is a valid image type, False otherwise
|
||||
"""
|
||||
if not mime_type:
|
||||
return False
|
||||
return mime_type.startswith("image/") and mime_type not in EXCLUDED_IMAGE_TYPES
|
||||
return (
|
||||
bool(mime_type)
|
||||
and mime_type.startswith("image/")
|
||||
and mime_type not in EXCLUDED_IMAGE_TYPES
|
||||
)
|
||||
|
||||
|
||||
def is_supported_by_vision_llm(mime_type: str) -> bool:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import re
|
||||
from copy import copy
|
||||
from dataclasses import dataclass
|
||||
from io import BytesIO
|
||||
from typing import IO
|
||||
|
||||
import bs4
|
||||
@@ -161,7 +162,7 @@ def format_document_soup(
|
||||
return strip_excessive_newlines_and_spaces(text)
|
||||
|
||||
|
||||
def parse_html_page_basic(text: str | IO[bytes]) -> str:
|
||||
def parse_html_page_basic(text: str | BytesIO | IO[bytes]) -> str:
|
||||
soup = bs4.BeautifulSoup(text, "html.parser")
|
||||
return format_document_soup(soup)
|
||||
|
||||
|
||||
@@ -196,6 +196,9 @@ class FileStoreDocumentBatchStorage(DocumentBatchStorage):
|
||||
for batch_file_name in batch_names:
|
||||
path_info = self.extract_path_info(batch_file_name)
|
||||
if path_info is None:
|
||||
logger.warning(
|
||||
f"Could not extract path info from batch file: {batch_file_name}"
|
||||
)
|
||||
continue
|
||||
new_batch_file_name = self._get_batch_file_name(path_info.batch_num)
|
||||
self.file_store.change_file_id(batch_file_name, new_batch_file_name)
|
||||
|
||||
@@ -19,6 +19,14 @@ class ChatFileType(str, Enum):
|
||||
# "user knowledge" is not a file type, it's a source or intent
|
||||
USER_KNOWLEDGE = "user_knowledge"
|
||||
|
||||
def is_text_file(self) -> bool:
|
||||
return self in (
|
||||
ChatFileType.PLAIN_TEXT,
|
||||
ChatFileType.DOC,
|
||||
ChatFileType.CSV,
|
||||
ChatFileType.USER_KNOWLEDGE,
|
||||
)
|
||||
|
||||
|
||||
class FileDescriptor(TypedDict):
|
||||
"""NOTE: is a `TypedDict` so it can be used as a type hint for a JSONB column
|
||||
|
||||
@@ -49,11 +49,10 @@ def sanitize_s3_key_name(file_name: str) -> str:
|
||||
|
||||
# Characters to avoid completely (replace with underscore)
|
||||
# These are characters that AWS recommends avoiding
|
||||
avoid_chars = r'[\\{}^%`\[\]"<>#|~]'
|
||||
avoid_chars = r'[\\{}^%`\[\]"<>#|~/]'
|
||||
|
||||
# Replace avoided characters with underscore
|
||||
sanitized = re.sub(avoid_chars, "_", file_name)
|
||||
|
||||
# Characters that might require special handling but are allowed
|
||||
# We'll URL encode these to be safe
|
||||
special_chars = r"[&$@=;:+,?\s]"
|
||||
@@ -81,6 +80,9 @@ def sanitize_s3_key_name(file_name: str) -> str:
|
||||
# Remove any trailing periods to avoid download issues
|
||||
sanitized = sanitized.rstrip(".")
|
||||
|
||||
# Remove multiple separators
|
||||
sanitized = re.sub(r"[-_]{2,}", "-", sanitized)
|
||||
|
||||
# If sanitization resulted in empty string, use a default
|
||||
if not sanitized:
|
||||
sanitized = "sanitized_file"
|
||||
|
||||
@@ -22,6 +22,8 @@ from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
RECENT_FOLDER_ID = -1
|
||||
|
||||
|
||||
def user_file_id_to_plaintext_file_name(user_file_id: int) -> str:
|
||||
"""Generate a consistent file name for storing plaintext content of a user file."""
|
||||
@@ -46,7 +48,6 @@ def store_user_file_plaintext(user_file_id: int, plaintext_content: str) -> bool
|
||||
# Get plaintext file name
|
||||
plaintext_file_name = user_file_id_to_plaintext_file_name(user_file_id)
|
||||
|
||||
# Use a separate session to avoid committing the caller's transaction
|
||||
try:
|
||||
file_store = get_default_file_store()
|
||||
file_content = BytesIO(plaintext_content.encode("utf-8"))
|
||||
@@ -245,14 +246,21 @@ def get_user_files_as_user(
|
||||
Fetches all UserFile database records for a given user.
|
||||
"""
|
||||
user_files = get_user_files(user_file_ids, user_folder_ids, db_session)
|
||||
current_user_files = []
|
||||
for user_file in user_files:
|
||||
# Note: if user_id is None, then all files should be None as well
|
||||
# (since auth must be disabled in this case)
|
||||
if user_file.user_id != user_id:
|
||||
raise ValueError(
|
||||
f"User {user_id} does not have access to file {user_file.id}"
|
||||
)
|
||||
return user_files
|
||||
if user_file.folder_id == RECENT_FOLDER_ID:
|
||||
if user_file.user_id == user_id:
|
||||
current_user_files.append(user_file)
|
||||
else:
|
||||
if user_file.user_id != user_id:
|
||||
raise ValueError(
|
||||
f"User {user_id} does not have access to file {user_file.id}"
|
||||
)
|
||||
current_user_files.append(user_file)
|
||||
|
||||
return current_user_files
|
||||
|
||||
|
||||
def save_file_from_url(url: str) -> str:
|
||||
|
||||
@@ -44,8 +44,7 @@ from onyx.db.document_set import fetch_document_sets_for_documents
|
||||
from onyx.db.models import Document as DBDocument
|
||||
from onyx.db.models import IndexModelStatus
|
||||
from onyx.db.search_settings import get_active_search_settings
|
||||
from onyx.db.tag import create_or_add_document_tag
|
||||
from onyx.db.tag import create_or_add_document_tag_list
|
||||
from onyx.db.tag import upsert_document_tags
|
||||
from onyx.db.user_documents import fetch_user_files_for_documents
|
||||
from onyx.db.user_documents import fetch_user_folders_for_documents
|
||||
from onyx.db.user_documents import update_user_file_token_count__no_commit
|
||||
@@ -150,24 +149,12 @@ def _upsert_documents_in_db(
|
||||
|
||||
# Insert document content metadata
|
||||
for doc in documents:
|
||||
for k, v in doc.metadata.items():
|
||||
if isinstance(v, list):
|
||||
create_or_add_document_tag_list(
|
||||
tag_key=k,
|
||||
tag_values=v,
|
||||
source=doc.source,
|
||||
document_id=doc.id,
|
||||
db_session=db_session,
|
||||
)
|
||||
continue
|
||||
|
||||
create_or_add_document_tag(
|
||||
tag_key=k,
|
||||
tag_value=v,
|
||||
source=doc.source,
|
||||
document_id=doc.id,
|
||||
db_session=db_session,
|
||||
)
|
||||
upsert_document_tags(
|
||||
document_id=doc.id,
|
||||
source=doc.source,
|
||||
metadata=doc.metadata,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
|
||||
def _get_aggregated_chunk_boost_factor(
|
||||
@@ -867,31 +854,27 @@ def index_doc_batch(
|
||||
user_file_id_to_raw_text: dict[int, str] = {}
|
||||
for document_id in updatable_ids:
|
||||
# Only calculate token counts for documents that have a user file ID
|
||||
if (
|
||||
document_id in doc_id_to_user_file_id
|
||||
and doc_id_to_user_file_id[document_id] is not None
|
||||
):
|
||||
user_file_id = doc_id_to_user_file_id[document_id]
|
||||
if not user_file_id:
|
||||
continue
|
||||
document_chunks = [
|
||||
chunk
|
||||
for chunk in chunks_with_embeddings
|
||||
if chunk.source_document.id == document_id
|
||||
]
|
||||
if document_chunks:
|
||||
combined_content = " ".join(
|
||||
[chunk.content for chunk in document_chunks]
|
||||
)
|
||||
token_count = (
|
||||
len(llm_tokenizer.encode(combined_content))
|
||||
if llm_tokenizer
|
||||
else 0
|
||||
)
|
||||
user_file_id_to_token_count[user_file_id] = token_count
|
||||
user_file_id_to_raw_text[user_file_id] = combined_content
|
||||
else:
|
||||
user_file_id_to_token_count[user_file_id] = None
|
||||
|
||||
user_file_id = doc_id_to_user_file_id.get(document_id)
|
||||
if user_file_id is None:
|
||||
continue
|
||||
|
||||
document_chunks = [
|
||||
chunk
|
||||
for chunk in chunks_with_embeddings
|
||||
if chunk.source_document.id == document_id
|
||||
]
|
||||
if document_chunks:
|
||||
combined_content = " ".join(
|
||||
[chunk.content for chunk in document_chunks]
|
||||
)
|
||||
token_count = (
|
||||
len(llm_tokenizer.encode(combined_content)) if llm_tokenizer else 0
|
||||
)
|
||||
user_file_id_to_token_count[user_file_id] = token_count
|
||||
user_file_id_to_raw_text[user_file_id] = combined_content
|
||||
else:
|
||||
user_file_id_to_token_count[user_file_id] = None
|
||||
|
||||
# we're concerned about race conditions where multiple simultaneous indexings might result
|
||||
# in one set of metadata overwriting another one in vespa.
|
||||
|
||||
@@ -24,6 +24,7 @@ from langchain_core.messages import SystemMessageChunk
|
||||
from langchain_core.messages.tool import ToolCallChunk
|
||||
from langchain_core.messages.tool import ToolMessage
|
||||
from langchain_core.prompt_values import PromptValue
|
||||
from litellm.utils import get_supported_openai_params
|
||||
|
||||
from onyx.configs.app_configs import LOG_DANSWER_MODEL_INTERACTIONS
|
||||
from onyx.configs.app_configs import MOCK_LLM_RESPONSE
|
||||
@@ -52,6 +53,8 @@ litellm.telemetry = False
|
||||
_LLM_PROMPT_LONG_TERM_LOG_CATEGORY = "llm_prompt"
|
||||
VERTEX_CREDENTIALS_FILE_KWARG = "vertex_credentials"
|
||||
VERTEX_LOCATION_KWARG = "vertex_location"
|
||||
LEGACY_MAX_TOKENS_KWARG = "max_tokens"
|
||||
STANDARD_MAX_TOKENS_KWARG = "max_completion_tokens"
|
||||
|
||||
|
||||
class LLMTimeoutError(Exception):
|
||||
@@ -313,14 +316,22 @@ class DefaultMultiLLM(LLM):
|
||||
|
||||
self._model_kwargs = model_kwargs
|
||||
|
||||
def log_model_configs(self) -> None:
|
||||
logger.debug(f"Config: {self.config}")
|
||||
self._max_token_param = LEGACY_MAX_TOKENS_KWARG
|
||||
try:
|
||||
params = get_supported_openai_params(model_name, model_provider)
|
||||
if STANDARD_MAX_TOKENS_KWARG in (params or []):
|
||||
self._max_token_param = STANDARD_MAX_TOKENS_KWARG
|
||||
except Exception as e:
|
||||
logger.warning(f"Error getting supported openai params: {e}")
|
||||
|
||||
def _safe_model_config(self) -> dict:
|
||||
dump = self.config.model_dump()
|
||||
dump["api_key"] = mask_string(dump.get("api_key", ""))
|
||||
return dump
|
||||
|
||||
def log_model_configs(self) -> None:
|
||||
logger.debug(f"Config: {self._safe_model_config()}")
|
||||
|
||||
def _record_call(self, prompt: LanguageModelInput) -> None:
|
||||
if self._long_term_logger:
|
||||
self._long_term_logger.record(
|
||||
@@ -393,11 +404,14 @@ class DefaultMultiLLM(LLM):
|
||||
messages=processed_prompt,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice if tools else None,
|
||||
max_tokens=max_tokens,
|
||||
# streaming choice
|
||||
stream=stream,
|
||||
# model params
|
||||
temperature=self._temperature,
|
||||
temperature=(
|
||||
1
|
||||
if self.config.model_name in ["gpt-5", "gpt-5-mini", "gpt-5-nano"]
|
||||
else self._temperature
|
||||
),
|
||||
timeout=timeout_override or self._timeout,
|
||||
# For now, we don't support parallel tool calls
|
||||
# NOTE: we can't pass this in if tools are not specified
|
||||
@@ -422,6 +436,7 @@ class DefaultMultiLLM(LLM):
|
||||
if structured_response_format
|
||||
else {}
|
||||
),
|
||||
**({self._max_token_param: max_tokens} if max_tokens else {}),
|
||||
**self._model_kwargs,
|
||||
)
|
||||
except Exception as e:
|
||||
|
||||
@@ -47,6 +47,9 @@ class WellKnownLLMProviderDescriptor(BaseModel):
|
||||
|
||||
OPENAI_PROVIDER_NAME = "openai"
|
||||
OPEN_AI_MODEL_NAMES = [
|
||||
"gpt-5",
|
||||
"gpt-5-mini",
|
||||
"gpt-5-nano",
|
||||
"o4-mini",
|
||||
"o3-mini",
|
||||
"o1-mini",
|
||||
@@ -73,7 +76,14 @@ OPEN_AI_MODEL_NAMES = [
|
||||
"gpt-3.5-turbo-16k-0613",
|
||||
"gpt-3.5-turbo-0301",
|
||||
]
|
||||
OPEN_AI_VISIBLE_MODEL_NAMES = ["o1", "o3-mini", "gpt-4o", "gpt-4o-mini"]
|
||||
OPEN_AI_VISIBLE_MODEL_NAMES = [
|
||||
"gpt-5",
|
||||
"gpt-5-mini",
|
||||
"o1",
|
||||
"o3-mini",
|
||||
"gpt-4o",
|
||||
"gpt-4o-mini",
|
||||
]
|
||||
|
||||
BEDROCK_PROVIDER_NAME = "bedrock"
|
||||
# need to remove all the weird "bedrock/eu-central-1/anthropic.claude-v1" named
|
||||
|
||||
@@ -72,6 +72,7 @@ class PreviousMessage(BaseModel):
|
||||
message_type = MessageType.USER
|
||||
elif isinstance(msg, AIMessage):
|
||||
message_type = MessageType.ASSISTANT
|
||||
|
||||
message = message_to_string(msg)
|
||||
return cls(
|
||||
message=message,
|
||||
|
||||
@@ -136,16 +136,7 @@ def _build_content(
|
||||
if not files:
|
||||
return message
|
||||
|
||||
text_files = [
|
||||
file
|
||||
for file in files
|
||||
if file.file_type
|
||||
in (
|
||||
ChatFileType.PLAIN_TEXT,
|
||||
ChatFileType.CSV,
|
||||
ChatFileType.USER_KNOWLEDGE,
|
||||
)
|
||||
]
|
||||
text_files = [file for file in files if file.file_type.is_text_file()]
|
||||
|
||||
if not text_files:
|
||||
return message
|
||||
|
||||
40
backend/onyx/natural_language_processing/constants.py
Normal file
40
backend/onyx/natural_language_processing/constants.py
Normal file
@@ -0,0 +1,40 @@
|
||||
"""
|
||||
Constants for natural language processing, including embedding and reranking models.
|
||||
|
||||
This file contains constants moved from model_server to support the gradual migration
|
||||
of API-based calls to bypass the model server.
|
||||
"""
|
||||
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
from shared_configs.enums import EmbedTextType
|
||||
|
||||
|
||||
# Default model names for different providers
|
||||
DEFAULT_OPENAI_MODEL = "text-embedding-3-small"
|
||||
DEFAULT_COHERE_MODEL = "embed-english-light-v3.0"
|
||||
DEFAULT_VOYAGE_MODEL = "voyage-large-2-instruct"
|
||||
DEFAULT_VERTEX_MODEL = "text-embedding-005"
|
||||
|
||||
|
||||
class EmbeddingModelTextType:
|
||||
"""Mapping of Onyx text types to provider-specific text types."""
|
||||
|
||||
PROVIDER_TEXT_TYPE_MAP = {
|
||||
EmbeddingProvider.COHERE: {
|
||||
EmbedTextType.QUERY: "search_query",
|
||||
EmbedTextType.PASSAGE: "search_document",
|
||||
},
|
||||
EmbeddingProvider.VOYAGE: {
|
||||
EmbedTextType.QUERY: "query",
|
||||
EmbedTextType.PASSAGE: "document",
|
||||
},
|
||||
EmbeddingProvider.GOOGLE: {
|
||||
EmbedTextType.QUERY: "RETRIEVAL_QUERY",
|
||||
EmbedTextType.PASSAGE: "RETRIEVAL_DOCUMENT",
|
||||
},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def get_type(provider: EmbeddingProvider, text_type: EmbedTextType) -> str:
|
||||
"""Get provider-specific text type string."""
|
||||
return EmbeddingModelTextType.PROVIDER_TEXT_TYPE_MAP[provider][text_type]
|
||||
@@ -1,3 +1,5 @@
|
||||
import asyncio
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
@@ -5,14 +7,26 @@ from concurrent.futures import as_completed
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from functools import partial
|
||||
from functools import wraps
|
||||
from types import TracebackType
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import aioboto3 # type: ignore
|
||||
import httpx
|
||||
import openai
|
||||
import requests
|
||||
import vertexai # type: ignore
|
||||
import voyageai # type: ignore
|
||||
from cohere import AsyncClient as CohereAsyncClient
|
||||
from google.oauth2 import service_account # type: ignore
|
||||
from httpx import HTTPError
|
||||
from litellm import aembedding
|
||||
from requests import JSONDecodeError
|
||||
from requests import RequestException
|
||||
from requests import Response
|
||||
from retry import retry
|
||||
from vertexai.language_models import TextEmbeddingInput # type: ignore
|
||||
from vertexai.language_models import TextEmbeddingModel # type: ignore
|
||||
|
||||
from onyx.configs.app_configs import INDEXING_EMBEDDING_MODEL_NUM_THREADS
|
||||
from onyx.configs.app_configs import LARGE_CHUNK_RATIO
|
||||
@@ -25,16 +39,26 @@ from onyx.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
|
||||
from onyx.connectors.models import ConnectorStopSignal
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.natural_language_processing.constants import DEFAULT_COHERE_MODEL
|
||||
from onyx.natural_language_processing.constants import DEFAULT_OPENAI_MODEL
|
||||
from onyx.natural_language_processing.constants import DEFAULT_VERTEX_MODEL
|
||||
from onyx.natural_language_processing.constants import DEFAULT_VOYAGE_MODEL
|
||||
from onyx.natural_language_processing.constants import EmbeddingModelTextType
|
||||
from onyx.natural_language_processing.exceptions import (
|
||||
ModelServerRateLimitError,
|
||||
)
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.natural_language_processing.utils import tokenizer_trim_content
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.search_nlp_models_utils import pass_aws_key
|
||||
from shared_configs.configs import API_BASED_EMBEDDING_TIMEOUT
|
||||
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
|
||||
from shared_configs.configs import INDEXING_MODEL_SERVER_PORT
|
||||
from shared_configs.configs import INDEXING_ONLY
|
||||
from shared_configs.configs import MODEL_SERVER_HOST
|
||||
from shared_configs.configs import MODEL_SERVER_PORT
|
||||
from shared_configs.configs import OPENAI_EMBEDDING_TIMEOUT
|
||||
from shared_configs.configs import VERTEXAI_EMBEDDING_LOCAL_BATCH_SIZE
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
from shared_configs.enums import EmbedTextType
|
||||
from shared_configs.enums import RerankerProvider
|
||||
@@ -53,6 +77,21 @@ from shared_configs.utils import batch_list
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# If we are not only indexing, we don't want to retry for very long
|
||||
_RETRY_DELAY = 10 if INDEXING_ONLY else 0.1
|
||||
_RETRY_TRIES = 10 if INDEXING_ONLY else 2
|
||||
|
||||
# OpenAI only allows 2048 embeddings to be computed at once
|
||||
_OPENAI_MAX_INPUT_LEN = 2048
|
||||
# Cohere allows up to 96 embeddings in a single embedding call
|
||||
_COHERE_MAX_INPUT_LEN = 96
|
||||
|
||||
# Authentication error string constants
|
||||
_AUTH_ERROR_401 = "401"
|
||||
_AUTH_ERROR_UNAUTHORIZED = "unauthorized"
|
||||
_AUTH_ERROR_INVALID_API_KEY = "invalid api key"
|
||||
_AUTH_ERROR_PERMISSION = "permission"
|
||||
|
||||
|
||||
WARM_UP_STRINGS = [
|
||||
"Onyx is amazing!",
|
||||
@@ -79,6 +118,377 @@ def build_model_server_url(
|
||||
return f"http://{model_server_url}"
|
||||
|
||||
|
||||
def is_authentication_error(error: Exception) -> bool:
    """Heuristically decide whether an exception describes an authentication failure.

    The exception is stringified and lower-cased, then searched for any of the
    module's authentication marker substrings.

    Args:
        error: The exception to inspect.

    Returns:
        bool: True if any marker substring occurs in the error text.
    """
    lowered_text = str(error).lower()
    auth_markers = (
        _AUTH_ERROR_401,
        _AUTH_ERROR_UNAUTHORIZED,
        _AUTH_ERROR_INVALID_API_KEY,
        _AUTH_ERROR_PERMISSION,
    )
    return any(marker in lowered_text for marker in auth_markers)
|
||||
|
||||
|
||||
def format_embedding_error(
    error: Exception,
    service_name: str,
    model: str | None,
    provider: EmbeddingProvider,
    sanitized_api_key: str | None = None,
    status_code: int | None = None,
) -> str:
    """Build the standard one-line description of an embedding failure.

    Args:
        error: The exception that was raised.
        service_name: Name of the service shown in the message.
        model: Model (or deployment) name, if known.
        provider: Embedding provider in use.
        sanitized_api_key: Masked API key to include in the message.
        status_code: HTTP status code; when truthy the message is worded as an
            HTTP error, otherwise it reports the exception's type.

    Returns:
        The formatted error string.
    """
    if status_code:
        headline = "HTTP error"
        detail = f"Status {status_code}"
    else:
        headline = "Exception"
        detail = f"{type(error)}"

    return (
        f"{headline} embedding text with {service_name} - {detail}: "
        f"Model: {model} "
        f"Provider: {provider} "
        f"API Key: {sanitized_api_key} "
        f"Exception: {error}"
    )
|
||||
|
||||
|
||||
# Custom exception for authentication errors
|
||||
class AuthenticationError(Exception):
    """Signals that a provider rejected the supplied credentials.

    Attributes are set in ``__init__``:
        provider: Name of the provider that rejected the request.
        message: Human-readable reason for the failure.
    """

    def __init__(self, provider: str, message: str = "API key is invalid or expired"):
        # The exception text combines both pieces so str(exc) is self-contained
        full_text = f"{provider} authentication failed: {message}"
        super().__init__(full_text)
        self.provider = provider
        self.message = message
|
||||
|
||||
|
||||
class CloudEmbedding:
    """Async client that embeds text by calling a cloud provider's API directly.

    An instance is bound to one provider and one API key. It owns an
    ``httpx.AsyncClient`` (used by the LiteLLM proxy path), so it should be used
    with ``async with`` or closed explicitly via ``aclose()``.
    """

    # Keys shorter than this are masked entirely: showing the first and last four
    # characters of a short key would reveal most (or all) of it.
    _MIN_KEY_LEN_FOR_PARTIAL_REVEAL = 16

    def __init__(
        self,
        api_key: str,
        provider: EmbeddingProvider,
        api_url: str | None = None,
        api_version: str | None = None,
        timeout: int = API_BASED_EMBEDDING_TIMEOUT,
    ) -> None:
        """Store the connection settings and create the shared HTTP client.

        Args:
            api_key: Provider credential. For the Google provider this is parsed
                as a JSON service-account document (see ``_embed_vertex``).
            provider: Which embedding provider to call.
            api_url: Endpoint URL; required for the LiteLLM proxy path and passed
                as ``api_base`` on the Azure path.
            api_version: API version passed through on the Azure path.
            timeout: Timeout given to the internal ``httpx.AsyncClient``.
        """
        self.provider = provider
        self.api_key = api_key
        self.api_url = api_url
        self.api_version = api_version
        self.timeout = timeout
        self.http_client = httpx.AsyncClient(timeout=timeout)
        self._closed = False
        self.sanitized_api_key = self._sanitize_api_key(api_key)

    @classmethod
    def _sanitize_api_key(cls, api_key: str | None) -> str:
        """Return a masked form of ``api_key`` that is safe to put in logs.

        Long keys keep their first and last four characters; empty or short keys
        are replaced by the mask alone so they are never exposed in full.
        """
        mask = "********"
        if not api_key or len(api_key) < cls._MIN_KEY_LEN_FOR_PARTIAL_REVEAL:
            return mask
        return api_key[:4] + mask + api_key[-4:]

    async def _embed_openai(
        self, texts: list[str], model: str | None, reduced_dimension: int | None
    ) -> list[Embedding]:
        """Embed ``texts`` with OpenAI, batching to the per-request input limit."""
        if not model:
            model = DEFAULT_OPENAI_MODEL

        # Use the OpenAI specific timeout for this one
        client = openai.AsyncOpenAI(
            api_key=self.api_key, timeout=OPENAI_EMBEDDING_TIMEOUT
        )

        final_embeddings: list[Embedding] = []

        for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN):
            response = await client.embeddings.create(
                input=text_batch,
                model=model,
                dimensions=reduced_dimension or openai.NOT_GIVEN,
            )
            final_embeddings.extend(
                [embedding.embedding for embedding in response.data]
            )
        return final_embeddings

    async def _embed_cohere(
        self, texts: list[str], model: str | None, embedding_type: str
    ) -> list[Embedding]:
        """Embed ``texts`` with Cohere, batching to the per-request input limit."""
        if not model:
            model = DEFAULT_COHERE_MODEL

        client = CohereAsyncClient(api_key=self.api_key)

        final_embeddings: list[Embedding] = []
        for text_batch in batch_list(texts, _COHERE_MAX_INPUT_LEN):
            # Does not use the same tokenizer as the Onyx API server but it's approximately the same
            # empirically it's only off by a very few tokens so it's not a big deal
            response = await client.embed(
                texts=text_batch,
                model=model,
                input_type=embedding_type,
                truncate="END",
            )
            final_embeddings.extend(cast(list[Embedding], response.embeddings))
        return final_embeddings

    async def _embed_voyage(
        self, texts: list[str], model: str | None, embedding_type: str
    ) -> list[Embedding]:
        """Embed ``texts`` with Voyage in a single request."""
        if not model:
            model = DEFAULT_VOYAGE_MODEL

        client = voyageai.AsyncClient(
            api_key=self.api_key, timeout=API_BASED_EMBEDDING_TIMEOUT
        )

        response = await client.embed(
            texts=texts,
            model=model,
            input_type=embedding_type,
            truncation=True,
        )
        return response.embeddings

    async def _embed_azure(
        self, texts: list[str], model: str | None
    ) -> list[Embedding]:
        """Embed ``texts`` through LiteLLM's ``aembedding`` using the Azure settings."""
        response = await aembedding(
            model=model,
            input=texts,
            timeout=API_BASED_EMBEDDING_TIMEOUT,
            api_key=self.api_key,
            api_base=self.api_url,
            api_version=self.api_version,
        )
        embeddings = [embedding["embedding"] for embedding in response.data]
        return embeddings

    async def _embed_vertex(
        self, texts: list[str], model: str | None, embedding_type: str
    ) -> list[Embedding]:
        """Embed ``texts`` with Vertex AI, issuing all batches concurrently."""
        if not model:
            model = DEFAULT_VERTEX_MODEL

        # For this provider the "API key" is a JSON service-account document
        service_account_info = json.loads(self.api_key)
        credentials = service_account.Credentials.from_service_account_info(
            service_account_info
        )
        project_id = service_account_info["project_id"]
        vertexai.init(project=project_id, credentials=credentials)
        client = TextEmbeddingModel.from_pretrained(model)

        inputs = [TextEmbeddingInput(text, embedding_type) for text in texts]

        # Split into batches of at most VERTEXAI_EMBEDDING_LOCAL_BATCH_SIZE texts
        max_texts_per_batch = VERTEXAI_EMBEDDING_LOCAL_BATCH_SIZE
        batches = [
            inputs[i : i + max_texts_per_batch]
            for i in range(0, len(inputs), max_texts_per_batch)
        ]

        # Dispatch all embedding calls asynchronously at once
        tasks = [
            client.get_embeddings_async(batch, auto_truncate=True) for batch in batches
        ]

        # Wait for all tasks to complete in parallel
        results = await asyncio.gather(*tasks)

        return [embedding.values for batch in results for embedding in batch]

    async def _embed_litellm_proxy(
        self, texts: list[str], model_name: str | None
    ) -> list[Embedding]:
        """Embed ``texts`` by POSTing to the configured LiteLLM proxy URL.

        Raises:
            ValueError: If no model name or no API URL is configured.
        """
        if not model_name:
            raise ValueError("Model name is required for LiteLLM proxy embedding.")

        if not self.api_url:
            raise ValueError("API URL is required for LiteLLM proxy embedding.")

        # The Authorization header is only sent when a key is configured
        headers = (
            {} if not self.api_key else {"Authorization": f"Bearer {self.api_key}"}
        )

        response = await self.http_client.post(
            self.api_url,
            json={
                "model": model_name,
                "input": texts,
            },
            headers=headers,
        )
        response.raise_for_status()
        result = response.json()
        return [embedding["embedding"] for embedding in result["data"]]

    async def embed(
        self,
        *,
        texts: list[str],
        text_type: EmbedTextType,
        model_name: str | None = None,
        deployment_name: str | None = None,
        reduced_dimension: int | None = None,
    ) -> list[Embedding]:
        """Embed ``texts`` with the configured provider, retrying failed attempts.

        Up to ``_RETRY_TRIES`` attempts are made with ``_RETRY_DELAY`` seconds
        between them. The retry loop is written out here because the synchronous
        ``retry`` decorator cannot retry a coroutine function: calling the wrapped
        function only creates the coroutine, and any exception is raised later
        when it is awaited, outside the decorator.

        Args:
            texts: Texts to embed.
            text_type: Whether the texts are queries or passages; used by the
                providers that take an explicit input type.
            model_name: Model to use; provider defaults apply when omitted.
            deployment_name: Azure deployment name.
            reduced_dimension: Optional output dimension for OpenAI.

        Returns:
            One embedding per input text.

        Raises:
            AuthenticationError: If the provider rejects the credentials. This is
                raised immediately, without retrying.
            RuntimeError: If every attempt fails for any other reason.
        """
        total_attempts = max(_RETRY_TRIES, 1)
        for attempt in range(1, total_attempts + 1):
            try:
                return await self._embed_once(
                    texts=texts,
                    text_type=text_type,
                    model_name=model_name,
                    deployment_name=deployment_name,
                    reduced_dimension=reduced_dimension,
                )
            except AuthenticationError:
                # Bad credentials will not fix themselves; fail fast
                raise
            except Exception as e:
                if attempt == total_attempts:
                    raise
                logger.warning(
                    f"Embedding attempt {attempt}/{total_attempts} with "
                    f"{self.provider} failed ({type(e).__name__}), "
                    f"retrying in {_RETRY_DELAY}s"
                )
                await asyncio.sleep(_RETRY_DELAY)

        # Unreachable: the final attempt either returns or re-raises above
        raise RuntimeError("Embedding retry loop exited unexpectedly")

    async def _embed_once(
        self,
        *,
        texts: list[str],
        text_type: EmbedTextType,
        model_name: str | None,
        deployment_name: str | None,
        reduced_dimension: int | None,
    ) -> list[Embedding]:
        """Run a single embedding attempt and normalise any failure.

        Authentication failures become ``AuthenticationError``; every other
        failure is logged and re-raised as ``RuntimeError`` carrying the
        formatted error description.
        """
        try:
            if self.provider == EmbeddingProvider.OPENAI:
                return await self._embed_openai(texts, model_name, reduced_dimension)
            elif self.provider == EmbeddingProvider.AZURE:
                return await self._embed_azure(texts, f"azure/{deployment_name}")
            elif self.provider == EmbeddingProvider.LITELLM:
                return await self._embed_litellm_proxy(texts, model_name)

            # The remaining providers take an explicit query/passage input type
            embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
            if self.provider == EmbeddingProvider.COHERE:
                return await self._embed_cohere(texts, model_name, embedding_type)
            elif self.provider == EmbeddingProvider.VOYAGE:
                return await self._embed_voyage(texts, model_name, embedding_type)
            elif self.provider == EmbeddingProvider.GOOGLE:
                return await self._embed_vertex(texts, model_name, embedding_type)
            else:
                raise ValueError(f"Unsupported provider: {self.provider}")
        except openai.AuthenticationError:
            raise AuthenticationError(provider="OpenAI")
        except httpx.HTTPStatusError as e:
            if e.response.status_code == 401:
                raise AuthenticationError(provider=str(self.provider))

            error_string = format_embedding_error(
                e,
                str(self.provider),
                model_name or deployment_name,
                self.provider,
                sanitized_api_key=self.sanitized_api_key,
                status_code=e.response.status_code,
            )
            logger.error(error_string)
            logger.debug(f"Exception texts: {texts}")

            raise RuntimeError(error_string)
        except Exception as e:
            if is_authentication_error(e):
                raise AuthenticationError(provider=str(self.provider))

            error_string = format_embedding_error(
                e,
                str(self.provider),
                model_name or deployment_name,
                self.provider,
                sanitized_api_key=self.sanitized_api_key,
            )
            logger.error(error_string)
            logger.debug(f"Exception texts: {texts}")

            raise RuntimeError(error_string)

    @staticmethod
    def create(
        api_key: str,
        provider: EmbeddingProvider,
        api_url: str | None = None,
        api_version: str | None = None,
    ) -> "CloudEmbedding":
        """Factory that logs the provider and returns a new ``CloudEmbedding``."""
        logger.debug(f"Creating Embedding instance for provider: {provider}")
        return CloudEmbedding(api_key, provider, api_url, api_version)

    async def aclose(self) -> None:
        """Explicitly close the client."""
        if not self._closed:
            await self.http_client.aclose()
            self._closed = True

    async def __aenter__(self) -> "CloudEmbedding":
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        await self.aclose()

    def __del__(self) -> None:
        """Finalizer to warn about unclosed clients."""
        # getattr: __init__ may have failed before _closed was assigned, in which
        # case there is no client to warn about
        if not getattr(self, "_closed", True):
            logger.warning(
                "CloudEmbedding was not properly closed. Use 'async with' or call aclose()"
            )
|
||||
|
||||
|
||||
# API-based reranking functions (moved from model server)
|
||||
async def cohere_rerank_api(
    query: str, docs: list[str], model_name: str, api_key: str
) -> list[float]:
    """Score ``docs`` against ``query`` using Cohere's rerank endpoint.

    Args:
        query: The search query.
        docs: Candidate documents to score.
        model_name: Cohere rerank model to use.
        api_key: Cohere API key.

    Returns:
        The relevance scores, ordered by each result's ``index`` field
        (presumably the document's position in ``docs`` - confirm against the
        Cohere API).
    """
    client = CohereAsyncClient(api_key=api_key)
    rerank_response = await client.rerank(
        query=query, documents=docs, model=model_name
    )
    by_index = sorted(rerank_response.results, key=lambda entry: entry.index)
    return [entry.relevance_score for entry in by_index]
|
||||
|
||||
|
||||
async def cohere_rerank_aws(
    query: str,
    docs: list[str],
    model_name: str,
    region_name: str,
    aws_access_key_id: str,
    aws_secret_access_key: str,
) -> list[float]:
    """Score ``docs`` against ``query`` with a rerank model invoked via AWS Bedrock.

    Args:
        query: The search query.
        docs: Candidate documents to score.
        model_name: Bedrock model id to invoke.
        region_name: AWS region of the ``bedrock-runtime`` client.
        aws_access_key_id: AWS access key id for the session.
        aws_secret_access_key: AWS secret access key for the session.

    Returns:
        The relevance scores ordered by each result's ``index`` field; an empty
        list when the response body has no ``results`` key.
    """
    aws_session = aioboto3.Session(
        aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key
    )
    async with aws_session.client(
        "bedrock-runtime", region_name=region_name
    ) as bedrock_client:
        request_payload = {
            "query": query,
            "documents": docs,
            "api_version": 2,
        }

        # Invoke the model; both request and response are JSON
        invoke_response = await bedrock_client.invoke_model(
            modelId=model_name,
            accept="application/json",
            contentType="application/json",
            body=json.dumps(request_payload),
        )

        # The body is a stream that has to be read asynchronously before parsing
        raw_body = await invoke_response["body"].read()
        parsed_body = json.loads(raw_body)

        ordered_results = sorted(
            parsed_body.get("results", []), key=lambda entry: entry["index"]
        )
        return [entry["relevance_score"] for entry in ordered_results]
|
||||
|
||||
|
||||
async def litellm_rerank(
    query: str, docs: list[str], api_url: str, model_name: str, api_key: str | None
) -> list[float]:
    """Score ``docs`` against ``query`` by POSTing to a LiteLLM rerank endpoint.

    Args:
        query: The search query.
        docs: Candidate documents to score.
        api_url: URL of the rerank endpoint.
        model_name: Model name sent in the request body.
        api_key: Optional bearer token; no Authorization header is sent when
            it is empty or None.

    Returns:
        The relevance scores ordered by each result's ``index`` field.

    Raises:
        httpx.HTTPStatusError: If the endpoint responds with an error status.
    """
    auth_headers = {} if not api_key else {"Authorization": f"Bearer {api_key}"}
    request_json = {
        "model": model_name,
        "query": query,
        "documents": docs,
    }
    async with httpx.AsyncClient() as http:
        reply = await http.post(api_url, json=request_json, headers=auth_headers)
        reply.raise_for_status()
        parsed = reply.json()
        ordered_results = sorted(parsed["results"], key=lambda entry: entry["index"])
        return [entry["relevance_score"] for entry in ordered_results]
|
||||
|
||||
|
||||
class EmbeddingModel:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -113,8 +523,84 @@ class EmbeddingModel:
|
||||
)
|
||||
self.callback = callback
|
||||
|
||||
model_server_url = build_model_server_url(server_host, server_port)
|
||||
self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"
|
||||
# Only build model server endpoint for local models
|
||||
if self.provider_type is None:
|
||||
model_server_url = build_model_server_url(server_host, server_port)
|
||||
self.embed_server_endpoint: str | None = (
|
||||
f"{model_server_url}/encoder/bi-encoder-embed"
|
||||
)
|
||||
else:
|
||||
# API providers don't need model server endpoint
|
||||
self.embed_server_endpoint = None
|
||||
|
||||
async def _make_direct_api_call(
    self,
    embed_request: EmbedRequest,
    tenant_id: str | None = None,
    request_id: str | None = None,
) -> EmbedResponse:
    """Embed the request's texts by calling the cloud provider directly.

    The request is validated first, then a ``CloudEmbedding`` client is opened
    for the duration of the call. ``tenant_id`` and ``request_id`` are accepted
    but not used inside this method.

    Raises:
        ValueError: If no provider is configured, a manual prefix is supplied,
            any text is empty, no texts are given, or the provider returns a
            ``None`` embedding.
        RuntimeError: If no API key is configured.
    """
    if self.provider_type is None:
        raise ValueError("Provider type is required for direct API calls")

    if self.api_key is None:
        logger.error("API key not provided for cloud model")
        raise RuntimeError("API key not provided for cloud model")

    # Cloud models take an explicit text type, so manual prefixes are rejected
    has_manual_prefix = (
        embed_request.manual_query_prefix or embed_request.manual_passage_prefix
    )
    if has_manual_prefix:
        logger.warning("Prefix provided for cloud model, which is not supported")
        raise ValueError(
            "Prefix string is not valid for cloud models. "
            "Cloud models take an explicit text type instead."
        )

    request_texts = embed_request.texts

    if not all(request_texts):
        logger.error("Empty strings provided for embedding")
        raise ValueError("Empty strings are not allowed for embedding.")

    if not request_texts:
        logger.error("No texts provided for embedding")
        raise ValueError("No texts provided for embedding.")

    started_at = time.monotonic()
    total_chars = sum(len(text) for text in embed_request.texts)

    logger.info(
        f"Embedding {len(embed_request.texts)} texts with {total_chars} total characters with provider: {self.provider_type}"
    )

    # The context manager closes the client's HTTP resources on exit
    async with CloudEmbedding(
        api_key=self.api_key,
        provider=self.provider_type,
        api_url=self.api_url,
        api_version=self.api_version,
    ) as cloud_model:
        embeddings = await cloud_model.embed(
            texts=embed_request.texts,
            model_name=embed_request.model_name,
            deployment_name=embed_request.deployment_name,
            text_type=embed_request.text_type,
            reduced_dimension=embed_request.reduced_dimension,
        )

    if any(embedding is None for embedding in embeddings):
        error_message = (
            "Embeddings contain None values\n"
            "Corresponding texts:\n" + "\n".join(embed_request.texts)
        )
        logger.error(error_message)
        raise ValueError(error_message)

    elapsed = time.monotonic() - started_at
    logger.info(
        f"event=embedding_provider "
        f"texts={len(embed_request.texts)} "
        f"chars={total_chars} "
        f"provider={self.provider_type} "
        f"elapsed={elapsed:.2f}"
    )

    return EmbedResponse(embeddings=embeddings)
|
||||
|
||||
def _make_model_server_request(
|
||||
self,
|
||||
@@ -122,6 +608,12 @@ class EmbeddingModel:
|
||||
tenant_id: str | None = None,
|
||||
request_id: str | None = None,
|
||||
) -> EmbedResponse:
|
||||
if self.embed_server_endpoint is None:
|
||||
raise ValueError("Model server endpoint is not configured for local models")
|
||||
|
||||
# Store the endpoint in a local variable to help mypy understand it's not None
|
||||
endpoint = self.embed_server_endpoint
|
||||
|
||||
def _make_request() -> Response:
|
||||
headers = {}
|
||||
if tenant_id:
|
||||
@@ -131,7 +623,7 @@ class EmbeddingModel:
|
||||
headers["X-Onyx-Request-ID"] = request_id
|
||||
|
||||
response = requests.post(
|
||||
self.embed_server_endpoint,
|
||||
endpoint,
|
||||
headers=headers,
|
||||
json=embed_request.model_dump(),
|
||||
)
|
||||
@@ -219,11 +711,28 @@ class EmbeddingModel:
|
||||
reduced_dimension=self.reduced_dimension,
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
response = self._make_model_server_request(
|
||||
embed_request, tenant_id=tenant_id, request_id=request_id
|
||||
)
|
||||
end_time = time.time()
|
||||
start_time = time.monotonic()
|
||||
|
||||
# Route between direct API calls and model server calls
|
||||
if self.provider_type is not None:
|
||||
# For API providers, make direct API call
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
asyncio.set_event_loop(loop)
|
||||
response = loop.run_until_complete(
|
||||
self._make_direct_api_call(
|
||||
embed_request, tenant_id=tenant_id, request_id=request_id
|
||||
)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
else:
|
||||
# For local models, use model server
|
||||
response = self._make_model_server_request(
|
||||
embed_request, tenant_id=tenant_id, request_id=request_id
|
||||
)
|
||||
|
||||
end_time = time.monotonic()
|
||||
|
||||
processing_time = end_time - start_time
|
||||
logger.debug(
|
||||
@@ -360,29 +869,92 @@ class RerankingModel:
|
||||
model_server_host: str = MODEL_SERVER_HOST,
|
||||
model_server_port: int = MODEL_SERVER_PORT,
|
||||
) -> None:
|
||||
model_server_url = build_model_server_url(model_server_host, model_server_port)
|
||||
self.rerank_server_endpoint = model_server_url + "/encoder/cross-encoder-scores"
|
||||
self.model_name = model_name
|
||||
self.provider_type = provider_type
|
||||
self.api_key = api_key
|
||||
self.api_url = api_url
|
||||
|
||||
# Only build model server endpoint for local models
|
||||
if self.provider_type is None:
|
||||
model_server_url = build_model_server_url(
|
||||
model_server_host, model_server_port
|
||||
)
|
||||
self.rerank_server_endpoint: str | None = (
|
||||
model_server_url + "/encoder/cross-encoder-scores"
|
||||
)
|
||||
else:
|
||||
# API providers don't need model server endpoint
|
||||
self.rerank_server_endpoint = None
|
||||
|
||||
async def _make_direct_rerank_call(
    self, query: str, passages: list[str]
) -> list[float]:
    """Score ``passages`` against ``query`` by calling the configured provider directly.

    Dispatches on ``self.provider_type`` to the Cohere, Bedrock or LiteLLM
    rerank helper.

    Raises:
        ValueError: If no provider or API key is configured, the provider is
            not one of the supported rerankers, or LiteLLM is selected without
            an API URL.
    """
    provider = self.provider_type
    if provider is None:
        raise ValueError("Provider type is required for direct API calls")

    api_key = self.api_key
    if api_key is None:
        raise ValueError("API key is required for cloud provider")

    if provider == RerankerProvider.COHERE:
        return await cohere_rerank_api(query, passages, self.model_name, api_key)

    if provider == RerankerProvider.BEDROCK:
        # For Bedrock the stored key is unpacked into id, secret and region
        access_key_id, secret_access_key, region = pass_aws_key(api_key)
        return await cohere_rerank_aws(
            query,
            passages,
            self.model_name,
            region,
            access_key_id,
            secret_access_key,
        )

    if provider == RerankerProvider.LITELLM:
        if self.api_url is None:
            raise ValueError("API URL is required for LiteLLM reranking.")
        return await litellm_rerank(
            query, passages, self.api_url, self.model_name, api_key
        )

    raise ValueError(f"Unsupported reranking provider: {provider}")
|
||||
|
||||
def predict(self, query: str, passages: list[str]) -> list[float]:
|
||||
rerank_request = RerankRequest(
|
||||
query=query,
|
||||
documents=passages,
|
||||
model_name=self.model_name,
|
||||
provider_type=self.provider_type,
|
||||
api_key=self.api_key,
|
||||
api_url=self.api_url,
|
||||
)
|
||||
# Route between direct API calls and model server calls
|
||||
if self.provider_type is not None:
|
||||
# For API providers, make direct API call
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
asyncio.set_event_loop(loop)
|
||||
return loop.run_until_complete(
|
||||
self._make_direct_rerank_call(query, passages)
|
||||
)
|
||||
finally:
|
||||
loop.close()
|
||||
else:
|
||||
# For local models, use model server
|
||||
if self.rerank_server_endpoint is None:
|
||||
raise ValueError(
|
||||
"Rerank server endpoint is not configured for local models"
|
||||
)
|
||||
|
||||
response = requests.post(
|
||||
self.rerank_server_endpoint, json=rerank_request.model_dump()
|
||||
)
|
||||
response.raise_for_status()
|
||||
rerank_request = RerankRequest(
|
||||
query=query,
|
||||
documents=passages,
|
||||
model_name=self.model_name,
|
||||
provider_type=self.provider_type,
|
||||
api_key=self.api_key,
|
||||
api_url=self.api_url,
|
||||
)
|
||||
|
||||
return RerankResponse(**response.json()).scores
|
||||
response = requests.post(
|
||||
self.rerank_server_endpoint, json=rerank_request.model_dump()
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
return RerankResponse(**response.json()).scores
|
||||
|
||||
|
||||
class QueryAnalysisModel:
|
||||
|
||||
@@ -151,7 +151,7 @@ def _build_ephemeral_publication_block(
|
||||
email=message_info.email,
|
||||
sender_id=message_info.sender_id,
|
||||
thread_messages=[],
|
||||
is_bot_msg=message_info.is_bot_msg,
|
||||
is_slash_command=message_info.is_slash_command,
|
||||
is_bot_dm=message_info.is_bot_dm,
|
||||
thread_to_respond=respond_ts,
|
||||
)
|
||||
@@ -225,10 +225,10 @@ def _build_doc_feedback_block(
|
||||
|
||||
def get_restate_blocks(
|
||||
msg: str,
|
||||
is_bot_msg: bool,
|
||||
is_slash_command: bool,
|
||||
) -> list[Block]:
|
||||
# Only the slash command needs this context because the user doesn't see their own input
|
||||
if not is_bot_msg:
|
||||
if not is_slash_command:
|
||||
return []
|
||||
|
||||
return [
|
||||
@@ -576,7 +576,7 @@ def build_slack_response_blocks(
|
||||
# If called with the OnyxBot slash command, the question is lost so we have to reshow it
|
||||
if not skip_restated_question:
|
||||
restate_question_block = get_restate_blocks(
|
||||
message_info.thread_messages[-1].message, message_info.is_bot_msg
|
||||
message_info.thread_messages[-1].message, message_info.is_slash_command
|
||||
)
|
||||
else:
|
||||
restate_question_block = []
|
||||
|
||||
@@ -177,7 +177,7 @@ def handle_generate_answer_button(
|
||||
sender_id=user_id or None,
|
||||
email=email or None,
|
||||
bypass_filters=True,
|
||||
is_bot_msg=False,
|
||||
is_slash_command=False,
|
||||
is_bot_dm=False,
|
||||
),
|
||||
slack_channel_config=slack_channel_config,
|
||||
|
||||
@@ -28,7 +28,7 @@ logger_base = setup_logger()
|
||||
|
||||
|
||||
def send_msg_ack_to_user(details: SlackMessageInfo, client: WebClient) -> None:
|
||||
if details.is_bot_msg and details.sender_id:
|
||||
if details.is_slash_command and details.sender_id:
|
||||
respond_in_thread_or_channel(
|
||||
client=client,
|
||||
channel=details.channel_to_respond,
|
||||
@@ -124,11 +124,11 @@ def handle_message(
|
||||
messages = message_info.thread_messages
|
||||
sender_id = message_info.sender_id
|
||||
bypass_filters = message_info.bypass_filters
|
||||
is_bot_msg = message_info.is_bot_msg
|
||||
is_slash_command = message_info.is_slash_command
|
||||
is_bot_dm = message_info.is_bot_dm
|
||||
|
||||
action = "slack_message"
|
||||
if is_bot_msg:
|
||||
if is_slash_command:
|
||||
action = "slack_slash_message"
|
||||
elif bypass_filters:
|
||||
action = "slack_tag_message"
|
||||
@@ -197,7 +197,7 @@ def handle_message(
|
||||
|
||||
# If configured to respond to team members only, then cannot be used with a /OnyxBot command
|
||||
# which would just respond to the sender
|
||||
if send_to and is_bot_msg:
|
||||
if send_to and is_slash_command:
|
||||
if sender_id:
|
||||
respond_in_thread_or_channel(
|
||||
client=client,
|
||||
|
||||
@@ -81,15 +81,15 @@ def handle_regular_answer(
|
||||
messages = message_info.thread_messages
|
||||
|
||||
message_ts_to_respond_to = message_info.msg_to_respond
|
||||
is_bot_msg = message_info.is_bot_msg
|
||||
is_slash_command = message_info.is_slash_command
|
||||
|
||||
# Capture whether response mode for channel is ephemeral. Even if the channel is set
|
||||
# to respond with an ephemeral message, we still send as non-ephemeral if
|
||||
# the message is a dm with the Onyx bot.
|
||||
send_as_ephemeral = (
|
||||
slack_channel_config.channel_config.get("is_ephemeral", False)
|
||||
and not message_info.is_bot_dm
|
||||
)
|
||||
or message_info.is_slash_command
|
||||
) and not message_info.is_bot_dm
|
||||
|
||||
# If the channel is configured to respond with an ephemeral message,
|
||||
# or the message is a dm to the Onyx bot, we should use the proper onyx user from the email.
|
||||
@@ -164,7 +164,7 @@ def handle_regular_answer(
|
||||
# in an attached document set were available to all users in the channel.)
|
||||
bypass_acl = False
|
||||
|
||||
if not message_ts_to_respond_to and not is_bot_msg:
|
||||
if not message_ts_to_respond_to and not is_slash_command:
|
||||
# if the message is not "/onyx" command, then it should have a message ts to respond to
|
||||
raise RuntimeError(
|
||||
"No message timestamp to respond to in `handle_message`. This should never happen."
|
||||
@@ -316,13 +316,14 @@ def handle_regular_answer(
|
||||
return True
|
||||
|
||||
# Got an answer at this point, can remove reaction and give results
|
||||
update_emote_react(
|
||||
emoji=DANSWER_REACT_EMOJI,
|
||||
channel=message_info.channel_to_respond,
|
||||
message_ts=message_info.msg_to_respond,
|
||||
remove=True,
|
||||
client=client,
|
||||
)
|
||||
if not is_slash_command: # Slash commands don't have reactions
|
||||
update_emote_react(
|
||||
emoji=DANSWER_REACT_EMOJI,
|
||||
channel=message_info.channel_to_respond,
|
||||
message_ts=message_info.msg_to_respond,
|
||||
remove=True,
|
||||
client=client,
|
||||
)
|
||||
|
||||
if answer.answer_valid is False:
|
||||
logger.notice(
|
||||
|
||||
@@ -130,6 +130,10 @@ _SLACK_GREETINGS_TO_IGNORE = {
|
||||
# This is always (currently) the user id of Slack's official slackbot
|
||||
_OFFICIAL_SLACKBOT_USER_ID = "USLACKBOT"
|
||||
|
||||
# Fields to exclude from Slack payload logging
|
||||
# Intention is to not log slack message content
|
||||
_EXCLUDED_SLACK_PAYLOAD_FIELDS = {"text", "blocks"}
|
||||
|
||||
|
||||
class SlackbotHandler:
|
||||
def __init__(self) -> None:
|
||||
@@ -570,6 +574,20 @@ class SlackbotHandler:
|
||||
sys.exit(0)
|
||||
|
||||
|
||||
def sanitize_slack_payload(payload: dict) -> dict:
    """Remove message content from a Slack payload so it can be logged.

    Every key listed in ``_EXCLUDED_SLACK_PAYLOAD_FIELDS`` is dropped at any
    nesting depth, not only at the top level and directly under ``event``.
    Payloads can carry message content in deeper objects (for example an
    ``event`` that wraps another ``message`` dict), and those copies must not
    reach the logs either.

    Args:
        payload: The Slack request payload. It is not modified.

    Returns:
        A new structure with the excluded keys removed from every nested dict.
        Lists are traversed; all other values are returned unchanged.
    """

    def _scrub(value: Any) -> Any:
        # Rebuild containers rather than mutating, so the caller's payload is untouched
        if isinstance(value, dict):
            return {
                key: _scrub(inner)
                for key, inner in value.items()
                if key not in _EXCLUDED_SLACK_PAYLOAD_FIELDS
            }
        if isinstance(value, list):
            return [_scrub(item) for item in value]
        return value

    return _scrub(payload)
|
||||
|
||||
|
||||
def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -> bool:
|
||||
"""True to keep going, False to ignore this Slack request"""
|
||||
|
||||
@@ -762,7 +780,10 @@ def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -
|
||||
if not check_message_limit():
|
||||
return False
|
||||
|
||||
logger.debug(f"Handling Slack request: {client.bot_name=} '{req.payload=}'")
|
||||
# Don't log Slack message content
|
||||
logger.debug(
|
||||
f"Handling Slack request: {client.bot_name=} '{sanitize_slack_payload(req.payload)=}'"
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
@@ -876,12 +897,13 @@ def build_request_details(
|
||||
sender_id=sender_id,
|
||||
email=email,
|
||||
bypass_filters=tagged,
|
||||
is_bot_msg=False,
|
||||
is_slash_command=False,
|
||||
is_bot_dm=event.get("channel_type") == "im",
|
||||
)
|
||||
|
||||
elif req.type == "slash_commands":
|
||||
channel = req.payload["channel_id"]
|
||||
channel_name = req.payload["channel_name"]
|
||||
msg = req.payload["text"]
|
||||
sender = req.payload["user_id"]
|
||||
expert_info = expert_info_from_slack_id(
|
||||
@@ -899,8 +921,8 @@ def build_request_details(
|
||||
sender_id=sender,
|
||||
email=email,
|
||||
bypass_filters=True,
|
||||
is_bot_msg=True,
|
||||
is_bot_dm=False,
|
||||
is_slash_command=True,
|
||||
is_bot_dm=channel_name == "directmessage",
|
||||
)
|
||||
|
||||
raise RuntimeError("Programming fault, this should never happen.")
|
||||
@@ -928,10 +950,9 @@ def process_message(
|
||||
if req.type == "events_api":
|
||||
event = cast(dict[str, Any], req.payload["event"])
|
||||
event_type = event.get("type")
|
||||
msg = cast(str, event.get("text", ""))
|
||||
logger.info(
|
||||
f"process_message start: {tenant_id=} {req.type=} {req.envelope_id=} "
|
||||
f"{event_type=} {msg=}"
|
||||
f"{event_type=}"
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
|
||||
@@ -13,7 +13,7 @@ class SlackMessageInfo(BaseModel):
|
||||
sender_id: str | None
|
||||
email: str | None
|
||||
bypass_filters: bool # User has tagged @OnyxBot
|
||||
is_bot_msg: bool # User is using /OnyxBot
|
||||
is_slash_command: bool # User is using /OnyxBot
|
||||
is_bot_dm: bool # User is direct messaging to OnyxBot
|
||||
|
||||
|
||||
@@ -25,7 +25,7 @@ class ActionValuesEphemeralMessageMessageInfo(BaseModel):
|
||||
email: str | None
|
||||
sender_id: str | None
|
||||
thread_messages: list[ThreadMessage] | None
|
||||
is_bot_msg: bool | None
|
||||
is_slash_command: bool | None
|
||||
is_bot_dm: bool | None
|
||||
thread_to_respond: str | None
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ import redis
|
||||
from onyx.redis.redis_connector_delete import RedisConnectorDelete
|
||||
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
|
||||
from onyx.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
|
||||
from onyx.redis.redis_connector_index import RedisConnectorIndex
|
||||
from onyx.redis.redis_connector_prune import RedisConnectorPrune
|
||||
from onyx.redis.redis_connector_stop import RedisConnectorStop
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
@@ -31,11 +30,6 @@ class RedisConnector:
|
||||
tenant_id, cc_pair_id, self.redis
|
||||
)
|
||||
|
||||
def new_index(self, search_settings_id: int) -> RedisConnectorIndex:
    """Build the RedisConnectorIndex helper for this connector and the given
    search settings id, sharing this object's tenant, cc-pair and redis client."""
    index_helper = RedisConnectorIndex(
        self.tenant_id,
        self.cc_pair_id,
        search_settings_id,
        self.redis,
    )
    return index_helper
|
||||
|
||||
@staticmethod
|
||||
def get_id_from_fence_key(key: str) -> str | None:
|
||||
"""
|
||||
@@ -81,3 +75,11 @@ class RedisConnector:
|
||||
|
||||
object_id = parts[1]
|
||||
return object_id
|
||||
|
||||
def db_lock_key(self, search_settings_id: int) -> str:
|
||||
"""
|
||||
Key for the db lock for an indexing attempt.
|
||||
Prevents multiple modifications to the current indexing attempt row
|
||||
from multiple docfetching/docprocessing tasks.
|
||||
"""
|
||||
return f"da_lock:indexing:db_{self.cc_pair_id}/{search_settings_id}"
|
||||
|
||||
@@ -1,126 +1,10 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
import redis
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.configs.constants import CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT
|
||||
|
||||
|
||||
class RedisConnectorIndexPayload(BaseModel):
    """Pydantic model describing one indexing payload.

    Only `submitted` is mandatory; the other three fields accept None.
    """

    # NOTE(review): presumably the id of the related index attempt row - confirm
    index_attempt_id: int | None
    # NOTE(review): presumably unset until the work actually starts - confirm
    started: datetime | None
    submitted: datetime
    # NOTE(review): presumably the id of the celery task doing the work - confirm
    celery_task_id: str | None
|
||||
|
||||
|
||||
class RedisConnectorIndex:
    """Redis bookkeeping for the indexing tasks of one cc-pair / search-settings
    combination. Instances are meant to be obtained through RedisConnector only."""

    PREFIX = "connectorindexing"
    FENCE_PREFIX = PREFIX + "_fence"  # "connectorindexing_fence"
    GENERATOR_TASK_PREFIX = f"{PREFIX}+generator"  # "connectorindexing+generator"
    GENERATOR_PROGRESS_PREFIX = (
        f"{PREFIX}_generator_progress"  # "connectorindexing_generator_progress"
    )
    GENERATOR_COMPLETE_PREFIX = (
        f"{PREFIX}_generator_complete"  # "connectorindexing_generator_complete"
    )

    # lock name prefixes
    GENERATOR_LOCK_PREFIX = "da_lock:indexing:docfetching"
    FILESTORE_LOCK_PREFIX = "da_lock:indexing:filestore"
    DB_LOCK_PREFIX = "da_lock:indexing:db"
    PER_WORKER_LOCK_PREFIX = "da_lock:indexing:per_worker"

    TERMINATE_PREFIX = f"{PREFIX}_terminate"  # "connectorindexing_terminate"
    TERMINATE_TTL = 600

    # Signals that the overall workflow is still active. The exact state of the
    # system cannot be captured at a single point in time, so a signal with a
    # TTL bridges the gaps between checks.
    ACTIVE_PREFIX = f"{PREFIX}_active"
    ACTIVE_TTL = 3600

    # Signals that the watchdog is running.
    WATCHDOG_PREFIX = f"{PREFIX}_watchdog"
    WATCHDOG_TTL = 300

    # Signals that the connector itself is still running.
    CONNECTOR_ACTIVE_PREFIX = f"{PREFIX}_connector_active"
    CONNECTOR_ACTIVE_TTL = CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT

    def __init__(
        self,
        tenant_id: str,
        cc_pair_id: int,
        search_settings_id: int,
        redis: redis.Redis,
    ) -> None:
        """Store the identifiers and client, and derive every per-attempt key."""
        self.tenant_id: str = tenant_id
        self.cc_pair_id = cc_pair_id
        self.search_settings_id = search_settings_id
        self.redis = redis

        # All per-attempt keys end in "_<cc_pair_id>/<search_settings_id>".
        key_suffix = f"_{cc_pair_id}/{search_settings_id}"

        self.generator_complete_key = self.GENERATOR_COMPLETE_PREFIX + key_suffix
        self.filestore_lock_key = self.FILESTORE_LOCK_PREFIX + key_suffix
        self.generator_lock_key = self.GENERATOR_LOCK_PREFIX + key_suffix
        self.per_worker_lock_key = self.PER_WORKER_LOCK_PREFIX + key_suffix
        self.db_lock_key = self.DB_LOCK_PREFIX + key_suffix
        self.terminate_key = self.TERMINATE_PREFIX + key_suffix

    def set_generator_complete(self, payload: int | None) -> None:
        """Store `payload` under the generator-complete key; a falsy payload
        (None or 0) removes the key instead."""
        if payload:
            self.redis.set(self.generator_complete_key, payload)
        else:
            self.redis.delete(self.generator_complete_key)

    def generator_clear(self) -> None:
        """Remove the generator-complete key."""
        self.redis.delete(self.generator_complete_key)

    def get_completion(self) -> int | None:
        """Return the stored generator-complete value as an int, or None when
        the key is absent."""
        raw_value = self.redis.get(self.generator_complete_key)
        return None if raw_value is None else int(cast(int, raw_value))

    def reset(self) -> None:
        """Delete the filestore lock, db lock, generator lock and
        generator-complete keys of this instance, in that order."""
        for stale_key in (
            self.filestore_lock_key,
            self.db_lock_key,
            self.generator_lock_key,
            self.generator_complete_key,
        ):
            self.redis.delete(stale_key)

    @staticmethod
    def reset_all(r: redis.Redis) -> None:
        """Deletes all redis values for all connectors"""
        # leaving these temporarily for backwards compat, TODO: remove
        prefixes_to_purge = (
            RedisConnectorIndex.CONNECTOR_ACTIVE_PREFIX,
            RedisConnectorIndex.ACTIVE_PREFIX,
            RedisConnectorIndex.FILESTORE_LOCK_PREFIX,
            RedisConnectorIndex.GENERATOR_COMPLETE_PREFIX,
            RedisConnectorIndex.GENERATOR_PROGRESS_PREFIX,
            RedisConnectorIndex.FENCE_PREFIX,
        )
        for key_prefix in prefixes_to_purge:
            for key in r.scan_iter(key_prefix + "*"):
                r.delete(key)
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from onyx.redis.redis_connector_delete import RedisConnectorDelete
|
||||
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
|
||||
from onyx.redis.redis_connector_index import RedisConnectorIndex
|
||||
from onyx.redis.redis_connector_prune import RedisConnectorPrune
|
||||
from onyx.redis.redis_document_set import RedisDocumentSet
|
||||
from onyx.redis.redis_usergroup import RedisUserGroup
|
||||
@@ -16,8 +15,6 @@ def is_fence(key_bytes: bytes) -> bool:
|
||||
return True
|
||||
if key_str.startswith(RedisConnectorPrune.FENCE_PREFIX):
|
||||
return True
|
||||
if key_str.startswith(RedisConnectorIndex.FENCE_PREFIX):
|
||||
return True
|
||||
if key_str.startswith(RedisConnectorPermissionSync.FENCE_PREFIX):
|
||||
return True
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ from onyx.db.connector import create_connector
|
||||
from onyx.db.connector_credential_pair import add_credential_to_connector
|
||||
from onyx.db.credentials import PUBLIC_CREDENTIAL_ID
|
||||
from onyx.db.document import check_docs_exist
|
||||
from onyx.db.document import mark_document_as_indexed_for_cc_pair__no_commit
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.index_attempt import mock_successful_index_attempt
|
||||
@@ -264,5 +265,13 @@ def seed_initial_documents(
|
||||
.values(chunk_count=doc.chunk_count)
|
||||
)
|
||||
|
||||
# Since we bypass the indexing flow, we need to manually mark the document as indexed
|
||||
mark_document_as_indexed_for_cc_pair__no_commit(
|
||||
connector_id=connector_id,
|
||||
credential_id=PUBLIC_CREDENTIAL_ID,
|
||||
document_ids=[doc.id for doc in docs],
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
kv_store.store(KV_DOCUMENTS_SEEDED_KEY, True)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import io
|
||||
import json
|
||||
import mimetypes
|
||||
import os
|
||||
@@ -101,8 +102,9 @@ from onyx.db.models import IndexAttempt
|
||||
from onyx.db.models import IndexingStatus
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserGroup__ConnectorCredentialPair
|
||||
from onyx.file_processing.extract_file_text import convert_docx_to_txt
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
from onyx.server.documents.models import AuthStatus
|
||||
from onyx.server.documents.models import AuthUrl
|
||||
@@ -124,6 +126,7 @@ from onyx.server.documents.models import IndexAttemptSnapshot
|
||||
from onyx.server.documents.models import ObjectCreationIdResponse
|
||||
from onyx.server.documents.models import RunConnectorRequest
|
||||
from onyx.server.models import StatusResponse
|
||||
from onyx.server.query_and_chat.chat_utils import mime_type_to_chat_file_type
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.telemetry import create_milestone_and_report
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
@@ -438,7 +441,9 @@ def is_zip_file(file: UploadFile) -> bool:
|
||||
)
|
||||
|
||||
|
||||
def upload_files(files: list[UploadFile]) -> FileUploadResponse:
|
||||
def upload_files(
|
||||
files: list[UploadFile], file_origin: FileOrigin = FileOrigin.CONNECTOR
|
||||
) -> FileUploadResponse:
|
||||
for file in files:
|
||||
if not file.filename:
|
||||
raise HTTPException(status_code=400, detail="File name cannot be empty")
|
||||
@@ -487,12 +492,17 @@ def upload_files(files: list[UploadFile]) -> FileUploadResponse:
|
||||
# For mypy, actual check happens at start of function
|
||||
assert file.filename is not None
|
||||
|
||||
# Special handling for docx files - only store the plaintext version
|
||||
if file.content_type and file.content_type.startswith(
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
||||
):
|
||||
docx_file_id = convert_docx_to_txt(file, file_store)
|
||||
deduped_file_paths.append(docx_file_id)
|
||||
# Special handling for doc files - only store the plaintext version
|
||||
file_type = mime_type_to_chat_file_type(file.content_type)
|
||||
if file_type == ChatFileType.DOC:
|
||||
extracted_text = extract_file_text(file.file, file.filename or "")
|
||||
text_file_id = file_store.save_file(
|
||||
content=io.BytesIO(extracted_text.encode()),
|
||||
display_name=file.filename,
|
||||
file_origin=file_origin,
|
||||
file_type="text/plain",
|
||||
)
|
||||
deduped_file_paths.append(text_file_id)
|
||||
deduped_file_names.append(file.filename)
|
||||
continue
|
||||
|
||||
@@ -520,7 +530,7 @@ def upload_files_api(
|
||||
files: list[UploadFile],
|
||||
_: User = Depends(current_curator_or_admin_user),
|
||||
) -> FileUploadResponse:
|
||||
return upload_files(files)
|
||||
return upload_files(files, FileOrigin.OTHER)
|
||||
|
||||
|
||||
@router.get("/admin/connector")
|
||||
|
||||
@@ -1,7 +1,12 @@
|
||||
import json
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import File
|
||||
from fastapi import Form
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Query
|
||||
from fastapi import UploadFile
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
@@ -27,6 +32,9 @@ from onyx.server.documents.models import CredentialDataUpdateRequest
|
||||
from onyx.server.documents.models import CredentialSnapshot
|
||||
from onyx.server.documents.models import CredentialSwapRequest
|
||||
from onyx.server.documents.models import ObjectCreationIdResponse
|
||||
from onyx.server.documents.private_key_types import FILE_TYPE_TO_FILE_PROCESSOR
|
||||
from onyx.server.documents.private_key_types import PrivateKeyFileTypes
|
||||
from onyx.server.documents.private_key_types import ProcessPrivateKeyFileProtocol
|
||||
from onyx.server.models import StatusResponse
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
@@ -76,6 +84,7 @@ def get_cc_source_full_info(
|
||||
document_source=source_type,
|
||||
get_editable=get_editable,
|
||||
)
|
||||
|
||||
return [
|
||||
CredentialSnapshot.from_credential_db_model(credential)
|
||||
for credential in credentials
|
||||
@@ -149,6 +158,70 @@ def create_credential_from_model(
|
||||
)
|
||||
|
||||
|
||||
@router.post("/credential/private-key")
def create_credential_with_private_key(
    credential_json: str = Form(...),
    admin_public: bool = Form(False),
    curator_public: bool = Form(False),
    groups: list[int] = Form([]),
    name: str | None = Form(None),
    source: str = Form(...),
    user: User | None = Depends(current_curator_or_admin_user),
    uploaded_file: UploadFile = File(...),
    field_key: str = Form(...),
    type_definition_key: str = Form(...),
    db_session: Session = Depends(get_session),
) -> ObjectCreationIdResponse:
    """Create a credential whose JSON embeds the contents of an uploaded
    private-key file.

    The uploaded file is run through the processor registered for
    `type_definition_key`, and the processor's string result is stored in the
    credential JSON under `field_key`.

    Raises:
        HTTPException(400): `credential_json` is not valid JSON or is not a
            JSON object, `source` is not a known document source, or
            `type_definition_key` has no registered processor. The processor
            itself may raise further HTTP errors for an invalid file.
    """
    try:
        credential_data = json.loads(credential_json)
    except json.JSONDecodeError as e:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid JSON in credential_json: {str(e)}",
        )
    # json.loads also accepts lists, strings and numbers; only an object can
    # receive the private key under field_key further down.
    if not isinstance(credential_data, dict):
        raise HTTPException(
            status_code=400,
            detail="Invalid credential_json: expected a JSON object",
        )

    # Resolve the source once. NOTE(review): DocumentSource is presumably an
    # Enum, whose constructor raises ValueError for unknown values - confirm.
    try:
        document_source = DocumentSource(source)
    except ValueError:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid source: {source}",
        )

    # The enum constructor raises ValueError for an unknown key, so the lookup
    # alone would never reach the "is None" branch below for bad input.
    try:
        private_key_file_type: PrivateKeyFileTypes | None = PrivateKeyFileTypes(
            type_definition_key
        )
    except ValueError:
        private_key_file_type = None

    private_key_processor: ProcessPrivateKeyFileProtocol | None = (
        FILE_TYPE_TO_FILE_PROCESSOR.get(private_key_file_type)
        if private_key_file_type is not None
        else None
    )
    if private_key_processor is None:
        raise HTTPException(
            status_code=400,
            detail="Invalid type definition key for private key file",
        )
    private_key_content: str = private_key_processor(uploaded_file)

    credential_data[field_key] = private_key_content

    credential_info = CredentialBase(
        credential_json=credential_data,
        admin_public=admin_public,
        curator_public=curator_public,
        groups=groups,
        name=name,
        source=document_source,
    )

    if not _ignore_credential_permissions(document_source):
        fetch_ee_implementation_or_noop(
            "onyx.db.user_group", "validate_object_creation_for_user", None
        )(
            db_session=db_session,
            user=user,
            target_group_ids=groups,
            object_is_public=curator_public,
        )

    # Temporary fix for empty Google App credentials
    if document_source == DocumentSource.GMAIL:
        cleanup_gmail_credentials(db_session=db_session)

    credential = create_credential(credential_info, user, db_session)
    return ObjectCreationIdResponse(
        id=credential.id,
        credential=CredentialSnapshot.from_credential_db_model(credential),
    )
|
||||
|
||||
|
||||
"""Endpoints for all"""
|
||||
|
||||
|
||||
@@ -209,6 +282,53 @@ def update_credential_data(
|
||||
return CredentialSnapshot.from_credential_db_model(credential)
|
||||
|
||||
|
||||
@router.put("/admin/credential/private-key/{credential_id}")
def update_credential_private_key(
    credential_id: int,
    name: str = Form(...),
    credential_json: str = Form(...),
    uploaded_file: UploadFile = File(...),
    field_key: str = Form(...),
    type_definition_key: str = Form(...),
    user: User = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> CredentialBase:
    """Replace a credential's name and JSON, embedding the contents of a newly
    uploaded private-key file.

    The uploaded file is run through the processor registered for
    `type_definition_key`, and the processor's string result is stored in the
    credential JSON under `field_key` before the credential is altered.

    Raises:
        HTTPException(400): `credential_json` is not valid JSON or is not a
            JSON object, or `type_definition_key` has no registered processor.
        HTTPException(401): `alter_credential` returned None for this
            credential id and user.
    """
    try:
        credential_data = json.loads(credential_json)
    except json.JSONDecodeError as e:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid JSON in credential_json: {str(e)}",
        )
    # json.loads also accepts lists, strings and numbers; only an object can
    # receive the private key under field_key further down.
    if not isinstance(credential_data, dict):
        raise HTTPException(
            status_code=400,
            detail="Invalid credential_json: expected a JSON object",
        )

    # The enum constructor raises ValueError for an unknown key, so the lookup
    # alone would never reach the "is None" branch below for bad input.
    try:
        private_key_file_type: PrivateKeyFileTypes | None = PrivateKeyFileTypes(
            type_definition_key
        )
    except ValueError:
        private_key_file_type = None

    private_key_processor: ProcessPrivateKeyFileProtocol | None = (
        FILE_TYPE_TO_FILE_PROCESSOR.get(private_key_file_type)
        if private_key_file_type is not None
        else None
    )
    if private_key_processor is None:
        raise HTTPException(
            status_code=400,
            detail="Invalid type definition key for private key file",
        )
    private_key_content: str = private_key_processor(uploaded_file)
    credential_data[field_key] = private_key_content

    credential = alter_credential(
        credential_id,
        name,
        credential_data,
        user,
        db_session,
    )

    if credential is None:
        raise HTTPException(
            status_code=401,
            detail=f"Credential {credential_id} does not exist or does not belong to user",
        )

    return CredentialSnapshot.from_credential_db_model(credential)
|
||||
|
||||
|
||||
@router.patch("/credential/{credential_id}")
|
||||
def update_credential_from_model(
|
||||
credential_id: int,
|
||||
|
||||
75
backend/onyx/server/documents/document_utils.py
Normal file
75
backend/onyx/server/documents/document_utils.py
Normal file
@@ -0,0 +1,75 @@
|
||||
from cryptography.hazmat.primitives.serialization import pkcs12
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _is_password_related_error(error: Exception) -> bool:
|
||||
"""
|
||||
Check if the exception indicates a password-related issue rather than a format issue.
|
||||
"""
|
||||
error_msg = str(error).lower()
|
||||
password_keywords = ["mac", "integrity", "password", "authentication", "verify"]
|
||||
return any(keyword in error_msg for keyword in password_keywords)
|
||||
|
||||
|
||||
def validate_pkcs12_content(file_bytes: bytes) -> bool:
    """
    Check whether the given bytes look like a PKCS#12 container.

    No password is needed: a load attempt that fails with a password-related
    error is still treated as a structurally valid file. Returns False for
    anything else, including unexpected errors; this function does not raise.
    """
    try:
        # Fewer than 10 bytes is rejected outright.
        if len(file_bytes) < 10:
            logger.debug("File too small to be a valid PKCS#12 file")
            return False

        # The first byte must be the ASN.1 SEQUENCE tag (0x30).
        if file_bytes[0] != 0x30:
            logger.debug("File does not start with ASN.1 SEQUENCE tag")
            return False

        # First attempt: load without any password. For protected files this
        # is expected to fail, but with a password-related ValueError.
        try:
            pkcs12.load_key_and_certificates(file_bytes, password=None)
        except ValueError as first_error:
            if _is_password_related_error(first_error):
                # Right format, wrong or missing password.
                logger.debug(
                    f"PKCS#12 format appears valid, password-related error: {first_error}"
                )
                return True
            # Any other ValueError is treated as a format problem.
            logger.debug(f"PKCS#12 format validation failed: {first_error}")
            return False
        except Exception as first_error:
            # Second attempt: retry with an empty password.
            try:
                pkcs12.load_key_and_certificates(file_bytes, password=b"")
            except ValueError as second_error:
                if _is_password_related_error(second_error):
                    logger.debug(
                        f"PKCS#12 format appears valid with empty password attempt: {second_error}"
                    )
                    return True
                logger.debug(
                    f"PKCS#12 validation failed on both attempts: {first_error}, {second_error}"
                )
                return False
            except Exception:
                logger.debug(f"PKCS#12 validation failed: {first_error}")
                return False
            else:
                return True
        else:
            return True

    except Exception as unexpected_error:
        logger.debug(f"Unexpected error during PKCS#12 validation: {unexpected_error}")
        return False
|
||||
57
backend/onyx/server/documents/private_key_types.py
Normal file
57
backend/onyx/server/documents/private_key_types.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import base64
|
||||
from enum import Enum
|
||||
from typing import Protocol
|
||||
|
||||
from fastapi import HTTPException
|
||||
from fastapi import UploadFile
|
||||
|
||||
from onyx.server.documents.document_utils import validate_pkcs12_content
|
||||
|
||||
|
||||
class ProcessPrivateKeyFileProtocol(Protocol):
    """Call signature shared by all private-key upload processors."""

    def __call__(self, file: UploadFile) -> str:
        """
        Take an uploaded file, validate it (for example its extension and
        content), and hand back its contents as a base64-encoded string.

        An exception is raised when validation fails.
        """
        ...
|
||||
|
||||
|
||||
class PrivateKeyFileTypes(Enum):
    """Kinds of private-key files that can be uploaded; the values are the
    string keys used to look a member up by value."""

    # .pfx file handled by process_sharepoint_private_key_file
    SHAREPOINT_PFX_FILE = "sharepoint_pfx_file"
|
||||
|
||||
|
||||
def process_sharepoint_private_key_file(file: UploadFile) -> str:
    """
    Validate an uploaded .pfx private-key file and return it base64-encoded.

    Both the file name and the file content are checked, so that renaming an
    arbitrary file to ".pfx" is not enough to get it accepted.

    Raises:
        HTTPException(400): the name does not end in ".pfx" (case-insensitive)
            or the content fails the PKCS#12 check.
    """
    # Cheap filter first: a missing or empty name can never end in ".pfx".
    uploaded_name = file.filename or ""
    if not uploaded_name.lower().endswith(".pfx"):
        raise HTTPException(
            status_code=400, detail="Invalid file type. Only .pfx files are supported."
        )

    # The same bytes are used for validation and for the encoded result.
    raw_pfx = file.file.read()

    # Content check guards against extension spoofing.
    content_ok = validate_pkcs12_content(raw_pfx)
    if not content_ok:
        raise HTTPException(
            status_code=400,
            detail="Invalid file content. The uploaded file does not appear to be a valid PKCS#12 (.pfx) file.",
        )

    return base64.b64encode(raw_pfx).decode("ascii")
|
||||
|
||||
|
||||
FILE_TYPE_TO_FILE_PROCESSOR: dict[
|
||||
PrivateKeyFileTypes, ProcessPrivateKeyFileProtocol
|
||||
] = {
|
||||
PrivateKeyFileTypes.SHAREPOINT_PFX_FILE: process_sharepoint_private_key_file,
|
||||
}
|
||||
@@ -23,6 +23,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.email_utils import send_user_email_invite
|
||||
from onyx.auth.invited_users import get_invited_users
|
||||
from onyx.auth.invited_users import remove_user_from_invited_users
|
||||
from onyx.auth.invited_users import write_invited_users
|
||||
from onyx.auth.noauth_user import fetch_no_auth_user
|
||||
from onyx.auth.noauth_user import set_no_auth_user_preferences
|
||||
@@ -367,15 +368,11 @@ def remove_invited_user(
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> int:
|
||||
tenant_id = get_current_tenant_id()
|
||||
user_emails = get_invited_users()
|
||||
remaining_users = [user for user in user_emails if user != user_email.user_email]
|
||||
|
||||
if MULTI_TENANT:
|
||||
fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.user_mapping", "remove_users_from_tenant", None
|
||||
)([user_email.user_email], tenant_id)
|
||||
|
||||
number_of_invited_users = write_invited_users(remaining_users)
|
||||
number_of_invited_users = remove_user_from_invited_users(user_email.user_email)
|
||||
|
||||
try:
|
||||
if MULTI_TENANT and not DEV_MODE:
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import asyncio
|
||||
import datetime
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
@@ -31,7 +30,6 @@ from onyx.chat.prompt_builder.citations_prompt import (
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.chat_configs import HARD_DELETE_CHATS
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.configs.model_configs import LITELLM_PASS_THROUGH_HEADERS
|
||||
@@ -63,9 +61,7 @@ from onyx.db.models import User
|
||||
from onyx.db.persona import get_persona_by_id
|
||||
from onyx.db.user_documents import create_user_files
|
||||
from onyx.file_processing.extract_file_text import docx_to_txt_filename
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.file_store.models import FileDescriptor
|
||||
from onyx.llm.exceptions import GenAIDisabledException
|
||||
from onyx.llm.factory import get_default_llms
|
||||
@@ -717,106 +713,65 @@ def upload_files_for_chat(
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="File size must be less than 20MB",
|
||||
detail="Images must be less than 20MB",
|
||||
)
|
||||
|
||||
file_store = get_default_file_store()
|
||||
|
||||
file_info: list[tuple[str, str | None, ChatFileType]] = []
|
||||
for file in files:
|
||||
file_type = mime_type_to_chat_file_type(file.content_type)
|
||||
|
||||
file_content = file.file.read() # Read the file content
|
||||
|
||||
# NOTE: Image conversion to JPEG used to be enforced here.
|
||||
# This was removed to:
|
||||
# 1. Preserve original file content for downloads
|
||||
# 2. Maintain transparency in formats like PNG
|
||||
# 3. Ameliorate issue with file conversion
|
||||
file_content_io = io.BytesIO(file_content)
|
||||
|
||||
new_content_type = file.content_type
|
||||
|
||||
# Store the file normally
|
||||
file_id = file_store.save_file(
|
||||
content=file_content_io,
|
||||
display_name=file.filename,
|
||||
file_origin=FileOrigin.CHAT_UPLOAD,
|
||||
file_type=new_content_type or file_type.value,
|
||||
# 5) Create a user file for each uploaded file
|
||||
user_files = create_user_files(files, RECENT_DOCS_FOLDER_ID, user, db_session)
|
||||
for user_file in user_files:
|
||||
# 6) Create connector
|
||||
connector_base = ConnectorBase(
|
||||
name=f"UserFile-{int(time.time())}",
|
||||
source=DocumentSource.FILE,
|
||||
input_type=InputType.LOAD_STATE,
|
||||
connector_specific_config={
|
||||
"file_locations": [user_file.file_id],
|
||||
"file_names": [user_file.name],
|
||||
"zip_metadata": {},
|
||||
},
|
||||
refresh_freq=None,
|
||||
prune_freq=None,
|
||||
indexing_start=None,
|
||||
)
|
||||
connector = create_connector(
|
||||
db_session=db_session,
|
||||
connector_data=connector_base,
|
||||
)
|
||||
|
||||
# 4) If the file is a doc, extract text and store that separately
|
||||
if file_type == ChatFileType.DOC:
|
||||
# Re-wrap bytes in a fresh BytesIO so we start at position 0
|
||||
extracted_text_io = io.BytesIO(file_content)
|
||||
extracted_text = extract_file_text(
|
||||
file=extracted_text_io, # use the bytes we already read
|
||||
file_name=file.filename or "",
|
||||
)
|
||||
# 7) Create credential
|
||||
credential_info = CredentialBase(
|
||||
credential_json={},
|
||||
admin_public=True,
|
||||
source=DocumentSource.FILE,
|
||||
curator_public=True,
|
||||
groups=[],
|
||||
name=f"UserFileCredential-{int(time.time())}",
|
||||
is_user_file=True,
|
||||
)
|
||||
credential = create_credential(credential_info, user, db_session)
|
||||
|
||||
text_file_id = file_store.save_file(
|
||||
content=io.BytesIO(extracted_text.encode()),
|
||||
display_name=file.filename,
|
||||
file_origin=FileOrigin.CHAT_UPLOAD,
|
||||
file_type="text/plain",
|
||||
)
|
||||
# Return the text file as the "main" file descriptor for doc types
|
||||
file_info.append((text_file_id, file.filename, ChatFileType.PLAIN_TEXT))
|
||||
else:
|
||||
file_info.append((file_id, file.filename, file_type))
|
||||
|
||||
# 5) Create a user file for each uploaded file
|
||||
user_files = create_user_files([file], RECENT_DOCS_FOLDER_ID, user, db_session)
|
||||
for user_file in user_files:
|
||||
# 6) Create connector
|
||||
connector_base = ConnectorBase(
|
||||
name=f"UserFile-{int(time.time())}",
|
||||
source=DocumentSource.FILE,
|
||||
input_type=InputType.LOAD_STATE,
|
||||
connector_specific_config={
|
||||
"file_locations": [user_file.file_id],
|
||||
"file_names": [user_file.name],
|
||||
"zip_metadata": {},
|
||||
},
|
||||
refresh_freq=None,
|
||||
prune_freq=None,
|
||||
indexing_start=None,
|
||||
)
|
||||
connector = create_connector(
|
||||
db_session=db_session,
|
||||
connector_data=connector_base,
|
||||
)
|
||||
|
||||
# 7) Create credential
|
||||
credential_info = CredentialBase(
|
||||
credential_json={},
|
||||
admin_public=True,
|
||||
source=DocumentSource.FILE,
|
||||
curator_public=True,
|
||||
groups=[],
|
||||
name=f"UserFileCredential-{int(time.time())}",
|
||||
is_user_file=True,
|
||||
)
|
||||
credential = create_credential(credential_info, user, db_session)
|
||||
|
||||
# 8) Create connector credential pair
|
||||
cc_pair = add_credential_to_connector(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
connector_id=connector.id,
|
||||
credential_id=credential.id,
|
||||
cc_pair_name=f"UserFileCCPair-{int(time.time())}",
|
||||
access_type=AccessType.PRIVATE,
|
||||
auto_sync_options=None,
|
||||
groups=[],
|
||||
)
|
||||
user_file.cc_pair_id = cc_pair.data
|
||||
db_session.commit()
|
||||
# 8) Create connector credential pair
|
||||
cc_pair = add_credential_to_connector(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
connector_id=connector.id,
|
||||
credential_id=credential.id,
|
||||
cc_pair_name=f"UserFileCCPair-{int(time.time())}",
|
||||
access_type=AccessType.PRIVATE,
|
||||
auto_sync_options=None,
|
||||
groups=[],
|
||||
)
|
||||
user_file.cc_pair_id = cc_pair.data
|
||||
db_session.commit()
|
||||
|
||||
return {
|
||||
"files": [
|
||||
{"id": file_id, "type": file_type, "name": file_name}
|
||||
for file_id, file_name, file_type in file_info
|
||||
{
|
||||
"id": user_file.file_id,
|
||||
"type": mime_type_to_chat_file_type(user_file.content_type),
|
||||
"name": user_file.name,
|
||||
}
|
||||
for user_file in user_files
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ from sqlalchemy import or_
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import OKTA_PROFILE_TOOL_ENABLED
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Tool as ToolDBModel
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
@@ -17,6 +18,9 @@ from onyx.tools.tool_implementations.internet_search.internet_search_tool import
|
||||
from onyx.tools.tool_implementations.internet_search.providers import (
|
||||
get_available_providers,
|
||||
)
|
||||
from onyx.tools.tool_implementations.okta_profile.okta_profile_tool import (
|
||||
OktaProfileTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -63,6 +67,19 @@ BUILT_IN_TOOLS: list[InCodeToolInfo] = [
|
||||
if (bool(get_available_providers()))
|
||||
else []
|
||||
),
|
||||
# Show Okta Profile tool if the environment variables are set
|
||||
*(
|
||||
[
|
||||
InCodeToolInfo(
|
||||
cls=OktaProfileTool,
|
||||
description="The Okta Profile Action allows the assistant to fetch user information from Okta.",
|
||||
in_code_tool_id=OktaProfileTool.__name__,
|
||||
display_name=OktaProfileTool._DISPLAY_NAME,
|
||||
)
|
||||
]
|
||||
if OKTA_PROFILE_TOOL_ENABLED
|
||||
else []
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -5,12 +5,12 @@ from typing import Generic
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TypeVar
|
||||
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.models import PreviousMessage
|
||||
from onyx.utils.special_types import JSON_ro
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.models import PreviousMessage
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.tools.message import ToolCallSummary
|
||||
from onyx.tools.models import ToolResponse
|
||||
@@ -53,8 +53,8 @@ class Tool(abc.ABC, Generic[OVERRIDE_T]):
|
||||
def get_args_for_non_tool_calling_llm(
|
||||
self,
|
||||
query: str,
|
||||
history: list[PreviousMessage],
|
||||
llm: LLM,
|
||||
history: list["PreviousMessage"],
|
||||
llm: "LLM",
|
||||
force_run: bool = False,
|
||||
) -> dict[str, Any] | None:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -14,6 +14,10 @@ from onyx.configs.app_configs import AZURE_DALLE_API_KEY
|
||||
from onyx.configs.app_configs import AZURE_DALLE_API_VERSION
|
||||
from onyx.configs.app_configs import AZURE_DALLE_DEPLOYMENT_NAME
|
||||
from onyx.configs.app_configs import IMAGE_MODEL_NAME
|
||||
from onyx.configs.app_configs import OAUTH_CLIENT_ID
|
||||
from onyx.configs.app_configs import OAUTH_CLIENT_SECRET
|
||||
from onyx.configs.app_configs import OKTA_API_TOKEN
|
||||
from onyx.configs.app_configs import OPENID_CONFIG_URL
|
||||
from onyx.configs.chat_configs import NUM_INTERNET_SEARCH_CHUNKS
|
||||
from onyx.configs.chat_configs import NUM_INTERNET_SEARCH_RESULTS
|
||||
from onyx.configs.model_configs import GEN_AI_TEMPERATURE
|
||||
@@ -41,6 +45,9 @@ from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
from onyx.tools.tool_implementations.internet_search.internet_search_tool import (
|
||||
InternetSearchTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations.okta_profile.okta_profile_tool import (
|
||||
OktaProfileTool,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.tools.utils import compute_all_tool_tokens
|
||||
from onyx.tools.utils import explicit_tool_calling_supported
|
||||
@@ -265,6 +272,33 @@ def construct_tools(
|
||||
"Internet search tool requires a Bing or Exa API key, please contact your Onyx admin to get it added!"
|
||||
)
|
||||
|
||||
# Handle Okta Profile Tool
|
||||
elif tool_cls.__name__ == OktaProfileTool.__name__:
|
||||
if not user_oauth_token:
|
||||
raise ValueError(
|
||||
"Okta Profile Tool requires user OAuth token but none found"
|
||||
)
|
||||
|
||||
if not all([OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET, OPENID_CONFIG_URL]):
|
||||
raise ValueError(
|
||||
"Okta Profile Tool requires OAuth configuration to be set"
|
||||
)
|
||||
|
||||
if not OKTA_API_TOKEN:
|
||||
raise ValueError(
|
||||
"Okta Profile Tool requires OKTA_API_TOKEN to be set"
|
||||
)
|
||||
|
||||
tool_dict[db_tool_model.id] = [
|
||||
OktaProfileTool(
|
||||
access_token=user_oauth_token,
|
||||
client_id=OAUTH_CLIENT_ID,
|
||||
client_secret=OAUTH_CLIENT_SECRET,
|
||||
openid_config_url=OPENID_CONFIG_URL,
|
||||
okta_api_token=OKTA_API_TOKEN,
|
||||
)
|
||||
]
|
||||
|
||||
# Handle custom tools
|
||||
elif db_tool_model.openapi_schema:
|
||||
if not custom_tool_config:
|
||||
|
||||
@@ -0,0 +1,243 @@
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.models import PreviousMessage
|
||||
from onyx.llm.utils import message_to_string
|
||||
from onyx.prompts.constants import GENERAL_SEP_PAT
|
||||
from onyx.tools.base_tool import BaseTool
|
||||
from onyx.tools.models import ToolResponse
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.special_types import JSON_ro
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
OKTA_PROFILE_RESPONSE_ID = "okta_profile"
|
||||
|
||||
OKTA_TOOL_DESCRIPTION = """
|
||||
The Okta profile tool can retrieve user profile information from Okta including:
|
||||
- User ID, status, creation date
|
||||
- Profile details like name, email, department, location, title, manager, and more
|
||||
- Account status and activity
|
||||
"""
|
||||
|
||||
|
||||
class OIDCConfig(BaseModel):
    """Fields read from the OpenID Connect discovery document.

    Only ``issuer`` is mandatory; each endpoint falls back to ``None`` when the
    discovery document does not provide it.
    """

    # Issuer identifier (required)
    issuer: str

    # Optional endpoints advertised by the provider
    jwks_uri: str | None = None
    userinfo_endpoint: str | None = None
    introspection_endpoint: str | None = None
    token_endpoint: str | None = None
|
||||
|
||||
|
||||
class OktaProfileTool(BaseTool):
    """Tool that looks up the current user's profile in Okta.

    The user's OAuth access token is resolved to an Okta ``uid`` (via the OIDC
    userinfo endpoint, falling back to token introspection) and that uid is then
    used to fetch the profile from the Okta Users API with an SSWS API token.
    """

    _NAME = "get_okta_profile"
    _DESCRIPTION = "This tool is used to get the user's profile information."
    _DISPLAY_NAME = "Okta Profile"

    def __init__(
        self,
        access_token: str,
        client_id: str,
        client_secret: str,
        openid_config_url: str,
        okta_api_token: str,
        request_timeout_sec: int = 15,
    ) -> None:
        """Store credentials and derive the Okta org URL.

        Args:
            access_token: The user's OAuth access token.
            client_id: OAuth client id, used as basic-auth user for introspection.
            client_secret: OAuth client secret, used as basic-auth password for
                introspection.
            openid_config_url: URL of the OpenID Connect discovery document.
            okta_api_token: Okta API token sent as ``SSWS`` to the Users API.
            request_timeout_sec: Timeout (seconds) applied to every HTTP request.
        """
        self.access_token = access_token
        self.client_id = client_id
        self.client_secret = client_secret
        self.openid_config_url = openid_config_url
        self.request_timeout_sec = request_timeout_sec

        # Extract Okta org URL from OpenID config URL using URL parsing
        # OpenID config URL format: https://{org}.okta.com/.well-known/openid-configuration
        parsed_url = urlparse(self.openid_config_url)
        # Leave the org URL empty when the config URL has no scheme/host so the
        # guard in _call_users_api can actually detect the misconfiguration
        # (a bare "://" would otherwise be treated as a valid org URL).
        self.okta_org_url = (
            f"{parsed_url.scheme}://{parsed_url.netloc}"
            if parsed_url.scheme and parsed_url.netloc
            else ""
        )
        self.okta_api_token = okta_api_token

        # Lazily populated by _load_oidc_config
        self._oidc_config: OIDCConfig | None = None

    @property
    def name(self) -> str:
        """Identifier used in the tool definition."""
        return self._NAME

    @property
    def description(self) -> str:
        """Description used in the tool definition."""
        return self._DESCRIPTION

    @property
    def display_name(self) -> str:
        """Human-readable tool name."""
        return self._DISPLAY_NAME

    def tool_definition(self) -> dict:
        """Return the function-style tool definition; the tool takes no arguments."""
        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": {"type": "object", "properties": {}, "required": []},
            },
        }

    def _load_oidc_config(self) -> OIDCConfig:
        """Fetch and cache the OIDC discovery document.

        Raises:
            requests.RequestException: If the request fails or returns an error status.
        """
        if self._oidc_config is not None:
            return self._oidc_config

        resp = requests.get(self.openid_config_url, timeout=self.request_timeout_sec)
        resp.raise_for_status()
        data = resp.json()
        self._oidc_config = OIDCConfig(**data)
        logger.debug(f"Loaded OIDC config from {self.openid_config_url}")
        return self._oidc_config

    def _call_userinfo(self, access_token: str) -> dict[str, Any] | None:
        """Call the OIDC userinfo endpoint with the bearer token.

        Returns the decoded JSON body on HTTP 200, otherwise None (missing
        endpoint, non-200 status, or request failure).
        """
        try:
            cfg = self._load_oidc_config()
            if not cfg.userinfo_endpoint:
                logger.info("OIDC config missing userinfo_endpoint")
                return None
            headers = {"Authorization": f"Bearer {access_token}"}
            r = requests.get(
                cfg.userinfo_endpoint, headers=headers, timeout=self.request_timeout_sec
            )
            if r.status_code == 200:
                return r.json()
            logger.info(
                f"userinfo call returned status {r.status_code}: {r.text[:200]}"
            )
            return None
        except requests.RequestException as e:
            logger.debug(f"userinfo request failed: {e}")
            return None

    def _call_introspection(self, access_token: str) -> dict[str, Any] | None:
        """Call the OIDC token introspection endpoint using client credentials.

        Returns the decoded JSON body on HTTP 200, otherwise None (missing
        endpoint, non-200 status, or request failure).
        """
        try:
            cfg = self._load_oidc_config()
            if not cfg.introspection_endpoint:
                logger.info("OIDC config missing introspection_endpoint")
                return None
            data = {
                "token": access_token,
                "token_type_hint": "access_token",
            }
            auth = (self.client_id, self.client_secret)
            r = requests.post(
                cfg.introspection_endpoint,
                data=data,
                auth=auth,
                headers={"Accept": "application/json"},
                timeout=self.request_timeout_sec,
            )
            if r.status_code == 200:
                return r.json()
            logger.info(
                f"introspection call returned status {r.status_code}: {r.text[:200]}"
            )
            return None
        except requests.RequestException as e:
            logger.debug(f"introspection request failed: {e}")
            return None

    def _call_users_api(self, uid: str) -> dict[str, Any]:
        """Call Okta Users API to fetch full user profile.

        Requires okta_org_url and okta_api_token to be set.

        Raises:
            ValueError: If configuration is missing, the request fails, or the
                API responds with a non-200 status.
        """
        # Local import: only this method needs it
        from urllib.parse import quote

        if not self.okta_org_url or not self.okta_api_token:
            raise ValueError(
                "Okta org URL and API token are required for user profile lookup"
            )

        try:
            # Percent-encode the uid so it can only ever occupy a single path segment
            url = (
                f"{self.okta_org_url.rstrip('/')}/api/v1/users/"
                f"{quote(str(uid), safe='')}"
            )
            headers = {"Authorization": f"SSWS {self.okta_api_token}"}
            r = requests.get(url, headers=headers, timeout=self.request_timeout_sec)
            if r.status_code == 200:
                return r.json()
            raise ValueError(
                f"Okta Users API call failed with status {r.status_code}: {r.text[:200]}"
            )
        except requests.RequestException as e:
            raise ValueError(f"Okta Users API request failed: {e}") from e

    def build_tool_message_content(
        self, *args: ToolResponse
    ) -> str | list[str | dict[str, Any]]:
        """Serialize the last tool response as JSON (``{}`` when there is none)."""
        # The tool emits a single aggregated packet; pass it through as compact JSON
        profile = args[-1].response if args else {}
        return json.dumps(profile)

    def get_args_for_non_tool_calling_llm(
        self,
        query: str,
        history: list[PreviousMessage],
        llm: LLM,
        force_run: bool = False,
    ) -> dict[str, Any] | None:
        """Decide whether the tool should run for an LLM without tool calling.

        Returns an empty argument dict when the tool should run (forced, or the
        LLM's answer contains "YES"), otherwise None.
        """
        if force_run:
            return {}

        # Use LLM to determine if this tool should be called based on the query
        prompt = f"""
You are helping to determine if an Okta profile lookup tool should be called based on a user's query.

{OKTA_TOOL_DESCRIPTION}

Query: {query}

Conversation history:
{GENERAL_SEP_PAT}
{history}
{GENERAL_SEP_PAT}

Should the Okta profile tool be called for this query? Respond with only "YES" or "NO".
""".strip()
        response = llm.invoke(prompt)
        if response and "YES" in message_to_string(response).upper():
            return {}

        return None

    def run(
        self, override_kwargs: None = None, **llm_kwargs: Any
    ) -> Generator[ToolResponse, None, None]:
        """Resolve the user's Okta uid and yield their profile.

        Raises:
            ValueError: If no uid can be resolved from the access token, the
                Users API call fails, or its response has no profile.
        """
        # Try to get UID from userinfo first, then fallback to introspection
        uid_candidate = None

        # Try userinfo endpoint first
        userinfo_data = self._call_userinfo(self.access_token)
        if userinfo_data and isinstance(userinfo_data, dict):
            uid_candidate = userinfo_data.get("uid")

        # Only try introspection if userinfo didn't provide a UID
        if not uid_candidate:
            introspection_data = self._call_introspection(self.access_token)
            if introspection_data and isinstance(introspection_data, dict):
                uid_candidate = introspection_data.get("uid")

        if not uid_candidate:
            raise ValueError(
                "Unable to fetch user profile from Okta. This likely means your Okta "
                "token has expired. Please logout, log back in, and try again."
            )

        # Call Users API to get full profile - this is now required
        users_api_data = self._call_users_api(uid_candidate)

        # Fail with a descriptive error instead of a bare KeyError when the
        # response is not shaped as expected
        profile = (
            users_api_data.get("profile") if isinstance(users_api_data, dict) else None
        )
        if profile is None:
            raise ValueError("Okta Users API response did not include a profile")

        yield ToolResponse(id=OKTA_PROFILE_RESPONSE_ID, response=profile)

    def final_result(self, *args: ToolResponse) -> JSON_ro:
        """Return the last tool response payload, or ``{}`` when there is none."""
        # Return the single aggregated profile packet
        if not args:
            return {}
        return args[-1].response
|
||||
@@ -13,6 +13,7 @@ from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.configs import SLACK_CHANNEL_ID
|
||||
from shared_configs.configs import TENANT_ID_PREFIX
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.contextvars import INDEX_ATTEMPT_INFO_CONTEXTVAR
|
||||
from shared_configs.contextvars import ONYX_REQUEST_ID_CONTEXTVAR
|
||||
|
||||
|
||||
@@ -34,30 +35,6 @@ class LoggerContextVars:
|
||||
doc_permission_sync_ctx.set(dict())
|
||||
|
||||
|
||||
class TaskAttemptSingleton:
|
||||
"""Used to tell if this process is an indexing job, and if so what is the
|
||||
unique identifier for this indexing attempt. For things like the API server,
|
||||
main background job (scheduler), etc. this will not be used."""
|
||||
|
||||
_INDEX_ATTEMPT_ID: None | int = None
|
||||
_CONNECTOR_CREDENTIAL_PAIR_ID: None | int = None
|
||||
|
||||
@classmethod
|
||||
def get_index_attempt_id(cls) -> None | int:
|
||||
return cls._INDEX_ATTEMPT_ID
|
||||
|
||||
@classmethod
|
||||
def get_connector_credential_pair_id(cls) -> None | int:
|
||||
return cls._CONNECTOR_CREDENTIAL_PAIR_ID
|
||||
|
||||
@classmethod
|
||||
def set_cc_and_index_id(
|
||||
cls, index_attempt_id: int, connector_credential_pair_id: int
|
||||
) -> None:
|
||||
cls._INDEX_ATTEMPT_ID = index_attempt_id
|
||||
cls._CONNECTOR_CREDENTIAL_PAIR_ID = connector_credential_pair_id
|
||||
|
||||
|
||||
def get_log_level_from_str(log_level_str: str = LOG_LEVEL) -> int:
|
||||
log_level_dict = {
|
||||
"CRITICAL": logging.CRITICAL,
|
||||
@@ -102,14 +79,12 @@ class OnyxLoggingAdapter(logging.LoggerAdapter):
|
||||
msg = f"[Doc Permissions Sync: {doc_permission_sync_ctx_dict['request_id']}] {msg}"
|
||||
break
|
||||
|
||||
index_attempt_id = TaskAttemptSingleton.get_index_attempt_id()
|
||||
cc_pair_id = TaskAttemptSingleton.get_connector_credential_pair_id()
|
||||
|
||||
if index_attempt_id is not None:
|
||||
msg = f"[Index Attempt: {index_attempt_id}] {msg}"
|
||||
|
||||
if cc_pair_id is not None:
|
||||
msg = f"[CC Pair: {cc_pair_id}] {msg}"
|
||||
index_attempt_info = INDEX_ATTEMPT_INFO_CONTEXTVAR.get()
|
||||
if index_attempt_info:
|
||||
cc_pair_id, index_attempt_id = index_attempt_info
|
||||
msg = (
|
||||
f"[Index Attempt: {index_attempt_id}] [CC Pair: {cc_pair_id}] {msg}"
|
||||
)
|
||||
|
||||
break
|
||||
|
||||
@@ -230,7 +205,7 @@ def setup_logger(
|
||||
log_levels = ["debug", "info", "notice"]
|
||||
for level in log_levels:
|
||||
file_name = (
|
||||
f"/var/log/{LOG_FILE_NAME}_{level}.log"
|
||||
f"/var/log/onyx/{LOG_FILE_NAME}_{level}.log"
|
||||
if is_containerized
|
||||
else f"./log/{LOG_FILE_NAME}_{level}.log"
|
||||
)
|
||||
|
||||
26
backend/onyx/utils/search_nlp_models_utils.py
Normal file
26
backend/onyx/utils/search_nlp_models_utils.py
Normal file
@@ -0,0 +1,26 @@
|
||||
def pass_aws_key(api_key: str) -> tuple[str, str, str]:
    """Parse AWS API key string into components.

    Args:
        api_key: String in format 'aws_ACCESSKEY_SECRETKEY_REGION'

    Returns:
        Tuple of (access_key, secret_key, region)

    Raises:
        ValueError: If key format is invalid
    """
    if not api_key.startswith("aws"):
        raise ValueError("API key must start with 'aws' prefix")

    parts = api_key.split("_")
    if len(parts) != 4:
        raise ValueError(
            f"API key must be in format 'aws_ACCESSKEY_SECRETKEY_REGION', got {len(parts) - 1} parts. "
            "This is an onyx specific format for formatting the aws secrets for bedrock"
        )

    # The length check above guarantees exactly four parts, so the unpacking
    # cannot fail and needs no exception handling.
    _, aws_access_key_id, aws_secret_access_key, aws_region = parts
    return aws_access_key_id, aws_secret_access_key, aws_region
|
||||
@@ -44,12 +44,12 @@ litellm==1.72.2
|
||||
lxml==5.3.0
|
||||
lxml_html_clean==0.2.2
|
||||
Mako==1.2.4
|
||||
markitdown[pdf, docx, pptx, xlsx, xls]==0.1.2
|
||||
msal==1.28.0
|
||||
nltk==3.9.1
|
||||
Office365-REST-Python-Client==2.5.9
|
||||
oauthlib==3.2.2
|
||||
openai==1.75.0
|
||||
openpyxl==3.0.10
|
||||
passlib==1.7.4
|
||||
playwright==1.41.2
|
||||
psutil==5.9.5
|
||||
@@ -66,7 +66,7 @@ pypdf==5.4.0
|
||||
pytest-mock==3.12.0
|
||||
pytest-playwright==0.7.0
|
||||
python-docx==1.1.2
|
||||
python-dotenv==1.0.0
|
||||
python-dotenv==1.1.1
|
||||
python-multipart==0.0.20
|
||||
pywikibot==9.0.0
|
||||
redis==5.0.8
|
||||
@@ -101,3 +101,5 @@ prometheus_client==0.21.0
|
||||
fastapi-limiter==0.1.6
|
||||
prometheus_fastapi_instrumentator==7.1.0
|
||||
sendgrid==6.11.0
|
||||
voyageai==0.2.3
|
||||
cohere==5.6.1
|
||||
|
||||
@@ -22,7 +22,6 @@ from onyx.configs.app_configs import REDIS_SSL
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.redis.redis_connector_index import RedisConnectorIndex
|
||||
from onyx.redis.redis_pool import RedisPool
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
@@ -130,9 +129,6 @@ def onyx_redis(
|
||||
logger.info(f"Purging locks associated with deleting cc_pair={cc_pair_id}.")
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
|
||||
match_pattern = f"{tenant_id}:{RedisConnectorIndex.FENCE_PREFIX}_{cc_pair_id}/*"
|
||||
purge_by_match_and_type(match_pattern, "string", batch, dry_run, r)
|
||||
|
||||
redis_delete_if_exists_helper(
|
||||
f"{tenant_id}:{redis_connector.prune.fence_key}", dry_run, r
|
||||
)
|
||||
|
||||
@@ -21,6 +21,11 @@ ONYX_REQUEST_ID_CONTEXTVAR: contextvars.ContextVar[str | None] = contextvars.Con
|
||||
"onyx_request_id", default=None
|
||||
)
|
||||
|
||||
# Used to store cc pair id and index attempt id in multithreaded environments
|
||||
INDEX_ATTEMPT_INFO_CONTEXTVAR: contextvars.ContextVar[tuple[int, int] | None] = (
|
||||
contextvars.ContextVar("index_attempt_info", default=None)
|
||||
)
|
||||
|
||||
"""Utils related to contextvars"""
|
||||
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ import pytest
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.salesforce.connector import SalesforceConnector
|
||||
from onyx.connectors.salesforce.utils import ACCOUNT_OBJECT_TYPE
|
||||
|
||||
|
||||
def extract_key_value_pairs_to_set(
|
||||
@@ -35,7 +36,7 @@ def _load_reference_data(
|
||||
@pytest.fixture
|
||||
def salesforce_connector() -> SalesforceConnector:
|
||||
connector = SalesforceConnector(
|
||||
requested_objects=["Account", "Contact", "Opportunity"],
|
||||
requested_objects=[ACCOUNT_OBJECT_TYPE, "Contact", "Opportunity"],
|
||||
)
|
||||
|
||||
username = os.environ["SF_USERNAME"]
|
||||
|
||||
@@ -1,14 +1,20 @@
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.sharepoint.connector import SharepointConnector
|
||||
from tests.daily.connectors.utils import load_all_docs_from_checkpoint_connector
|
||||
|
||||
# NOTE: Sharepoint site for tests is "sharepoint-tests"
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -43,6 +49,24 @@ EXPECTED_DOCUMENTS = [
|
||||
),
|
||||
]
|
||||
|
||||
EXPECTED_PAGES = [
|
||||
ExpectedDocument(
|
||||
semantic_identifier="CollabHome",
|
||||
content=(
|
||||
"# Home\n\nDisplay recent news.\n\n## News\n\nShow recent activities from your site\n\n"
|
||||
"## Site activity\n\n## Quick links\n\nLearn about a team site\n\nLearn how to add a page\n\n"
|
||||
"Add links to important documents and pages.\n\n## Quick links\n\nDocuments\n\n"
|
||||
"Add a document library\n\n## Document library"
|
||||
),
|
||||
folder_path=None,
|
||||
),
|
||||
ExpectedDocument(
|
||||
semantic_identifier="Home",
|
||||
content="# Home",
|
||||
folder_path=None,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def verify_document_metadata(doc: Document) -> None:
|
||||
"""Verify common metadata that should be present on all documents."""
|
||||
@@ -61,7 +85,7 @@ def verify_document_content(doc: Document, expected: ExpectedDocument) -> None:
|
||||
assert doc.semantic_identifier == expected.semantic_identifier
|
||||
assert len(doc.sections) == 1
|
||||
assert doc.sections[0].text is not None
|
||||
assert expected.content in doc.sections[0].text
|
||||
assert expected.content == doc.sections[0].text
|
||||
verify_document_metadata(doc)
|
||||
|
||||
|
||||
@@ -76,6 +100,17 @@ def find_document(documents: list[Document], semantic_identifier: str) -> Docume
|
||||
return matching_docs[0]
|
||||
|
||||
|
||||
@pytest.fixture
def mock_store_image() -> MagicMock:
    """Provide a MagicMock whose call result is a canned (ImageSection, file id) pair.

    The tests patch it over store_image_and_create_section.
    """
    canned_section = ImageSection(
        image_file_id="mocked-file-id", link="https://example.com/image"
    )
    return MagicMock(return_value=(canned_section, "mocked-file-id"))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sharepoint_credentials() -> dict[str, str]:
|
||||
return {
|
||||
@@ -87,199 +122,247 @@ def sharepoint_credentials() -> dict[str, str]:
|
||||
|
||||
def test_sharepoint_connector_all_sites__docs_only(
|
||||
mock_get_unstructured_api_key: MagicMock,
|
||||
mock_store_image: MagicMock,
|
||||
sharepoint_credentials: dict[str, str],
|
||||
) -> None:
|
||||
# Initialize connector with no sites
|
||||
connector = SharepointConnector(include_site_pages=False)
|
||||
with patch(
|
||||
"onyx.connectors.sharepoint.connector.store_image_and_create_section",
|
||||
mock_store_image,
|
||||
):
|
||||
# Initialize connector with no sites
|
||||
connector = SharepointConnector(
|
||||
include_site_pages=False, include_site_documents=True
|
||||
)
|
||||
|
||||
# Load credentials
|
||||
connector.load_credentials(sharepoint_credentials)
|
||||
# Load credentials
|
||||
connector.load_credentials(sharepoint_credentials)
|
||||
|
||||
# Not asserting expected sites because that can change in test tenant at any time
|
||||
# Finding any docs is good enough to verify that the connector is working
|
||||
document_batches = list(connector.load_from_state())
|
||||
assert document_batches, "Should find documents from all sites"
|
||||
# Not asserting expected sites because that can change in test tenant at any time
|
||||
# Finding any docs is good enough to verify that the connector is working
|
||||
document_batches = load_all_docs_from_checkpoint_connector(
|
||||
connector=connector,
|
||||
start=0,
|
||||
end=time.time(),
|
||||
)
|
||||
assert document_batches, "Should find documents from all sites"
|
||||
|
||||
|
||||
def test_sharepoint_connector_all_sites__pages_only(
    mock_get_unstructured_api_key: MagicMock,
    mock_store_image: MagicMock,
    sharepoint_credentials: dict[str, str],
) -> None:
    """Smoke test: with no sites given and documents disabled, site pages are found."""
    patch_target = "onyx.connectors.sharepoint.connector.store_image_and_create_section"
    with patch(patch_target, mock_store_image):
        # Connector restricted to site pages (no documents, no explicit sites)
        pages_connector = SharepointConnector(
            include_site_documents=False, include_site_pages=True
        )
        pages_connector.load_credentials(sharepoint_credentials)

        # The set of sites in the test tenant can change at any time, so nothing
        # specific is asserted; any result shows the connector is working.
        found_pages = load_all_docs_from_checkpoint_connector(
            connector=pages_connector,
            start=0,
            end=time.time(),
        )
        assert found_pages, "Should find site pages from all sites"
|
||||
|
||||
|
||||
def test_sharepoint_connector_specific_folder(
|
||||
mock_get_unstructured_api_key: MagicMock,
|
||||
mock_store_image: MagicMock,
|
||||
sharepoint_credentials: dict[str, str],
|
||||
) -> None:
|
||||
# Initialize connector with the test site URL and specific folder
|
||||
connector = SharepointConnector(
|
||||
sites=[os.environ["SHAREPOINT_SITE"] + "/Shared Documents/test"]
|
||||
)
|
||||
with patch(
|
||||
"onyx.connectors.sharepoint.connector.store_image_and_create_section",
|
||||
mock_store_image,
|
||||
):
|
||||
# Initialize connector with the test site URL and specific folder
|
||||
connector = SharepointConnector(
|
||||
sites=[os.environ["SHAREPOINT_SITE"] + "/Shared Documents/test"],
|
||||
include_site_pages=False,
|
||||
include_site_documents=True,
|
||||
)
|
||||
|
||||
# Load credentials
|
||||
connector.load_credentials(sharepoint_credentials)
|
||||
# Load credentials
|
||||
connector.load_credentials(sharepoint_credentials)
|
||||
|
||||
# Get all documents
|
||||
document_batches = list(connector.load_from_state())
|
||||
found_documents: list[Document] = [
|
||||
doc for batch in document_batches for doc in batch
|
||||
]
|
||||
# Get all documents
|
||||
found_documents: list[Document] = load_all_docs_from_checkpoint_connector(
|
||||
connector=connector,
|
||||
start=0,
|
||||
end=time.time(),
|
||||
)
|
||||
|
||||
# Should only find documents in the test folder
|
||||
test_folder_docs = [
|
||||
doc
|
||||
for doc in EXPECTED_DOCUMENTS
|
||||
if doc.folder_path and doc.folder_path.startswith("test")
|
||||
]
|
||||
assert len(found_documents) == len(
|
||||
test_folder_docs
|
||||
), "Should only find documents in test folder"
|
||||
# Should only find documents in the test folder
|
||||
test_folder_docs = [
|
||||
doc
|
||||
for doc in EXPECTED_DOCUMENTS
|
||||
if doc.folder_path and doc.folder_path.startswith("test")
|
||||
]
|
||||
assert len(found_documents) == len(
|
||||
test_folder_docs
|
||||
), "Should only find documents in test folder"
|
||||
|
||||
# Verify each expected document
|
||||
for expected in test_folder_docs:
|
||||
doc = find_document(found_documents, expected.semantic_identifier)
|
||||
verify_document_content(doc, expected)
|
||||
# Verify each expected document
|
||||
for expected in test_folder_docs:
|
||||
doc = find_document(found_documents, expected.semantic_identifier)
|
||||
verify_document_content(doc, expected)
|
||||
|
||||
|
||||
def test_sharepoint_connector_root_folder__docs_only(
|
||||
mock_get_unstructured_api_key: MagicMock,
|
||||
mock_store_image: MagicMock,
|
||||
sharepoint_credentials: dict[str, str],
|
||||
) -> None:
|
||||
# Initialize connector with the base site URL
|
||||
connector = SharepointConnector(
|
||||
sites=[os.environ["SHAREPOINT_SITE"]], include_site_pages=False
|
||||
)
|
||||
with patch(
|
||||
"onyx.connectors.sharepoint.connector.store_image_and_create_section",
|
||||
mock_store_image,
|
||||
):
|
||||
# Initialize connector with the base site URL
|
||||
connector = SharepointConnector(
|
||||
sites=[os.environ["SHAREPOINT_SITE"]],
|
||||
include_site_pages=False,
|
||||
include_site_documents=True,
|
||||
)
|
||||
|
||||
# Load credentials
|
||||
connector.load_credentials(sharepoint_credentials)
|
||||
# Load credentials
|
||||
connector.load_credentials(sharepoint_credentials)
|
||||
|
||||
# Get all documents
|
||||
document_batches = list(connector.load_from_state())
|
||||
found_documents: list[Document] = [
|
||||
doc for batch in document_batches for doc in batch
|
||||
]
|
||||
# Get all documents
|
||||
found_documents: list[Document] = load_all_docs_from_checkpoint_connector(
|
||||
connector=connector,
|
||||
start=0,
|
||||
end=time.time(),
|
||||
)
|
||||
|
||||
assert len(found_documents) == len(
|
||||
EXPECTED_DOCUMENTS
|
||||
), "Should find all documents in main library"
|
||||
assert len(found_documents) == len(
|
||||
EXPECTED_DOCUMENTS
|
||||
), "Should find all documents in main library"
|
||||
|
||||
# Verify each expected document
|
||||
for expected in EXPECTED_DOCUMENTS:
|
||||
doc = find_document(found_documents, expected.semantic_identifier)
|
||||
verify_document_content(doc, expected)
|
||||
# Verify each expected document
|
||||
for expected in EXPECTED_DOCUMENTS:
|
||||
doc = find_document(found_documents, expected.semantic_identifier)
|
||||
verify_document_content(doc, expected)
|
||||
|
||||
|
||||
def test_sharepoint_connector_other_library(
|
||||
mock_get_unstructured_api_key: MagicMock,
|
||||
mock_store_image: MagicMock,
|
||||
sharepoint_credentials: dict[str, str],
|
||||
) -> None:
|
||||
# Initialize connector with the other library
|
||||
connector = SharepointConnector(
|
||||
sites=[
|
||||
os.environ["SHAREPOINT_SITE"] + "/Other Library",
|
||||
with patch(
|
||||
"onyx.connectors.sharepoint.connector.store_image_and_create_section",
|
||||
mock_store_image,
|
||||
):
|
||||
# Initialize connector with the other library
|
||||
connector = SharepointConnector(
|
||||
sites=[
|
||||
os.environ["SHAREPOINT_SITE"] + "/Other Library",
|
||||
],
|
||||
include_site_pages=False,
|
||||
include_site_documents=True,
|
||||
)
|
||||
|
||||
# Load credentials
|
||||
connector.load_credentials(sharepoint_credentials)
|
||||
|
||||
# Get all documents
|
||||
found_documents: list[Document] = load_all_docs_from_checkpoint_connector(
|
||||
connector=connector,
|
||||
start=0,
|
||||
end=time.time(),
|
||||
)
|
||||
expected_documents: list[ExpectedDocument] = [
|
||||
doc for doc in EXPECTED_DOCUMENTS if doc.library == "Other Library"
|
||||
]
|
||||
)
|
||||
|
||||
# Load credentials
|
||||
connector.load_credentials(sharepoint_credentials)
|
||||
# Should find all documents in `Other Library`
|
||||
assert len(found_documents) == len(
|
||||
expected_documents
|
||||
), "Should find all documents in `Other Library`"
|
||||
|
||||
# Get all documents
|
||||
document_batches = list(connector.load_from_state())
|
||||
found_documents: list[Document] = [
|
||||
doc for batch in document_batches for doc in batch
|
||||
]
|
||||
expected_documents: list[ExpectedDocument] = [
|
||||
doc for doc in EXPECTED_DOCUMENTS if doc.library == "Other Library"
|
||||
]
|
||||
|
||||
# Should find all documents in `Other Library`
|
||||
assert len(found_documents) == len(
|
||||
expected_documents
|
||||
), "Should find all documents in `Other Library`"
|
||||
|
||||
# Verify each expected document
|
||||
for expected in expected_documents:
|
||||
doc = find_document(found_documents, expected.semantic_identifier)
|
||||
verify_document_content(doc, expected)
|
||||
# Verify each expected document
|
||||
for expected in expected_documents:
|
||||
doc = find_document(found_documents, expected.semantic_identifier)
|
||||
verify_document_content(doc, expected)
|
||||
|
||||
|
||||
def test_sharepoint_connector_poll(
|
||||
mock_get_unstructured_api_key: MagicMock,
|
||||
mock_store_image: MagicMock,
|
||||
sharepoint_credentials: dict[str, str],
|
||||
) -> None:
|
||||
# Initialize connector with the base site URL
|
||||
connector = SharepointConnector(
|
||||
sites=["https://danswerai.sharepoint.com/sites/sharepoint-tests"]
|
||||
)
|
||||
with patch(
|
||||
"onyx.connectors.sharepoint.connector.store_image_and_create_section",
|
||||
mock_store_image,
|
||||
):
|
||||
# Initialize connector with the base site URL
|
||||
connector = SharepointConnector(sites=[os.environ["SHAREPOINT_SITE"]])
|
||||
|
||||
# Load credentials
|
||||
connector.load_credentials(sharepoint_credentials)
|
||||
# Load credentials
|
||||
connector.load_credentials(sharepoint_credentials)
|
||||
|
||||
# Set time window to only capture test1.docx (modified at 2025-01-28 20:51:42+00:00)
|
||||
start = datetime(2025, 1, 28, 20, 51, 30, tzinfo=timezone.utc) # 12 seconds before
|
||||
end = datetime(2025, 1, 28, 20, 51, 50, tzinfo=timezone.utc) # 8 seconds after
|
||||
# Set time window to only capture test1.docx (modified at 2025-01-28 20:51:42+00:00)
|
||||
start = datetime(
|
||||
2025, 1, 28, 20, 51, 30, tzinfo=timezone.utc
|
||||
) # 12 seconds before
|
||||
end = datetime(2025, 1, 28, 20, 51, 50, tzinfo=timezone.utc) # 8 seconds after
|
||||
|
||||
# Get documents within the time window
|
||||
document_batches = list(connector._fetch_from_sharepoint(start=start, end=end))
|
||||
found_documents: list[Document] = [
|
||||
doc for batch in document_batches for doc in batch
|
||||
]
|
||||
# Get documents within the time window
|
||||
found_documents: list[Document] = load_all_docs_from_checkpoint_connector(
|
||||
connector=connector,
|
||||
start=start.timestamp(),
|
||||
end=end.timestamp(),
|
||||
)
|
||||
|
||||
# Should only find test1.docx
|
||||
assert len(found_documents) == 1, "Should only find one document in the time window"
|
||||
doc = found_documents[0]
|
||||
assert doc.semantic_identifier == "test1.docx"
|
||||
verify_document_metadata(doc)
|
||||
verify_document_content(
|
||||
doc, [d for d in EXPECTED_DOCUMENTS if d.semantic_identifier == "test1.docx"][0]
|
||||
)
|
||||
# Should only find test1.docx
|
||||
assert (
|
||||
len(found_documents) == 1
|
||||
), "Should only find one document in the time window"
|
||||
doc = found_documents[0]
|
||||
assert doc.semantic_identifier == "test1.docx"
|
||||
verify_document_content(
|
||||
doc,
|
||||
next(
|
||||
d for d in EXPECTED_DOCUMENTS if d.semantic_identifier == "test1.docx"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def test_sharepoint_connector_pages(
|
||||
mock_get_unstructured_api_key: MagicMock,
|
||||
mock_store_image: MagicMock,
|
||||
sharepoint_credentials: dict[str, str],
|
||||
) -> None:
|
||||
# Initialize connector with the base site URL
|
||||
connector = SharepointConnector(
|
||||
sites=["https://danswerai.sharepoint.com/sites/sharepoint-tests-pages"]
|
||||
)
|
||||
with patch(
|
||||
"onyx.connectors.sharepoint.connector.store_image_and_create_section",
|
||||
mock_store_image,
|
||||
):
|
||||
connector = SharepointConnector(
|
||||
sites=[os.environ["SHAREPOINT_SITE"]],
|
||||
include_site_pages=True,
|
||||
include_site_documents=False,
|
||||
)
|
||||
|
||||
# Load credentials
|
||||
connector.load_credentials(sharepoint_credentials)
|
||||
connector.load_credentials(sharepoint_credentials)
|
||||
|
||||
# Get documents within the time window
|
||||
document_batches = list(connector.load_from_state())
|
||||
found_documents: list[Document] = [
|
||||
doc for batch in document_batches for doc in batch
|
||||
]
|
||||
found_documents = load_all_docs_from_checkpoint_connector(
|
||||
connector=connector,
|
||||
start=0,
|
||||
end=time.time(),
|
||||
)
|
||||
|
||||
# Should only find CollabHome
|
||||
assert len(found_documents) == 1, "Should only find one page"
|
||||
doc = found_documents[0]
|
||||
assert doc.semantic_identifier == "CollabHome"
|
||||
verify_document_metadata(doc)
|
||||
assert len(doc.sections) == 1
|
||||
assert (
|
||||
doc.sections[0].text
|
||||
== """
|
||||
# Home
|
||||
assert len(found_documents) == len(
|
||||
EXPECTED_PAGES
|
||||
), "Should find all pages in test site"
|
||||
|
||||
Display recent news.
|
||||
|
||||
## News
|
||||
|
||||
Show recent activities from your site
|
||||
|
||||
## Site activity
|
||||
|
||||
## Quick links
|
||||
|
||||
Learn about a team site
|
||||
|
||||
Learn how to add a page
|
||||
|
||||
Add links to important documents and pages.
|
||||
|
||||
## Quick links
|
||||
|
||||
Documents
|
||||
|
||||
Add a document library
|
||||
|
||||
## Document library
|
||||
""".strip()
|
||||
)
|
||||
for expected in EXPECTED_PAGES:
|
||||
doc = find_document(found_documents, expected.semantic_identifier)
|
||||
verify_document_content(doc, expected)
|
||||
|
||||
Binary file not shown.
@@ -0,0 +1,113 @@
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import InputType
|
||||
from onyx.connectors.sharepoint.connector import SharepointAuthMethod
|
||||
from onyx.db.enums import AccessType
|
||||
from tests.integration.common_utils.managers.cc_pair import CCPairManager
|
||||
from tests.integration.common_utils.managers.connector import ConnectorManager
|
||||
from tests.integration.common_utils.managers.credential import CredentialManager
|
||||
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.reset import reset_all
|
||||
from tests.integration.common_utils.test_models import DATestCCPair
|
||||
from tests.integration.common_utils.test_models import DATestConnector
|
||||
from tests.integration.common_utils.test_models import DATestCredential
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
# Positional layout of the value yielded by the `sharepoint_test_env_setup`
# fixture: the three users first, then the credential, connector and cc-pair
# objects created for the SharePoint connector.
SharepointTestEnvSetupTuple = tuple[
    DATestUser,  # admin_user
    DATestUser,  # regular_user_1
    DATestUser,  # regular_user_2
    DATestCredential,
    DATestConnector,
    DATestCCPair,
]
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def sharepoint_test_env_setup() -> Generator[SharepointTestEnvSetupTuple]:
    """Build the SharePoint permission-sync environment shared by this module.

    Resets all data, reads the certificate credentials from the
    ``PERM_SYNC_SHAREPOINT_*`` environment variables, creates three users, an
    LLM provider, a SharePoint credential, a connector and a cc-pair (both with
    ``AccessType.SYNC``), then calls ``CCPairManager.wait_for_indexing_completion``
    and ``CCPairManager.wait_for_sync`` with an unbounded timeout.

    Skips when any of the four credential environment variables is unset or
    empty.

    Yields:
        ``(admin_user, regular_user_1, regular_user_2, credential, connector,
        cc_pair)``.
    """
    # Reset all data before running the test
    reset_all()

    # Required environment variables for SharePoint certificate authentication
    sp_client_id = os.environ.get("PERM_SYNC_SHAREPOINT_CLIENT_ID")
    sp_private_key = os.environ.get("PERM_SYNC_SHAREPOINT_PRIVATE_KEY")
    sp_certificate_password = os.environ.get(
        "PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD"
    )
    sp_directory_id = os.environ.get("PERM_SYNC_SHAREPOINT_DIRECTORY_ID")
    sharepoint_sites = "https://danswerai.sharepoint.com/sites/Permisisonsync"
    admin_email = "admin@onyx.app"
    user1_email = "subash@onyx.app"
    user2_email = "raunak@onyx.app"

    # All four values go into the credential payload below, so the client ID
    # must be checked too; otherwise a credential with a None client ID would
    # be created instead of skipping.
    required_env_values = (
        sp_client_id,
        sp_private_key,
        sp_certificate_password,
        sp_directory_id,
    )
    if not all(required_env_values):
        pytest.skip("Skipping test because required environment variables are not set")

    # Certificate-based credentials
    credentials = {
        "authentication_method": SharepointAuthMethod.CERTIFICATE.value,
        "sp_client_id": sp_client_id,
        "sp_private_key": sp_private_key,
        "sp_certificate_password": sp_certificate_password,
        "sp_directory_id": sp_directory_id,
    }

    # Create users
    admin_user: DATestUser = UserManager.create(email=admin_email)
    regular_user_1: DATestUser = UserManager.create(email=user1_email)
    regular_user_2: DATestUser = UserManager.create(email=user2_email)

    # Create LLM provider for search functionality
    LLMProviderManager.create(user_performing_action=admin_user)

    # Create credential
    credential: DATestCredential = CredentialManager.create(
        source=DocumentSource.SHAREPOINT,
        credential_json=credentials,
        user_performing_action=admin_user,
    )

    # Create connector with SharePoint-specific configuration
    connector: DATestConnector = ConnectorManager.create(
        name="SharePoint Test",
        input_type=InputType.POLL,
        source=DocumentSource.SHAREPOINT,
        connector_specific_config={
            "sites": sharepoint_sites.split(","),
        },
        access_type=AccessType.SYNC,  # Enable permission sync
        user_performing_action=admin_user,
    )

    # Create CC pair with permission sync enabled
    cc_pair: DATestCCPair = CCPairManager.create(
        credential_id=credential.id,
        connector_id=connector.id,
        access_type=AccessType.SYNC,  # Enable permission sync
        user_performing_action=admin_user,
    )

    # Wait for both indexing and permission sync to complete.
    # Both waits use an infinite timeout, so they never time out on their own.
    before = datetime.now(tz=timezone.utc)
    CCPairManager.wait_for_indexing_completion(
        cc_pair=cc_pair,
        after=before,
        user_performing_action=admin_user,
        timeout=float("inf"),
    )

    # Wait for permission sync completion specifically
    CCPairManager.wait_for_sync(
        cc_pair=cc_pair,
        after=before,
        user_performing_action=admin_user,
        timeout=float("inf"),
    )

    yield admin_user, regular_user_1, regular_user_2, credential, connector, cc_pair
|
||||
@@ -0,0 +1,214 @@
|
||||
import os
|
||||
from typing import List
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.access.access import _get_access_for_documents
|
||||
from ee.onyx.db.external_perm import fetch_external_groups_for_user
|
||||
from onyx.access.utils import prefix_external_group
|
||||
from onyx.access.utils import prefix_user_email
|
||||
from onyx.configs.constants import PUBLIC_DOC_PAT
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import User
|
||||
from onyx.db.users import fetch_user_by_id
|
||||
from onyx.utils.logger import setup_logger
|
||||
from tests.integration.common_utils.test_models import DATestCCPair
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
from tests.integration.connector_job_tests.sharepoint.conftest import (
|
||||
SharepointTestEnvSetupTuple,
|
||||
)
|
||||
|
||||
# Module-level logger used by the ACL helper functions and tests in this file.
logger = setup_logger()
|
||||
|
||||
|
||||
def get_user_acl(user: User, db_session: Session) -> set[str]:
    """Build the set of ACL entries that `user` matches.

    For a real user the set contains one prefixed entry per external group
    returned by `fetch_external_groups_for_user`, the prefixed user email, and
    `PUBLIC_DOC_PAT`.

    A falsy `user` yields only `PUBLIC_DOC_PAT`. Previously the group lookup was
    guarded against a missing user but `user.email` was still dereferenced, so
    that case raised `AttributeError` instead of returning a result.

    Args:
        user: The database user whose ACL should be computed.
        db_session: Session used to look up the user's external groups.

    Returns:
        The set of ACL strings for the user.
    """
    if not user:
        return {PUBLIC_DOC_PAT}

    db_external_groups = fetch_external_groups_for_user(db_session, user.id)
    user_acl = {
        prefix_external_group(db_external_group.external_user_group_id)
        for db_external_group in db_external_groups
    }
    user_acl.update({prefix_user_email(user.email), PUBLIC_DOC_PAT})
    return user_acl
|
||||
|
||||
|
||||
def get_user_document_access_via_acl(
    test_user: DATestUser, document_ids: List[str], db_session: Session
) -> List[str]:
    """Return the IDs among `document_ids` whose ACL overlaps the user's ACL.

    The test user is resolved to a database user by ID; if no such user exists
    an error is logged and an empty list is returned. Each document's ACL comes
    from `_get_access_for_documents(...)[doc_id].to_acl()`, and a document is
    considered accessible when it shares at least one entry with the set
    produced by `get_user_acl`.
    """
    # Resolve the test-model user to the actual database row.
    user = fetch_user_by_id(db_session, UUID(test_user.id))
    if not user:
        logger.error(f"Could not find user with ID {test_user.id}")
        return []

    user_acl = get_user_acl(user, db_session)
    logger.info(f"User {user.email} ACL entries: {user_acl}")

    # Access information keyed by document ID.
    doc_access_map = _get_access_for_documents(document_ids, db_session)
    logger.info(f"Found access info for {len(doc_access_map)} documents")

    granted_ids = []
    for doc_id, access_info in doc_access_map.items():
        doc_acl = access_info.to_acl()
        logger.info(f"Document {doc_id} ACL: {doc_acl}")

        # No shared entry between the two ACLs means the user cannot see it.
        if user_acl.isdisjoint(doc_acl):
            logger.info(f"User {user.email} does NOT have access to document {doc_id}")
            continue

        granted_ids.append(doc_id)
        logger.info(f"User {user.email} has access to document {doc_id}")

    return granted_ids
|
||||
|
||||
|
||||
def get_all_connector_documents(
    cc_pair: DATestCCPair, db_session: Session
) -> List[str]:
    """Return the IDs of all document rows linked to `cc_pair`.

    Selects `DocumentByConnectorCredentialPair.id` for rows whose connector and
    credential IDs both match those of `cc_pair`, and logs how many were found.
    """
    from sqlalchemy import select
    from onyx.db.models import DocumentByConnectorCredentialPair

    same_connector = (
        DocumentByConnectorCredentialPair.connector_id == cc_pair.connector_id
    )
    same_credential = (
        DocumentByConnectorCredentialPair.credential_id == cc_pair.credential_id
    )
    stmt = select(DocumentByConnectorCredentialPair.id).where(
        same_connector, same_credential
    )

    # Each result row carries a single column: the document ID.
    rows = db_session.execute(stmt).fetchall()
    document_ids = [row[0] for row in rows]
    logger.info(
        f"Found {len(document_ids)} documents for connector {cc_pair.connector_id}"
    )

    return document_ids
|
||||
|
||||
|
||||
def get_documents_by_permission_type(
    document_ids: List[str], db_session: Session
) -> List[str]:
    """Return the IDs of the documents whose access info is marked public.

    Despite the function name, only one category is produced: the result is a
    flat list of document IDs for which `is_public` is truthy, not a mapping of
    permission type to IDs.

    Args:
        document_ids: Document IDs to inspect.
        db_session: Session passed to `_get_access_for_documents`.

    Returns:
        The public document IDs, in the iteration order of the access map.
    """
    doc_access_map = _get_access_for_documents(document_ids, db_session)
    return [
        doc_id
        for doc_id, doc_access in doc_access_map.items()
        if doc_access.is_public
    ]
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
    reason="Permission tests are enterprise only",
)
def test_public_documents_accessible_by_all_users(
    sharepoint_test_env_setup: SharepointTestEnvSetupTuple,
) -> None:
    """Check, via ACL comparison, how many connector documents each regular user can access."""
    # Only the two regular users and the cc-pair are needed from the fixture.
    _, regular_user_1, regular_user_2, _, _, cc_pair = sharepoint_test_env_setup

    with get_session_with_current_tenant() as db_session:
        # Every document indexed for this connector/credential pair.
        all_document_ids = get_all_connector_documents(cc_pair, db_session)

        # Resolve access for user 1 first, then user 2.
        accessible_docs_user1, accessible_docs_user2 = (
            get_user_document_access_via_acl(
                test_user=member,
                document_ids=all_document_ids,
                db_session=db_session,
            )
            for member in (regular_user_1, regular_user_2)
        )

        logger.info(f"User 1 has access to {len(accessible_docs_user1)} documents")
        logger.info(f"User 2 has access to {len(accessible_docs_user2)} documents")

        # The expected counts are exact: 8 documents for user 1, 1 for user 2.
        # NOTE(review): presumably these mirror the contents of the test
        # SharePoint site -- confirm before changing.
        assert len(accessible_docs_user1) == 8, (
            f"User 1 should have access to documents. Found "
            f"{len(accessible_docs_user1)} accessible docs out of "
            f"{len(all_document_ids)} total"
        )
        assert len(accessible_docs_user2) == 1, (
            f"User 2 should have access to documents. Found "
            f"{len(accessible_docs_user2)} accessible docs out of "
            f"{len(all_document_ids)} total"
        )

        logger.info(
            "Successfully verified public documents are accessible by users via ACL"
        )
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
    reason="Permission tests are enterprise only",
)
def test_group_based_permissions(
    sharepoint_test_env_setup: SharepointTestEnvSetupTuple,
) -> None:
    """Check, via ACL comparison, that user 2 can access nothing beyond the public documents."""
    # Only the two regular users and the cc-pair are needed from the fixture.
    _, regular_user_1, regular_user_2, _, _, cc_pair = sharepoint_test_env_setup

    with get_session_with_current_tenant() as db_session:
        # Every document indexed for this connector/credential pair.
        all_document_ids = get_all_connector_documents(cc_pair, db_session)
        if not all_document_ids:
            pytest.skip("No documents found for connector - skipping test")

        # Resolve access for user 1 first, then user 2.
        accessible_docs_user1, accessible_docs_user2 = (
            get_user_document_access_via_acl(
                test_user=member,
                document_ids=all_document_ids,
                db_session=db_session,
            )
            for member in (regular_user_1, regular_user_2)
        )

        logger.info(f"User 1 has access to {len(accessible_docs_user1)} documents")
        logger.info(f"User 2 has access to {len(accessible_docs_user2)} documents")

        public_doc_ids = set(
            get_documents_by_permission_type(all_document_ids, db_session)
        )

        # Anything user 2 can reach that is not public is a violation.
        non_public_access_user2 = [
            doc for doc in accessible_docs_user2 if doc not in public_doc_ids
        ]

        assert (
            len(non_public_access_user2) == 0
        ), f"User 2 should only have access to public documents. Found access to non-public docs: {non_public_access_user2}"
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user