Compare commits


51 Commits

Author SHA1 Message Date
pablonyx
76ac0243f5 k 2025-03-30 13:10:55 -07:00
rkuo-danswer
cb5bbd3812 Feature/mit integration tests (#4299)
* new mit integration test template

* edit

* fix problem with ACL type tags and MIT testing for test_connector_deletion

* fix test_connector_deletion_for_overlapping_connectors

* disable some enterprise only tests in MIT version

* disable a bunch of user group / curator tests in MIT version

* wire off more tests

* typo fix

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2025-03-30 02:41:08 +00:00
Yuhong Sun
742d29e504 Remove BETA 2025-03-29 15:38:46 -07:00
SubashMohan
ecc155d082 fix: ensure base_url ends with a trailing slash (#4388) 2025-03-29 14:34:30 -07:00
pablonyx
0857e4809d fix background color 2025-03-28 16:33:30 -07:00
Chris Weaver
22e00a1f5c Fix duplicate docs (#4378)
* Initial

* Fix duplicate docs

* Add tests

* Switch to list comprehension

* Fix test
2025-03-28 22:25:26 +00:00
Chris Weaver
0d0588a0c1 Remove OnyxContext (#4376)
* Remove OnyxContext

* Fix UT

* Fix tests v2
2025-03-28 12:39:51 -07:00
rkuo-danswer
aab777f844 Bugfix/acl prefix (#4377)
* fix acl prefixing

* increase timeout a tad

* block access to init'ing DocumentAccess directly, fix test to work with ee/MIT

* fix env var checks

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-03-28 05:52:35 +00:00
pablonyx
babbe7689a k (#4380) 2025-03-28 02:23:45 +00:00
evan-danswer
a123661c92 fixed shared folder issue (#4371)
* fixed shared folder issue

* fix existing tests

* default allow files shared with me for service account
2025-03-27 23:39:52 +00:00
pablonyx
c554889baf Fix actions link (#4374) 2025-03-27 16:39:35 -07:00
rkuo-danswer
f08fa878a6 refactor file extension checking and add test for blob s3 (#4369)
* refactor file extension checking and add test for blob s3

* code review

* fix checking ext

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-03-27 18:57:44 +00:00
pablonyx
d307534781 add some debug logging (#4328) 2025-03-27 11:49:32 -07:00
rkuo-danswer
6f54791910 adjust some vars in real time (#4365)
* adjust some vars in real time

* some sanity checking

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-03-27 17:30:08 +00:00
pablonyx
0d5497bb6b Add multi-tenant user invitation flow test (#4360) 2025-03-27 09:53:15 -07:00
Chris Weaver
7648627503 Save all logs + add log persistence to most Onyx-owned containers (#4368)
* Save all logs + add log persistence to most Onyx-owned containers

* Separate volumes for each container

* Small fixes
2025-03-26 22:25:39 -07:00
pablonyx
927554d5ca slight robustification (#4367) 2025-03-27 03:23:36 +00:00
pablonyx
7dcec6caf5 Fix session touching (#4363)
* fix session touching

* Revert "fix session touching"

This reverts commit c473d5c9a2.

* Revert "Revert "fix session touching""

This reverts commit 26a71d40b6.

* update

* quick nit
2025-03-27 01:18:46 +00:00
rkuo-danswer
036648146d possible fix for confluence query filter (#4280)
* possible fix for confluence query filter

* nuke the attachment filter query ... it doesn't work!

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-03-27 00:35:14 +00:00
rkuo-danswer
2aa4697ac8 permission sync runs so often that it starves out other tasks if run at high priority (#4364)
Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-03-27 00:22:53 +00:00
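
The change above drops the permission-sync task to a lower priority so it no longer crowds out other Celery work. A minimal sketch of enqueueing a task at an explicit priority in Celery, with placeholder task, argument, and broker names (not Onyx's actual setup):

from celery import Celery

app = Celery("example", broker="redis://localhost:6379/0")  # placeholder broker URL

@app.task
def permission_sync(doc_id: str) -> None:
    ...

# Enqueue at a non-default priority so frequent syncs don't crowd out other work.
# Which end of the range counts as "low" is broker-dependent, so check the broker's
# priority semantics before picking a value.
permission_sync.apply_async(args=("doc-123",), priority=8)
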
rkuo-danswer
bc9b4e4f45 use slack's built in rate limit handler for the bot (#4362)
Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-03-26 21:55:04 +00:00
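
Roughly what "use Slack's built-in rate limit handler" looks like with slack_sdk; the token and channel below are placeholders rather than Onyx configuration:

from slack_sdk import WebClient
from slack_sdk.http_retry.builtin_handlers import RateLimitErrorRetryHandler

client = WebClient(token="xoxb-placeholder-token")
# Automatically retries calls that fail with HTTP 429, waiting out Retry-After.
client.retry_handlers.append(RateLimitErrorRetryHandler(max_retry_count=3))

client.chat_postMessage(channel="#general", text="hello")
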
evan-danswer
178a64f298 fix issue with drive connector service account indexing (#4356)
* fix issue with drive connector service account indexing

* correct checkpoint resumption

* final set of fixes

* nit

* fix typing

* logging and CW comments

* nit
2025-03-26 20:54:26 +00:00
pablonyx
c79f1edf1d add a flush (#4361) 2025-03-26 14:40:52 -07:00
pablonyx
7c8e23aa54 Fix saml conversion from ext_perm -> basic (#4343)
* fix saml conversion from ext_perm -> basic

* quick nit

* minor fix

* finalize

* update

* quick fix
2025-03-26 20:36:51 +00:00
pablonyx
d37b427d52 fix email flow (#4339) 2025-03-26 18:59:12 +00:00
pablonyx
a65fefd226 test fix 2025-03-26 12:43:38 -07:00
rkuo-danswer
bb09bde519 Bugfix/google drive size threshold 2 (#4355) 2025-03-26 12:06:36 -07:00
Tim Rosenblatt
0f6cf0fc58 Fixes docker logs helper text in run-nginx.sh (#3678)
The docker container name is slightly wrong, and this commit fixes it.
2025-03-26 09:03:35 -07:00
pablonyx
fed06b592d Auto refresh credentials (#4268)
* Auto refresh credentials

* remove dupes

* clean up + tests

* k

* quick nit

* add brief comment

* misc typing
2025-03-26 01:53:31 +00:00
pablonyx
8d92a1524e fix invitation on cloud (#4351)
* fix invitation on cloud

* k
2025-03-26 01:25:17 +00:00
pablonyx
ecfea9f5ed Email formatting devices (#4353)
* update email formatting

* k

* update

* k

* nit
2025-03-25 21:42:32 +00:00
rkuo-danswer
b269f1ba06 fix broken function call (#4354)
Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-03-25 21:07:31 +00:00
pablonyx
30c878efa5 Quick fix (#4341)
* quick fix

* Revert "quick fix"

This reverts commit f113616276.

* smaller change
2025-03-25 18:39:55 +00:00
pablonyx
2024776c19 Respect contextvars when parallelizing for Google Drive (#4291)
* k

* k

* fix typing
2025-03-25 17:40:12 +00:00
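
Respecting contextvars when parallelizing means each submitted call runs inside a copy of the submitting thread's context, so values such as a tenant id survive the thread hop. A small generic sketch (names are illustrative, not Onyx's helpers):

import contextvars
from concurrent.futures import ThreadPoolExecutor

tenant_id: contextvars.ContextVar[str] = contextvars.ContextVar("tenant_id", default="public")

def fetch(item: str) -> str:
    # Reads the tenant id set by the submitting thread, not the worker thread's default.
    return f"{tenant_id.get()}:{item}"

tenant_id.set("tenant_abc")
with ThreadPoolExecutor(max_workers=4) as pool:
    # copy_context() runs in the submitting thread, so each task carries its values;
    # copying per task also avoids entering the same Context from multiple threads.
    futures = [pool.submit(contextvars.copy_context().run, fetch, item) for item in ("a", "b", "c")]
    print([f.result() for f in futures])  # ['tenant_abc:a', 'tenant_abc:b', 'tenant_abc:c']
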
pablonyx
431316929c k (#4336) 2025-03-25 17:00:35 +00:00
pablonyx
c5b9c6e308 update (#4344) 2025-03-25 16:56:23 +00:00
pablonyx
73dd188b3f update (#4338) 2025-03-25 16:55:25 +00:00
evan-danswer
79b061abbc Daylight savings time handling (#4345)
* confluence timezone improvements

* confluence timezone improvements
2025-03-25 16:11:30 +00:00
rkuo-danswer
552f1ead4f use correct namespace in redis for certain keys (#4340)
Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-03-25 04:10:31 +00:00
evan-danswer
17925b49e8 typing fix (#4342)
* typing fix

* changed type hint to help future coders
2025-03-25 01:01:13 +00:00
rkuo-danswer
55fb5c3ca5 add size threshold for google drive (#4329)
* add size threshold for google drive

* greptile nits

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-03-24 04:09:28 +00:00
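
A rough sketch of what a size threshold for Drive files can look like; the cap and field handling are assumptions for illustration, not the values used in this commit:

MAX_DRIVE_FILE_SIZE_BYTES = 50 * 1024 * 1024  # assumed example cap, not Onyx's value

def should_index(drive_file: dict) -> bool:
    # Drive's files API reports "size" (bytes, as a string) only for binary content;
    # Google-native docs omit it, so they are passed through here.
    size_str = drive_file.get("size")
    if size_str is None:
        return True
    return int(size_str) <= MAX_DRIVE_FILE_SIZE_BYTES

print(should_index({"name": "big.pdf", "size": str(200 * 1024 * 1024)}))  # False
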
evan-danswer
99546e4a4d zendesk checkpointed connector (#4311)
* zendesk v1

* logic fix

* zendesk testing

* add unit tests

* zendesk caching

* CW comments

* fix unit tests
2025-03-23 20:43:13 +00:00
pablonyx
c25d56f4a5 Improved drive flow UX (#4331)
* wip

* k

* looking good

* cleaned up

* quick nit
2025-03-23 19:21:03 +00:00
Chris Weaver
35f3f4f120 Small slack bot fixes (#4333) 2025-03-22 23:22:17 +00:00
Weves
25b69a8aca Adjust spammy log 2025-03-22 14:52:09 -07:00
pablonyx
1b7d710b2a Fix links from file metadata (#4324)
* quick fix

* clarify comment

* fix file metadata

* k
2025-03-22 18:21:47 +00:00
pablonyx
ae3d3db3f4 Update slack bot listing endpoint (#4325)
* update slack bot listing endpoint

* nit
2025-03-22 18:21:31 +00:00
evan-danswer
fb79a9e700 Checkpointed GitHub connector (#4307)
* WIP github checkpointing

* first draft of github checkpointing

* nit

* CW comments

* github basic connector test

* connector test env var

* secrets cant start with GITHUB_

* unit tests and bug fix

* connector failures

* address CW comments

* validation fix

* validation fix

* remove prints

* fixed tests

* 100 items per page
2025-03-22 01:48:05 +00:00
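
Checkpointing here means the connector records a resumable cursor after each page so an interrupted indexing run continues where it stopped. A generic sketch under that assumption, not Onyx's actual connector interface:

from dataclasses import dataclass
from typing import Iterator

@dataclass
class Checkpoint:
    cursor: str | None = None  # e.g. a page token or last-seen timestamp
    has_more: bool = True

def fetch_page(cursor: str | None, page_size: int) -> tuple[list[dict], str | None]:
    # Stand-in for a real GitHub API call: returns one page, then signals completion.
    if cursor is None:
        return [{"id": i} for i in range(page_size)], "page-2"
    return [{"id": i} for i in range(3)], None

def load_from_checkpoint(checkpoint: Checkpoint, page_size: int = 100) -> Iterator[tuple[list[dict], Checkpoint]]:
    while checkpoint.has_more:
        docs, next_cursor = fetch_page(checkpoint.cursor, page_size)
        checkpoint = Checkpoint(cursor=next_cursor, has_more=next_cursor is not None)
        # The caller persists `checkpoint` after consuming `docs`, so a crash here
        # resumes from the last completed page instead of the beginning.
        yield docs, checkpoint

for docs, checkpoint in load_from_checkpoint(Checkpoint()):
    pass  # index docs, then save checkpoint
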
rkuo-danswer
587ba11bbc alembic script logging fixes (#4322)
* log fixing

* fix typos

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-03-22 00:50:58 +00:00
pablonyx
fce81ebb60 Minor ux nits (#4327)
* k

* quick fix
2025-03-21 21:50:56 +00:00
Chris Weaver
61facfb0a8 Fix slack connector (#4326) 2025-03-21 21:30:03 +00:00
154 changed files with 5507 additions and 1707 deletions

View File

@@ -0,0 +1,209 @@
name: Run MIT Integration Tests v2
concurrency:
group: Run-MIT-Integration-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
cancel-in-progress: true
on:
merge_group:
pull_request:
branches:
- main
- "release/**"
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
jobs:
integration-tests:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on, runner=32cpu-linux-x64, "run-id=${{ github.run_id }}"]
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
# tag every docker image with "test" so that we can spin up the correct set
# of images during testing
# We don't need to build the Web Docker image since it's not yet used
# in the integration tests. We have a separate action to verify that it builds
# successfully.
- name: Pull Web Docker image
run: |
docker pull onyxdotapp/onyx-web-server:latest
docker tag onyxdotapp/onyx-web-server:latest onyxdotapp/onyx-web-server:test
# we use the runs-on cache for docker builds
# in conjunction with runs-on runners, it has better speed and unlimited caching
# https://runs-on.com/caching/s3-cache-for-github-actions/
# https://runs-on.com/caching/docker/
# https://github.com/moby/buildkit#s3-cache-experimental
# images are built and run locally for testing purposes. Not pushed.
- name: Build Backend Docker image
uses: ./.github/actions/custom-build-and-push
with:
context: ./backend
file: ./backend/Dockerfile
platforms: linux/amd64
tags: onyxdotapp/onyx-backend:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
- name: Build Model Server Docker image
uses: ./.github/actions/custom-build-and-push
with:
context: ./backend
file: ./backend/Dockerfile.model_server
platforms: linux/amd64
tags: onyxdotapp/onyx-model-server:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
- name: Build integration test Docker image
uses: ./.github/actions/custom-build-and-push
with:
context: ./backend
file: ./backend/tests/integration/Dockerfile
platforms: linux/amd64
tags: onyxdotapp/onyx-integration:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
# NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
- name: Start Docker containers
run: |
cd deployment/docker_compose
AUTH_TYPE=basic \
POSTGRES_POOL_PRE_PING=true \
POSTGRES_USE_NULL_POOL=true \
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
INTEGRATION_TESTS_MODE=true \
docker compose -f docker-compose.dev.yml -p onyx-stack up -d
id: start_docker
- name: Wait for service to be ready
run: |
echo "Starting wait-for-service script..."
docker logs -f onyx-stack-api_server-1 &
start_time=$(date +%s)
timeout=300 # 5 minutes in seconds
while true; do
current_time=$(date +%s)
elapsed_time=$((current_time - start_time))
if [ $elapsed_time -ge $timeout ]; then
echo "Timeout reached. Service did not become ready in 5 minutes."
exit 1
fi
# Use curl with error handling to ignore specific exit code 56
response=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:8080/health || echo "curl_error")
if [ "$response" = "200" ]; then
echo "Service is ready!"
break
elif [ "$response" = "curl_error" ]; then
echo "Curl encountered an error, possibly exit code 56. Continuing to retry..."
else
echo "Service not ready yet (HTTP status $response). Retrying in 5 seconds..."
fi
sleep 5
done
echo "Finished waiting for service."
- name: Start Mock Services
run: |
cd backend/tests/integration/mock_services
docker compose -f docker-compose.mock-it-services.yml \
-p mock-it-services-stack up -d
# NOTE: Use pre-ping/null to reduce flakiness due to dropped connections
- name: Run Standard Integration Tests
run: |
echo "Running integration tests..."
docker run --rm --network onyx-stack_default \
--name test-runner \
-e POSTGRES_HOST=relational_db \
-e POSTGRES_USER=postgres \
-e POSTGRES_PASSWORD=password \
-e POSTGRES_DB=postgres \
-e POSTGRES_POOL_PRE_PING=true \
-e POSTGRES_USE_NULL_POOL=true \
-e VESPA_HOST=index \
-e REDIS_HOST=cache \
-e API_SERVER_HOST=api_server \
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
-e TEST_WEB_HOSTNAME=test-runner \
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
onyxdotapp/onyx-integration:test \
/app/tests/integration/tests \
/app/tests/integration/connector_job_tests
continue-on-error: true
id: run_tests
- name: Check test results
run: |
if [ ${{ steps.run_tests.outcome }} == 'failure' ]; then
echo "Integration tests failed. Exiting with error."
exit 1
else
echo "All integration tests passed successfully."
fi
# ------------------------------------------------------------
# Always gather logs BEFORE "down":
- name: Dump API server logs
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p onyx-stack logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true
- name: Dump all-container logs (optional)
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p onyx-stack logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
- name: Upload logs
if: always()
uses: actions/upload-artifact@v4
with:
name: docker-all-logs
path: ${{ github.workspace }}/docker-compose.log
# ------------------------------------------------------------
- name: Stop Docker containers
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p onyx-stack down -v
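
The pre-ping/null-pool note in the workflow above maps to SQLAlchemy engine options; a minimal sketch of how such environment flags are typically wired, with a placeholder DSN (not the project's engine code):

import os
from sqlalchemy import create_engine
from sqlalchemy.pool import NullPool

pre_ping = os.environ.get("POSTGRES_POOL_PRE_PING", "").lower() == "true"
use_null_pool = os.environ.get("POSTGRES_USE_NULL_POOL", "").lower() == "true"

engine_kwargs: dict = {"pool_pre_ping": pre_ping}  # test connections with a ping before each checkout
if use_null_pool:
    engine_kwargs["poolclass"] = NullPool  # no pooling: open/close a connection per use
engine = create_engine(
    "postgresql+psycopg2://postgres:password@localhost:5432/postgres",  # placeholder DSN
    **engine_kwargs,
)
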

View File

@@ -9,6 +9,10 @@ on:
- cron: "0 16 * * *"
env:
# AWS
AWS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS: ${{ secrets.AWS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS }}
AWS_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS: ${{ secrets.AWS_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS }}
# Confluence
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_TEST_SPACE: ${{ secrets.CONFLUENCE_TEST_SPACE }}
@@ -45,6 +49,8 @@ env:
SHAREPOINT_CLIENT_SECRET: ${{ secrets.SHAREPOINT_CLIENT_SECRET }}
SHAREPOINT_CLIENT_DIRECTORY_ID: ${{ secrets.SHAREPOINT_CLIENT_DIRECTORY_ID }}
SHAREPOINT_SITE: ${{ secrets.SHAREPOINT_SITE }}
# Github
ACCESS_TOKEN_GITHUB: ${{ secrets.ACCESS_TOKEN_GITHUB }}
# Gitbook
GITBOOK_SPACE_ID: ${{ secrets.GITBOOK_SPACE_ID }}
GITBOOK_API_KEY: ${{ secrets.GITBOOK_API_KEY }}

View File

@@ -102,6 +102,7 @@ COPY ./alembic /app/alembic
COPY ./alembic_tenants /app/alembic_tenants
COPY ./alembic.ini /app/alembic.ini
COPY supervisord.conf /usr/etc/supervisord.conf
COPY ./static /app/static
# Escape hatch scripts
COPY ./scripts/debugging /app/scripts/debugging

View File

@@ -84,7 +84,7 @@ keys = console
keys = generic
[logger_root]
level = WARN
level = INFO
handlers = console
qualname =

View File

@@ -25,6 +25,9 @@ from shared_configs.configs import MULTI_TENANT, POSTGRES_DEFAULT_SCHEMA
from onyx.db.models import Base
from celery.backends.database.session import ResultModelBase # type: ignore
# Make sure in alembic.ini [logger_root] level=INFO is set or most logging will be
# hidden! (defaults to level=WARN)
# Alembic Config object
config = context.config
@@ -36,6 +39,7 @@ if config.config_file_name is not None and config.attributes.get(
target_metadata = [Base.metadata, ResultModelBase.metadata]
EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
logger = logging.getLogger(__name__)
ssl_context: ssl.SSLContext | None = None
@@ -64,7 +68,7 @@ def include_object(
return True
def get_schema_options() -> tuple[str, bool, bool]:
def get_schema_options() -> tuple[str, bool, bool, bool]:
x_args_raw = context.get_x_argument()
x_args = {}
for arg in x_args_raw:
@@ -76,6 +80,10 @@ def get_schema_options() -> tuple[str, bool, bool]:
create_schema = x_args.get("create_schema", "true").lower() == "true"
upgrade_all_tenants = x_args.get("upgrade_all_tenants", "false").lower() == "true"
# continue on error with individual tenant
# only applies to online migrations
continue_on_error = x_args.get("continue", "false").lower() == "true"
if (
MULTI_TENANT
and schema_name == POSTGRES_DEFAULT_SCHEMA
@@ -86,14 +94,12 @@ def get_schema_options() -> tuple[str, bool, bool]:
"Please specify a tenant-specific schema."
)
return schema_name, create_schema, upgrade_all_tenants
return schema_name, create_schema, upgrade_all_tenants, continue_on_error
def do_run_migrations(
connection: Connection, schema_name: str, create_schema: bool
) -> None:
logger.info(f"About to migrate schema: {schema_name}")
if create_schema:
connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema_name}"'))
connection.execute(text("COMMIT"))
@@ -134,7 +140,12 @@ def provide_iam_token_for_alembic(
async def run_async_migrations() -> None:
schema_name, create_schema, upgrade_all_tenants = get_schema_options()
(
schema_name,
create_schema,
upgrade_all_tenants,
continue_on_error,
) = get_schema_options()
engine = create_async_engine(
build_connection_string(),
@@ -151,9 +162,15 @@ async def run_async_migrations() -> None:
if upgrade_all_tenants:
tenant_schemas = get_all_tenant_ids()
i_tenant = 0
num_tenants = len(tenant_schemas)
for schema in tenant_schemas:
i_tenant += 1
logger.info(
f"Migrating schema: index={i_tenant} num_tenants={num_tenants} schema={schema}"
)
try:
logger.info(f"Migrating schema: {schema}")
async with engine.connect() as connection:
await connection.run_sync(
do_run_migrations,
@@ -162,7 +179,12 @@ async def run_async_migrations() -> None:
)
except Exception as e:
logger.error(f"Error migrating schema {schema}: {e}")
raise
if not continue_on_error:
logger.error("--continue is not set, raising exception!")
raise
logger.warning("--continue is set, continuing to next schema.")
else:
try:
logger.info(f"Migrating schema: {schema_name}")
@@ -180,7 +202,11 @@ async def run_async_migrations() -> None:
def run_migrations_offline() -> None:
schema_name, _, upgrade_all_tenants = get_schema_options()
"""This doesn't really get used when we migrate in the cloud."""
logger.info("run_migrations_offline starting.")
schema_name, _, upgrade_all_tenants, continue_on_error = get_schema_options()
url = build_connection_string()
if upgrade_all_tenants:
@@ -230,6 +256,7 @@ def run_migrations_offline() -> None:
def run_migrations_online() -> None:
logger.info("run_migrations_online starting.")
asyncio.run(run_async_migrations())

View File

@@ -28,6 +28,20 @@ depends_on = None
def upgrade() -> None:
# First, drop any existing indexes to avoid conflicts
op.execute("COMMIT")
op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_message_tsv;")
op.execute("COMMIT")
op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_session_desc_tsv;")
op.execute("COMMIT")
op.execute("DROP INDEX IF EXISTS idx_chat_message_message_lower;")
# Drop existing columns if they exist
op.execute("ALTER TABLE chat_message DROP COLUMN IF EXISTS message_tsv;")
op.execute("ALTER TABLE chat_session DROP COLUMN IF EXISTS description_tsv;")
# Create a GIN index for full-text search on chat_message.message
op.execute(
"""

View File

@@ -93,12 +93,12 @@ def _get_access_for_documents(
)
# To avoid collisions of group namings between connectors, they need to be prefixed
access_map[document_id] = DocumentAccess(
user_emails=non_ee_access.user_emails,
user_groups=set(user_group_info.get(document_id, [])),
access_map[document_id] = DocumentAccess.build(
user_emails=list(non_ee_access.user_emails),
user_groups=user_group_info.get(document_id, []),
is_public=is_public_anywhere,
external_user_emails=ext_u_emails,
external_user_group_ids=ext_u_groups,
external_user_emails=list(ext_u_emails),
external_user_group_ids=list(ext_u_groups),
)
return access_map

View File

@@ -2,7 +2,6 @@ from ee.onyx.server.query_and_chat.models import OneShotQAResponse
from onyx.chat.models import AllCitations
from onyx.chat.models import LLMRelevanceFilterResponse
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import OnyxContexts
from onyx.chat.models import QADocsResponse
from onyx.chat.models import StreamingError
from onyx.chat.process_message import ChatPacketStream
@@ -32,8 +31,6 @@ def gather_stream_for_answer_api(
response.llm_selected_doc_indices = packet.llm_selected_doc_indices
elif isinstance(packet, AllCitations):
response.citations = packet.citations
elif isinstance(packet, OnyxContexts):
response.contexts = packet
if answer:
response.answer = answer

View File

@@ -25,6 +25,10 @@ SAML_CONF_DIR = os.environ.get("SAML_CONF_DIR") or "/app/ee/onyx/configs/saml_co
#####
# Auto Permission Sync
#####
DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY = int(
os.environ.get("DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
)
# In seconds, default is 5 minutes
CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY = int(
os.environ.get("CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY") or 5 * 60
@@ -39,6 +43,7 @@ CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC = (
CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY = int(
os.environ.get("CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
)
NUM_PERMISSION_WORKERS = int(os.environ.get("NUM_PERMISSION_WORKERS") or 2)
@@ -72,6 +77,13 @@ OAUTH_GOOGLE_DRIVE_CLIENT_SECRET = os.environ.get(
"OAUTH_GOOGLE_DRIVE_CLIENT_SECRET", ""
)
GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY = int(
os.environ.get("GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY") or 5 * 60
)
SLACK_PERMISSION_DOC_SYNC_FREQUENCY = int(
os.environ.get("SLACK_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
)
# The posthog client does not accept empty API keys or hosts however it fails silently
# when the capture is called. These defaults prevent Posthog issues from breaking the Onyx app

View File

@@ -3,6 +3,8 @@ from collections.abc import Generator
from ee.onyx.configs.app_configs import CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import SLACK_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.db.external_perm import ExternalUserGroup
from ee.onyx.external_permissions.confluence.doc_sync import confluence_doc_sync
from ee.onyx.external_permissions.confluence.group_sync import confluence_group_sync
@@ -66,13 +68,13 @@ GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC: set[DocumentSource] = {
DOC_PERMISSION_SYNC_PERIODS: dict[DocumentSource, int] = {
# Polling is not supported so we fetch all doc permissions every 5 minutes
DocumentSource.CONFLUENCE: CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY,
DocumentSource.SLACK: 5 * 60,
DocumentSource.SLACK: SLACK_PERMISSION_DOC_SYNC_FREQUENCY,
}
# If nothing is specified here, we run the doc_sync every time the celery beat runs
EXTERNAL_GROUP_SYNC_PERIODS: dict[DocumentSource, int] = {
# Polling is not supported so we fetch all group permissions every 30 minutes
DocumentSource.GOOGLE_DRIVE: 5 * 60,
DocumentSource.GOOGLE_DRIVE: GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY,
DocumentSource.CONFLUENCE: CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY,
}

View File

@@ -64,7 +64,15 @@ def get_application() -> FastAPI:
add_tenant_id_middleware(application, logger)
if AUTH_TYPE == AuthType.CLOUD:
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
# For Google OAuth, refresh tokens are requested by:
# 1. Adding the right scopes
# 2. Properly configuring OAuth in Google Cloud Console to allow offline access
oauth_client = GoogleOAuth2(
OAUTH_CLIENT_ID,
OAUTH_CLIENT_SECRET,
# Use standard scopes that include profile and email
scopes=["openid", "email", "profile"],
)
include_auth_router_with_prefix(
application,
create_onyx_oauth_router(
@@ -87,6 +95,16 @@ def get_application() -> FastAPI:
)
if AUTH_TYPE == AuthType.OIDC:
# Ensure we request offline_access for refresh tokens
try:
oidc_scopes = list(OIDC_SCOPE_OVERRIDE or BASE_SCOPES)
if "offline_access" not in oidc_scopes:
oidc_scopes.append("offline_access")
except Exception as e:
logger.warning(f"Error configuring OIDC scopes: {e}")
# Fall back to default scopes if there's an error
oidc_scopes = BASE_SCOPES
include_auth_router_with_prefix(
application,
create_onyx_oauth_router(
@@ -94,8 +112,8 @@ def get_application() -> FastAPI:
OAUTH_CLIENT_ID,
OAUTH_CLIENT_SECRET,
OPENID_CONFIG_URL,
# BASE_SCOPES is the same as not setting this
base_scopes=OIDC_SCOPE_OVERRIDE or BASE_SCOPES,
# Use the configured scopes
base_scopes=oidc_scopes,
),
auth_backend,
USER_AUTH_SECRET,

View File

@@ -14,7 +14,6 @@ from ee.onyx.server.query_and_chat.models import (
BasicCreateChatMessageWithHistoryRequest,
)
from ee.onyx.server.query_and_chat.models import ChatBasicResponse
from ee.onyx.server.query_and_chat.models import SimpleDoc
from onyx.auth.users import current_user
from onyx.chat.chat_utils import combine_message_thread
from onyx.chat.chat_utils import create_chat_chain
@@ -56,25 +55,6 @@ logger = setup_logger()
router = APIRouter(prefix="/chat")
def _translate_doc_response_to_simple_doc(
doc_response: QADocsResponse,
) -> list[SimpleDoc]:
return [
SimpleDoc(
id=doc.document_id,
semantic_identifier=doc.semantic_identifier,
link=doc.link,
blurb=doc.blurb,
match_highlights=[
highlight for highlight in doc.match_highlights if highlight
],
source_type=doc.source_type,
metadata=doc.metadata,
)
for doc in doc_response.top_documents
]
def _get_final_context_doc_indices(
final_context_docs: list[LlmDoc] | None,
top_docs: list[SavedSearchDoc] | None,
@@ -111,9 +91,6 @@ def _convert_packet_stream_to_response(
elif isinstance(packet, QADocsResponse):
response.top_documents = packet.top_documents
# TODO: deprecate `simple_search_docs`
response.simple_search_docs = _translate_doc_response_to_simple_doc(packet)
# This is a no-op if agent_sub_questions hasn't already been filled
if packet.level is not None and packet.level_question_num is not None:
id = (packet.level, packet.level_question_num)

View File

@@ -8,7 +8,6 @@ from pydantic import model_validator
from ee.onyx.server.manage.models import StandardAnswer
from onyx.chat.models import CitationInfo
from onyx.chat.models import OnyxContexts
from onyx.chat.models import PersonaOverrideConfig
from onyx.chat.models import QADocsResponse
from onyx.chat.models import SubQuestionIdentifier
@@ -164,8 +163,6 @@ class ChatBasicResponse(BaseModel):
cited_documents: dict[int, str] | None = None
# FOR BACKWARDS COMPATIBILITY
# TODO: deprecate both of these
simple_search_docs: list[SimpleDoc] | None = None
llm_chunks_indices: list[int] | None = None
# agentic fields
@@ -220,4 +217,3 @@ class OneShotQAResponse(BaseModel):
llm_selected_doc_indices: list[int] | None = None
error_msg: str | None = None
chat_message_id: int | None = None
contexts: OnyxContexts | None = None

View File

@@ -38,6 +38,7 @@ router = APIRouter(prefix="/auth/saml")
async def upsert_saml_user(email: str) -> User:
logger.debug(f"Attempting to upsert SAML user with email: {email}")
get_async_session_context = contextlib.asynccontextmanager(
get_async_session
) # type:ignore
@@ -48,9 +49,13 @@ async def upsert_saml_user(email: str) -> User:
async with get_user_db_context(session) as user_db:
async with get_user_manager_context(user_db) as user_manager:
try:
return await user_manager.get_by_email(email)
user = await user_manager.get_by_email(email)
# If user has a non-authenticated role, treat as non-existent
if not user.role.is_web_login:
raise exceptions.UserNotExists()
return user
except exceptions.UserNotExists:
logger.notice("Creating user from SAML login")
logger.info("Creating user from SAML login")
user_count = await get_user_count()
role = UserRole.ADMIN if user_count == 0 else UserRole.BASIC
@@ -59,11 +64,10 @@ async def upsert_saml_user(email: str) -> User:
password = fastapi_users_pw_helper.generate()
hashed_pass = fastapi_users_pw_helper.hash(password)
user: User = await user_manager.create(
user = await user_manager.create(
UserCreate(
email=email,
password=hashed_pass,
is_verified=True,
role=role,
)
)

View File

@@ -87,11 +87,14 @@ async def get_or_provision_tenant(
# If we have a pre-provisioned tenant, assign it to the user
await assign_tenant_to_user(tenant_id, email, referral_source)
logger.info(f"Assigned pre-provisioned tenant {tenant_id} to user {email}")
return tenant_id
else:
# If no pre-provisioned tenant is available, create a new one on-demand
tenant_id = await create_tenant(email, referral_source)
return tenant_id
# Notify control plane if we have created / assigned a new tenant
if not DEV_MODE:
await notify_control_plane(tenant_id, email, referral_source)
return tenant_id
except Exception as e:
# If we've encountered an error, log and raise an exception
@@ -116,10 +119,6 @@ async def create_tenant(email: str, referral_source: str | None = None) -> str:
# Provision tenant on data plane
await provision_tenant(tenant_id, email)
# Notify control plane if not already done in provision_tenant
if not DEV_MODE and referral_source:
await notify_control_plane(tenant_id, email, referral_source)
except Exception as e:
logger.exception(f"Tenant provisioning failed: {str(e)}")
# Attempt to rollback the tenant provisioning
@@ -561,7 +560,3 @@ async def assign_tenant_to_user(
except Exception:
logger.exception(f"Failed to assign tenant {tenant_id} to user {email}")
raise Exception("Failed to assign tenant to user")
# Notify control plane with retry logic
if not DEV_MODE:
await notify_control_plane(tenant_id, email, referral_source)

View File

@@ -70,6 +70,7 @@ def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
"""
Add users to a tenant with proper transaction handling.
Checks if users already have a tenant mapping to avoid duplicates.
If a user already has an active mapping to any tenant, the new mapping will be added as inactive.
"""
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
try:
@@ -88,9 +89,25 @@ def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
.first()
)
# If user already has an active mapping, add this one as inactive
if not existing_mapping:
# Only add if mapping doesn't exist
db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))
# Check if the user already has an active mapping to any tenant
has_active_mapping = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email == email,
UserTenantMapping.active == True, # noqa: E712
)
.first()
)
db_session.add(
UserTenantMapping(
email=email,
tenant_id=tenant_id,
active=False if has_active_mapping else True,
)
)
# Commit the transaction
db_session.commit()

View File

@@ -18,7 +18,7 @@ def _get_access_for_document(
document_id=document_id,
)
return DocumentAccess.build(
doc_access = DocumentAccess.build(
user_emails=info[1] if info and info[1] else [],
user_groups=[],
external_user_emails=[],
@@ -26,6 +26,8 @@ def _get_access_for_document(
is_public=info[2] if info else False,
)
return doc_access
def get_access_for_document(
document_id: str,
@@ -38,12 +40,12 @@ def get_access_for_document(
def get_null_document_access() -> DocumentAccess:
return DocumentAccess(
user_emails=set(),
user_groups=set(),
return DocumentAccess.build(
user_emails=[],
user_groups=[],
is_public=False,
external_user_emails=set(),
external_user_group_ids=set(),
external_user_emails=[],
external_user_group_ids=[],
)
@@ -56,18 +58,18 @@ def _get_access_for_documents(
document_ids=document_ids,
)
doc_access = {
document_id: DocumentAccess(
user_emails=set([email for email in user_emails if email]),
document_id: DocumentAccess.build(
user_emails=[email for email in user_emails if email],
# MIT version will wipe all groups and external groups on update
user_groups=set(),
user_groups=[],
is_public=is_public,
external_user_emails=set(),
external_user_group_ids=set(),
external_user_emails=[],
external_user_group_ids=[],
)
for document_id, user_emails, is_public in document_access_info
}
# Sometimes the document has not be indexed by the indexing job yet, in those cases
# Sometimes the document has not been indexed by the indexing job yet, in those cases
# the document does not exist and so we use least permissive. Specifically the EE version
# checks the MIT version permissions and creates a superset. This ensures that this flow
# does not fail even if the Document has not yet been indexed.

View File

@@ -56,34 +56,46 @@ class DocExternalAccess:
)
@dataclass(frozen=True)
@dataclass(frozen=True, init=False)
class DocumentAccess(ExternalAccess):
# User emails for Onyx users, None indicates admin
user_emails: set[str | None]
# Names of user groups associated with this document
user_groups: set[str]
def to_acl(self) -> set[str]:
return set(
[
prefix_user_email(user_email)
for user_email in self.user_emails
if user_email
]
+ [prefix_user_group(group_name) for group_name in self.user_groups]
+ [
prefix_user_email(user_email)
for user_email in self.external_user_emails
]
+ [
# The group names are already prefixed by the source type
# This adds an additional prefix of "external_group:"
prefix_external_group(group_name)
for group_name in self.external_user_group_ids
]
+ ([PUBLIC_DOC_PAT] if self.is_public else [])
external_user_emails: set[str]
external_user_group_ids: set[str]
is_public: bool
def __init__(self) -> None:
raise TypeError(
"Use `DocumentAccess.build(...)` instead of creating an instance directly."
)
def to_acl(self) -> set[str]:
# the acl's emitted by this function are prefixed by type
# to get the native objects, access the member variables directly
acl_set: set[str] = set()
for user_email in self.user_emails:
if user_email:
acl_set.add(prefix_user_email(user_email))
for group_name in self.user_groups:
acl_set.add(prefix_user_group(group_name))
for external_user_email in self.external_user_emails:
acl_set.add(prefix_user_email(external_user_email))
for external_group_id in self.external_user_group_ids:
acl_set.add(prefix_external_group(external_group_id))
if self.is_public:
acl_set.add(PUBLIC_DOC_PAT)
return acl_set
@classmethod
def build(
cls,
@@ -93,29 +105,32 @@ class DocumentAccess(ExternalAccess):
external_user_group_ids: list[str],
is_public: bool,
) -> "DocumentAccess":
return cls(
external_user_emails={
prefix_user_email(external_email)
for external_email in external_user_emails
},
external_user_group_ids={
prefix_external_group(external_group_id)
for external_group_id in external_user_group_ids
},
user_emails={
prefix_user_email(user_email)
for user_email in user_emails
if user_email
},
user_groups=set(user_groups),
is_public=is_public,
"""Don't prefix incoming data wth acl type, prefix on read from to_acl!"""
obj = object.__new__(cls)
object.__setattr__(
obj, "user_emails", {user_email for user_email in user_emails if user_email}
)
object.__setattr__(obj, "user_groups", set(user_groups))
object.__setattr__(
obj,
"external_user_emails",
{external_email for external_email in external_user_emails},
)
object.__setattr__(
obj,
"external_user_group_ids",
{external_group_id for external_group_id in external_user_group_ids},
)
object.__setattr__(obj, "is_public", is_public)
return obj
default_public_access = DocumentAccess(
external_user_emails=set(),
external_user_group_ids=set(),
user_emails=set(),
user_groups=set(),
default_public_access = DocumentAccess.build(
external_user_emails=[],
external_user_group_ids=[],
user_emails=[],
user_groups=[],
is_public=True,
)
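
A short call-pattern sketch for the refactor above, with placeholder values and an assumed import path: access objects are created through build(), and ACL prefixes are only applied when to_acl() is read.

from onyx.access.models import DocumentAccess  # module path assumed

access = DocumentAccess.build(
    user_emails=["alice@example.com"],
    user_groups=["engineering"],
    external_user_emails=[],
    external_user_group_ids=["confluence__space-readers"],
    is_public=False,
)
acl = access.to_acl()  # prefixes are applied here, on read, not at build time
# Calling DocumentAccess(...) directly now raises TypeError, steering callers to build().
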

View File

@@ -7,7 +7,6 @@ from langgraph.types import StreamWriter
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxContext
from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
from onyx.chat.stream_processing.answer_response_handler import CitationResponseHandler
from onyx.chat.stream_processing.answer_response_handler import (
@@ -24,7 +23,7 @@ def process_llm_stream(
should_stream_answer: bool,
writer: StreamWriter,
final_search_results: list[LlmDoc] | None = None,
displayed_search_results: list[OnyxContext] | list[LlmDoc] | None = None,
displayed_search_results: list[LlmDoc] | None = None,
) -> AIMessageChunk:
tool_call_chunk = AIMessageChunk(content="")

View File

@@ -156,7 +156,6 @@ def generate_initial_answer(
for tool_response in yield_search_responses(
query=question,
get_retrieved_sections=lambda: answer_generation_documents.context_documents,
get_reranked_sections=lambda: answer_generation_documents.streaming_documents,
get_final_context_sections=lambda: answer_generation_documents.context_documents,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,

View File

@@ -183,7 +183,6 @@ def generate_validate_refined_answer(
for tool_response in yield_search_responses(
query=question,
get_retrieved_sections=lambda: answer_generation_documents.context_documents,
get_reranked_sections=lambda: answer_generation_documents.streaming_documents,
get_final_context_sections=lambda: answer_generation_documents.context_documents,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,

View File

@@ -57,7 +57,6 @@ def format_results(
for tool_response in yield_search_responses(
query=state.question,
get_retrieved_sections=lambda: reranked_documents,
get_reranked_sections=lambda: state.retrieved_documents,
get_final_context_sections=lambda: reranked_documents,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,

View File

@@ -13,9 +13,7 @@ from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
from onyx.tools.tool_implementations.search.search_utils import (
context_from_inference_section,
)
from onyx.tools.tool_implementations.search.search_utils import section_to_llm_doc
from onyx.tools.tool_implementations.search_like_tool_utils import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
@@ -59,9 +57,7 @@ def basic_use_tool_response(
search_response_summary = cast(SearchResponseSummary, yield_item.response)
for section in search_response_summary.top_sections:
if section.center_chunk.document_id not in initial_search_results:
initial_search_results.append(
context_from_inference_section(section)
)
initial_search_results.append(section_to_llm_doc(section))
new_tool_call_chunk = AIMessageChunk(content="")
if not agent_config.behavior.skip_gen_ai_answer_generation:

View File

@@ -16,10 +16,10 @@ from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import AuthType
from onyx.configs.constants import ONYX_DEFAULT_APPLICATION_NAME
from onyx.configs.constants import ONYX_SLACK_URL
from onyx.configs.constants import TENANT_ID_COOKIE_NAME
from onyx.db.models import User
from onyx.server.runtime.onyx_runtime import OnyxRuntime
from onyx.utils.file import FileWithMimeType
from onyx.utils.url import add_url_params
from onyx.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import MULTI_TENANT
@@ -62,6 +62,11 @@ HTML_EMAIL_TEMPLATE = """\
}}
.header img {{
max-width: 140px;
width: 140px;
height: auto;
filter: brightness(1.1) contrast(1.2);
border-radius: 8px;
padding: 5px;
}}
.body-content {{
padding: 20px 30px;
@@ -78,12 +83,16 @@ HTML_EMAIL_TEMPLATE = """\
}}
.cta-button {{
display: inline-block;
padding: 12px 20px;
background-color: #000000;
padding: 14px 24px;
background-color: #0055FF;
color: #ffffff !important;
text-decoration: none;
border-radius: 4px;
font-weight: 500;
font-weight: 600;
font-size: 16px;
margin-top: 10px;
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);
text-align: center;
}}
.footer {{
font-size: 13px;
@@ -166,6 +175,7 @@ def send_email(
if not EMAIL_CONFIGURED:
raise ValueError("Email is not configured.")
# Create a multipart/alternative message - this indicates these are alternative versions of the same content
msg = MIMEMultipart("alternative")
msg["Subject"] = subject
msg["To"] = user_email
@@ -174,17 +184,30 @@ def send_email(
msg["Date"] = formatdate(localtime=True)
msg["Message-ID"] = make_msgid(domain="onyx.app")
part_text = MIMEText(text_body, "plain")
part_html = MIMEText(html_body, "html")
msg.attach(part_text)
msg.attach(part_html)
# Add text part first (lowest priority)
text_part = MIMEText(text_body, "plain")
msg.attach(text_part)
if inline_png:
# For HTML with images, create a multipart/related container
related = MIMEMultipart("related")
# Add the HTML part to the related container
html_part = MIMEText(html_body, "html")
related.attach(html_part)
# Add image with proper Content-ID to the related container
img = MIMEImage(inline_png[1], _subtype="png")
img.add_header("Content-ID", inline_png[0]) # CID reference
img.add_header("Content-ID", f"<{inline_png[0]}>")
img.add_header("Content-Disposition", "inline", filename=inline_png[0])
msg.attach(img)
related.attach(img)
# Add the related part to the message (higher priority than text)
msg.attach(related)
else:
# No images, just add HTML directly (higher priority than text)
html_part = MIMEText(html_body, "html")
msg.attach(html_part)
try:
with smtplib.SMTP(SMTP_SERVER, SMTP_PORT) as s:
@@ -332,17 +355,23 @@ def send_forgot_password_email(
onyx_file = OnyxRuntime.get_emailable_logo()
subject = f"{application_name} Forgot Password"
link = f"{WEB_DOMAIN}/auth/reset-password?token={token}"
if MULTI_TENANT:
link += f"&{TENANT_ID_COOKIE_NAME}={tenant_id}"
message = f"<p>Click the following link to reset your password:</p><p>{link}</p>"
subject = f"Reset Your {application_name} Password"
heading = "Reset Your Password"
tenant_param = f"&tenant={tenant_id}" if tenant_id and MULTI_TENANT else ""
message = "<p>Please click the button below to reset your password. This link will expire in 24 hours.</p>"
cta_text = "Reset Password"
cta_link = f"{WEB_DOMAIN}/auth/reset-password?token={token}{tenant_param}"
html_content = build_html_email(
application_name,
"Reset Your Password",
heading,
message,
cta_text,
cta_link,
)
text_content = (
f"Please click the following link to reset your password. This link will expire in 24 hours.\n"
f"{WEB_DOMAIN}/auth/reset-password?token={token}{tenant_param}"
)
text_content = f"Click the following link to reset your password: {link}"
send_email(
user_email,
subject,
@@ -356,6 +385,7 @@ def send_forgot_password_email(
def send_user_verification_email(
user_email: str,
token: str,
new_organization: bool = False,
mail_from: str = EMAIL_FROM,
) -> None:
# Builds a verification email
@@ -372,6 +402,8 @@ def send_user_verification_email(
subject = f"{application_name} Email Verification"
link = f"{WEB_DOMAIN}/auth/verify-email?token={token}"
if new_organization:
link = add_url_params(link, {"first_user": "true"})
message = (
f"<p>Click the following link to verify your email address:</p><p>{link}</p>"
)
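
The send_email change above nests a multipart/related part (HTML plus inline image) inside a multipart/alternative envelope. A standalone stdlib sketch of that MIME shape, with placeholder addresses and image bytes:

from email.mime.image import MIMEImage
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText

msg = MIMEMultipart("alternative")  # clients render the best alternative they support
msg["Subject"] = "Example"
msg["From"] = "noreply@example.com"
msg["To"] = "user@example.com"

msg.attach(MIMEText("Plain-text fallback", "plain"))  # lowest-priority alternative

related = MIMEMultipart("related")  # HTML plus the resources it references
related.attach(MIMEText('<p>Logo: <img src="cid:logo"></p>', "html"))

img = MIMEImage(b"\x89PNG\r\n\x1a\n...", _subtype="png")  # placeholder bytes
img.add_header("Content-ID", "<logo>")  # angle brackets so cid:logo resolves
img.add_header("Content-Disposition", "inline", filename="logo.png")
related.attach(img)

msg.attach(related)  # highest-priority alternative
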

View File

@@ -0,0 +1,211 @@
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
from typing import Dict
from typing import List
from typing import Optional
import httpx
from fastapi_users.manager import BaseUserManager
from sqlalchemy.ext.asyncio import AsyncSession
from onyx.configs.app_configs import OAUTH_CLIENT_ID
from onyx.configs.app_configs import OAUTH_CLIENT_SECRET
from onyx.configs.app_configs import TRACK_EXTERNAL_IDP_EXPIRY
from onyx.db.models import OAuthAccount
from onyx.db.models import User
from onyx.utils.logger import setup_logger
logger = setup_logger()
# Standard OAuth refresh token endpoints
REFRESH_ENDPOINTS = {
"google": "https://oauth2.googleapis.com/token",
}
# NOTE: Keeping this as a utility function for potential future debugging,
# but not using it in production code
async def _test_expire_oauth_token(
user: User,
oauth_account: OAuthAccount,
db_session: AsyncSession,
user_manager: BaseUserManager[User, Any],
expire_in_seconds: int = 10,
) -> bool:
"""
Utility function for testing - Sets an OAuth token to expire in a short time
to facilitate testing of the refresh flow.
Not used in production code.
"""
try:
new_expires_at = int(
(datetime.now(timezone.utc).timestamp() + expire_in_seconds)
)
updated_data: Dict[str, Any] = {"expires_at": new_expires_at}
await user_manager.user_db.update_oauth_account(
user, cast(Any, oauth_account), updated_data
)
return True
except Exception as e:
logger.exception(f"Error setting artificial expiration: {str(e)}")
return False
async def refresh_oauth_token(
user: User,
oauth_account: OAuthAccount,
db_session: AsyncSession,
user_manager: BaseUserManager[User, Any],
) -> bool:
"""
Attempt to refresh an OAuth token that's about to expire or has expired.
Returns True if successful, False otherwise.
"""
if not oauth_account.refresh_token:
logger.warning(
f"No refresh token available for {user.email}'s {oauth_account.oauth_name} account"
)
return False
provider = oauth_account.oauth_name
if provider not in REFRESH_ENDPOINTS:
logger.warning(f"Refresh endpoint not configured for provider: {provider}")
return False
try:
logger.info(f"Refreshing OAuth token for {user.email}'s {provider} account")
async with httpx.AsyncClient() as client:
response = await client.post(
REFRESH_ENDPOINTS[provider],
data={
"client_id": OAUTH_CLIENT_ID,
"client_secret": OAUTH_CLIENT_SECRET,
"refresh_token": oauth_account.refresh_token,
"grant_type": "refresh_token",
},
headers={"Content-Type": "application/x-www-form-urlencoded"},
)
if response.status_code != 200:
logger.error(
f"Failed to refresh OAuth token: Status {response.status_code}"
)
return False
token_data = response.json()
new_access_token = token_data.get("access_token")
new_refresh_token = token_data.get(
"refresh_token", oauth_account.refresh_token
)
expires_in = token_data.get("expires_in")
# Calculate new expiry time if provided
new_expires_at: Optional[int] = None
if expires_in:
new_expires_at = int(
(datetime.now(timezone.utc).timestamp() + expires_in)
)
# Update the OAuth account
updated_data: Dict[str, Any] = {
"access_token": new_access_token,
"refresh_token": new_refresh_token,
}
if new_expires_at:
updated_data["expires_at"] = new_expires_at
# Update oidc_expiry in user model if we're tracking it
if TRACK_EXTERNAL_IDP_EXPIRY:
oidc_expiry = datetime.fromtimestamp(
new_expires_at, tz=timezone.utc
)
await user_manager.user_db.update(
user, {"oidc_expiry": oidc_expiry}
)
# Update the OAuth account
await user_manager.user_db.update_oauth_account(
user, cast(Any, oauth_account), updated_data
)
logger.info(f"Successfully refreshed OAuth token for {user.email}")
return True
except Exception as e:
logger.exception(f"Error refreshing OAuth token: {str(e)}")
return False
async def check_and_refresh_oauth_tokens(
user: User,
db_session: AsyncSession,
user_manager: BaseUserManager[User, Any],
) -> None:
"""
Check if any OAuth tokens are expired or about to expire and refresh them.
"""
if not hasattr(user, "oauth_accounts") or not user.oauth_accounts:
return
now_timestamp = datetime.now(timezone.utc).timestamp()
# Buffer time to refresh tokens before they expire (in seconds)
buffer_seconds = 300 # 5 minutes
for oauth_account in user.oauth_accounts:
# Skip accounts without refresh tokens
if not oauth_account.refresh_token:
continue
# If token is about to expire, refresh it
if (
oauth_account.expires_at
and oauth_account.expires_at - now_timestamp < buffer_seconds
):
logger.info(f"OAuth token for {user.email} is about to expire - refreshing")
success = await refresh_oauth_token(
user, oauth_account, db_session, user_manager
)
if not success:
logger.warning(
"Failed to refresh OAuth token. User may need to re-authenticate."
)
async def check_oauth_account_has_refresh_token(
user: User,
oauth_account: OAuthAccount,
) -> bool:
"""
Check if an OAuth account has a refresh token.
Returns True if a refresh token exists, False otherwise.
"""
return bool(oauth_account.refresh_token)
async def get_oauth_accounts_requiring_refresh_token(user: User) -> List[OAuthAccount]:
"""
Returns a list of OAuth accounts for a user that are missing refresh tokens.
These accounts will need re-authentication to get refresh tokens.
"""
if not hasattr(user, "oauth_accounts") or not user.oauth_accounts:
return []
accounts_needing_refresh = []
for oauth_account in user.oauth_accounts:
has_refresh_token = await check_oauth_account_has_refresh_token(
user, oauth_account
)
if not has_refresh_token:
accounts_needing_refresh.append(oauth_account)
return accounts_needing_refresh

View File

@@ -26,6 +26,7 @@ class UserRole(str, Enum):
SLACK_USER = "slack_user"
EXT_PERM_USER = "ext_perm_user"
@property
def is_web_login(self) -> bool:
return self not in [
UserRole.SLACK_USER,

View File

@@ -5,12 +5,16 @@ import string
import uuid
from collections.abc import AsyncGenerator
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import Any
from typing import cast
from typing import Dict
from typing import List
from typing import Optional
from typing import Protocol
from typing import Tuple
from typing import TypeVar
import jwt
from email_validator import EmailNotValidError
@@ -315,7 +319,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
except exceptions.UserAlreadyExists:
user = await self.get_by_email(user_create.email)
# Handle case where user has used product outside of web and is now creating an account through web
if not user.role.is_web_login() and user_create.role.is_web_login():
if not user.role.is_web_login and user_create.role.is_web_login:
user_update = UserUpdateWithRole(
password=user_create.password,
is_verified=user_create.is_verified,
@@ -486,7 +490,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
)
# Handle case where user has used product outside of web and is now creating an account through web
if not user.role.is_web_login():
if not user.role.is_web_login:
await self.user_db.update(
user,
{
@@ -581,8 +585,10 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
logger.notice(
f"Verification requested for user {user.id}. Verification token: {token}"
)
send_user_verification_email(user.email, token)
user_count = await get_user_count()
send_user_verification_email(
user.email, token, new_organization=user_count == 1
)
async def authenticate(
self, credentials: OAuth2PasswordRequestForm
@@ -623,7 +629,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
self.password_helper.hash(credentials.password)
return None
if not user.role.is_web_login():
if not user.role.is_web_login:
raise BasicAuthenticationError(
detail="NO_WEB_LOGIN_AND_HAS_NO_PASSWORD",
)
@@ -688,16 +694,20 @@ cookie_transport = CookieTransport(
)
def get_redis_strategy() -> RedisStrategy:
return TenantAwareRedisStrategy()
T = TypeVar("T", covariant=True)
ID = TypeVar("ID", contravariant=True)
def get_database_strategy(
access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db),
) -> DatabaseStrategy:
return DatabaseStrategy(
access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS
)
# Protocol for strategies that support token refreshing without inheritance.
class RefreshableStrategy(Protocol):
"""Protocol for authentication strategies that support token refreshing."""
async def refresh_token(self, token: Optional[str], user: Any) -> str:
"""
Refresh an existing token by extending its lifetime.
Returns either the same token with extended expiration or a new token.
"""
...
class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]):
@@ -756,6 +766,75 @@ class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]):
redis = await get_async_redis_connection()
await redis.delete(f"{self.key_prefix}{token}")
async def refresh_token(self, token: Optional[str], user: User) -> str:
"""Refresh a token by extending its expiration time in Redis."""
if token is None:
# If no token provided, create a new one
return await self.write_token(user)
redis = await get_async_redis_connection()
token_key = f"{self.key_prefix}{token}"
# Check if token exists
token_data_str = await redis.get(token_key)
if not token_data_str:
# Token not found, create new one
return await self.write_token(user)
# Token exists, extend its lifetime
token_data = json.loads(token_data_str)
await redis.set(
token_key,
json.dumps(token_data),
ex=self.lifetime_seconds,
)
return token
class RefreshableDatabaseStrategy(DatabaseStrategy[User, uuid.UUID, AccessToken]):
"""Database strategy with token refreshing capabilities."""
def __init__(
self,
access_token_db: AccessTokenDatabase[AccessToken],
lifetime_seconds: Optional[int] = None,
):
super().__init__(access_token_db, lifetime_seconds)
self._access_token_db = access_token_db
async def refresh_token(self, token: Optional[str], user: User) -> str:
"""Refresh a token by updating its expiration time in the database."""
if token is None:
return await self.write_token(user)
# Find the token in database
access_token = await self._access_token_db.get_by_token(token)
if access_token is None:
# Token not found, create new one
return await self.write_token(user)
# Update expiration time
new_expires = datetime.now(timezone.utc) + timedelta(
seconds=float(self.lifetime_seconds or SESSION_EXPIRE_TIME_SECONDS)
)
await self._access_token_db.update(access_token, {"expires": new_expires})
return token
def get_redis_strategy() -> TenantAwareRedisStrategy:
return TenantAwareRedisStrategy()
def get_database_strategy(
access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db),
) -> RefreshableDatabaseStrategy:
return RefreshableDatabaseStrategy(
access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS
)
if AUTH_BACKEND == AuthBackend.REDIS:
auth_backend = AuthenticationBackend(
@@ -806,6 +885,88 @@ class FastAPIUserWithLogoutRouter(FastAPIUsers[models.UP, models.ID]):
return router
def get_refresh_router(
self,
backend: AuthenticationBackend,
requires_verification: bool = REQUIRE_EMAIL_VERIFICATION,
) -> APIRouter:
"""
Provide a router for session token refreshing.
"""
# Import the oauth_refresher here to avoid circular imports
from onyx.auth.oauth_refresher import check_and_refresh_oauth_tokens
router = APIRouter()
get_current_user_token = self.authenticator.current_user_token(
active=True, verified=requires_verification
)
refresh_responses: OpenAPIResponseType = {
**{
status.HTTP_401_UNAUTHORIZED: {
"description": "Missing token or inactive user."
}
},
**backend.transport.get_openapi_login_responses_success(),
}
@router.post(
"/refresh", name=f"auth:{backend.name}.refresh", responses=refresh_responses
)
async def refresh(
user_token: Tuple[models.UP, str] = Depends(get_current_user_token),
strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy),
user_manager: BaseUserManager[models.UP, models.ID] = Depends(
get_user_manager
),
db_session: AsyncSession = Depends(get_async_session),
) -> Response:
try:
user, token = user_token
logger.info(f"Processing token refresh request for user {user.email}")
# Check if user has OAuth accounts that need refreshing
await check_and_refresh_oauth_tokens(
user=cast(User, user),
db_session=db_session,
user_manager=cast(Any, user_manager),
)
# Check if strategy supports refreshing
supports_refresh = hasattr(strategy, "refresh_token") and callable(
getattr(strategy, "refresh_token")
)
if supports_refresh:
try:
refresh_method = getattr(strategy, "refresh_token")
new_token = await refresh_method(token, user)
logger.info(
f"Successfully refreshed session token for user {user.email}"
)
return await backend.transport.get_login_response(new_token)
except Exception as e:
logger.error(f"Error refreshing session token: {str(e)}")
# Fallback to logout and login if refresh fails
await backend.logout(strategy, user, token)
return await backend.login(strategy, user)
# Fallback: logout and login again
logger.info(
"Strategy doesn't support refresh - using logout/login flow"
)
await backend.logout(strategy, user, token)
return await backend.login(strategy, user)
except Exception as e:
logger.error(f"Unexpected error in refresh endpoint: {str(e)}")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Token refresh failed: {str(e)}",
)
return router
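For reference, a minimal sketch of how a client might call the new refresh endpoint; the /auth prefix and the fastapiusersauth cookie name are assumptions, not taken from this diff:
import httpx
# Minimal sketch of exercising the refresh endpoint. The "/auth" mount point and
# the "fastapiusersauth" cookie name are assumptions; a bearer transport would
# return the new token in the JSON body instead of a cookie.
def refresh_session(base_url: str, session_cookie: str) -> str | None:
    resp = httpx.post(
        f"{base_url}/auth/refresh",
        cookies={"fastapiusersauth": session_cookie},
    )
    if resp.status_code == 401:
        return None  # missing token or inactive user, per the documented responses
    resp.raise_for_status()
    return resp.cookies.get("fastapiusersauth")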
fastapi_users = FastAPIUserWithLogoutRouter[User, uuid.UUID](
get_user_manager, [auth_backend]
@@ -1039,12 +1200,20 @@ def get_oauth_router(
"referral_source": referral_source or "default_referral",
}
state = generate_state_token(state_data, state_secret)
# Get the basic authorization URL
authorization_url = await oauth_client.get_authorization_url(
authorize_redirect_url,
state,
scopes,
)
# For Google OAuth, add parameters to request refresh tokens
if oauth_client.name == "google":
authorization_url = add_url_params(
authorization_url, {"access_type": "offline", "prompt": "consent"}
)
return OAuth2AuthorizeResponse(authorization_url=authorization_url)
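The add_url_params helper is not shown in this hunk; a minimal sketch of what such a helper could look like, built on urllib.parse (illustrative only, the real implementation may differ):
from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse
# Illustrative sketch of a query-parameter helper like the add_url_params used
# above; the actual Onyx implementation is not part of this diff.
def add_url_params_sketch(url: str, params: dict[str, str]) -> str:
    parts = urlparse(url)
    query = dict(parse_qsl(parts.query))
    query.update(params)  # e.g. {"access_type": "offline", "prompt": "consent"}
    return urlunparse(parts._replace(query=urlencode(query)))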
@router.get(

View File

@@ -34,7 +34,6 @@ from onyx.redis.redis_connector_ext_group_sync import RedisConnectorExternalGrou
from onyx.redis.redis_connector_prune import RedisConnectorPrune
from onyx.redis.redis_document_set import RedisDocumentSet
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_shared_redis_client
from onyx.redis.redis_usergroup import RedisUserGroup
from onyx.utils.logger import ColoredFormatter
from onyx.utils.logger import PlainFormatter
@@ -225,7 +224,7 @@ def wait_for_redis(sender: Any, **kwargs: Any) -> None:
Will raise WorkerShutdown to kill the celery worker if the timeout
is reached."""
r = get_shared_redis_client()
r = get_redis_client(tenant_id=POSTGRES_DEFAULT_SCHEMA)
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
@@ -311,7 +310,7 @@ def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
# Set up variables for waiting on primary worker
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
r = get_shared_redis_client()
r = get_redis_client(tenant_id=POSTGRES_DEFAULT_SCHEMA)
time_start = time.monotonic()
logger.info("Waiting for primary worker to be ready...")

View File

@@ -1,6 +1,5 @@
from datetime import timedelta
from typing import Any
from typing import cast
from celery import Celery
from celery import signals
@@ -10,12 +9,10 @@ from celery.utils.log import get_task_logger
import onyx.background.celery.apps.app_base as app_base
from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
from onyx.configs.constants import ONYX_CLOUD_REDIS_RUNTIME
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
from onyx.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
from onyx.db.engine import get_all_tenant_ids
from onyx.db.engine import SqlEngine
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.server.runtime.onyx_runtime import OnyxRuntime
from onyx.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST
from shared_configs.configs import MULTI_TENANT
@@ -141,8 +138,6 @@ class DynamicTenantScheduler(PersistentScheduler):
"""Only updates the actual beat schedule on the celery app when it changes"""
do_update = False
r = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
task_logger.debug("_try_updating_schedule starting")
tenant_ids = get_all_tenant_ids()
@@ -152,16 +147,7 @@ class DynamicTenantScheduler(PersistentScheduler):
current_schedule = self.schedule.items()
# get potential new state
beat_multiplier = CLOUD_BEAT_MULTIPLIER_DEFAULT
beat_multiplier_raw = r.get(f"{ONYX_CLOUD_REDIS_RUNTIME}:beat_multiplier")
if beat_multiplier_raw is not None:
try:
beat_multiplier_bytes = cast(bytes, beat_multiplier_raw)
beat_multiplier = float(beat_multiplier_bytes.decode())
except ValueError:
task_logger.error(
f"Invalid beat_multiplier value: {beat_multiplier_raw}"
)
beat_multiplier = OnyxRuntime.get_beat_multiplier()
new_schedule = self._generate_schedule(tenant_ids, beat_multiplier)
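The removed block above is what OnyxRuntime.get_beat_multiplier() now centralizes; a rough sketch reconstructed from that removed logic (the real OnyxRuntime method is not shown in this diff and may differ):
from typing import cast
from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
from onyx.configs.constants import ONYX_CLOUD_REDIS_RUNTIME
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
from onyx.redis.redis_pool import get_redis_replica_client
# Sketch only: read the runtime multiplier from Redis, falling back to the
# default when the key is missing or unparseable, as the removed code did.
def get_beat_multiplier_sketch() -> float:
    r = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
    raw = r.get(f"{ONYX_CLOUD_REDIS_RUNTIME}:beat_multiplier")
    if raw is None:
        return CLOUD_BEAT_MULTIPLIER_DEFAULT
    try:
        return float(cast(bytes, raw).decode())
    except ValueError:
        return CLOUD_BEAT_MULTIPLIER_DEFAULT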

View File

@@ -38,10 +38,11 @@ from onyx.redis.redis_connector_index import RedisConnectorIndex
from onyx.redis.redis_connector_prune import RedisConnectorPrune
from onyx.redis.redis_connector_stop import RedisConnectorStop
from onyx.redis.redis_document_set import RedisDocumentSet
from onyx.redis.redis_pool import get_shared_redis_client
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_usergroup import RedisUserGroup
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
logger = setup_logger()
@@ -102,7 +103,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
# This is singleton work that should be done on startup exactly once
# by the primary worker. This is unnecessary in the multi tenant scenario
r = get_shared_redis_client()
r = get_redis_client(tenant_id=POSTGRES_DEFAULT_SCHEMA)
# Log the role and slave count - being connected to a slave or slave count > 0 could be problematic
info: dict[str, Any] = cast(dict, r.info("replication"))
@@ -235,7 +236,7 @@ class HubPeriodicTask(bootsteps.StartStopStep):
lock: RedisLock = worker.primary_worker_lock
r = get_shared_redis_client()
r = get_redis_client(tenant_id=POSTGRES_DEFAULT_SCHEMA)
if lock.owned():
task_logger.debug("Reacquiring primary worker lock.")

View File

@@ -14,7 +14,7 @@ logger = setup_logger()
# Only set up memory monitoring in container environment
if is_running_in_container():
# Set up a dedicated memory monitoring logger
MEMORY_LOG_DIR = "/var/log/persisted-logs/memory"
MEMORY_LOG_DIR = "/var/log/memory"
MEMORY_LOG_FILE = os.path.join(MEMORY_LOG_DIR, "memory_usage.log")
MEMORY_LOG_MAX_BYTES = 10 * 1024 * 1024 # 10MB
MEMORY_LOG_BACKUP_COUNT = 5 # Keep 5 backup files

View File

@@ -21,6 +21,7 @@ BEAT_EXPIRES_DEFAULT = 15 * 60 # 15 minutes (in seconds)
# we have a better implementation (backpressure, etc)
# Note that DynamicTenantScheduler can adjust the runtime value for this via Redis
CLOUD_BEAT_MULTIPLIER_DEFAULT = 8.0
CLOUD_DOC_PERMISSION_SYNC_MULTIPLIER_DEFAULT = 1.0
# tasks that run in either self-hosted or cloud
beat_task_templates: list[dict] = []

View File

@@ -30,6 +30,9 @@ from onyx.db.connector_credential_pair import (
)
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import get_connector_credential_pairs
from onyx.db.document import (
delete_all_documents_by_connector_credential_pair__no_commit,
)
from onyx.db.document import get_document_ids_for_connector_credential_pair
from onyx.db.document_set import delete_document_set_cc_pair_relationship__no_commit
from onyx.db.engine import get_session_with_current_tenant
@@ -386,6 +389,8 @@ def monitor_connector_deletion_taskset(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
credential_id_to_delete: int | None = None
connector_id_to_delete: int | None = None
if not cc_pair:
task_logger.warning(
f"Connector deletion - cc_pair not found: cc_pair={cc_pair_id}"
@@ -440,16 +445,35 @@ def monitor_connector_deletion_taskset(
db_session=db_session,
)
# Store IDs before potentially expiring cc_pair
connector_id_to_delete = cc_pair.connector_id
credential_id_to_delete = cc_pair.credential_id
# Explicitly delete document by connector credential pair records before deleting the connector
# This is needed because connector_id is a primary key in that table and cascading deletes won't work
delete_all_documents_by_connector_credential_pair__no_commit(
db_session=db_session,
connector_id=connector_id_to_delete,
credential_id=credential_id_to_delete,
)
# Flush to ensure document deletion happens before connector deletion
db_session.flush()
# Expire the cc_pair to ensure SQLAlchemy doesn't try to manage its state
# related to the deleted DocumentByConnectorCredentialPair during commit
db_session.expire(cc_pair)
# finally, delete the cc-pair
delete_connector_credential_pair__no_commit(
db_session=db_session,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
connector_id=connector_id_to_delete,
credential_id=credential_id_to_delete,
)
# if there are no credentials left, delete the connector
connector = fetch_connector_by_id(
db_session=db_session,
connector_id=cc_pair.connector_id,
connector_id=connector_id_to_delete,
)
if not connector or not len(connector.credentials):
task_logger.info(
@@ -482,15 +506,15 @@ def monitor_connector_deletion_taskset(
task_logger.exception(
f"Connector deletion exceptioned: "
f"cc_pair={cc_pair_id} connector={cc_pair.connector_id} credential={cc_pair.credential_id}"
f"cc_pair={cc_pair_id} connector={connector_id_to_delete} credential={credential_id_to_delete}"
)
raise e
task_logger.info(
f"Connector deletion succeeded: "
f"cc_pair={cc_pair_id} "
f"connector={cc_pair.connector_id} "
f"credential={cc_pair.credential_id} "
f"connector={connector_id_to_delete} "
f"credential={credential_id_to_delete} "
f"docs_deleted={fence_data.num_tasks}"
)
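The explicit delete/flush/expire ordering above is a general SQLAlchemy pattern for removing child rows that cascades do not cover before deleting their parent; a generic sketch with placeholder models (not Onyx code):
from typing import Any
from sqlalchemy import delete
from sqlalchemy.orm import Session
# Generic sketch of the ordering used above: delete dependent rows explicitly,
# flush so the DELETE hits the database first, expire the parent so SQLAlchemy
# does not try to manage the already-deleted children at commit time, then
# delete the parent. ChildRow/parent are placeholders, not Onyx models.
def delete_parent_with_children(
    db_session: Session, parent: Any, ChildRow: Any, parent_id: int
) -> None:
    db_session.execute(delete(ChildRow).where(ChildRow.parent_id == parent_id))
    db_session.flush()
    db_session.expire(parent)
    db_session.delete(parent)
    db_session.commit()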
@@ -540,7 +564,7 @@ def validate_connector_deletion_fences(
def validate_connector_deletion_fence(
tenant_id: str,
key_bytes: bytes,
queued_tasks: set[str],
queued_upsert_tasks: set[str],
r: Redis,
) -> None:
"""Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.
@@ -627,7 +651,7 @@ def validate_connector_deletion_fence(
member_bytes = cast(bytes, member)
member_str = member_bytes.decode("utf-8")
if member_str in queued_tasks:
if member_str in queued_upsert_tasks:
continue
tasks_not_in_celery += 1

View File

@@ -17,6 +17,7 @@ from redis.exceptions import LockError
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from ee.onyx.configs.app_configs import DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs
from ee.onyx.db.document import upsert_document_external_perms
from ee.onyx.external_permissions.sync_params import DOC_PERMISSION_SYNC_PERIODS
@@ -63,11 +64,14 @@ from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSyn
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.redis.redis_pool import redis_lock_dump
from onyx.server.runtime.onyx_runtime import OnyxRuntime
from onyx.server.utils import make_short_id
from onyx.utils.logger import doc_permission_sync_ctx
from onyx.utils.logger import format_error_for_logging
from onyx.utils.logger import LoggerContextVars
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
logger = setup_logger()
@@ -104,9 +108,10 @@ def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> b
source_sync_period = DOC_PERMISSION_SYNC_PERIODS.get(cc_pair.connector.source)
# If RESTRICTED_FETCH_PERIOD[source] is None, we always run the sync.
if not source_sync_period:
return True
source_sync_period = DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY
source_sync_period *= int(OnyxRuntime.get_doc_permission_sync_multiplier())
# If the last sync is greater than the full fetch period, we run the sync
next_sync = last_perm_sync + timedelta(seconds=source_sync_period)
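As a concrete illustration of the new multiplier (all numbers below are made up for the example):
from datetime import datetime, timedelta, timezone
# Example only: a 30-minute per-source period combined with a multiplier of 2
# pushes the next permission sync out to one hour after the last one.
source_sync_period = 30 * 60      # example DOC_PERMISSION_SYNC_PERIODS value
source_sync_period *= int(2.0)    # example multiplier from OnyxRuntime
last_perm_sync = datetime(2025, 3, 27, 12, 0, tzinfo=timezone.utc)
next_sync = last_perm_sync + timedelta(seconds=source_sync_period)
assert next_sync == datetime(2025, 3, 27, 13, 0, tzinfo=timezone.utc)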
@@ -284,7 +289,7 @@ def try_creating_permissions_sync_task(
),
queue=OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC,
task_id=custom_task_id,
priority=OnyxCeleryPriority.HIGH,
priority=OnyxCeleryPriority.MEDIUM,
)
# fill in the celery task id
@@ -875,6 +880,21 @@ def monitor_ccpair_permissions_taskset(
f"remaining={remaining} "
f"initial={initial}"
)
# Add telemetry for permission syncing progress
optional_telemetry(
record_type=RecordType.PERMISSION_SYNC_PROGRESS,
data={
"cc_pair_id": cc_pair_id,
"id": payload.id if payload else None,
"total_docs": initial if initial is not None else 0,
"remaining_docs": remaining,
"synced_docs": (initial - remaining) if initial is not None else 0,
"is_complete": remaining == 0,
},
tenant_id=tenant_id,
)
if remaining > 0:
return

View File

@@ -271,7 +271,7 @@ def try_creating_external_group_sync_task(
),
queue=OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC,
task_id=custom_task_id,
priority=OnyxCeleryPriority.HIGH,
priority=OnyxCeleryPriority.MEDIUM,
)
payload.celery_task_id = result.id

View File

@@ -72,6 +72,7 @@ from onyx.redis.redis_pool import get_redis_replica_client
from onyx.redis.redis_pool import redis_lock_dump
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
from onyx.redis.redis_utils import is_fence
from onyx.server.runtime.onyx_runtime import OnyxRuntime
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import global_version
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
@@ -401,7 +402,11 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
logger.warning(f"Adding {key_bytes} to the lookup table.")
redis_client.sadd(OnyxRedisConstants.ACTIVE_FENCES, key_bytes)
redis_client.set(OnyxRedisSignals.BLOCK_BUILD_FENCE_LOOKUP_TABLE, 1, ex=300)
redis_client.set(
OnyxRedisSignals.BLOCK_BUILD_FENCE_LOOKUP_TABLE,
1,
ex=OnyxRuntime.get_build_fence_lookup_table_interval(),
)
# 1/3: KICKOFF

View File

@@ -56,9 +56,12 @@ from onyx.indexing.indexing_pipeline import build_indexing_pipeline
from onyx.natural_language_processing.search_nlp_models import (
InformationContentClassificationModel,
)
from onyx.redis.redis_connector import RedisConnector
from onyx.utils.logger import setup_logger
from onyx.utils.logger import TaskAttemptSingleton
from onyx.utils.telemetry import create_milestone_and_report
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
from onyx.utils.variable_functionality import global_version
from shared_configs.configs import MULTI_TENANT
@@ -435,7 +438,7 @@ def _run_indexing(
while checkpoint.has_more:
logger.info(
f"Running '{ctx.source}' connector with checkpoint: {checkpoint}"
f"Running '{ctx.source.value}' connector with checkpoint: {checkpoint}"
)
for document_batch, failure, next_checkpoint in connector_runner.run(
checkpoint
@@ -570,6 +573,22 @@ def _run_indexing(
if callback:
callback.progress("_run_indexing", len(doc_batch_cleaned))
# Add telemetry for indexing progress
optional_telemetry(
record_type=RecordType.INDEXING_PROGRESS,
data={
"index_attempt_id": index_attempt_id,
"cc_pair_id": ctx.cc_pair_id,
"connector_id": ctx.connector_id,
"credential_id": ctx.credential_id,
"total_docs_indexed": document_count,
"total_chunks": chunk_count,
"batch_num": batch_num,
"source": ctx.source.value,
},
tenant_id=tenant_id,
)
memory_tracer.increment_and_maybe_trace()
# make sure the checkpoints aren't getting too large at some regular interval
@@ -585,6 +604,30 @@ def _run_indexing(
checkpoint=checkpoint,
)
# Add telemetry for completed indexing
redis_connector = RedisConnector(tenant_id, ctx.cc_pair_id)
redis_connector_index = redis_connector.new_index(
index_attempt_start.search_settings_id
)
final_progress = redis_connector_index.get_progress() or 0
optional_telemetry(
record_type=RecordType.INDEXING_COMPLETE,
data={
"index_attempt_id": index_attempt_id,
"cc_pair_id": ctx.cc_pair_id,
"connector_id": ctx.connector_id,
"credential_id": ctx.credential_id,
"total_docs_indexed": document_count,
"total_chunks": chunk_count,
"batch_count": batch_num,
"time_elapsed_seconds": time.monotonic() - start_time,
"source": ctx.source.value,
"redis_progress": final_progress,
},
tenant_id=tenant_id,
)
except Exception as e:
logger.exception(
"Connector run exceptioned after elapsed time: "

View File

@@ -194,17 +194,6 @@ class StreamingError(BaseModel):
stack_trace: str | None = None
class OnyxContext(BaseModel):
content: str
document_id: str
semantic_identifier: str
blurb: str
class OnyxContexts(BaseModel):
contexts: list[OnyxContext]
class OnyxAnswer(BaseModel):
answer: str | None
@@ -270,7 +259,6 @@ class PersonaOverrideConfig(BaseModel):
AnswerQuestionPossibleReturn = (
OnyxAnswerPiece
| CitationInfo
| OnyxContexts
| FileChatDisplay
| CustomToolResponse
| StreamingError

View File

@@ -29,7 +29,6 @@ from onyx.chat.models import LLMRelevanceFilterResponse
from onyx.chat.models import MessageResponseIDInfo
from onyx.chat.models import MessageSpecificCitations
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import OnyxContexts
from onyx.chat.models import PromptConfig
from onyx.chat.models import QADocsResponse
from onyx.chat.models import RefinedAnswerImprovement
@@ -73,6 +72,7 @@ from onyx.db.chat import get_or_create_root_message
from onyx.db.chat import reserve_message_id
from onyx.db.chat import translate_db_message_to_chat_message_detail
from onyx.db.chat import translate_db_search_doc_to_server_search_doc
from onyx.db.chat import update_chat_session_updated_at_timestamp
from onyx.db.engine import get_session_context_manager
from onyx.db.milestone import check_multi_assistant_milestone
from onyx.db.milestone import create_milestone_if_not_exists
@@ -130,7 +130,6 @@ from onyx.tools.tool_implementations.internet_search.internet_search_tool import
from onyx.tools.tool_implementations.search.search_tool import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTENT_ID
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
@@ -299,7 +298,6 @@ def _get_force_search_settings(
ChatPacket = (
StreamingError
| QADocsResponse
| OnyxContexts
| LLMRelevanceFilterResponse
| FinalUsedContextDocsResponse
| ChatMessageDetail
@@ -918,8 +916,6 @@ def stream_chat_message_objects(
response=custom_tool_response.tool_result,
tool_name=custom_tool_response.tool_name,
)
elif packet.id == SEARCH_DOC_CONTENT_ID and include_contexts:
yield cast(OnyxContexts, packet.response)
elif isinstance(packet, StreamStopInfo):
if packet.stop_reason == StreamStopReason.FINISHED:
@@ -1069,6 +1065,8 @@ def stream_chat_message_objects(
prev_message = next_answer_message
logger.debug("Committing messages")
# Explicitly update the timestamp on the chat session
update_chat_session_updated_at_timestamp(chat_session_id, db_session)
db_session.commit() # actually save user / assistant message
yield AgenticMessageResponseIDInfo(agentic_message_ids=agentic_message_ids)

View File

@@ -301,6 +301,10 @@ def prune_sections(
def _merge_doc_chunks(chunks: list[InferenceChunk]) -> InferenceSection:
assert (
len(set([chunk.document_id for chunk in chunks])) == 1
), "One distinct document must be passed into merge_doc_chunks"
# Assuming there are no duplicates by this point
sorted_chunks = sorted(chunks, key=lambda x: x.chunk_id)

View File

@@ -3,7 +3,6 @@ from collections.abc import Sequence
from pydantic import BaseModel
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxContext
from onyx.context.search.models import InferenceChunk
@@ -12,7 +11,7 @@ class DocumentIdOrderMapping(BaseModel):
def map_document_id_order(
chunks: Sequence[InferenceChunk | LlmDoc | OnyxContext], one_indexed: bool = True
chunks: Sequence[InferenceChunk | LlmDoc], one_indexed: bool = True
) -> DocumentIdOrderMapping:
order_mapping = {}
current = 1 if one_indexed else 0

View File

@@ -1,6 +1,8 @@
import json
import os
import urllib.parse
from datetime import datetime
from datetime import timezone
from typing import cast
from onyx.auth.schemas import AuthBackend
@@ -157,10 +159,7 @@ VESPA_CLOUD_CERT_PATH = os.environ.get("VESPA_CLOUD_CERT_PATH")
VESPA_CLOUD_KEY_PATH = os.environ.get("VESPA_CLOUD_KEY_PATH")
# Number of documents in a batch during indexing (further batching done by chunks before passing to bi-encoder)
try:
INDEX_BATCH_SIZE = int(os.environ.get("INDEX_BATCH_SIZE", 16))
except ValueError:
INDEX_BATCH_SIZE = 16
INDEX_BATCH_SIZE = int(os.environ.get("INDEX_BATCH_SIZE") or 16)
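The `or 16` form treats a set-but-empty variable the same as an unset one, which the plain default argument would not; note it no longer swallows non-numeric values the way the old try/except did. A quick illustration:
import os
# An empty INDEX_BATCH_SIZE now falls back to 16 instead of raising.
os.environ["INDEX_BATCH_SIZE"] = ""
assert int(os.environ.get("INDEX_BATCH_SIZE") or 16) == 16
# int(os.environ.get("INDEX_BATCH_SIZE", 16)) would raise ValueError here,
# and a non-numeric value like "abc" still raises with either form.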
MAX_DRIVE_WORKERS = int(os.environ.get("MAX_DRIVE_WORKERS", 4))
@@ -386,10 +385,27 @@ CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD = int(
# https://community.developer.atlassian.com/t/confluence-cloud-time-zone-get-via-rest-api/35954/16
# https://jira.atlassian.com/browse/CONFCLOUD-69670
def get_current_tz_offset() -> int:
# datetime now() gets local time, datetime.now(timezone.utc) gets UTC time.
# remove tzinfo to compare non-timezone-aware objects.
time_diff = datetime.now() - datetime.now(timezone.utc).replace(tzinfo=None)
return round(time_diff.total_seconds() / 3600)
# enter as a floating point offset from UTC in hours (-24 < val < 24)
# this will be applied globally, so it probably makes sense to transition this to per
# connector at some point.
CONFLUENCE_TIMEZONE_OFFSET = float(os.environ.get("CONFLUENCE_TIMEZONE_OFFSET", 0.0))
# For the default value, we assume that the user's local timezone is more likely to be
# correct (i.e. the configured user's timezone or the default server one) than UTC.
# https://developer.atlassian.com/cloud/confluence/cql-fields/#created
CONFLUENCE_TIMEZONE_OFFSET = float(
os.environ.get("CONFLUENCE_TIMEZONE_OFFSET", get_current_tz_offset())
)
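A small worked example of the new default: the helper compares local wall-clock time with UTC and rounds to whole hours, so a host set to US Pacific Daylight Time would default the offset to -7:
from datetime import datetime, timezone
# Same computation as get_current_tz_offset(): local time minus UTC, rounded
# to whole hours. On a US/Pacific host during DST this prints -7.
time_diff = datetime.now() - datetime.now(timezone.utc).replace(tzinfo=None)
print(round(time_diff.total_seconds() / 3600))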
GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD = int(
os.environ.get("GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD", 10 * 1024 * 1024)
)
JIRA_CONNECTOR_LABELS_TO_SKIP = [
ignored_tag
@@ -676,3 +692,7 @@ IMAGE_ANALYSIS_SYSTEM_PROMPT = os.environ.get(
"IMAGE_ANALYSIS_SYSTEM_PROMPT",
DEFAULT_IMAGE_ANALYSIS_SYSTEM_PROMPT,
)
DISABLE_AUTO_AUTH_REFRESH = (
os.environ.get("DISABLE_AUTO_AUTH_REFRESH", "").lower() == "true"
)

View File

@@ -382,6 +382,7 @@ ONYX_CLOUD_TENANT_ID = "cloud"
# the redis namespace for runtime variables
ONYX_CLOUD_REDIS_RUNTIME = "runtime"
CLOUD_BUILD_FENCE_LOOKUP_TABLE_INTERVAL_DEFAULT = 600
class OnyxCeleryTask:

View File

@@ -87,7 +87,7 @@ class BlobStorageConnector(LoadConnector, PollConnector):
credentials.get(key)
for key in ["aws_access_key_id", "aws_secret_access_key"]
):
raise ConnectorMissingCredentialError("Google Cloud Storage")
raise ConnectorMissingCredentialError("Amazon S3")
session = boto3.Session(
aws_access_key_id=credentials["aws_access_key_id"],

View File

@@ -65,19 +65,7 @@ _RESTRICTIONS_EXPANSION_FIELDS = [
_SLIM_DOC_BATCH_SIZE = 5000
_ATTACHMENT_EXTENSIONS_TO_FILTER_OUT = [
"gif",
"mp4",
"mov",
"mp3",
"wav",
]
_FULL_EXTENSION_FILTER_STRING = "".join(
[
f" and title!~'*.{extension}'"
for extension in _ATTACHMENT_EXTENSIONS_TO_FILTER_OUT
]
)
ONE_HOUR = 3600
class ConfluenceConnector(
@@ -207,7 +195,6 @@ class ConfluenceConnector(
def _construct_attachment_query(self, confluence_page_id: str) -> str:
attachment_query = f"type=attachment and container='{confluence_page_id}'"
attachment_query += self.cql_label_filter
attachment_query += _FULL_EXTENSION_FILTER_STRING
return attachment_query
def _get_comment_string_for_page_id(self, page_id: str) -> str:
@@ -372,11 +359,13 @@ class ConfluenceConnector(
if not validate_attachment_filetype(
attachment,
):
logger.info(f"Skipping attachment: {attachment['title']}")
continue
logger.info(f"Processing attachment: {attachment['title']}")
# Attempt to get textual content or image summarization:
try:
logger.info(f"Processing attachment: {attachment['title']}")
response = convert_attachment_to_content(
confluence_client=self.confluence_client,
attachment=attachment,
@@ -429,7 +418,17 @@ class ConfluenceConnector(
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateDocumentsOutput:
return self._fetch_document_batches(start, end)
try:
return self._fetch_document_batches(start, end)
except Exception as e:
if "field 'updated' is invalid" in str(e) and start is not None:
logger.warning(
"Confluence says we provided an invalid 'updated' field. This may indicate"
"a real issue, but can also appear during edge cases like daylight"
f"savings time changes. Retrying with a 1 hour offset. Error: {e}"
)
return self._fetch_document_batches(start - ONE_HOUR, end)
raise
def retrieve_all_slim_documents(
self,

View File

@@ -498,10 +498,12 @@ class OnyxConfluence:
new_start = get_start_param_from_url(url_suffix)
previous_start = get_start_param_from_url(old_url_suffix)
if new_start - previous_start > len(results):
logger.warning(
logger.debug(
f"Start was updated by more than the amount of results "
f"retrieved. This is a bug with Confluence. Start: {new_start}, "
f"Previous Start: {previous_start}, Len Results: {len(results)}."
f"retrieved for `{url_suffix}`. This is a bug with Confluence, "
"but we have logic to work around it - don't worry this isn't"
f" causing an issue. Start: {new_start}, Previous Start: "
f"{previous_start}, Len Results: {len(results)}."
)
# Update the url_suffix to use the adjusted start

View File

@@ -28,8 +28,9 @@ from onyx.connectors.models import TextSection
from onyx.file_processing.extract_file_text import detect_encoding
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.extract_file_text import is_accepted_file_ext
from onyx.file_processing.extract_file_text import is_text_file_extension
from onyx.file_processing.extract_file_text import is_valid_file_ext
from onyx.file_processing.extract_file_text import OnyxExtensionType
from onyx.file_processing.extract_file_text import read_text_file
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import request_with_retries
@@ -69,7 +70,9 @@ def _process_egnyte_file(
file_name = file_metadata["name"]
extension = get_file_ext(file_name)
if not is_valid_file_ext(extension):
if not is_accepted_file_ext(
extension, OnyxExtensionType.Plain | OnyxExtensionType.Document
):
logger.warning(f"Skipping file '{file_name}' with extension '{extension}'")
return None
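The bitwise OR above suggests OnyxExtensionType is a flag-style enum; a minimal sketch of that pattern (Plain/Document/All appear in this diff, the other member and all values are assumptions):
from enum import Flag, auto
# Sketch of a flag-style extension type, matching how OnyxExtensionType is
# combined with `|` above; the real members and values may differ.
class ExtensionTypeSketch(Flag):
    Plain = auto()
    Document = auto()
    Multimedia = auto()  # assumed extra member, not taken from this diff
    All = Plain | Document | Multimedia
def is_accepted_sketch(ext_type: ExtensionTypeSketch, accepted: ExtensionTypeSketch) -> bool:
    return bool(ext_type & accepted)
assert is_accepted_sketch(
    ExtensionTypeSketch.Plain, ExtensionTypeSketch.Plain | ExtensionTypeSketch.Document
)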

View File

@@ -22,8 +22,9 @@ from onyx.db.engine import get_session_with_current_tenant
from onyx.db.pg_file_store import get_pgfilestore_by_file_name
from onyx.file_processing.extract_file_text import extract_text_and_images
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.extract_file_text import is_valid_file_ext
from onyx.file_processing.extract_file_text import is_accepted_file_ext
from onyx.file_processing.extract_file_text import load_files_from_zip
from onyx.file_processing.extract_file_text import OnyxExtensionType
from onyx.file_processing.image_utils import store_image_and_create_section
from onyx.file_store.file_store import get_default_file_store
from onyx.utils.logger import setup_logger
@@ -51,7 +52,7 @@ def _read_files_and_metadata(
file_content, ignore_dirs=True
):
yield os.path.join(directory_path, file_info.filename), subfile, metadata
elif is_valid_file_ext(extension):
elif is_accepted_file_ext(extension, OnyxExtensionType.All):
yield file_name, file_content, metadata
else:
logger.warning(f"Skipping file '{file_name}' with extension '{extension}'")
@@ -122,7 +123,7 @@ def _process_file(
logger.warning(f"No file record found for '{file_name}' in PG; skipping.")
return []
if not is_valid_file_ext(extension):
if not is_accepted_file_ext(extension, OnyxExtensionType.All):
logger.warning(
f"Skipping file '{file_name}' with unrecognized extension '{extension}'"
)
@@ -219,24 +220,34 @@ def _process_file(
# 2) Otherwise: text-based approach. Possibly with embedded images.
file.seek(0)
text_content = ""
embedded_images: list[tuple[bytes, str]] = []
# Extract text and images from the file
text_content, embedded_images = extract_text_and_images(
extraction_result = extract_text_and_images(
file=file,
file_name=file_name,
pdf_pass=pdf_pass,
)
# Merge file-specific metadata (from file content) with provided metadata
if extraction_result.metadata:
logger.debug(
f"Found file-specific metadata for {file_name}: {extraction_result.metadata}"
)
metadata.update(extraction_result.metadata)
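The object returned by extract_text_and_images is only touched through three attributes here; a minimal sketch of such a result container (the class name is an assumption, since the real definition is not in this diff):
from dataclasses import dataclass, field
# Sketch of the result container implied by the attribute accesses above
# (text_content, embedded_images, metadata); the real class in
# onyx.file_processing may be named and shaped differently.
@dataclass
class ExtractionResultSketch:
    text_content: str = ""
    embedded_images: list[tuple[bytes, str]] = field(default_factory=list)  # (data, name)
    metadata: dict[str, str] = field(default_factory=dict)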
# Build sections: first the text as a single Section
sections: list[TextSection | ImageSection] = []
link_in_meta = metadata.get("link")
if text_content.strip():
sections.append(TextSection(link=link_in_meta, text=text_content.strip()))
if extraction_result.text_content.strip():
logger.debug(f"Creating TextSection for {file_name} with link: {link_in_meta}")
sections.append(
TextSection(link=link_in_meta, text=extraction_result.text_content.strip())
)
# Then any extracted images from docx, etc.
for idx, (img_data, img_name) in enumerate(embedded_images, start=1):
for idx, (img_data, img_name) in enumerate(
extraction_result.embedded_images, start=1
):
# Store each embedded image as a separate file in PGFileStore
# and create a section with the image reference
try:

View File

@@ -1,8 +1,10 @@
import copy
import time
from collections.abc import Iterator
from collections.abc import Generator
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from enum import Enum
from typing import Any
from typing import cast
@@ -13,26 +15,30 @@ from github.GithubException import GithubException
from github.Issue import Issue
from github.PaginatedList import PaginatedList
from github.PullRequest import PullRequest
from github.Requester import Requester
from pydantic import BaseModel
from typing_extensions import override
from onyx.configs.app_configs import GITHUB_CONNECTOR_BASE_URL
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import CheckpointConnector
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import ConnectorCheckpoint
from onyx.connectors.interfaces import ConnectorFailure
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import TextSection
from onyx.utils.batching import batch_generator
from onyx.utils.logger import setup_logger
logger = setup_logger()
ITEMS_PER_PAGE = 100
_MAX_NUM_RATE_LIMIT_RETRIES = 5
@@ -48,7 +54,7 @@ def _sleep_after_rate_limit_exception(github_client: Github) -> None:
def _get_batch_rate_limited(
git_objs: PaginatedList, page_num: int, github_client: Github, attempt_num: int = 0
) -> list[Any]:
) -> list[PullRequest | Issue]:
if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES:
raise RuntimeError(
"Re-tried fetching batch too many times. Something is going wrong with fetching objects from Github"
@@ -69,21 +75,6 @@ def _get_batch_rate_limited(
)
def _batch_github_objects(
git_objs: PaginatedList, github_client: Github, batch_size: int
) -> Iterator[list[Any]]:
page_num = 0
while True:
batch = _get_batch_rate_limited(git_objs, page_num, github_client)
page_num += 1
if not batch:
break
for mini_batch in batch_generator(batch, batch_size=batch_size):
yield mini_batch
def _convert_pr_to_document(pull_request: PullRequest) -> Document:
return Document(
id=pull_request.html_url,
@@ -95,7 +86,9 @@ def _convert_pr_to_document(pull_request: PullRequest) -> Document:
# updated_at is UTC time but is timezone unaware, explicitly add UTC
# as there is logic in indexing to prevent wrong timestamped docs
# due to local time discrepancies with UTC
doc_updated_at=pull_request.updated_at.replace(tzinfo=timezone.utc),
doc_updated_at=pull_request.updated_at.replace(tzinfo=timezone.utc)
if pull_request.updated_at
else None,
metadata={
"merged": str(pull_request.merged),
"state": pull_request.state,
@@ -122,31 +115,58 @@ def _convert_issue_to_document(issue: Issue) -> Document:
)
class GithubConnector(LoadConnector, PollConnector):
class SerializedRepository(BaseModel):
# id is part of the raw_data as well, just pulled out for convenience
id: int
headers: dict[str, str | int]
raw_data: dict[str, Any]
def to_Repository(self, requester: Requester) -> Repository.Repository:
return Repository.Repository(
requester, self.headers, self.raw_data, completed=True
)
class GithubConnectorStage(Enum):
START = "start"
PRS = "prs"
ISSUES = "issues"
class GithubConnectorCheckpoint(ConnectorCheckpoint):
stage: GithubConnectorStage
curr_page: int
cached_repo_ids: list[int] | None = None
cached_repo: SerializedRepository | None = None
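Since the checkpoint is a plain pydantic model, it can be persisted as JSON between runs and rehydrated via validate_checkpoint_json; a quick round-trip sketch (assuming pydantic v2, which the model_validate_json call further below implies):
# Round-trip sketch: serialize the checkpoint, then restore it.
cp = GithubConnectorCheckpoint(
    stage=GithubConnectorStage.PRS, curr_page=0, has_more=True
)
blob = cp.model_dump_json()
restored = GithubConnectorCheckpoint.model_validate_json(blob)
assert restored.stage is GithubConnectorStage.PRS and restored.curr_page == 0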
class GithubConnector(CheckpointConnector[GithubConnectorCheckpoint]):
def __init__(
self,
repo_owner: str,
repositories: str | None = None,
batch_size: int = INDEX_BATCH_SIZE,
state_filter: str = "all",
include_prs: bool = True,
include_issues: bool = False,
) -> None:
self.repo_owner = repo_owner
self.repositories = repositories
self.batch_size = batch_size
self.state_filter = state_filter
self.include_prs = include_prs
self.include_issues = include_issues
self.github_client: Github | None = None
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
# defaults to 30 items per page, can be set to as high as 100
self.github_client = (
Github(
credentials["github_access_token"], base_url=GITHUB_CONNECTOR_BASE_URL
credentials["github_access_token"],
base_url=GITHUB_CONNECTOR_BASE_URL,
per_page=ITEMS_PER_PAGE,
)
if GITHUB_CONNECTOR_BASE_URL
else Github(credentials["github_access_token"])
else Github(credentials["github_access_token"], per_page=ITEMS_PER_PAGE)
)
return None
@@ -217,85 +237,193 @@ class GithubConnector(LoadConnector, PollConnector):
return self._get_all_repos(github_client, attempt_num + 1)
def _fetch_from_github(
self, start: datetime | None = None, end: datetime | None = None
) -> GenerateDocumentsOutput:
self,
checkpoint: GithubConnectorCheckpoint,
start: datetime | None = None,
end: datetime | None = None,
) -> Generator[Document | ConnectorFailure, None, GithubConnectorCheckpoint]:
if self.github_client is None:
raise ConnectorMissingCredentialError("GitHub")
repos = []
if self.repositories:
if "," in self.repositories:
# Multiple repositories specified
repos = self._get_github_repos(self.github_client)
checkpoint = copy.deepcopy(checkpoint)
# First run of the connector, fetch all repos and store in checkpoint
if checkpoint.cached_repo_ids is None:
repos = []
if self.repositories:
if "," in self.repositories:
# Multiple repositories specified
repos = self._get_github_repos(self.github_client)
else:
# Single repository (backward compatibility)
repos = [self._get_github_repo(self.github_client)]
else:
# Single repository (backward compatibility)
repos = [self._get_github_repo(self.github_client)]
else:
# All repositories
repos = self._get_all_repos(self.github_client)
# All repositories
repos = self._get_all_repos(self.github_client)
if not repos:
checkpoint.has_more = False
return checkpoint
for repo in repos:
if self.include_prs:
logger.info(f"Fetching PRs for repo: {repo.name}")
pull_requests = repo.get_pulls(
state=self.state_filter, sort="updated", direction="desc"
)
checkpoint.cached_repo_ids = sorted([repo.id for repo in repos])
checkpoint.cached_repo = SerializedRepository(
id=checkpoint.cached_repo_ids[0],
headers=repos[0].raw_headers,
raw_data=repos[0].raw_data,
)
checkpoint.stage = GithubConnectorStage.PRS
checkpoint.curr_page = 0
# save checkpoint with repo ids retrieved
return checkpoint
for pr_batch in _batch_github_objects(
pull_requests, self.github_client, self.batch_size
assert checkpoint.cached_repo is not None, "No repo saved in checkpoint"
repo = checkpoint.cached_repo.to_Repository(self.github_client.requester)
if self.include_prs and checkpoint.stage == GithubConnectorStage.PRS:
logger.info(f"Fetching PRs for repo: {repo.name}")
pull_requests = repo.get_pulls(
state=self.state_filter, sort="updated", direction="desc"
)
doc_batch: list[Document] = []
pr_batch = _get_batch_rate_limited(
pull_requests, checkpoint.curr_page, self.github_client
)
checkpoint.curr_page += 1
done_with_prs = False
for pr in pr_batch:
# we iterate backwards in time, so at this point we stop processing prs
if (
start is not None
and pr.updated_at
and pr.updated_at.replace(tzinfo=timezone.utc) < start
):
doc_batch: list[Document] = []
for pr in pr_batch:
if start is not None and pr.updated_at < start:
yield doc_batch
break
if end is not None and pr.updated_at > end:
continue
doc_batch.append(_convert_pr_to_document(cast(PullRequest, pr)))
yield doc_batch
if self.include_issues:
logger.info(f"Fetching issues for repo: {repo.name}")
issues = repo.get_issues(
state=self.state_filter, sort="updated", direction="desc"
)
for issue_batch in _batch_github_objects(
issues, self.github_client, self.batch_size
yield from doc_batch
done_with_prs = True
break
# Skip PRs updated after the end date
if (
end is not None
and pr.updated_at
and pr.updated_at.replace(tzinfo=timezone.utc) > end
):
doc_batch = []
for issue in issue_batch:
issue = cast(Issue, issue)
if start is not None and issue.updated_at < start:
yield doc_batch
break
if end is not None and issue.updated_at > end:
continue
if issue.pull_request is not None:
# PRs are handled separately
continue
doc_batch.append(_convert_issue_to_document(issue))
yield doc_batch
continue
try:
doc_batch.append(_convert_pr_to_document(cast(PullRequest, pr)))
except Exception as e:
error_msg = f"Error converting PR to document: {e}"
logger.exception(error_msg)
yield ConnectorFailure(
failed_document=DocumentFailure(
document_id=str(pr.id), document_link=pr.html_url
),
failure_message=error_msg,
exception=e,
)
continue
def load_from_state(self) -> GenerateDocumentsOutput:
return self._fetch_from_github()
# if we found any PRs on the page, yield any associated documents and return the checkpoint
if not done_with_prs and len(pr_batch) > 0:
yield from doc_batch
return checkpoint
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
start_datetime = datetime.utcfromtimestamp(start)
end_datetime = datetime.utcfromtimestamp(end)
# if we went past the start date during the loop or there are no more
# prs to get, we move on to issues
checkpoint.stage = GithubConnectorStage.ISSUES
checkpoint.curr_page = 0
checkpoint.stage = GithubConnectorStage.ISSUES
if self.include_issues and checkpoint.stage == GithubConnectorStage.ISSUES:
logger.info(f"Fetching issues for repo: {repo.name}")
issues = repo.get_issues(
state=self.state_filter, sort="updated", direction="desc"
)
doc_batch = []
issue_batch = _get_batch_rate_limited(
issues, checkpoint.curr_page, self.github_client
)
checkpoint.curr_page += 1
done_with_issues = False
for issue in cast(list[Issue], issue_batch):
# we iterate backwards in time, so at this point we stop processing issues
if (
start is not None
and issue.updated_at.replace(tzinfo=timezone.utc) < start
):
yield from doc_batch
done_with_issues = True
break
# Skip issues updated after the end date
if (
end is not None
and issue.updated_at.replace(tzinfo=timezone.utc) > end
):
continue
if issue.pull_request is not None:
# PRs are handled separately
continue
try:
doc_batch.append(_convert_issue_to_document(issue))
except Exception as e:
error_msg = f"Error converting issue to document: {e}"
logger.exception(error_msg)
yield ConnectorFailure(
failed_document=DocumentFailure(
document_id=str(issue.id),
document_link=issue.html_url,
),
failure_message=error_msg,
exception=e,
)
continue
# if we found any issues on the page, yield them and return the checkpoint
if not done_with_issues and len(issue_batch) > 0:
yield from doc_batch
return checkpoint
# if we went past the start date during the loop or there are no more
# issues to get, we move on to the next repo
checkpoint.stage = GithubConnectorStage.PRS
checkpoint.curr_page = 0
checkpoint.has_more = len(checkpoint.cached_repo_ids) > 1
if checkpoint.cached_repo_ids:
next_id = checkpoint.cached_repo_ids.pop()
next_repo = self.github_client.get_repo(next_id)
checkpoint.cached_repo = SerializedRepository(
id=next_id,
headers=next_repo.raw_headers,
raw_data=next_repo.raw_data,
)
return checkpoint
@override
def load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: GithubConnectorCheckpoint,
) -> CheckpointOutput[GithubConnectorCheckpoint]:
start_datetime = datetime.fromtimestamp(start, tz=timezone.utc)
end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)
# Move start time back by 3 hours, since some Issues/PRs are getting dropped
# Could be due to delayed processing on GitHub side
# The non-updated issues since last poll will be shortcut-ed and not embedded
adjusted_start_datetime = start_datetime - timedelta(hours=3)
epoch = datetime.utcfromtimestamp(0)
epoch = datetime.fromtimestamp(0, tz=timezone.utc)
if adjusted_start_datetime < epoch:
adjusted_start_datetime = epoch
return self._fetch_from_github(adjusted_start_datetime, end_datetime)
return self._fetch_from_github(
checkpoint, start=adjusted_start_datetime, end=end_datetime
)
def validate_connector_settings(self) -> None:
if self.github_client is None:
@@ -397,6 +525,16 @@ class GithubConnector(LoadConnector, PollConnector):
f"Unexpected error during GitHub settings validation: {exc}"
)
def validate_checkpoint_json(
self, checkpoint_json: str
) -> GithubConnectorCheckpoint:
return GithubConnectorCheckpoint.model_validate_json(checkpoint_json)
def build_dummy_checkpoint(self) -> GithubConnectorCheckpoint:
return GithubConnectorCheckpoint(
stage=GithubConnectorStage.PRS, curr_page=0, has_more=True
)
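A sketch of driving the connector to completion across checkpoints; it assumes the CheckpointOutput generator returns the next checkpoint as its StopIteration value, which matches the Generator[..., GithubConnectorCheckpoint] annotation on _fetch_from_github above:
import time
# Sketch only: collect documents across checkpointed runs. Error handling and
# checkpoint persistence between process restarts are omitted.
def run_to_completion(connector: GithubConnector) -> list[Document]:
    docs: list[Document] = []
    checkpoint = connector.build_dummy_checkpoint()
    while checkpoint.has_more:
        gen = connector.load_from_checkpoint(0, time.time(), checkpoint)
        while True:
            try:
                item = next(gen)
            except StopIteration as stop:
                checkpoint = stop.value  # next checkpoint to resume from
                break
            if isinstance(item, Document):
                docs.append(item)
            # ConnectorFailure items could be logged or collected here
    return docs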
if __name__ == "__main__":
import os
@@ -406,7 +544,9 @@ if __name__ == "__main__":
repositories=os.environ["REPOSITORIES"],
)
connector.load_credentials(
{"github_access_token": os.environ["GITHUB_ACCESS_TOKEN"]}
{"github_access_token": os.environ["ACCESS_TOKEN_GITHUB"]}
)
document_batches = connector.load_from_checkpoint(
0, time.time(), connector.build_dummy_checkpoint()
)
document_batches = connector.load_from_state()
print(next(document_batches))

View File

@@ -2,11 +2,11 @@ import copy
import threading
from collections.abc import Callable
from collections.abc import Iterator
from concurrent.futures import as_completed
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from enum import Enum
from functools import partial
from typing import Any
from typing import cast
from typing import Protocol
from urllib.parse import urlparse
@@ -15,6 +15,7 @@ from google.oauth2.service_account import Credentials as ServiceAccountCredentia
from googleapiclient.errors import HttpError # type: ignore
from typing_extensions import override
from onyx.configs.app_configs import GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.app_configs import MAX_DRIVE_WORKERS
from onyx.configs.constants import DocumentSource
@@ -27,7 +28,9 @@ from onyx.connectors.google_drive.doc_conversion import (
)
from onyx.connectors.google_drive.file_retrieval import crawl_folders_for_files
from onyx.connectors.google_drive.file_retrieval import get_all_files_for_oauth
from onyx.connectors.google_drive.file_retrieval import get_all_files_in_my_drive
from onyx.connectors.google_drive.file_retrieval import (
get_all_files_in_my_drive_and_shared,
)
from onyx.connectors.google_drive.file_retrieval import get_files_in_shared_drive
from onyx.connectors.google_drive.file_retrieval import get_root_folder_id
from onyx.connectors.google_drive.models import DriveRetrievalStage
@@ -57,13 +60,13 @@ from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import EntityFailure
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.lazy import lazy_eval
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder
from onyx.utils.threadpool_concurrency import parallel_yield
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from onyx.utils.threadpool_concurrency import ThreadSafeDict
logger = setup_logger()
@@ -85,12 +88,18 @@ def _extract_ids_from_urls(urls: list[str]) -> list[str]:
def _convert_single_file(
creds: Any,
primary_admin_email: str,
allow_images: bool,
size_threshold: int,
retriever_email: str,
file: dict[str, Any],
) -> Document | ConnectorFailure | None:
user_email = file.get("owners", [{}])[0].get("emailAddress") or primary_admin_email
# We used to always get the user email from the file owners when available,
# but this was causing issues with shared folders where the owner was not included in the service account
# now we use the email of the account that successfully listed the file. Leaving this in case we end up
# wanting to retry with file owners and/or admin email at some point.
# user_email = file.get("owners", [{}])[0].get("emailAddress") or primary_admin_email
user_email = retriever_email
# Only construct these services when needed
user_drive_service = lazy_eval(
lambda: get_drive_service(creds, user_email=user_email)
@@ -103,6 +112,7 @@ def _convert_single_file(
drive_service=user_drive_service,
docs_service=docs_service,
allow_images=allow_images,
size_threshold=size_threshold,
)
@@ -238,6 +248,8 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
self._retrieved_ids: set[str] = set()
self.allow_images = False
self.size_threshold = GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD
def set_allow_images(self, value: bool) -> None:
self.allow_images = value
@@ -445,10 +457,11 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
logger.info(f"Getting all files in my drive as '{user_email}'")
yield from add_retrieval_info(
get_all_files_in_my_drive(
get_all_files_in_my_drive_and_shared(
service=drive_service,
update_traversed_ids_func=self._update_traversed_parent_ids,
is_slim=is_slim,
include_shared_with_me=self.include_files_shared_with_me,
start=curr_stage.completed_until if resuming else start,
end=end,
),
@@ -456,6 +469,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
DriveRetrievalStage.MY_DRIVE_FILES,
)
curr_stage.stage = DriveRetrievalStage.SHARED_DRIVE_FILES
resuming = False # we are starting the next stage for the first time
if curr_stage.stage == DriveRetrievalStage.SHARED_DRIVE_FILES:
@@ -491,7 +505,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
)
yield from _yield_from_drive(drive_id, start)
curr_stage.stage = DriveRetrievalStage.FOLDER_FILES
resuming = False # we are starting the next stage for the first time
if curr_stage.stage == DriveRetrievalStage.FOLDER_FILES:
def _yield_from_folder_crawl(
@@ -544,6 +558,16 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
checkpoint, is_slim, DriveRetrievalStage.MY_DRIVE_FILES
)
# Setup initial completion map on first connector run
for email in all_org_emails:
# don't overwrite existing completion map on resuming runs
if email in checkpoint.completion_map:
continue
checkpoint.completion_map[email] = StageCompletion(
stage=DriveRetrievalStage.START,
completed_until=0,
)
# we've found all users and drives, now time to actually start
# fetching stuff
logger.info(f"Found {len(all_org_emails)} users to impersonate")
@@ -557,11 +581,6 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
drive_ids_to_retrieve, checkpoint
)
for email in all_org_emails:
checkpoint.completion_map[email] = StageCompletion(
stage=DriveRetrievalStage.START,
completed_until=0,
)
user_retrieval_gens = [
self._impersonate_user_for_retrieval(
email,
@@ -792,10 +811,12 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
return
for file in drive_files:
if file.error is not None:
if file.error is None:
checkpoint.completion_map[file.user_email].update(
stage=file.completion_stage,
completed_until=file.drive_file[GoogleFields.MODIFIED_TIME.value],
completed_until=datetime.fromisoformat(
file.drive_file[GoogleFields.MODIFIED_TIME.value]
).timestamp(),
completed_until_parent_id=file.parent_id,
)
yield file
@@ -897,117 +918,86 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
checkpoint: GoogleDriveCheckpoint,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[list[Document | ConnectorFailure]]:
) -> Iterator[Document | ConnectorFailure]:
try:
# Create a larger process pool for file conversion
with ThreadPoolExecutor(max_workers=8) as executor:
# Prepare a partial function with the credentials and admin email
convert_func = partial(
_convert_single_file,
self.creds,
self.primary_admin_email,
self.allow_images,
# Prepare a partial function with the credentials and admin email
convert_func = partial(
_convert_single_file,
self.creds,
self.allow_images,
self.size_threshold,
)
# Fetch files in batches
batches_complete = 0
files_batch: list[RetrievedDriveFile] = []
def _yield_batch(
files_batch: list[RetrievedDriveFile],
) -> Iterator[Document | ConnectorFailure]:
nonlocal batches_complete
# Process the batch using run_functions_tuples_in_parallel
func_with_args = [
(
convert_func,
(
file.user_email,
file.drive_file,
),
)
for file in files_batch
]
results = cast(
list[Document | ConnectorFailure | None],
run_functions_tuples_in_parallel(func_with_args, max_workers=8),
)
# Fetch files in batches
batches_complete = 0
files_batch: list[GoogleDriveFileType] = []
for retrieved_file in self._fetch_drive_items(
is_slim=False,
checkpoint=checkpoint,
start=start,
end=end,
):
if retrieved_file.error is not None:
failure_stage = retrieved_file.completion_stage.value
failure_message = (
f"retrieval failure during stage: {failure_stage},"
)
failure_message += f"user: {retrieved_file.user_email},"
failure_message += (
f"parent drive/folder: {retrieved_file.parent_id},"
)
failure_message += f"error: {retrieved_file.error}"
logger.error(failure_message)
yield [
ConnectorFailure(
failed_entity=EntityFailure(
entity_id=failure_stage,
),
failure_message=failure_message,
exception=retrieved_file.error,
)
]
continue
files_batch.append(retrieved_file.drive_file)
docs_and_failures = [result for result in results if result is not None]
if len(files_batch) < self.batch_size:
continue
if docs_and_failures:
yield from docs_and_failures
batches_complete += 1
# Process the batch
futures = [
executor.submit(convert_func, file) for file in files_batch
]
documents = []
for future in as_completed(futures):
try:
doc = future.result()
if doc is not None:
documents.append(doc)
except Exception as e:
error_str = f"Error converting file: {e}"
logger.error(error_str)
yield [
ConnectorFailure(
failed_document=DocumentFailure(
document_id=retrieved_file.drive_file["id"],
document_link=retrieved_file.drive_file[
"webViewLink"
],
),
failure_message=error_str,
exception=e,
)
]
for retrieved_file in self._fetch_drive_items(
is_slim=False,
checkpoint=checkpoint,
start=start,
end=end,
):
if retrieved_file.error is not None:
failure_stage = retrieved_file.completion_stage.value
failure_message = (
f"retrieval failure during stage: {failure_stage},"
)
failure_message += f"user: {retrieved_file.user_email},"
failure_message += (
f"parent drive/folder: {retrieved_file.parent_id},"
)
failure_message += f"error: {retrieved_file.error}"
logger.error(failure_message)
yield ConnectorFailure(
failed_entity=EntityFailure(
entity_id=failure_stage,
),
failure_message=failure_message,
exception=retrieved_file.error,
)
if documents:
yield documents
batches_complete += 1
files_batch = []
continue
files_batch.append(retrieved_file)
if batches_complete > BATCHES_PER_CHECKPOINT:
checkpoint.retrieved_folder_and_drive_ids = self._retrieved_ids
return # create a new checkpoint
if len(files_batch) < self.batch_size:
continue
# Process any remaining files
if files_batch:
futures = [
executor.submit(convert_func, file) for file in files_batch
]
documents = []
for future in as_completed(futures):
try:
doc = future.result()
if doc is not None:
documents.append(doc)
except Exception as e:
error_str = f"Error converting file: {e}"
logger.error(error_str)
yield [
ConnectorFailure(
failed_document=DocumentFailure(
document_id=retrieved_file.drive_file["id"],
document_link=retrieved_file.drive_file[
"webViewLink"
],
),
failure_message=error_str,
exception=e,
)
]
yield from _yield_batch(files_batch)
files_batch = []
if documents:
yield documents
if batches_complete > BATCHES_PER_CHECKPOINT:
checkpoint.retrieved_folder_and_drive_ids = self._retrieved_ids
return # create a new checkpoint
# Process any remaining files
if files_batch:
yield from _yield_batch(files_batch)
except Exception as e:
logger.exception(f"Error extracting documents from Google Drive: {e}")
raise e
@@ -1029,10 +1019,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
checkpoint = copy.deepcopy(checkpoint)
self._retrieved_ids = checkpoint.retrieved_folder_and_drive_ids
try:
for doc_list in self._extract_docs_from_google_drive(
checkpoint, start, end
):
yield from doc_list
yield from self._extract_docs_from_google_drive(checkpoint, start, end)
except Exception as e:
if MISSING_SCOPES_ERROR_STR in str(e):
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
@@ -1067,9 +1054,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
raise RuntimeError(
"_extract_slim_docs_from_google_drive: Stop signal detected"
)
callback.progress("_extract_slim_docs_from_google_drive", 1)
yield slim_batch
def retrieve_all_slim_documents(

View File

@@ -76,7 +76,7 @@ def is_gdrive_image_mime_type(mime_type: str) -> bool:
return is_valid_image_type(mime_type)
def _extract_sections_basic(
def _download_and_extract_sections_basic(
file: dict[str, str],
service: GoogleDriveService,
allow_images: bool,
@@ -87,35 +87,17 @@ def _extract_sections_basic(
mime_type = file["mimeType"]
link = file.get("webViewLink", "")
try:
# skip images if not explicitly enabled
if not allow_images and is_gdrive_image_mime_type(mime_type):
return []
# skip images if not explicitly enabled
if not allow_images and is_gdrive_image_mime_type(mime_type):
return []
# For Google Docs, Sheets, and Slides, export as plain text
if mime_type in GOOGLE_MIME_TYPES_TO_EXPORT:
export_mime_type = GOOGLE_MIME_TYPES_TO_EXPORT[mime_type]
# Use the correct API call for exporting files
request = service.files().export_media(
fileId=file_id, mimeType=export_mime_type
)
response_bytes = io.BytesIO()
downloader = MediaIoBaseDownload(response_bytes, request)
done = False
while not done:
_, done = downloader.next_chunk()
response = response_bytes.getvalue()
if not response:
logger.warning(f"Failed to export {file_name} as {export_mime_type}")
return []
text = response.decode("utf-8")
return [TextSection(link=link, text=text)]
# For other file types, download the file
# Use the correct API call for downloading files
request = service.files().get_media(fileId=file_id)
# For Google Docs, Sheets, and Slides, export as plain text
if mime_type in GOOGLE_MIME_TYPES_TO_EXPORT:
export_mime_type = GOOGLE_MIME_TYPES_TO_EXPORT[mime_type]
# Use the correct API call for exporting files
request = service.files().export_media(
fileId=file_id, mimeType=export_mime_type
)
response_bytes = io.BytesIO()
downloader = MediaIoBaseDownload(response_bytes, request)
done = False
@@ -124,88 +106,100 @@ def _extract_sections_basic(
response = response_bytes.getvalue()
if not response:
logger.warning(f"Failed to download {file_name}")
logger.warning(f"Failed to export {file_name} as {export_mime_type}")
return []
# Process based on mime type
if mime_type == "text/plain":
text = response.decode("utf-8")
return [TextSection(link=link, text=text)]
text = response.decode("utf-8")
return [TextSection(link=link, text=text)]
elif (
mime_type
== "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
):
text, _ = docx_to_text_and_images(io.BytesIO(response))
return [TextSection(link=link, text=text)]
# For other file types, download the file
# Use the correct API call for downloading files
request = service.files().get_media(fileId=file_id)
response_bytes = io.BytesIO()
downloader = MediaIoBaseDownload(response_bytes, request)
done = False
while not done:
_, done = downloader.next_chunk()
elif (
mime_type
== "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
):
text = xlsx_to_text(io.BytesIO(response))
return [TextSection(link=link, text=text)]
response = response_bytes.getvalue()
if not response:
logger.warning(f"Failed to download {file_name}")
return []
elif (
mime_type
== "application/vnd.openxmlformats-officedocument.presentationml.presentation"
):
text = pptx_to_text(io.BytesIO(response))
return [TextSection(link=link, text=text)]
# Process based on mime type
if mime_type == "text/plain":
text = response.decode("utf-8")
return [TextSection(link=link, text=text)]
elif is_gdrive_image_mime_type(mime_type):
# For images, store them for later processing
sections: list[TextSection | ImageSection] = []
try:
with get_session_with_current_tenant() as db_session:
elif (
mime_type
== "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
):
text, _ = docx_to_text_and_images(io.BytesIO(response))
return [TextSection(link=link, text=text)]
elif (
mime_type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
):
text = xlsx_to_text(io.BytesIO(response))
return [TextSection(link=link, text=text)]
elif (
mime_type
== "application/vnd.openxmlformats-officedocument.presentationml.presentation"
):
text = pptx_to_text(io.BytesIO(response))
return [TextSection(link=link, text=text)]
elif is_gdrive_image_mime_type(mime_type):
# For images, store them for later processing
sections: list[TextSection | ImageSection] = []
try:
with get_session_with_current_tenant() as db_session:
section, embedded_id = store_image_and_create_section(
db_session=db_session,
image_data=response,
file_name=file_id,
display_name=file_name,
media_type=mime_type,
file_origin=FileOrigin.CONNECTOR,
link=link,
)
sections.append(section)
except Exception as e:
logger.error(f"Failed to process image {file_name}: {e}")
return sections
elif mime_type == "application/pdf":
text, _pdf_meta, images = read_pdf_file(io.BytesIO(response))
pdf_sections: list[TextSection | ImageSection] = [
TextSection(link=link, text=text)
]
# Process embedded images in the PDF
try:
with get_session_with_current_tenant() as db_session:
for idx, (img_data, img_name) in enumerate(images):
section, embedded_id = store_image_and_create_section(
db_session=db_session,
image_data=response,
file_name=file_id,
display_name=file_name,
media_type=mime_type,
image_data=img_data,
file_name=f"{file_id}_img_{idx}",
display_name=img_name or f"{file_name} - image {idx}",
file_origin=FileOrigin.CONNECTOR,
link=link,
)
sections.append(section)
except Exception as e:
logger.error(f"Failed to process image {file_name}: {e}")
return sections
pdf_sections.append(section)
except Exception as e:
logger.error(f"Failed to process PDF images in {file_name}: {e}")
return pdf_sections
elif mime_type == "application/pdf":
text, _pdf_meta, images = read_pdf_file(io.BytesIO(response))
pdf_sections: list[TextSection | ImageSection] = [
TextSection(link=link, text=text)
]
# Process embedded images in the PDF
try:
with get_session_with_current_tenant() as db_session:
for idx, (img_data, img_name) in enumerate(images):
section, embedded_id = store_image_and_create_section(
db_session=db_session,
image_data=img_data,
file_name=f"{file_id}_img_{idx}",
display_name=img_name or f"{file_name} - image {idx}",
file_origin=FileOrigin.CONNECTOR,
)
pdf_sections.append(section)
except Exception as e:
logger.error(f"Failed to process PDF images in {file_name}: {e}")
return pdf_sections
else:
# For unsupported file types, try to extract text
try:
text = extract_file_text(io.BytesIO(response), file_name)
return [TextSection(link=link, text=text)]
except Exception as e:
logger.warning(f"Failed to extract text from {file_name}: {e}")
return []
except Exception as e:
logger.error(f"Error processing file {file_name}: {e}")
return []
else:
# For unsupported file types, try to extract text
try:
text = extract_file_text(io.BytesIO(response), file_name)
return [TextSection(link=link, text=text)]
except Exception as e:
logger.warning(f"Failed to extract text from {file_name}: {e}")
return []
def convert_drive_item_to_document(
@@ -213,6 +207,7 @@ def convert_drive_item_to_document(
drive_service: Callable[[], GoogleDriveService],
docs_service: Callable[[], GoogleDocsService],
allow_images: bool,
size_threshold: int,
) -> Document | ConnectorFailure | None:
"""
Main entry point for converting a Google Drive file => Document object.
@@ -240,9 +235,24 @@ def convert_drive_item_to_document(
f"Error in advanced parsing: {e}. Falling back to basic extraction."
)
size_str = file.get("size")
if size_str:
try:
size_int = int(size_str)
except ValueError:
logger.warning(f"Parsing string to int failed: size_str={size_str}")
else:
if size_int > size_threshold:
logger.warning(
f"{file.get('name')} exceeds size threshold of {size_threshold}. Skipping."
)
return None
# If we don't have sections yet, use the basic extraction method
if not sections:
sections = _extract_sections_basic(file, drive_service(), allow_images)
sections = _download_and_extract_sections_basic(
file, drive_service(), allow_images
)
# If we still don't have any sections, skip this file
if not sections:

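The new size guard parses Drive's string `size` field defensively before comparing it to the threshold. A minimal standalone sketch of the same pattern; the `file` dict contents and the 10 MB threshold below are illustrative, not the connector's actual configuration:

def exceeds_size_threshold(file: dict[str, str], size_threshold: int) -> bool:
    """Return True if the Drive file metadata reports a size above the threshold."""
    size_str = file.get("size")
    if not size_str:
        # Google-native files (Docs/Sheets/Slides) often omit `size`; never skip them here.
        return False
    try:
        size_int = int(size_str)
    except ValueError:
        # Malformed size metadata: treat as not-too-large rather than dropping the file.
        return False
    return size_int > size_threshold

# Illustrative usage with fake metadata: a 50 MB PDF against a 10 MB threshold.
print(exceeds_size_threshold({"name": "big.pdf", "size": "52428800"}, 10 * 1024 * 1024))  # True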
View File

@@ -123,7 +123,7 @@ def crawl_folders_for_files(
end=end,
):
found_files = True
logger.info(f"Found file: {file['name']}")
logger.info(f"Found file: {file['name']}, user email: {user_email}")
yield RetrievedDriveFile(
drive_file=file,
user_email=user_email,
@@ -214,10 +214,11 @@ def get_files_in_shared_drive(
yield file
def get_all_files_in_my_drive(
def get_all_files_in_my_drive_and_shared(
service: GoogleDriveService,
update_traversed_ids_func: Callable,
is_slim: bool,
include_shared_with_me: bool,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
@@ -229,7 +230,8 @@ def get_all_files_in_my_drive(
# Get all folders being queried and add them to the traversed set
folder_query = f"mimeType = '{DRIVE_FOLDER_TYPE}'"
folder_query += " and trashed = false"
folder_query += " and 'me' in owners"
if not include_shared_with_me:
folder_query += " and 'me' in owners"
found_folders = False
for file in execute_paginated_retrieval(
retrieval_function=service.files().list,
@@ -246,7 +248,8 @@ def get_all_files_in_my_drive(
# Then get the files
file_query = f"mimeType != '{DRIVE_FOLDER_TYPE}'"
file_query += " and trashed = false"
file_query += " and 'me' in owners"
if not include_shared_with_me:
file_query += " and 'me' in owners"
file_query += _generate_time_range_filter(start, end)
yield from execute_paginated_retrieval(
retrieval_function=service.files().list,

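A short sketch of how the files.list query string changes when `include_shared_with_me` is enabled; the folder MIME type constant value is an assumption based on the standard Drive API value:

DRIVE_FOLDER_TYPE = "application/vnd.google-apps.folder"  # assumed value of the repo constant

def build_my_drive_file_query(include_shared_with_me: bool) -> str:
    """Build the files.list `q` filter used for My Drive traversal."""
    query = f"mimeType != '{DRIVE_FOLDER_TYPE}'"
    query += " and trashed = false"
    if not include_shared_with_me:
        # Restrict to files the authenticated user owns; dropping this clause
        # is what pulls in "Shared with me" items.
        query += " and 'me' in owners"
    return query

print(build_my_drive_file_query(include_shared_with_me=False))
print(build_my_drive_file_query(include_shared_with_me=True))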
View File

@@ -75,7 +75,7 @@ class HighspotClient:
self.key = key
self.secret = secret
self.base_url = base_url
self.base_url = base_url.rstrip("/") + "/"
self.timeout = timeout
# Set up session with retry logic

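The one-line change above guarantees exactly one trailing slash so later path concatenation cannot drop URL segments. A quick standalone illustration (the example URL is illustrative):

def normalize_base_url(base_url: str) -> str:
    """Ensure the base URL ends with exactly one trailing slash."""
    return base_url.rstrip("/") + "/"

# Both spellings normalize to the same value.
assert normalize_base_url("https://api.example.com/v1.0") == "https://api.example.com/v1.0/"
assert normalize_base_url("https://api.example.com/v1.0/") == "https://api.example.com/v1.0/"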
View File

@@ -20,8 +20,8 @@ from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.file_processing.extract_file_text import ALL_ACCEPTED_FILE_EXTENSIONS
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import VALID_FILE_EXTENSIONS
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
@@ -298,7 +298,7 @@ class HighspotConnector(LoadConnector, PollConnector, SlimConnector):
elif (
is_valid_format
and file_extension in VALID_FILE_EXTENSIONS
and file_extension in ALL_ACCEPTED_FILE_EXTENSIONS
and can_download
):
# For documents, try to get the text content

View File

@@ -8,7 +8,6 @@ from typing import TypeAlias
from typing import TypeVar
from pydantic import BaseModel
from typing_extensions import override
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import ConnectorCheckpoint
@@ -231,7 +230,7 @@ class CheckpointConnector(BaseConnector[CT]):
"""
raise NotImplementedError
@override
@abc.abstractmethod
def build_dummy_checkpoint(self) -> CT:
raise NotImplementedError

View File

@@ -438,7 +438,11 @@ def _get_all_doc_ids(
class ProcessedSlackMessage(BaseModel):
doc: Document | None
thread_ts: str | None
# if the message is part of a thread, this is the thread_ts
# otherwise, this is the message_ts. Either way, it will be a unique identifier.
# In the future, if the message becomes a thread, then the thread_ts
# will be set to the message_ts.
thread_or_message_ts: str
failure: ConnectorFailure | None
@@ -452,6 +456,7 @@ def _process_message(
msg_filter_func: Callable[[MessageType], bool] = default_msg_filter,
) -> ProcessedSlackMessage:
thread_ts = message.get("thread_ts")
thread_or_message_ts = thread_ts or message["ts"]
try:
# causes random failures for testing checkpointing / continue on failure
# import random
@@ -467,16 +472,18 @@ def _process_message(
seen_thread_ts=seen_thread_ts,
msg_filter_func=msg_filter_func,
)
return ProcessedSlackMessage(doc=doc, thread_ts=thread_ts, failure=None)
return ProcessedSlackMessage(
doc=doc, thread_or_message_ts=thread_or_message_ts, failure=None
)
except Exception as e:
logger.exception(f"Error processing message {message['ts']}")
return ProcessedSlackMessage(
doc=None,
thread_ts=thread_ts,
thread_or_message_ts=thread_or_message_ts,
failure=ConnectorFailure(
failed_document=DocumentFailure(
document_id=_build_doc_id(
channel_id=channel["id"], thread_ts=(thread_ts or message["ts"])
channel_id=channel["id"], thread_ts=thread_or_message_ts
),
document_link=get_message_link(message, client, channel["id"]),
),
@@ -616,7 +623,7 @@ class SlackConnector(SlimConnector, CheckpointConnector[SlackCheckpoint]):
for future in as_completed(futures):
processed_slack_message = future.result()
doc = processed_slack_message.doc
thread_ts = processed_slack_message.thread_ts
thread_or_message_ts = processed_slack_message.thread_or_message_ts
failure = processed_slack_message.failure
if doc:
# handle race conditions here since this is single
@@ -624,11 +631,13 @@ class SlackConnector(SlimConnector, CheckpointConnector[SlackCheckpoint]):
# but since this is single threaded, we won't run into simul
# writes. At worst, we can duplicate a thread, which will be
# deduped later on.
if thread_ts not in seen_thread_ts:
if thread_or_message_ts not in seen_thread_ts:
yield doc
assert thread_ts, "found non-None doc with None thread_ts"
seen_thread_ts.add(thread_ts)
assert (
thread_or_message_ts
), "found non-None doc with None thread_or_message_ts"
seen_thread_ts.add(thread_or_message_ts)
elif failure:
yield failure

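A compact sketch of the dedup key introduced above: thread replies share their parent's `thread_ts`, while top-level messages fall back to their own `ts`, so every processed message gets a non-None key that can safely go into `seen_thread_ts`. The message dicts are illustrative:

def thread_or_message_key(message: dict) -> str:
    """Thread replies reuse the parent's thread_ts; top-level messages use their own ts."""
    return message.get("thread_ts") or message["ts"]

seen: set[str] = set()
for msg in [
    {"ts": "1711.0001"},                            # top-level message
    {"ts": "1711.0002", "thread_ts": "1711.0001"},  # reply in that thread
    {"ts": "1711.0003"},                            # another top-level message
]:
    key = thread_or_message_key(msg)
    if key in seen:
        continue  # thread already handled; skip the duplicate
    seen.add(key)
    print("processing", key)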
View File

@@ -1,23 +1,32 @@
import copy
import time
from collections.abc import Iterator
from typing import Any
from typing import cast
import requests
from pydantic import BaseModel
from requests.exceptions import HTTPError
from typing_extensions import override
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.app_configs import ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS
from onyx.configs.constants import DocumentSource
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
time_str_to_utc,
)
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.interfaces import CheckpointConnector
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import ConnectorFailure
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.file_processing.html_utils import parse_html_page_basic
@@ -26,6 +35,7 @@ from onyx.utils.retry_wrapper import retry_builder
MAX_PAGE_SIZE = 30 # Zendesk API maximum
MAX_AUTHOR_MAP_SIZE = 50_000 # Reset author map cache if it gets too large
_SLIM_BATCH_SIZE = 1000
@@ -53,10 +63,22 @@ class ZendeskClient:
# Sleep for the duration indicated by the Retry-After header
time.sleep(int(retry_after))
elif (
response.status_code == 403
and response.json().get("error") == "SupportProductInactive"
):
return response.json()
response.raise_for_status()
return response.json()
class ZendeskPageResponse(BaseModel):
data: list[dict[str, Any]]
meta: dict[str, Any]
has_more: bool
def _get_content_tag_mapping(client: ZendeskClient) -> dict[str, str]:
content_tags: dict[str, str] = {}
params = {"page[size]": MAX_PAGE_SIZE}
@@ -82,11 +104,9 @@ def _get_content_tag_mapping(client: ZendeskClient) -> dict[str, str]:
def _get_articles(
client: ZendeskClient, start_time: int | None = None, page_size: int = MAX_PAGE_SIZE
) -> Iterator[dict[str, Any]]:
params = (
{"start_time": start_time, "page[size]": page_size}
if start_time
else {"page[size]": page_size}
)
params = {"page[size]": page_size, "sort_by": "updated_at", "sort_order": "asc"}
if start_time is not None:
params["start_time"] = start_time
while True:
data = client.make_request("help_center/articles", params)
@@ -98,10 +118,30 @@ def _get_articles(
params["page[after]"] = data["meta"]["after_cursor"]
def _get_article_page(
client: ZendeskClient,
start_time: int | None = None,
after_cursor: str | None = None,
page_size: int = MAX_PAGE_SIZE,
) -> ZendeskPageResponse:
params = {"page[size]": page_size, "sort_by": "updated_at", "sort_order": "asc"}
if start_time is not None:
params["start_time"] = start_time
if after_cursor is not None:
params["page[after]"] = after_cursor
data = client.make_request("help_center/articles", params)
return ZendeskPageResponse(
data=data["articles"],
meta=data["meta"],
has_more=bool(data["meta"].get("has_more", False)),
)
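A sketch of how a caller might drive the cursor-based article pagination, persisting `after_cursor` between pages the way the checkpoint does. It assumes it runs alongside the `_get_article_page` helper and `ZendeskPageResponse` model shown above; client construction is omitted:

def iter_article_pages(client, start_time: int | None = None):
    """Yield article pages until Zendesk reports no more results, resuming via after_cursor."""
    after_cursor: str | None = None
    while True:
        page = _get_article_page(client, start_time=start_time, after_cursor=after_cursor)
        yield page.data
        if not page.has_more:
            break
        # Persisting this cursor (e.g. in ZendeskConnectorCheckpoint.after_cursor_articles)
        # lets an interrupted run resume exactly where it stopped.
        after_cursor = page.meta.get("after_cursor")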
def _get_tickets(
client: ZendeskClient, start_time: int | None = None
) -> Iterator[dict[str, Any]]:
params = {"start_time": start_time} if start_time else {"start_time": 0}
params = {"start_time": start_time or 0}
while True:
data = client.make_request("incremental/tickets.json", params)
@@ -114,9 +154,33 @@ def _get_tickets(
break
def _fetch_author(client: ZendeskClient, author_id: str) -> BasicExpertInfo | None:
# TODO: maybe these don't need to be their own functions?
def _get_tickets_page(
client: ZendeskClient, start_time: int | None = None
) -> ZendeskPageResponse:
params = {"start_time": start_time or 0}
# NOTE: for some reason Zendesk doesn't seem to respect the start_time param
# in my local testing with very few tickets. We'll look into it if this becomes an
# issue in larger deployments.
data = client.make_request("incremental/tickets.json", params)
if data.get("error") == "SupportProductInactive":
raise ValueError(
"Zendesk Support Product is not active for this account, No tickets to index"
)
return ZendeskPageResponse(
data=data["tickets"],
meta={"end_time": data["end_time"]},
has_more=not bool(data.get("end_of_stream", False)),
)
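The incremental ticket export is timestamp-driven: each page returns an `end_time` that becomes the next request's `start_time`, and `end_of_stream` signals completion. A minimal driver loop under the same assumptions as the article sketch above:

def iter_ticket_pages(client, start_time: int = 0):
    """Yield ticket pages, advancing the start_time cursor the way the checkpoint does."""
    next_start = start_time
    while True:
        page = _get_tickets_page(client, start_time=next_start)
        yield page.data
        if not page.has_more:
            break
        # meta["end_time"] is Zendesk's cursor for the incremental export API;
        # it is what ZendeskConnectorCheckpoint.next_start_time_tickets stores.
        next_start = page.meta["end_time"]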
def _fetch_author(
client: ZendeskClient, author_id: str | int
) -> BasicExpertInfo | None:
# Skip fetching if author_id is invalid
if not author_id or author_id == "-1":
# cast to str to avoid issues with zendesk changing their types
if not author_id or str(author_id) == "-1":
return None
try:
@@ -278,13 +342,22 @@ def _ticket_to_document(
)
class ZendeskConnector(LoadConnector, PollConnector, SlimConnector):
class ZendeskConnectorCheckpoint(ConnectorCheckpoint):
# We use cursor-based paginated retrieval for articles
after_cursor_articles: str | None
# We use timestamp-based paginated retrieval for tickets
next_start_time_tickets: int | None
cached_author_map: dict[str, BasicExpertInfo] | None
cached_content_tags: dict[str, str] | None
class ZendeskConnector(SlimConnector, CheckpointConnector[ZendeskConnectorCheckpoint]):
def __init__(
self,
batch_size: int = INDEX_BATCH_SIZE,
content_type: str = "articles",
) -> None:
self.batch_size = batch_size
self.content_type = content_type
self.subdomain = ""
# Fetch all tags ahead of time
@@ -304,33 +377,50 @@ class ZendeskConnector(LoadConnector, PollConnector, SlimConnector):
)
return None
def load_from_state(self) -> GenerateDocumentsOutput:
return self.poll_source(None, None)
def poll_source(
self, start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None
) -> GenerateDocumentsOutput:
@override
def load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: ZendeskConnectorCheckpoint,
) -> CheckpointOutput[ZendeskConnectorCheckpoint]:
if self.client is None:
raise ZendeskCredentialsNotSetUpError()
self.content_tags = _get_content_tag_mapping(self.client)
if checkpoint.cached_content_tags is None:
checkpoint.cached_content_tags = _get_content_tag_mapping(self.client)
return checkpoint # save the content tags to the checkpoint
self.content_tags = checkpoint.cached_content_tags
if self.content_type == "articles":
yield from self._poll_articles(start)
checkpoint = yield from self._retrieve_articles(start, end, checkpoint)
return checkpoint
elif self.content_type == "tickets":
yield from self._poll_tickets(start)
checkpoint = yield from self._retrieve_tickets(start, end, checkpoint)
return checkpoint
else:
raise ValueError(f"Unsupported content_type: {self.content_type}")
def _poll_articles(
self, start: SecondsSinceUnixEpoch | None
) -> GenerateDocumentsOutput:
articles = _get_articles(self.client, start_time=int(start) if start else None)
def _retrieve_articles(
self,
start: SecondsSinceUnixEpoch | None,
end: SecondsSinceUnixEpoch | None,
checkpoint: ZendeskConnectorCheckpoint,
) -> CheckpointOutput[ZendeskConnectorCheckpoint]:
checkpoint = copy.deepcopy(checkpoint)
# This one is built on the fly as there may be many more authors than tags
author_map: dict[str, BasicExpertInfo] = {}
author_map: dict[str, BasicExpertInfo] = checkpoint.cached_author_map or {}
after_cursor = checkpoint.after_cursor_articles
doc_batch: list[Document] = []
doc_batch = []
response = _get_article_page(
self.client,
start_time=int(start) if start else None,
after_cursor=after_cursor,
)
articles = response.data
has_more = response.has_more
after_cursor = response.meta.get("after_cursor")
for article in articles:
if (
article.get("body") is None
@@ -342,66 +432,109 @@ class ZendeskConnector(LoadConnector, PollConnector, SlimConnector):
):
continue
new_author_map, documents = _article_to_document(
article, self.content_tags, author_map, self.client
)
try:
new_author_map, document = _article_to_document(
article, self.content_tags, author_map, self.client
)
except Exception as e:
yield ConnectorFailure(
failed_document=DocumentFailure(
document_id=f"{article.get('id')}",
document_link=article.get("html_url", ""),
),
failure_message=str(e),
exception=e,
)
continue
if new_author_map:
author_map.update(new_author_map)
doc_batch.append(documents)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch.clear()
doc_batch.append(document)
if doc_batch:
yield doc_batch
if not has_more:
yield from doc_batch
checkpoint.has_more = False
return checkpoint
def _poll_tickets(
self, start: SecondsSinceUnixEpoch | None
) -> GenerateDocumentsOutput:
# Sometimes no documents are retrieved, but the cursor
# is still updated so the connector makes progress.
yield from doc_batch
checkpoint.after_cursor_articles = after_cursor
last_doc_updated_at = doc_batch[-1].doc_updated_at if doc_batch else None
checkpoint.has_more = bool(
end is None
or last_doc_updated_at is None
or last_doc_updated_at.timestamp() <= end
)
checkpoint.cached_author_map = (
author_map if len(author_map) <= MAX_AUTHOR_MAP_SIZE else None
)
return checkpoint
def _retrieve_tickets(
self,
start: SecondsSinceUnixEpoch | None,
end: SecondsSinceUnixEpoch | None,
checkpoint: ZendeskConnectorCheckpoint,
) -> CheckpointOutput[ZendeskConnectorCheckpoint]:
checkpoint = copy.deepcopy(checkpoint)
if self.client is None:
raise ZendeskCredentialsNotSetUpError()
author_map: dict[str, BasicExpertInfo] = {}
author_map: dict[str, BasicExpertInfo] = checkpoint.cached_author_map or {}
ticket_generator = _get_tickets(
self.client, start_time=int(start) if start else None
doc_batch: list[Document] = []
next_start_time = int(checkpoint.next_start_time_tickets or start or 0)
ticket_response = _get_tickets_page(self.client, start_time=next_start_time)
tickets = ticket_response.data
has_more = ticket_response.has_more
next_start_time = ticket_response.meta["end_time"]
for ticket in tickets:
if ticket.get("status") == "deleted":
continue
try:
new_author_map, document = _ticket_to_document(
ticket=ticket,
author_map=author_map,
client=self.client,
default_subdomain=self.subdomain,
)
except Exception as e:
yield ConnectorFailure(
failed_document=DocumentFailure(
document_id=f"{ticket.get('id')}",
document_link=ticket.get("url", ""),
),
failure_message=str(e),
exception=e,
)
continue
if new_author_map:
author_map.update(new_author_map)
doc_batch.append(document)
if not has_more:
yield from doc_batch
checkpoint.has_more = False
return checkpoint
yield from doc_batch
checkpoint.next_start_time_tickets = next_start_time
last_doc_updated_at = doc_batch[-1].doc_updated_at if doc_batch else None
checkpoint.has_more = bool(
end is None
or last_doc_updated_at is None
or last_doc_updated_at.timestamp() <= end
)
while True:
doc_batch = []
for _ in range(self.batch_size):
try:
ticket = next(ticket_generator)
# Check if the ticket status is deleted and skip it if so
if ticket.get("status") == "deleted":
continue
new_author_map, documents = _ticket_to_document(
ticket=ticket,
author_map=author_map,
client=self.client,
default_subdomain=self.subdomain,
)
if new_author_map:
author_map.update(new_author_map)
doc_batch.append(documents)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch.clear()
except StopIteration:
# No more tickets to process
if doc_batch:
yield doc_batch
return
if doc_batch:
yield doc_batch
checkpoint.cached_author_map = (
author_map if len(author_map) <= MAX_AUTHOR_MAP_SIZE else None
)
return checkpoint
def retrieve_all_slim_documents(
self,
@@ -441,10 +574,51 @@ class ZendeskConnector(LoadConnector, PollConnector, SlimConnector):
if slim_doc_batch:
yield slim_doc_batch
@override
def validate_connector_settings(self) -> None:
if self.client is None:
raise ZendeskCredentialsNotSetUpError()
try:
_get_article_page(self.client, start_time=0)
except HTTPError as e:
# Check for HTTP status codes
if e.response.status_code == 401:
raise CredentialExpiredError(
"Your Zendesk credentials appear to be invalid or expired (HTTP 401)."
) from e
elif e.response.status_code == 403:
raise InsufficientPermissionsError(
"Your Zendesk token does not have sufficient permissions (HTTP 403)."
) from e
elif e.response.status_code == 404:
raise ConnectorValidationError(
"Zendesk resource not found (HTTP 404)."
) from e
else:
raise ConnectorValidationError(
f"Unexpected Zendesk error (status={e.response.status_code}): {e}"
) from e
@override
def validate_checkpoint_json(
self, checkpoint_json: str
) -> ZendeskConnectorCheckpoint:
return ZendeskConnectorCheckpoint.model_validate_json(checkpoint_json)
@override
def build_dummy_checkpoint(self) -> ZendeskConnectorCheckpoint:
return ZendeskConnectorCheckpoint(
after_cursor_articles=None,
next_start_time_tickets=None,
cached_author_map=None,
cached_content_tags=None,
has_more=True,
)
if __name__ == "__main__":
import os
import time
connector = ZendeskConnector()
connector.load_credentials(
@@ -457,6 +631,8 @@ if __name__ == "__main__":
current = time.time()
one_day_ago = current - 24 * 60 * 60 # 1 day
document_batches = connector.poll_source(one_day_ago, current)
document_batches = connector.load_from_checkpoint(
one_day_ago, current, connector.build_dummy_checkpoint()
)
print(next(document_batches))

View File

@@ -339,6 +339,12 @@ class SearchPipeline:
self._retrieved_sections = self._get_sections()
return self._retrieved_sections
@property
def merged_retrieved_sections(self) -> list[InferenceSection]:
"""Should be used to display in the UI in order to prevent displaying
multiple sections for the same document as separate "documents"."""
return _merge_sections(sections=self.retrieved_sections)
@property
def reranked_sections(self) -> list[InferenceSection]:
"""Reranking is always done at the chunk level since section merging could create arbitrarily
@@ -415,6 +421,10 @@ class SearchPipeline:
raise ValueError(
"Basic search evaluation operation called while DISABLE_LLM_DOC_RELEVANCE is enabled."
)
# NOTE: final_context_sections must be accessed before accessing self._postprocessing_generator
# since the property sets the generator. DO NOT REMOVE.
_ = self.final_context_sections
self._section_relevance = next(
cast(
Iterator[list[SectionRelevancePiece]],

View File

@@ -1089,3 +1089,20 @@ def log_agent_sub_question_results(
db_session.commit()
return None
def update_chat_session_updated_at_timestamp(
chat_session_id: UUID, db_session: Session
) -> None:
"""
Explicitly update the timestamp on a chat session without modifying other fields.
This is useful when adding messages to a chat session to reflect recent activity.
"""
# Direct SQL update to avoid loading the entire object if it's not already loaded
db_session.execute(
update(ChatSession)
.where(ChatSession.id == chat_session_id)
.values(time_updated=func.now())
)
# No commit - the caller is responsible for committing the transaction
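Because the helper intentionally leaves the commit to the caller, it composes with whatever transaction added the message. A hypothetical call site (the surrounding function name and `message` object are illustrative):

from uuid import UUID

def add_message_and_touch_session(db_session, chat_session_id: UUID, message) -> None:
    """Add a chat message and bump the session's time_updated in one transaction."""
    db_session.add(message)
    update_chat_session_updated_at_timestamp(
        chat_session_id=chat_session_id, db_session=db_session
    )
    db_session.commit()  # a single commit covers both changes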

View File

@@ -555,6 +555,28 @@ def delete_documents_by_connector_credential_pair__no_commit(
db_session.execute(stmt)
def delete_all_documents_by_connector_credential_pair__no_commit(
db_session: Session,
connector_id: int,
credential_id: int,
) -> None:
"""Deletes all document by connector credential pair entries for a specific connector and credential.
This is primarily used during connector deletion to ensure all references are removed
before deleting the connector itself. This is crucial because connector_id is part of the
primary key in DocumentByConnectorCredentialPair, and attempting to delete the Connector
would otherwise try to set the foreign key to NULL, which fails for primary keys.
NOTE: Does not commit the transaction, this must be done by the caller.
"""
stmt = delete(DocumentByConnectorCredentialPair).where(
and_(
DocumentByConnectorCredentialPair.connector_id == connector_id,
DocumentByConnectorCredentialPair.credential_id == credential_id,
)
)
db_session.execute(stmt)
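A sketch of the deletion ordering the docstring describes: clear the DocumentByConnectorCredentialPair rows first, then delete the connector, then commit once. The wrapper function name is illustrative:

def delete_connector_and_references(db_session, connector, credential_id: int) -> None:
    # Remove the composite-primary-key rows that reference the connector first;
    # otherwise deleting the Connector would try to NULL part of a primary key.
    delete_all_documents_by_connector_credential_pair__no_commit(
        db_session=db_session,
        connector_id=connector.id,
        credential_id=credential_id,
    )
    db_session.delete(connector)
    db_session.commit()  # caller-controlled commit, as the helper requires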
def delete_documents__no_commit(db_session: Session, document_ids: list[str]) -> None:
db_session.execute(delete(DbDocument).where(DbDocument.id.in_(document_ids)))

View File

@@ -8,23 +8,31 @@ from sqlalchemy import and_
from sqlalchemy import delete
from sqlalchemy import desc
from sqlalchemy import func
from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.orm import contains_eager
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from sqlalchemy.sql import Select
from onyx.connectors.models import ConnectorFailure
from onyx.db.engine import get_session_context_manager
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.enums import IndexingStatus
from onyx.db.enums import IndexModelStatus
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexAttemptError
from onyx.db.models import IndexingStatus
from onyx.db.models import IndexModelStatus
from onyx.db.models import SearchSettings
from onyx.server.documents.models import ConnectorCredentialPair
from onyx.server.documents.models import ConnectorCredentialPairIdentifier
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
# Comment out unused imports that cause mypy errors
# from onyx.auth.models import UserRole
# from onyx.configs.constants import MAX_LAST_VALID_CHECKPOINT_AGE_SECONDS
# from onyx.db.connector_credential_pair import ConnectorCredentialPairIdentifier
# from onyx.db.engine import async_query_for_dms
logger = setup_logger()
@@ -201,6 +209,17 @@ def mark_attempt_in_progress(
attempt.status = IndexingStatus.IN_PROGRESS
attempt.time_started = index_attempt.time_started or func.now() # type: ignore
db_session.commit()
# Add telemetry for index attempt status change
optional_telemetry(
record_type=RecordType.INDEX_ATTEMPT_STATUS,
data={
"index_attempt_id": index_attempt.id,
"status": IndexingStatus.IN_PROGRESS.value,
"cc_pair_id": index_attempt.connector_credential_pair_id,
"search_settings_id": index_attempt.search_settings_id,
},
)
except Exception:
db_session.rollback()
raise
@@ -219,6 +238,19 @@ def mark_attempt_succeeded(
attempt.status = IndexingStatus.SUCCESS
db_session.commit()
# Add telemetry for index attempt status change
optional_telemetry(
record_type=RecordType.INDEX_ATTEMPT_STATUS,
data={
"index_attempt_id": index_attempt_id,
"status": IndexingStatus.SUCCESS.value,
"cc_pair_id": attempt.connector_credential_pair_id,
"search_settings_id": attempt.search_settings_id,
"total_docs_indexed": attempt.total_docs_indexed,
"new_docs_indexed": attempt.new_docs_indexed,
},
)
except Exception:
db_session.rollback()
raise
@@ -237,6 +269,19 @@ def mark_attempt_partially_succeeded(
attempt.status = IndexingStatus.COMPLETED_WITH_ERRORS
db_session.commit()
# Add telemetry for index attempt status change
optional_telemetry(
record_type=RecordType.INDEX_ATTEMPT_STATUS,
data={
"index_attempt_id": index_attempt_id,
"status": IndexingStatus.COMPLETED_WITH_ERRORS.value,
"cc_pair_id": attempt.connector_credential_pair_id,
"search_settings_id": attempt.search_settings_id,
"total_docs_indexed": attempt.total_docs_indexed,
"new_docs_indexed": attempt.new_docs_indexed,
},
)
except Exception:
db_session.rollback()
raise
@@ -259,6 +304,20 @@ def mark_attempt_canceled(
attempt.status = IndexingStatus.CANCELED
attempt.error_msg = reason
db_session.commit()
# Add telemetry for index attempt status change
optional_telemetry(
record_type=RecordType.INDEX_ATTEMPT_STATUS,
data={
"index_attempt_id": index_attempt_id,
"status": IndexingStatus.CANCELED.value,
"cc_pair_id": attempt.connector_credential_pair_id,
"search_settings_id": attempt.search_settings_id,
"reason": reason,
"total_docs_indexed": attempt.total_docs_indexed,
"new_docs_indexed": attempt.new_docs_indexed,
},
)
except Exception:
db_session.rollback()
raise
@@ -283,6 +342,20 @@ def mark_attempt_failed(
attempt.error_msg = failure_reason
attempt.full_exception_trace = full_exception_trace
db_session.commit()
# Add telemetry for index attempt status change
optional_telemetry(
record_type=RecordType.INDEX_ATTEMPT_STATUS,
data={
"index_attempt_id": index_attempt_id,
"status": IndexingStatus.FAILED.value,
"cc_pair_id": attempt.connector_credential_pair_id,
"search_settings_id": attempt.search_settings_id,
"reason": failure_reason,
"total_docs_indexed": attempt.total_docs_indexed,
"new_docs_indexed": attempt.new_docs_indexed,
},
)
except Exception:
db_session.rollback()
raise
@@ -434,7 +507,7 @@ def get_latest_index_attempts_parallel(
eager_load_cc_pair: bool = False,
only_finished: bool = False,
) -> Sequence[IndexAttempt]:
with get_session_context_manager() as db_session:
with get_session_with_current_tenant() as db_session:
return get_latest_index_attempts(
secondary_index,
db_session,

View File

@@ -24,7 +24,9 @@ from onyx.db.models import User__UserGroup
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
def validate_user_role_update(requested_role: UserRole, current_role: UserRole) -> None:
def validate_user_role_update(
requested_role: UserRole, current_role: UserRole, explicit_override: bool = False
) -> None:
"""
Validate that a user role update is valid.
Assumed only admins can hit this endpoint.
@@ -57,6 +59,9 @@ def validate_user_role_update(requested_role: UserRole, current_role: UserRole)
detail="To change a Limited User's role, they must first login to Onyx via the web app.",
)
if explicit_override:
return
if requested_role == UserRole.CURATOR:
# This shouldn't happen, but just in case
raise HTTPException(

View File

@@ -5,13 +5,15 @@ import re
import zipfile
from collections.abc import Callable
from collections.abc import Iterator
from collections.abc import Sequence
from email.parser import Parser as EmailParser
from enum import auto
from enum import IntFlag
from io import BytesIO
from pathlib import Path
from typing import Any
from typing import IO
from typing import List
from typing import Tuple
from typing import NamedTuple
import chardet
import docx # type: ignore
@@ -35,7 +37,7 @@ logger = setup_logger()
TEXT_SECTION_SEPARATOR = "\n\n"
PLAIN_TEXT_FILE_EXTENSIONS = [
ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS = [
".txt",
".md",
".mdx",
@@ -49,7 +51,7 @@ PLAIN_TEXT_FILE_EXTENSIONS = [
".yaml",
]
VALID_FILE_EXTENSIONS = PLAIN_TEXT_FILE_EXTENSIONS + [
ACCEPTED_DOCUMENT_FILE_EXTENSIONS = [
".pdf",
".docx",
".pptx",
@@ -57,12 +59,21 @@ VALID_FILE_EXTENSIONS = PLAIN_TEXT_FILE_EXTENSIONS + [
".eml",
".epub",
".html",
]
ACCEPTED_IMAGE_FILE_EXTENSIONS = [
".png",
".jpg",
".jpeg",
".webp",
]
ALL_ACCEPTED_FILE_EXTENSIONS = (
ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS
+ ACCEPTED_DOCUMENT_FILE_EXTENSIONS
+ ACCEPTED_IMAGE_FILE_EXTENSIONS
)
IMAGE_MEDIA_TYPES = [
"image/png",
"image/jpeg",
@@ -70,8 +81,15 @@ IMAGE_MEDIA_TYPES = [
]
class OnyxExtensionType(IntFlag):
Plain = auto()
Document = auto()
Multimedia = auto()
All = Plain | Document | Multimedia
def is_text_file_extension(file_name: str) -> bool:
return any(file_name.endswith(ext) for ext in PLAIN_TEXT_FILE_EXTENSIONS)
return any(file_name.endswith(ext) for ext in ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS)
def get_file_ext(file_path_or_name: str | Path) -> str:
@@ -83,8 +101,20 @@ def is_valid_media_type(media_type: str) -> bool:
return media_type in IMAGE_MEDIA_TYPES
def is_valid_file_ext(ext: str) -> bool:
return ext in VALID_FILE_EXTENSIONS
def is_accepted_file_ext(ext: str, ext_type: OnyxExtensionType) -> bool:
if ext_type & OnyxExtensionType.Plain:
if ext in ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS:
return True
if ext_type & OnyxExtensionType.Document:
if ext in ACCEPTED_DOCUMENT_FILE_EXTENSIONS:
return True
if ext_type & OnyxExtensionType.Multimedia:
if ext in ACCEPTED_IMAGE_FILE_EXTENSIONS:
return True
return False
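Because `OnyxExtensionType` is an `IntFlag`, callers combine categories with `|` and the check above tests membership per category. A few illustrative calls, assuming the names above are importable from this file-processing module:

# Accept text and document formats, but not images.
assert is_accepted_file_ext(".md", OnyxExtensionType.Plain | OnyxExtensionType.Document)
assert is_accepted_file_ext(".pdf", OnyxExtensionType.Plain | OnyxExtensionType.Document)
assert not is_accepted_file_ext(".png", OnyxExtensionType.Plain | OnyxExtensionType.Document)

# OnyxExtensionType.All covers every accepted extension, including images.
assert is_accepted_file_ext(".png", OnyxExtensionType.All)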
def is_text_file(file: IO[bytes]) -> bool:
@@ -219,7 +249,7 @@ def pdf_to_text(file: IO[Any], pdf_pass: str | None = None) -> str:
def read_pdf_file(
file: IO[Any], pdf_pass: str | None = None, extract_images: bool = False
) -> tuple[str, dict, list[tuple[bytes, str]]]:
) -> tuple[str, dict[str, Any], Sequence[tuple[bytes, str]]]:
"""
Returns the text, basic PDF metadata, and optionally extracted images.
"""
@@ -282,13 +312,13 @@ def read_pdf_file(
def docx_to_text_and_images(
file: IO[Any],
) -> Tuple[str, List[Tuple[bytes, str]]]:
) -> tuple[str, Sequence[tuple[bytes, str]]]:
"""
Extract text from a docx. If embed_images=True, also extract inline images.
Return (text_content, list_of_images).
"""
paragraphs = []
embedded_images: List[Tuple[bytes, str]] = []
embedded_images: list[tuple[bytes, str]] = []
doc = docx.Document(file)
@@ -382,6 +412,9 @@ def extract_file_text(
"""
Legacy function that returns *only text*, ignoring embedded images.
For backward-compatibility in code that only wants text.
NOTE: "ignoring" here appears to mean returning an empty string for files it can't
handle (such as images).
"""
extension_to_function: dict[str, Callable[[IO[Any]], str]] = {
".pdf": pdf_to_text,
@@ -405,7 +438,9 @@ def extract_file_text(
if extension is None:
extension = get_file_ext(file_name)
if is_valid_file_ext(extension):
if is_accepted_file_ext(
extension, OnyxExtensionType.Plain | OnyxExtensionType.Document
):
func = extension_to_function.get(extension, file_io_to_text)
file.seek(0)
return func(file)
@@ -426,14 +461,22 @@ def extract_file_text(
return ""
class ExtractionResult(NamedTuple):
"""Structured result from text and image extraction from various file types."""
text_content: str
embedded_images: Sequence[tuple[bytes, str]]
metadata: dict[str, Any]
def extract_text_and_images(
file: IO[Any],
file_name: str,
pdf_pass: str | None = None,
) -> Tuple[str, List[Tuple[bytes, str]]]:
) -> ExtractionResult:
"""
Primary new function for the updated connector.
Returns (text_content, [(embedded_img_bytes, embedded_img_name), ...]).
Returns a structured extraction result with the text content, embedded images, and metadata.
"""
try:
@@ -442,7 +485,9 @@ def extract_text_and_images(
# If the user doesn't want embedded images, unstructured is fine
file.seek(0)
text_content = unstructured_to_text(file, file_name)
return (text_content, [])
return ExtractionResult(
text_content=text_content, embedded_images=[], metadata={}
)
extension = get_file_ext(file_name)
@@ -450,54 +495,76 @@ def extract_text_and_images(
if extension == ".docx":
file.seek(0)
text_content, images = docx_to_text_and_images(file)
return (text_content, images)
return ExtractionResult(
text_content=text_content, embedded_images=images, metadata={}
)
# For PDFs, extract the text and any embedded images as well.
if extension == ".pdf":
file.seek(0)
text_content, _, images = read_pdf_file(file, pdf_pass, extract_images=True)
return (text_content, images)
text_content, pdf_metadata, images = read_pdf_file(
file, pdf_pass, extract_images=True
)
return ExtractionResult(
text_content=text_content, embedded_images=images, metadata=pdf_metadata
)
# For PPTX, XLSX, EML, etc., we do not show embedded image logic here.
# You can do something similar to docx if needed.
if extension == ".pptx":
file.seek(0)
return (pptx_to_text(file), [])
return ExtractionResult(
text_content=pptx_to_text(file), embedded_images=[], metadata={}
)
if extension == ".xlsx":
file.seek(0)
return (xlsx_to_text(file), [])
return ExtractionResult(
text_content=xlsx_to_text(file), embedded_images=[], metadata={}
)
if extension == ".eml":
file.seek(0)
return (eml_to_text(file), [])
return ExtractionResult(
text_content=eml_to_text(file), embedded_images=[], metadata={}
)
if extension == ".epub":
file.seek(0)
return (epub_to_text(file), [])
return ExtractionResult(
text_content=epub_to_text(file), embedded_images=[], metadata={}
)
if extension == ".html":
file.seek(0)
return (parse_html_page_basic(file), [])
return ExtractionResult(
text_content=parse_html_page_basic(file),
embedded_images=[],
metadata={},
)
# If we reach here and it's a recognized text extension
if is_text_file_extension(file_name):
file.seek(0)
encoding = detect_encoding(file)
text_content_raw, _ = read_text_file(
text_content_raw, file_metadata = read_text_file(
file, encoding=encoding, ignore_onyx_metadata=False
)
return (text_content_raw, [])
return ExtractionResult(
text_content=text_content_raw,
embedded_images=[],
metadata=file_metadata,
)
# If it's an image file or some other unsupported type, we don't parse embedded images;
# just return empty text
file.seek(0)
return ("", [])
return ExtractionResult(text_content="", embedded_images=[], metadata={})
except Exception as e:
logger.exception(f"Failed to extract text/images from {file_name}: {e}")
return ("", [])
return ExtractionResult(text_content="", embedded_images=[], metadata={})
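Since `extract_text_and_images` now returns a `NamedTuple` instead of a bare tuple, callers can unpack by field name. A hedged usage sketch, assuming the function is importable from this module; the file path is illustrative:

with open("example.docx", "rb") as f:  # illustrative path
    result = extract_text_and_images(f, file_name="example.docx")

print(len(result.text_content))
for image_bytes, image_name in result.embedded_images:
    print(image_name, len(image_bytes))
print(result.metadata)

# Positional unpacking also works, because a NamedTuple is still a tuple,
# but note it now yields three fields rather than the old two.
text, images, metadata = result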
def convert_docx_to_txt(

View File

@@ -15,6 +15,7 @@ EXCLUDED_IMAGE_TYPES = [
"image/tiff",
"image/gif",
"image/svg+xml",
"image/avif",
]

View File

@@ -361,7 +361,15 @@ def get_application() -> FastAPI:
)
if AUTH_TYPE == AuthType.GOOGLE_OAUTH:
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
# For Google OAuth, refresh tokens are requested by:
# 1. Adding the right scopes
# 2. Properly configuring OAuth in Google Cloud Console to allow offline access
oauth_client = GoogleOAuth2(
OAUTH_CLIENT_ID,
OAUTH_CLIENT_SECRET,
# Use standard scopes that include profile and email
scopes=["openid", "email", "profile"],
)
include_auth_router_with_prefix(
application,
create_onyx_oauth_router(
@@ -383,6 +391,13 @@ def get_application() -> FastAPI:
prefix="/auth",
)
# Add refresh token endpoint for OAuth as well
include_auth_router_with_prefix(
application,
fastapi_users.get_refresh_router(auth_backend),
prefix="/auth",
)
application.add_exception_handler(
RequestValidationError, validation_exception_handler
)

View File

@@ -15,7 +15,6 @@ from onyx.configs.constants import MessageType
from onyx.configs.constants import SearchFeedbackType
from onyx.configs.onyxbot_configs import DANSWER_FOLLOWUP_EMOJI
from onyx.connectors.slack.utils import expert_info_from_slack_id
from onyx.connectors.slack.utils import make_slack_api_rate_limited
from onyx.context.search.models import SavedSearchDoc
from onyx.db.chat import get_chat_message
from onyx.db.chat import translate_db_message_to_chat_message_detail
@@ -553,8 +552,7 @@ def handle_followup_resolved_button(
# Delete the message with the option to mark resolved
if not immediate:
slack_call = make_slack_api_rate_limited(client.web_client.chat_delete)
response = slack_call(
response = client.web_client.chat_delete(
channel=channel_id,
ts=message_ts,
)

View File

@@ -170,7 +170,8 @@ def handle_message(
respond_tag_only = channel_conf.get("respond_tag_only") or False
respond_member_group_list = channel_conf.get("respond_member_group_list", None)
if respond_tag_only and not bypass_filters:
# NOTE: always respond in DMs, as long as the default config is not disabled.
if respond_tag_only and not bypass_filters and not is_bot_dm:
logger.info(
"Skipping message since the channel is configured such that "
"OnyxBot only responds to tags"

View File

@@ -18,6 +18,9 @@ from prometheus_client import start_http_server
from redis.lock import Lock
from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from slack_sdk.http_retry import ConnectionErrorRetryHandler
from slack_sdk.http_retry import RateLimitErrorRetryHandler
from slack_sdk.http_retry import RetryHandler
from slack_sdk.socket_mode.request import SocketModeRequest
from slack_sdk.socket_mode.response import SocketModeResponse
from sqlalchemy.orm import Session
@@ -944,10 +947,21 @@ def _get_socket_client(
) -> TenantSocketModeClient:
# For more info on how to set this up, check out the docs:
# https://docs.onyx.app/slack_bot_setup
# use the retry handlers built into the slack sdk
connection_error_retry_handler = ConnectionErrorRetryHandler()
rate_limit_error_retry_handler = RateLimitErrorRetryHandler(max_retry_count=7)
slack_retry_handlers: list[RetryHandler] = [
connection_error_retry_handler,
rate_limit_error_retry_handler,
]
return TenantSocketModeClient(
# This app-level token will be used only for establishing a connection
app_token=slack_bot_tokens.app_token,
web_client=WebClient(token=slack_bot_tokens.bot_token),
web_client=WebClient(
token=slack_bot_tokens.bot_token, retry_handlers=slack_retry_handlers
),
tenant_id=tenant_id,
slack_bot_id=slack_bot_id,
)

View File

@@ -30,7 +30,6 @@ from onyx.configs.onyxbot_configs import (
from onyx.configs.onyxbot_configs import (
DANSWER_BOT_RESPONSE_LIMIT_TIME_PERIOD_SECONDS,
)
from onyx.connectors.slack.utils import make_slack_api_rate_limited
from onyx.connectors.slack.utils import SlackTextCleaner
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.users import get_user_by_email
@@ -125,13 +124,18 @@ def update_emote_react(
)
return
func = client.reactions_remove if remove else client.reactions_add
slack_call = make_slack_api_rate_limited(func) # type: ignore
slack_call(
name=emoji,
channel=channel,
timestamp=message_ts,
)
if remove:
client.reactions_remove(
name=emoji,
channel=channel,
timestamp=message_ts,
)
else:
client.reactions_add(
name=emoji,
channel=channel,
timestamp=message_ts,
)
except SlackApiError as e:
if remove:
logger.error(f"Failed to remove Reaction due to: {e}")
@@ -200,9 +204,8 @@ def respond_in_thread_or_channel(
message_ids: list[str] = []
if not receiver_ids:
slack_call = make_slack_api_rate_limited(client.chat_postMessage)
try:
response = slack_call(
response = client.chat_postMessage(
channel=channel,
text=text,
blocks=blocks,
@@ -224,7 +227,7 @@ def respond_in_thread_or_channel(
blocks_without_urls.append(_build_error_block(str(e)))
# Try again without blocks containing URLs
response = slack_call(
response = client.chat_postMessage(
channel=channel,
text=text,
blocks=blocks_without_urls,
@@ -236,11 +239,9 @@ def respond_in_thread_or_channel(
message_ids.append(response["message_ts"])
else:
slack_call = make_slack_api_rate_limited(client.chat_postEphemeral)
for receiver in receiver_ids:
try:
response = slack_call(
response = client.chat_postEphemeral(
channel=channel,
user=receiver,
text=text,
@@ -263,7 +264,7 @@ def respond_in_thread_or_channel(
blocks_without_urls.append(_build_error_block(str(e)))
# Try again without blocks containing URLs
response = slack_call(
response = client.chat_postEphemeral(
channel=channel,
user=receiver,
text=text,
@@ -500,7 +501,7 @@ def fetch_user_semantic_id_from_id(
if not user_id:
return None
response = make_slack_api_rate_limited(client.users_info)(user=user_id)
response = client.users_info(user=user_id)
if not response["ok"]:
return None

View File

@@ -31,6 +31,7 @@ PUBLIC_ENDPOINT_SPECS = [
# just gets the version of Onyx (e.g. 0.3.11)
("/version", {"GET"}),
# stuff related to basic auth
("/auth/refresh", {"POST"}),
("/auth/register", {"POST"}),
("/auth/login", {"POST"}),
("/auth/logout", {"POST"}),

View File

@@ -132,6 +132,7 @@ class UserByEmail(BaseModel):
class UserRoleUpdateRequest(BaseModel):
user_email: str
new_role: UserRole
explicit_override: bool = False
class UserRoleResponse(BaseModel):

View File

@@ -261,9 +261,6 @@ def create_bot(
# Create a default Slack channel config
default_channel_config = ChannelConfig(
channel_name=None,
respond_member_group_list=[],
answer_filters=[],
follow_up_tags=[],
respond_tag_only=True,
)
insert_slack_channel_config(
@@ -371,7 +368,9 @@ def get_all_channels_from_slack_api(
_: User | None = Depends(current_admin_user),
) -> list[SlackChannel]:
"""
Fetches channels the bot is a member of from the Slack API.
Fetches all channels in the Slack workspace using the conversations_list API.
This includes both public and private channels that are visible to the app,
not just the ones the bot is a member of.
Handles pagination with a limit to avoid excessive API calls.
"""
tokens = fetch_slack_bot_tokens(db_session, bot_id)
@@ -386,20 +385,20 @@ def get_all_channels_from_slack_api(
current_page = 0
try:
# Use users_conversations with limited pagination
# Use conversations_list to get all channels in the workspace (including ones the bot is not a member of)
while current_page < MAX_SLACK_PAGES:
current_page += 1
# Make API call with cursor if we have one
if next_cursor:
response = client.users_conversations(
response = client.conversations_list(
types="public_channel,private_channel",
exclude_archived=True,
cursor=next_cursor,
limit=SLACK_API_CHANNELS_PER_PAGE,
)
else:
response = client.users_conversations(
response = client.conversations_list(
types="public_channel,private_channel",
exclude_archived=True,
limit=SLACK_API_CHANNELS_PER_PAGE,

View File

@@ -102,6 +102,7 @@ def set_user_role(
validate_user_role_update(
requested_role=requested_role,
current_role=current_role,
explicit_override=user_role_update_request.explicit_override,
)
if user_to_update.id == current_user.id:
@@ -122,6 +123,22 @@ def set_user_role(
db_session.commit()
class TestUpsertRequest(BaseModel):
email: str
@router.post("/manage/users/test-upsert-user")
async def test_upsert_user(
request: TestUpsertRequest,
_: User = Depends(current_admin_user),
) -> None | FullUserSnapshot:
"""Test endpoint for upsert_saml_user. Only used for integration testing."""
user = await fetch_ee_implementation_or_noop(
"onyx.server.saml", "upsert_saml_user", None
)(email=request.email)
return FullUserSnapshot.from_user_model(user) if user else None
@router.get("/manage/users/accepted")
def list_accepted_users(
q: str | None = Query(default=None),
@@ -296,7 +313,7 @@ def bulk_invite_users(
detail=f"Invalid email address: {email} - {str(e)}",
)
if MULTI_TENANT and not DEV_MODE:
if MULTI_TENANT:
try:
fetch_ee_implementation_or_noop(
"onyx.server.tenants.provisioning", "add_users_to_tenant", None
@@ -318,7 +335,7 @@ def bulk_invite_users(
except Exception as e:
logger.error(f"Error sending email invite to invited users: {e}")
if not MULTI_TENANT:
if not MULTI_TENANT or DEV_MODE:
return number_of_invited_users
# for billing purposes, write to the control plane about the number of new users
@@ -359,7 +376,7 @@ def remove_invited_user(
number_of_invited_users = write_invited_users(remaining_users)
try:
if MULTI_TENANT:
if MULTI_TENANT and not DEV_MODE:
fetch_ee_implementation_or_noop(
"onyx.server.tenants.billing", "register_tenant_users", None
)(tenant_id, get_total_users_count(db_session))

View File

@@ -1,10 +1,19 @@
import io
from typing import cast
from PIL import Image
from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
from onyx.background.celery.tasks.beat_schedule import (
CLOUD_DOC_PERMISSION_SYNC_MULTIPLIER_DEFAULT,
)
from onyx.configs.constants import CLOUD_BUILD_FENCE_LOOKUP_TABLE_INTERVAL_DEFAULT
from onyx.configs.constants import ONYX_CLOUD_REDIS_RUNTIME
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
from onyx.configs.constants import ONYX_EMAILABLE_LOGO_MAX_DIM
from onyx.db.engine import get_session_with_shared_schema
from onyx.file_store.file_store import PostgresBackedFileStore
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.utils.file import FileWithMimeType
from onyx.utils.file import OnyxStaticFileManager
from onyx.utils.variable_functionality import (
@@ -87,3 +96,72 @@ class OnyxRuntime:
)
return OnyxRuntime._get_with_static_fallback(db_filename, STATIC_FILENAME)
@staticmethod
def get_beat_multiplier() -> float:
"""the beat multiplier is used to scale up or down the frequency of certain beat
tasks in the cloud. It has a significant effect on load and is useful to adjust
in real time."""
beat_multiplier: float = CLOUD_BEAT_MULTIPLIER_DEFAULT
r = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
beat_multiplier_raw = r.get(f"{ONYX_CLOUD_REDIS_RUNTIME}:beat_multiplier")
if beat_multiplier_raw is not None:
try:
beat_multiplier_bytes = cast(bytes, beat_multiplier_raw)
beat_multiplier = float(beat_multiplier_bytes.decode())
except ValueError:
pass
if beat_multiplier <= 0.0:
return 1.0
return beat_multiplier
@staticmethod
def get_doc_permission_sync_multiplier() -> float:
"""Permission syncs are a significant source of load / queueing in the cloud."""
value: float = CLOUD_DOC_PERMISSION_SYNC_MULTIPLIER_DEFAULT
r = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
value_raw = r.get(f"{ONYX_CLOUD_REDIS_RUNTIME}:doc_permission_sync_multiplier")
if value_raw is not None:
try:
value_bytes = cast(bytes, value_raw)
value = float(value_bytes.decode())
except ValueError:
pass
if value <= 0.0:
return 1.0
return value
@staticmethod
def get_build_fence_lookup_table_interval() -> int:
"""We maintain an active fence table to make lookups of existing fences efficient.
However, reconstructing the table is expensive, so being able to adjust the rebuild interval in real time is useful.
"""
interval: int = CLOUD_BUILD_FENCE_LOOKUP_TABLE_INTERVAL_DEFAULT
r = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
interval_raw = r.get(
f"{ONYX_CLOUD_REDIS_RUNTIME}:build_fence_lookup_table_interval"
)
if interval_raw is not None:
try:
interval_bytes = cast(bytes, interval_raw)
interval = int(interval_bytes.decode())
except ValueError:
pass
if interval <= 0.0:
return CLOUD_BUILD_FENCE_LOOKUP_TABLE_INTERVAL_DEFAULT
return interval
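These getters read per-key overrides from Redis on the cloud tenant and fall back to compiled defaults, which is what makes the values adjustable in real time. A hypothetical operator sketch of setting an override; the key names come from the getters above, but the write path (a primary, writable Redis client with local defaults) is an assumption:

import redis  # illustrative: any client pointed at the primary Redis instance works

from onyx.configs.constants import ONYX_CLOUD_REDIS_RUNTIME  # import path taken from this file's imports

r = redis.Redis()  # assumes local defaults; real deployments supply their own connection settings
# Override the value read by OnyxRuntime.get_beat_multiplier().
r.set(f"{ONYX_CLOUD_REDIS_RUNTIME}:beat_multiplier", "2.0")
# Override the permission sync multiplier as well.
r.set(f"{ONYX_CLOUD_REDIS_RUNTIME}:doc_permission_sync_multiplier", "0.5")

# Workers pick the new values up on their next read; non-positive or malformed
# values fall back to the compiled defaults inside the getters.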

View File

@@ -12,7 +12,6 @@ from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import ContextualPruningConfig
from onyx.chat.models import DocumentPruningConfig
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxContexts
from onyx.chat.models import PromptConfig
from onyx.chat.models import SectionRelevancePiece
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
@@ -42,9 +41,6 @@ from onyx.tools.models import SearchQueryInfo
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_utils import (
context_from_inference_section,
)
from onyx.tools.tool_implementations.search.search_utils import llm_doc_to_dict
from onyx.tools.tool_implementations.search_like_tool_utils import (
build_next_prompt_for_search_like_tool,
@@ -58,7 +54,6 @@ from onyx.utils.special_types import JSON_ro
logger = setup_logger()
SEARCH_RESPONSE_SUMMARY_ID = "search_response_summary"
SEARCH_DOC_CONTENT_ID = "search_doc_content"
SECTION_RELEVANCE_LIST_ID = "section_relevance_list"
SEARCH_EVALUATION_ID = "llm_doc_eval"
QUERY_FIELD = "query"
@@ -357,13 +352,13 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
recency_bias_multiplier=search_pipeline.search_query.recency_bias_multiplier,
)
yield from yield_search_responses(
query,
lambda: search_pipeline.retrieved_sections,
lambda: search_pipeline.reranked_sections,
lambda: search_pipeline.final_context_sections,
search_query_info,
lambda: search_pipeline.section_relevance,
self,
query=query,
# give back the merged sections to prevent duplicate docs from appearing in the UI
get_retrieved_sections=lambda: search_pipeline.merged_retrieved_sections,
get_final_context_sections=lambda: search_pipeline.final_context_sections,
search_query_info=search_query_info,
get_section_relevance=lambda: search_pipeline.section_relevance,
search_tool=self,
)
def final_result(self, *args: ToolResponse) -> JSON_ro:
@@ -405,7 +400,6 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
def yield_search_responses(
query: str,
get_retrieved_sections: Callable[[], list[InferenceSection]],
get_reranked_sections: Callable[[], list[InferenceSection]],
get_final_context_sections: Callable[[], list[InferenceSection]],
search_query_info: SearchQueryInfo,
get_section_relevance: Callable[[], list[SectionRelevancePiece] | None],
@@ -423,16 +417,6 @@ def yield_search_responses(
),
)
yield ToolResponse(
id=SEARCH_DOC_CONTENT_ID,
response=OnyxContexts(
contexts=[
context_from_inference_section(section)
for section in get_reranked_sections()
]
),
)
section_relevance = get_section_relevance()
yield ToolResponse(
id=SECTION_RELEVANCE_LIST_ID,

View File

@@ -1,5 +1,4 @@
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxContext
from onyx.context.search.models import InferenceSection
from onyx.prompts.prompt_utils import clean_up_source
@@ -32,10 +31,23 @@ def section_to_dict(section: InferenceSection, section_num: int) -> dict:
return doc_dict
def context_from_inference_section(section: InferenceSection) -> OnyxContext:
return OnyxContext(
content=section.combined_content,
def section_to_llm_doc(section: InferenceSection) -> LlmDoc:
possible_link_chunks = [section.center_chunk] + section.chunks
link: str | None = None
for chunk in possible_link_chunks:
if chunk.source_links:
link = list(chunk.source_links.values())[0]
break
return LlmDoc(
document_id=section.center_chunk.document_id,
content=section.combined_content,
source_type=section.center_chunk.source_type,
semantic_identifier=section.center_chunk.semantic_identifier,
metadata=section.center_chunk.metadata,
updated_at=section.center_chunk.updated_at,
blurb=section.center_chunk.blurb,
link=link,
source_links=section.center_chunk.source_links,
match_highlights=section.center_chunk.match_highlights,
)

View File

@@ -36,6 +36,10 @@ class RecordType(str, Enum):
LATENCY = "latency"
FAILURE = "failure"
METRIC = "metric"
INDEXING_PROGRESS = "indexing_progress"
INDEXING_COMPLETE = "indexing_complete"
PERMISSION_SYNC_PROGRESS = "permission_sync_progress"
INDEX_ATTEMPT_STATUS = "index_attempt_status"
def _get_or_generate_customer_id_mt(tenant_id: str) -> str:

View File

@@ -6,14 +6,17 @@ import uuid
from collections.abc import Callable
from collections.abc import Iterator
from collections.abc import MutableMapping
from collections.abc import Sequence
from concurrent.futures import as_completed
from concurrent.futures import FIRST_COMPLETED
from concurrent.futures import Future
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import wait
from typing import Any
from typing import cast
from typing import Generic
from typing import overload
from typing import Protocol
from typing import TypeVar
from pydantic import GetCoreSchemaHandler
@@ -145,13 +148,20 @@ class ThreadSafeDict(MutableMapping[KT, VT]):
return collections.abc.ValuesView(self)
class CallableProtocol(Protocol):
def __call__(self, *args: Any, **kwargs: Any) -> Any:
...
def run_functions_tuples_in_parallel(
functions_with_args: list[tuple[Callable, tuple]],
functions_with_args: Sequence[tuple[CallableProtocol, tuple[Any, ...]]],
allow_failures: bool = False,
max_workers: int | None = None,
) -> list[Any]:
"""
Executes multiple functions in parallel and returns a list of the results for each function.
This function preserves contextvars across threads, which is important for maintaining
context like tenant IDs in database sessions.
Args:
functions_with_args: List of tuples each containing the function callable and a tuple of arguments.
@@ -159,7 +169,7 @@ def run_functions_tuples_in_parallel(
max_workers: Max number of worker threads
Returns:
dict: A dictionary mapping function names to their results or error messages.
list: A list of results from each function, in the same order as the input functions.
"""
workers = (
min(max_workers, len(functions_with_args))
@@ -186,7 +196,7 @@ def run_functions_tuples_in_parallel(
results.append((index, future.result()))
except Exception as e:
logger.exception(f"Function at index {index} failed due to {e}")
results.append((index, None))
results.append((index, None)) # type: ignore
if not allow_failures:
raise
@@ -288,7 +298,7 @@ def run_with_timeout(
if task.is_alive():
task.end()
return task.result
return task.result # type: ignore
# NOTE: this function should really only be used when run_functions_tuples_in_parallel is
@@ -304,9 +314,9 @@ def run_in_background(
"""
context = contextvars.copy_context()
# Timeout not used in the non-blocking case
task = TimeoutThread(-1, context.run, func, *args, **kwargs)
task = TimeoutThread(-1, context.run, func, *args, **kwargs) # type: ignore
task.start()
return task
return cast(TimeoutThread[R], task)
def wait_on_background(task: TimeoutThread[R]) -> R:
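
A minimal, hedged usage sketch for the updated run_functions_tuples_in_parallel signature above. The import path and the toy functions are assumptions (the file path is not shown in this diff); per the docstring, results come back in input order, and with allow_failures=True a failed call contributes None instead of raising.

# Sketch only; import path assumed, not taken from this diff.
from onyx.utils.threadpool_concurrency import (
    run_functions_tuples_in_parallel,
    run_in_background,
    wait_on_background,
)

def add(a: int, b: int) -> int:
    return a + b

def shout(s: str) -> str:
    return s.upper()

# A Sequence of (callable, args-tuple) pairs, matching the new parameter type.
results = run_functions_tuples_in_parallel(
    [(add, (1, 2)), (shout, ("onyx",))],
    allow_failures=True,  # a failed call yields None in its slot instead of raising
)
assert results == [3, "ONYX"]

# The run_in_background / wait_on_background pair (also touched above) for a
# single non-blocking call:
task = run_in_background(add, 40, 2)
assert wait_on_background(task) == 42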

View File

@@ -56,7 +56,7 @@ puremagic==1.28
pyairtable==3.0.1
pycryptodome==3.19.1
pydantic==2.8.2
PyGithub==1.58.2
PyGithub==2.5.0
python-dateutil==2.8.2
python-gitlab==3.9.0
python-pptx==0.6.23

View File

@@ -78,19 +78,19 @@ def generate_dummy_chunk(
for i in range(number_of_document_sets):
document_set_names.append(f"Document Set {i}")
user_emails: set[str | None] = set()
user_groups: set[str] = set()
external_user_emails: set[str] = set()
external_user_group_ids: set[str] = set()
user_emails: list[str | None] = []
user_groups: list[str] = []
external_user_emails: list[str] = []
external_user_group_ids: list[str] = []
for i in range(number_of_acl_entries):
user_emails.add(f"user_{i}@example.com")
user_groups.add(f"group_{i}")
external_user_emails.add(f"external_user_{i}@example.com")
external_user_group_ids.add(f"external_group_{i}")
user_emails.append(f"user_{i}@example.com")
user_groups.append(f"group_{i}")
external_user_emails.append(f"external_user_{i}@example.com")
external_user_group_ids.append(f"external_group_{i}")
return DocMetadataAwareIndexChunk.from_index_chunk(
index_chunk=chunk,
access=DocumentAccess(
access=DocumentAccess.build(
user_emails=user_emails,
user_groups=user_groups,
external_user_emails=external_user_emails,

View File

@@ -0,0 +1,77 @@
import os
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from onyx.configs.constants import BlobType
from onyx.connectors.blob.connector import BlobStorageConnector
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.file_processing.extract_file_text import ACCEPTED_DOCUMENT_FILE_EXTENSIONS
from onyx.file_processing.extract_file_text import ACCEPTED_IMAGE_FILE_EXTENSIONS
from onyx.file_processing.extract_file_text import ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS
from onyx.file_processing.extract_file_text import get_file_ext
@pytest.fixture
def blob_connector(request: pytest.FixtureRequest) -> BlobStorageConnector:
connector = BlobStorageConnector(
bucket_type=BlobType.S3, bucket_name="onyx-connector-tests"
)
connector.load_credentials(
{
"aws_access_key_id": os.environ["AWS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS"],
"aws_secret_access_key": os.environ[
"AWS_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS"
],
}
)
return connector
@patch(
"onyx.file_processing.extract_file_text.get_unstructured_api_key",
return_value=None,
)
def test_blob_s3_connector(
mock_get_api_key: MagicMock, blob_connector: BlobStorageConnector
) -> None:
"""
Plain and document file types should be fully indexed.
Multimedia and unknown file types will be indexed by title only with one empty section.
This is intentional in order to allow searching by just the title even if we can't
index the file content.
"""
all_docs: list[Document] = []
document_batches = blob_connector.load_from_state()
for doc_batch in document_batches:
for doc in doc_batch:
all_docs.append(doc)
#
assert len(all_docs) == 19
for doc in all_docs:
section = doc.sections[0]
assert isinstance(section, TextSection)
file_extension = get_file_ext(doc.semantic_identifier)
if file_extension in ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS:
assert len(section.text) > 0
continue
if file_extension in ACCEPTED_DOCUMENT_FILE_EXTENSIONS:
assert len(section.text) > 0
continue
if file_extension in ACCEPTED_IMAGE_FILE_EXTENSIONS:
assert len(section.text) == 0
continue
# unknown extension
assert len(section.text) == 0
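
The docstring above can be restated as a tiny helper, shown here as a hedged sketch (not part of the test file): it assumes the three extension sets behave as their names suggest and simply mirrors the branching used in the assertions above.

# Sketch only; reuses the imports shown at the top of this test file.
def expected_to_have_text(file_name: str) -> bool:
    ext = get_file_ext(file_name)
    # plain-text and document types get their content indexed;
    # image/multimedia/unknown types are indexed by title with one empty section
    return (
        ext in ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS
        or ext in ACCEPTED_DOCUMENT_FILE_EXTENSIONS
    )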

View File

@@ -0,0 +1,54 @@
import os
import time
import pytest
from onyx.configs.constants import DocumentSource
from onyx.connectors.github.connector import GithubConnector
from tests.daily.connectors.utils import load_all_docs_from_checkpoint_connector
@pytest.fixture
def github_connector() -> GithubConnector:
connector = GithubConnector(
repo_owner="onyx-dot-app",
repositories="documentation",
include_prs=True,
include_issues=True,
)
connector.load_credentials(
{
"github_access_token": os.environ["ACCESS_TOKEN_GITHUB"],
}
)
return connector
def test_github_connector_basic(github_connector: GithubConnector) -> None:
docs = load_all_docs_from_checkpoint_connector(
connector=github_connector,
start=0,
end=time.time(),
)
assert len(docs) > 0 # We expect at least one PR to exist
# Test the first document's structure
doc = docs[0]
# Verify basic document properties
assert doc.source == DocumentSource.GITHUB
assert doc.secondary_owners is None
assert doc.from_ingestion_api is False
assert doc.additional_info is None
# Verify GitHub-specific properties
assert "github.com" in doc.id # Should be a GitHub URL
assert doc.metadata is not None
assert "state" in doc.metadata
assert "merged" in doc.metadata
# Verify sections
assert len(doc.sections) == 1
section = doc.sections[0]
assert section.link == doc.id # Section link should match document ID
assert isinstance(section.text, str) # Should have some text content

View File

@@ -1,5 +1,6 @@
import json
import os
import resource
from collections.abc import Callable
import pytest
@@ -136,3 +137,22 @@ def google_drive_service_acct_connector_factory() -> (
return connector
return _connector_factory
@pytest.fixture(scope="session", autouse=True)
def set_resource_limits() -> None:
# The Google SDK is aggressive about using up file descriptors, and the default
# limit on macOS is low; these tests fail randomly unless the descriptor limit is raised.
RLIMIT_MINIMUM = 2048
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
desired_soft = min(RLIMIT_MINIMUM, hard)  # never raise the soft limit above the hard limit
print(f"Open file limit: soft={soft} hard={hard} soft_required={RLIMIT_MINIMUM}")
if soft < desired_soft:
print(f"Raising open file limit: {soft} -> {desired_soft}")
resource.setrlimit(resource.RLIMIT_NOFILE, (desired_soft, hard))
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
print(f"New open file limit: soft={soft} hard={hard}")
return

View File

@@ -58,6 +58,16 @@ SECTIONS_FOLDER_URL = (
"https://drive.google.com/drive/u/5/folders/1loe6XJ-pJxu9YYPv7cF3Hmz296VNzA33"
)
EXTERNAL_SHARED_FOLDER_URL = (
"https://drive.google.com/drive/folders/1sWC7Oi0aQGgifLiMnhTjvkhRWVeDa-XS"
)
EXTERNAL_SHARED_DOCS_IN_FOLDER = [
"https://docs.google.com/document/d/1Sywmv1-H6ENk2GcgieKou3kQHR_0te1mhIUcq8XlcdY"
]
EXTERNAL_SHARED_DOC_SINGLETON = (
"https://docs.google.com/document/d/11kmisDfdvNcw5LYZbkdPVjTOdj-Uc5ma6Jep68xzeeA"
)
SHARED_DRIVE_3_URL = "https://drive.google.com/drive/folders/0AJYm2K_I_vtNUk9PVA"
ADMIN_EMAIL = "admin@onyx-test.com"
@@ -161,10 +171,14 @@ def _get_expected_file_content(file_id: int) -> str:
return file_text_template.format(file_id)
def assert_retrieved_docs_match_expected(
def assert_expected_docs_in_retrieved_docs(
retrieved_docs: list[Document],
expected_file_ids: Sequence[int],
) -> None:
"""NOTE: as far as i can tell this does NOT assert for an exact match.
it only checks to see if that the expected file id's are IN the retrieved doc list
"""
expected_file_names = {
file_name_template.format(file_id) for file_id in expected_file_ids
}
@@ -175,7 +189,7 @@ def assert_retrieved_docs_match_expected(
retrieved_docs.sort(key=lambda x: x.semantic_identifier)
for doc in retrieved_docs:
print(f"doc.semantic_identifier: {doc.semantic_identifier}")
print(f"retrieved doc: doc.semantic_identifier={doc.semantic_identifier}")
# Filter out invalid prefixes to prevent different tests from interfering with each other
valid_retrieved_docs = [
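
To make the NOTE above concrete, here is a small illustrative contrast between the containment check this helper performs and an exact match; the file names are hypothetical.

# Illustration only; names are hypothetical.
expected = {"file_0.txt", "file_1.txt"}
retrieved = {"file_0.txt", "file_1.txt", "file_from_another_test.txt"}

assert expected.issubset(retrieved)  # what assert_expected_docs_in_retrieved_docs effectively checks
assert expected != retrieved  # an exact-match assertion would fail here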

View File

@@ -7,7 +7,7 @@ from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_EMAIL
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FOLDER_3_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import (
assert_retrieved_docs_match_expected,
assert_expected_docs_in_retrieved_docs,
)
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_URL
@@ -62,7 +62,7 @@ def test_include_all(
+ FOLDER_2_2_FILE_IDS
+ SECTIONS_FILE_IDS
)
assert_retrieved_docs_match_expected(
assert_expected_docs_in_retrieved_docs(
retrieved_docs=retrieved_docs,
expected_file_ids=expected_file_ids,
)
@@ -100,7 +100,7 @@ def test_include_shared_drives_only(
+ FOLDER_2_2_FILE_IDS
+ SECTIONS_FILE_IDS
)
assert_retrieved_docs_match_expected(
assert_expected_docs_in_retrieved_docs(
retrieved_docs=retrieved_docs,
expected_file_ids=expected_file_ids,
)
@@ -128,7 +128,7 @@ def test_include_my_drives_only(
# Should only get primary_admin's My Drive because we are impersonating them
expected_file_ids = ADMIN_FILE_IDS + ADMIN_FOLDER_3_FILE_IDS
assert_retrieved_docs_match_expected(
assert_expected_docs_in_retrieved_docs(
retrieved_docs=retrieved_docs,
expected_file_ids=expected_file_ids,
)
@@ -161,7 +161,7 @@ def test_drive_one_only(
+ FOLDER_1_1_FILE_IDS
+ FOLDER_1_2_FILE_IDS
)
assert_retrieved_docs_match_expected(
assert_expected_docs_in_retrieved_docs(
retrieved_docs=retrieved_docs,
expected_file_ids=expected_file_ids,
)
@@ -198,7 +198,7 @@ def test_folder_and_shared_drive(
+ FOLDER_2_1_FILE_IDS
+ FOLDER_2_2_FILE_IDS
)
assert_retrieved_docs_match_expected(
assert_expected_docs_in_retrieved_docs(
retrieved_docs=retrieved_docs,
expected_file_ids=expected_file_ids,
)
@@ -241,7 +241,7 @@ def test_folders_only(
+ FOLDER_2_2_FILE_IDS
+ ADMIN_FOLDER_3_FILE_IDS
)
assert_retrieved_docs_match_expected(
assert_expected_docs_in_retrieved_docs(
retrieved_docs=retrieved_docs,
expected_file_ids=expected_file_ids,
)
@@ -271,7 +271,7 @@ def test_personal_folders_only(
retrieved_docs = load_all_docs(connector)
expected_file_ids = ADMIN_FOLDER_3_FILE_IDS
assert_retrieved_docs_match_expected(
assert_expected_docs_in_retrieved_docs(
retrieved_docs=retrieved_docs,
expected_file_ids=expected_file_ids,
)

View File

@@ -1,13 +1,23 @@
from collections.abc import Callable
from unittest.mock import MagicMock
from unittest.mock import patch
from urllib.parse import urlparse
from onyx.connectors.google_drive.connector import GoogleDriveConnector
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_EMAIL
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FOLDER_3_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import (
assert_retrieved_docs_match_expected,
assert_expected_docs_in_retrieved_docs,
)
from tests.daily.connectors.google_drive.consts_and_utils import (
EXTERNAL_SHARED_DOC_SINGLETON,
)
from tests.daily.connectors.google_drive.consts_and_utils import (
EXTERNAL_SHARED_DOCS_IN_FOLDER,
)
from tests.daily.connectors.google_drive.consts_and_utils import (
EXTERNAL_SHARED_FOLDER_URL,
)
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_URL
@@ -70,12 +80,40 @@ def test_include_all(
+ FOLDER_2_2_FILE_IDS
+ SECTIONS_FILE_IDS
)
assert_retrieved_docs_match_expected(
assert_expected_docs_in_retrieved_docs(
retrieved_docs=retrieved_docs,
expected_file_ids=expected_file_ids,
)
@patch(
"onyx.file_processing.extract_file_text.get_unstructured_api_key",
return_value=None,
)
def test_include_shared_drives_only_with_size_threshold(
mock_get_api_key: MagicMock,
google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
print("\n\nRunning test_include_shared_drives_only_with_size_threshold")
connector = google_drive_service_acct_connector_factory(
primary_admin_email=ADMIN_EMAIL,
include_shared_drives=True,
include_my_drives=False,
include_files_shared_with_me=False,
shared_folder_urls=None,
shared_drive_urls=None,
my_drive_emails=None,
)
# this threshold will skip one file
connector.size_threshold = 16384
retrieved_docs = load_all_docs(connector)
# 2 extra files from shared drive owned by non-admin and not shared with admin
assert len(retrieved_docs) == 52
@patch(
"onyx.file_processing.extract_file_text.get_unstructured_api_key",
return_value=None,
@@ -94,6 +132,7 @@ def test_include_shared_drives_only(
shared_drive_urls=None,
my_drive_emails=None,
)
retrieved_docs = load_all_docs(connector)
# Should only get shared drives
@@ -108,7 +147,11 @@ def test_include_shared_drives_only(
+ FOLDER_2_2_FILE_IDS
+ SECTIONS_FILE_IDS
)
assert_retrieved_docs_match_expected(
# 2 extra files from shared drive owned by non-admin and not shared with admin
assert len(retrieved_docs) == 53
assert_expected_docs_in_retrieved_docs(
retrieved_docs=retrieved_docs,
expected_file_ids=expected_file_ids,
)
@@ -142,7 +185,7 @@ def test_include_my_drives_only(
+ TEST_USER_2_FILE_IDS
+ TEST_USER_3_FILE_IDS
)
assert_retrieved_docs_match_expected(
assert_expected_docs_in_retrieved_docs(
retrieved_docs=retrieved_docs,
expected_file_ids=expected_file_ids,
)
@@ -176,7 +219,7 @@ def test_drive_one_only(
+ FOLDER_1_1_FILE_IDS
+ FOLDER_1_2_FILE_IDS
)
assert_retrieved_docs_match_expected(
assert_expected_docs_in_retrieved_docs(
retrieved_docs=retrieved_docs,
expected_file_ids=expected_file_ids,
)
@@ -214,7 +257,7 @@ def test_folder_and_shared_drive(
+ FOLDER_2_1_FILE_IDS
+ FOLDER_2_2_FILE_IDS
)
assert_retrieved_docs_match_expected(
assert_expected_docs_in_retrieved_docs(
retrieved_docs=retrieved_docs,
expected_file_ids=expected_file_ids,
)
@@ -257,12 +300,70 @@ def test_folders_only(
+ FOLDER_2_2_FILE_IDS
+ ADMIN_FOLDER_3_FILE_IDS
)
assert_retrieved_docs_match_expected(
assert_expected_docs_in_retrieved_docs(
retrieved_docs=retrieved_docs,
expected_file_ids=expected_file_ids,
)
def test_shared_folder_owned_by_external_user(
google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
print("\n\nRunning test_shared_folder_owned_by_external_user")
connector = google_drive_service_acct_connector_factory(
primary_admin_email=ADMIN_EMAIL,
include_shared_drives=False,
include_my_drives=False,
include_files_shared_with_me=False,
shared_drive_urls=None,
shared_folder_urls=EXTERNAL_SHARED_FOLDER_URL,
my_drive_emails=None,
)
retrieved_docs = load_all_docs(connector)
expected_docs = EXTERNAL_SHARED_DOCS_IN_FOLDER
assert len(retrieved_docs) == len(expected_docs) # 1 for now
assert expected_docs[0] in retrieved_docs[0].id
def test_shared_with_me(
google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
print("\n\nRunning test_shared_with_me")
connector = google_drive_service_acct_connector_factory(
primary_admin_email=ADMIN_EMAIL,
include_shared_drives=False,
include_my_drives=True,
include_files_shared_with_me=True,
shared_drive_urls=None,
shared_folder_urls=None,
my_drive_emails=None,
)
retrieved_docs = load_all_docs(connector)
print(retrieved_docs)
expected_file_ids = (
ADMIN_FILE_IDS
+ ADMIN_FOLDER_3_FILE_IDS
+ TEST_USER_1_FILE_IDS
+ TEST_USER_2_FILE_IDS
+ TEST_USER_3_FILE_IDS
)
assert_expected_docs_in_retrieved_docs(
retrieved_docs=retrieved_docs,
expected_file_ids=expected_file_ids,
)
retrieved_ids = {urlparse(doc.id).path.split("/")[-2] for doc in retrieved_docs}
for id in retrieved_ids:
print(id)
assert EXTERNAL_SHARED_DOC_SINGLETON.split("/")[-1] in retrieved_ids
assert EXTERNAL_SHARED_DOCS_IN_FOLDER[0].split("/")[-1] in retrieved_ids
@patch(
"onyx.file_processing.extract_file_text.get_unstructured_api_key",
return_value=None,
@@ -288,7 +389,7 @@ def test_specific_emails(
retrieved_docs = load_all_docs(connector)
expected_file_ids = TEST_USER_1_FILE_IDS + TEST_USER_3_FILE_IDS
assert_retrieved_docs_match_expected(
assert_expected_docs_in_retrieved_docs(
retrieved_docs=retrieved_docs,
expected_file_ids=expected_file_ids,
)
@@ -318,7 +419,7 @@ def get_specific_folders_in_my_drive(
retrieved_docs = load_all_docs(connector)
expected_file_ids = ADMIN_FOLDER_3_FILE_IDS
assert_retrieved_docs_match_expected(
assert_expected_docs_in_retrieved_docs(
retrieved_docs=retrieved_docs,
expected_file_ids=expected_file_ids,
)

View File

@@ -5,7 +5,7 @@ from unittest.mock import patch
from onyx.connectors.google_drive.connector import GoogleDriveConnector
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FOLDER_3_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import (
assert_retrieved_docs_match_expected,
assert_expected_docs_in_retrieved_docs,
)
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_2_FILE_IDS
@@ -50,7 +50,7 @@ def test_all(
+ ADMIN_FOLDER_3_FILE_IDS
+ list(range(0, 2))
)
assert_retrieved_docs_match_expected(
assert_expected_docs_in_retrieved_docs(
retrieved_docs=retrieved_docs,
expected_file_ids=expected_file_ids,
)
@@ -83,7 +83,7 @@ def test_shared_drives_only(
+ FOLDER_1_1_FILE_IDS
+ FOLDER_1_2_FILE_IDS
)
assert_retrieved_docs_match_expected(
assert_expected_docs_in_retrieved_docs(
retrieved_docs=retrieved_docs,
expected_file_ids=expected_file_ids,
)
@@ -114,7 +114,7 @@ def test_shared_with_me_only(
ADMIN_FOLDER_3_FILE_IDS
+ list(range(0, 2))
)
assert_retrieved_docs_match_expected(
assert_expected_docs_in_retrieved_docs(
retrieved_docs=retrieved_docs,
expected_file_ids=expected_file_ids,
)
@@ -142,7 +142,7 @@ def test_my_drive_only(
# These are the files from my drive
expected_file_ids = TEST_USER_1_FILE_IDS
assert_retrieved_docs_match_expected(
assert_expected_docs_in_retrieved_docs(
retrieved_docs=retrieved_docs,
expected_file_ids=expected_file_ids,
)
@@ -172,7 +172,7 @@ def test_shared_my_drive_folder(
# this is a folder from admin's drive that is shared with me
ADMIN_FOLDER_3_FILE_IDS
)
assert_retrieved_docs_match_expected(
assert_expected_docs_in_retrieved_docs(
retrieved_docs=retrieved_docs,
expected_file_ids=expected_file_ids,
)
@@ -199,7 +199,7 @@ def test_shared_drive_folder(
retrieved_docs = load_all_docs(connector)
expected_file_ids = FOLDER_1_FILE_IDS + FOLDER_1_1_FILE_IDS + FOLDER_1_2_FILE_IDS
assert_retrieved_docs_match_expected(
assert_expected_docs_in_retrieved_docs(
retrieved_docs=retrieved_docs,
expected_file_ids=expected_file_ids,
)

View File

@@ -2,12 +2,14 @@ import json
import os
import time
from pathlib import Path
from typing import cast
import pytest
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import Document
from onyx.connectors.zendesk.connector import ZendeskConnector
from tests.daily.connectors.utils import load_all_docs_from_checkpoint_connector
def load_test_data(file_name: str = "test_zendesk_data.json") -> dict[str, dict]:
@@ -50,7 +52,7 @@ def get_credentials() -> dict[str, str]:
def test_zendesk_connector_basic(
request: pytest.FixtureRequest, connector_fixture: str
) -> None:
connector = request.getfixturevalue(connector_fixture)
connector = cast(ZendeskConnector, request.getfixturevalue(connector_fixture))
test_data = load_test_data()
all_docs: list[Document] = []
target_test_doc_id: str
@@ -61,12 +63,11 @@ def test_zendesk_connector_basic(
target_doc: Document | None = None
for doc_batch in connector.poll_source(0, time.time()):
for doc in doc_batch:
all_docs.append(doc)
if doc.id == target_test_doc_id:
target_doc = doc
print(f"target_doc {target_doc}")
for doc in load_all_docs_from_checkpoint_connector(connector, 0, time.time()):
all_docs.append(doc)
if doc.id == target_test_doc_id:
target_doc = doc
print(f"target_doc {target_doc}")
assert len(all_docs) > 0, "No documents were retrieved from the connector"
assert (
@@ -111,8 +112,10 @@ def test_zendesk_connector_basic(
def test_zendesk_connector_slim(zendesk_article_connector: ZendeskConnector) -> None:
# Get full doc IDs
all_full_doc_ids = set()
for doc_batch in zendesk_article_connector.load_from_state():
all_full_doc_ids.update([doc.id for doc in doc_batch])
for doc in load_all_docs_from_checkpoint_connector(
zendesk_article_connector, 0, time.time()
):
all_full_doc_ids.add(doc.id)
# Get slim doc IDs
all_slim_doc_ids = set()

View File

@@ -6,7 +6,7 @@ API_SERVER_PROTOCOL = os.getenv("API_SERVER_PROTOCOL") or "http"
API_SERVER_HOST = os.getenv("API_SERVER_HOST") or "localhost"
API_SERVER_PORT = os.getenv("API_SERVER_PORT") or "8080"
API_SERVER_URL = f"{API_SERVER_PROTOCOL}://{API_SERVER_HOST}:{API_SERVER_PORT}"
MAX_DELAY = 45
MAX_DELAY = 60
GENERAL_HEADERS = {"Content-Type": "application/json"}

View File

@@ -5,6 +5,7 @@ import requests
from requests.models import Response
from onyx.context.search.models import RetrievalDetails
from onyx.context.search.models import SavedSearchDoc
from onyx.file_store.models import FileDescriptor
from onyx.llm.override_models import LLMOverride
from onyx.llm.override_models import PromptOverride
@@ -97,17 +98,24 @@ class ChatSessionManager:
for data in response_data:
if "rephrased_query" in data:
analyzed.rephrased_query = data["rephrased_query"]
elif "tool_name" in data:
if "tool_name" in data:
analyzed.tool_name = data["tool_name"]
analyzed.tool_result = (
data.get("tool_result")
if analyzed.tool_name == "run_search"
else None
)
elif "relevance_summaries" in data:
if "relevance_summaries" in data:
analyzed.relevance_summaries = data["relevance_summaries"]
elif "answer_piece" in data and data["answer_piece"]:
if "answer_piece" in data and data["answer_piece"]:
analyzed.full_message += data["answer_piece"]
if "top_documents" in data:
assert (
analyzed.top_documents is None
), "top_documents should only be set once"
analyzed.top_documents = [
SavedSearchDoc(**doc) for doc in data["top_documents"]
]
return analyzed

View File

@@ -9,7 +9,9 @@ from requests import HTTPError
from onyx.auth.schemas import UserRole
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
from onyx.server.documents.models import PaginatedReturn
from onyx.server.manage.models import UserInfo
from onyx.server.models import FullUserSnapshot
from onyx.server.models import InvitedUserSnapshot
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestUser
@@ -123,10 +125,15 @@ class UserManager:
user_to_set: DATestUser,
target_role: UserRole,
user_performing_action: DATestUser,
explicit_override: bool = False,
) -> DATestUser:
response = requests.patch(
url=f"{API_SERVER_URL}/manage/set-user-role",
json={"user_email": user_to_set.email, "new_role": target_role.value},
json={
"user_email": user_to_set.email,
"new_role": target_role.value,
"explicit_override": explicit_override,
},
headers=user_performing_action.headers,
)
response.raise_for_status()
@@ -240,3 +247,69 @@ class UserManager:
total_items=data["total_items"],
)
return paginated_result
@staticmethod
def invite_user(
user_to_invite_email: str, user_performing_action: DATestUser
) -> None:
"""Invite a user by email to join the organization.
Args:
user_to_invite_email: Email of the user to invite
user_performing_action: User with admin permissions performing the invitation
"""
response = requests.put(
url=f"{API_SERVER_URL}/manage/admin/users",
headers=user_performing_action.headers,
json={"emails": [user_to_invite_email]},
)
response.raise_for_status()
@staticmethod
def accept_invitation(tenant_id: str, user_performing_action: DATestUser) -> None:
"""Accept an invitation to join the organization.
Args:
tenant_id: ID of the tenant/organization to accept invitation for
user_performing_action: User accepting the invitation
"""
response = requests.post(
url=f"{API_SERVER_URL}/tenants/users/invite/accept",
headers=user_performing_action.headers,
json={"tenant_id": tenant_id},
)
response.raise_for_status()
@staticmethod
def get_invited_users(
user_performing_action: DATestUser,
) -> list[InvitedUserSnapshot]:
"""Get a list of all invited users.
Args:
user_performing_action: User with admin permissions performing the action
Returns:
List of invited user snapshots
"""
response = requests.get(
url=f"{API_SERVER_URL}/manage/users/invited",
headers=user_performing_action.headers,
)
response.raise_for_status()
return [InvitedUserSnapshot(**user) for user in response.json()]
@staticmethod
def get_user_info(user_performing_action: DATestUser) -> UserInfo:
"""Get user info for the current user.
Args:
user_performing_action: User performing the action
"""
response = requests.get(
url=f"{API_SERVER_URL}/me",
headers=user_performing_action.headers,
)
response.raise_for_status()
return UserInfo(**response.json())
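
A hedged sketch of how the new UserManager helpers above compose into an invite-and-accept flow. The user names are illustrative; the multi-tenant integration test later in this diff exercises the real flow with full assertions.

# Sketch only; relies on the helpers defined above.
admin = UserManager.create(name="admin")
invitee = UserManager.create(name="invitee")

UserManager.invite_user(invitee.email, admin)

info = UserManager.get_user_info(invitee)
if info.tenant_info and info.tenant_info.invitation:
    UserManager.accept_invitation(info.tenant_info.invitation.tenant_id, invitee)

# Once accepted, the invitee should no longer appear in the invited-users list.
assert invitee.email not in [u.email for u in UserManager.get_invited_users(admin)]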

View File

@@ -10,6 +10,7 @@ from pydantic import Field
from onyx.auth.schemas import UserRole
from onyx.configs.constants import QAFeedbackType
from onyx.context.search.enums import RecencyBiasSetting
from onyx.context.search.models import SavedSearchDoc
from onyx.db.enums import AccessType
from onyx.server.documents.models import DocumentSource
from onyx.server.documents.models import IndexAttemptSnapshot
@@ -157,7 +158,7 @@ class StreamedResponse(BaseModel):
full_message: str = ""
rephrased_query: str | None = None
tool_name: str | None = None
top_documents: list[dict[str, Any]] | None = None
top_documents: list[SavedSearchDoc] | None = None
relevance_summaries: list[dict[str, Any]] | None = None
tool_result: Any | None = None
user: str | None = None

View File

@@ -0,0 +1,70 @@
from onyx.db.models import UserRole
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestUser
INVITED_BASIC_USER = "basic_user"
INVITED_BASIC_USER_EMAIL = "basic_user@test.com"
def test_user_invitation_flow(reset_multitenant: None) -> None:
# Create first user (admin)
admin_user: DATestUser = UserManager.create(name="admin")
assert UserManager.is_role(admin_user, UserRole.ADMIN)
# Create second user
invited_user: DATestUser = UserManager.create(name="admin_invited")
assert UserManager.is_role(invited_user, UserRole.ADMIN)
# Admin user invites both the previously registered user and a not-yet-registered user
UserManager.invite_user(invited_user.email, admin_user)
UserManager.invite_user(INVITED_BASIC_USER_EMAIL, admin_user)
invited_basic_user: DATestUser = UserManager.create(
name=INVITED_BASIC_USER, email=INVITED_BASIC_USER_EMAIL
)
assert UserManager.is_role(invited_basic_user, UserRole.BASIC)
# Verify the user is in the invited users list
invited_users = UserManager.get_invited_users(admin_user)
assert invited_user.email in [
user.email for user in invited_users
], f"User {invited_user.email} not found in invited users list"
# Get user info to check tenant information
user_info = UserManager.get_user_info(invited_user)
# Extract the tenant_id from the invitation
invited_tenant_id = (
user_info.tenant_info.invitation.tenant_id
if user_info.tenant_info and user_info.tenant_info.invitation
else None
)
assert invited_tenant_id is not None, "Expected to find an invitation tenant_id"
UserManager.accept_invitation(invited_tenant_id, invited_user)
# Get updated user info after accepting invitation
updated_user_info = UserManager.get_user_info(invited_user)
# Verify the user is no longer in the invited users list
updated_invited_users = UserManager.get_invited_users(admin_user)
assert invited_user.email not in [
user.email for user in updated_invited_users
], f"User {invited_user.email} should not be in invited users list after accepting"
# Verify the user has BASIC role in the organization
assert (
updated_user_info.role == UserRole.BASIC
), f"Expected user to have BASIC role, but got {updated_user_info.role}"
# Verify user is in the organization
user_page = UserManager.get_user_page(
user_performing_action=admin_user, role_filter=[UserRole.BASIC]
)
# Check if the invited user is in the list of users with BASIC role
invited_user_emails = [user.email for user in user_page.items]
assert invited_user.email in invited_user_emails, (
f"User {invited_user.email} not found in the list of basic users "
f"in the organization. Available users: {invited_user_emails}"
)

View File

@@ -0,0 +1,97 @@
import os
import pytest
import requests
from onyx.auth.schemas import UserRole
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestUser
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="SAML tests are enterprise only",
)
def test_saml_user_conversion(reset: None) -> None:
"""
Test that SAML login correctly converts users with non-authenticated roles
(SLACK_USER or EXT_PERM_USER) to authenticated roles (BASIC).
This test:
1. Creates an admin and a regular user
2. Changes the regular user's role to EXT_PERM_USER
3. Simulates a SAML login by calling the test endpoint
4. Verifies the user's role is converted to BASIC
"""
# Create an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(email="admin@onyx-test.com")
# Create a regular user that we'll convert to EXT_PERM_USER
test_user_email = "ext_perm_user@example.com"
test_user = UserManager.create(email=test_user_email)
# Verify the user was created with BASIC role initially
assert UserManager.is_role(test_user, UserRole.BASIC)
# Change the user's role to EXT_PERM_USER using the UserManager
UserManager.set_role(
user_to_set=test_user,
target_role=UserRole.EXT_PERM_USER,
user_performing_action=admin_user,
explicit_override=True,
)
# Verify the user has EXT_PERM_USER role now
assert UserManager.is_role(test_user, UserRole.EXT_PERM_USER)
# Simulate SAML login by calling the test endpoint
response = requests.post(
f"{API_SERVER_URL}/manage/users/test-upsert-user",
json={"email": test_user_email},
headers=admin_user.headers, # Use admin headers for authorization
)
response.raise_for_status()
# Verify the response indicates the role changed to BASIC
user_data = response.json()
assert user_data["role"] == UserRole.BASIC.value
# Verify user role was changed in the database
assert UserManager.is_role(test_user, UserRole.BASIC)
# Do the same test with SLACK_USER
slack_user_email = "slack_user@example.com"
slack_user = UserManager.create(email=slack_user_email)
# Verify the user was created with BASIC role initially
assert UserManager.is_role(slack_user, UserRole.BASIC)
# Change the user's role to SLACK_USER
UserManager.set_role(
user_to_set=slack_user,
target_role=UserRole.SLACK_USER,
user_performing_action=admin_user,
explicit_override=True,
)
# Verify the user has SLACK_USER role
assert UserManager.is_role(slack_user, UserRole.SLACK_USER)
# Simulate SAML login again
response = requests.post(
f"{API_SERVER_URL}/manage/users/test-upsert-user",
json={"email": slack_user_email},
headers=admin_user.headers,
)
response.raise_for_status()
# Verify the response indicates the role changed to BASIC
user_data = response.json()
assert user_data["role"] == UserRole.BASIC.value
# Verify the user's role was changed in the database
assert UserManager.is_role(slack_user, UserRole.BASIC)

View File

@@ -5,6 +5,7 @@ This file contains tests for the following:
- Updates the document sets and user groups to remove the connector
- Ensures that deleting a connector that is part of an overlapping document set and/or user group works as expected
"""
import os
from uuid import uuid4
from sqlalchemy.orm import Session
@@ -32,6 +33,13 @@ from tests.integration.common_utils.vespa import vespa_fixture
def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None:
user_group_1: DATestUserGroup
user_group_2: DATestUserGroup
is_ee = (
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() == "true"
)
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(name="admin_user")
# create api key
@@ -78,16 +86,17 @@ def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None:
print("Document sets created and synced")
# create user groups
user_group_1: DATestUserGroup = UserGroupManager.create(
cc_pair_ids=[cc_pair_1.id],
user_performing_action=admin_user,
)
user_group_2: DATestUserGroup = UserGroupManager.create(
cc_pair_ids=[cc_pair_1.id, cc_pair_2.id],
user_performing_action=admin_user,
)
UserGroupManager.wait_for_sync(user_performing_action=admin_user)
if is_ee:
# create user groups
user_group_1 = UserGroupManager.create(
cc_pair_ids=[cc_pair_1.id],
user_performing_action=admin_user,
)
user_group_2 = UserGroupManager.create(
cc_pair_ids=[cc_pair_1.id, cc_pair_2.id],
user_performing_action=admin_user,
)
UserGroupManager.wait_for_sync(user_performing_action=admin_user)
# inject a finished index attempt and index attempt error (exercises foreign key errors)
with Session(get_sqlalchemy_engine()) as db_session:
@@ -147,12 +156,13 @@ def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None:
)
# Update local records to match the database for later comparison
user_group_1.cc_pair_ids = []
user_group_2.cc_pair_ids = [cc_pair_2.id]
doc_set_1.cc_pair_ids = []
doc_set_2.cc_pair_ids = [cc_pair_2.id]
cc_pair_1.groups = []
cc_pair_2.groups = [user_group_2.id]
if is_ee:
cc_pair_2.groups = [user_group_2.id]
else:
cc_pair_2.groups = []
CCPairManager.wait_for_deletion_completion(
cc_pair_id=cc_pair_1.id, user_performing_action=admin_user
@@ -168,11 +178,15 @@ def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None:
verify_deleted=True,
)
cc_pair_2_group_name_expected = []
if is_ee:
cc_pair_2_group_name_expected = [user_group_2.name]
DocumentManager.verify(
vespa_client=vespa_client,
cc_pair=cc_pair_2,
doc_set_names=[doc_set_2.name],
group_names=[user_group_2.name],
group_names=cc_pair_2_group_name_expected,
doc_creating_user=admin_user,
verify_deleted=False,
)
@@ -193,15 +207,19 @@ def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None:
user_performing_action=admin_user,
)
# validate user groups
UserGroupManager.verify(
user_group=user_group_1,
user_performing_action=admin_user,
)
UserGroupManager.verify(
user_group=user_group_2,
user_performing_action=admin_user,
)
if is_ee:
user_group_1.cc_pair_ids = []
user_group_2.cc_pair_ids = [cc_pair_2.id]
# validate user groups
UserGroupManager.verify(
user_group=user_group_1,
user_performing_action=admin_user,
)
UserGroupManager.verify(
user_group=user_group_2,
user_performing_action=admin_user,
)
def test_connector_deletion_for_overlapping_connectors(
@@ -210,6 +228,13 @@ def test_connector_deletion_for_overlapping_connectors(
"""Checks to make sure that connectors with overlapping documents work properly. Specifically, that the overlapping
document (1) still exists and (2) has the right document set / group post-deletion of one of the connectors.
"""
user_group_1: DATestUserGroup
user_group_2: DATestUserGroup
is_ee = (
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() == "true"
)
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(name="admin_user")
# create api key
@@ -281,47 +306,48 @@ def test_connector_deletion_for_overlapping_connectors(
doc_creating_user=admin_user,
)
# create a user group and attach it to connector 1
user_group_1: DATestUserGroup = UserGroupManager.create(
name="Test User Group 1",
cc_pair_ids=[cc_pair_1.id],
user_performing_action=admin_user,
)
UserGroupManager.wait_for_sync(
user_groups_to_check=[user_group_1],
user_performing_action=admin_user,
)
cc_pair_1.groups = [user_group_1.id]
if is_ee:
# create a user group and attach it to connector 1
user_group_1 = UserGroupManager.create(
name="Test User Group 1",
cc_pair_ids=[cc_pair_1.id],
user_performing_action=admin_user,
)
UserGroupManager.wait_for_sync(
user_groups_to_check=[user_group_1],
user_performing_action=admin_user,
)
cc_pair_1.groups = [user_group_1.id]
print("User group 1 created and synced")
print("User group 1 created and synced")
# create a user group and attach it to connector 2
user_group_2: DATestUserGroup = UserGroupManager.create(
name="Test User Group 2",
cc_pair_ids=[cc_pair_2.id],
user_performing_action=admin_user,
)
UserGroupManager.wait_for_sync(
user_groups_to_check=[user_group_2],
user_performing_action=admin_user,
)
cc_pair_2.groups = [user_group_2.id]
# create a user group and attach it to connector 2
user_group_2 = UserGroupManager.create(
name="Test User Group 2",
cc_pair_ids=[cc_pair_2.id],
user_performing_action=admin_user,
)
UserGroupManager.wait_for_sync(
user_groups_to_check=[user_group_2],
user_performing_action=admin_user,
)
cc_pair_2.groups = [user_group_2.id]
print("User group 2 created and synced")
print("User group 2 created and synced")
# verify vespa document is in the user group
DocumentManager.verify(
vespa_client=vespa_client,
cc_pair=cc_pair_1,
group_names=[user_group_1.name, user_group_2.name],
doc_creating_user=admin_user,
)
DocumentManager.verify(
vespa_client=vespa_client,
cc_pair=cc_pair_2,
group_names=[user_group_1.name, user_group_2.name],
doc_creating_user=admin_user,
)
# verify vespa document is in the user group
DocumentManager.verify(
vespa_client=vespa_client,
cc_pair=cc_pair_1,
group_names=[user_group_1.name, user_group_2.name],
doc_creating_user=admin_user,
)
DocumentManager.verify(
vespa_client=vespa_client,
cc_pair=cc_pair_2,
group_names=[user_group_1.name, user_group_2.name],
doc_creating_user=admin_user,
)
# delete connector 1
CCPairManager.pause_cc_pair(
@@ -354,11 +380,15 @@ def test_connector_deletion_for_overlapping_connectors(
# verify the document is not in any document sets
# verify the document is only in user group 2
group_names_expected = []
if is_ee:
group_names_expected = [user_group_2.name]
DocumentManager.verify(
vespa_client=vespa_client,
cc_pair=cc_pair_2,
doc_set_names=[],
group_names=[user_group_2.name],
group_names=group_names_expected,
doc_creating_user=admin_user,
verify_deleted=False,
)

View File

@@ -1,3 +1,6 @@
import os
import pytest
import requests
from onyx.configs.constants import MessageType
@@ -12,6 +15,10 @@ from tests.integration.common_utils.test_models import DATestCCPair
from tests.integration.common_utils.test_models import DATestUser
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="/chat/send-message-simple-with-history is enterprise only",
)
def test_all_stream_chat_message_objects_outputs(reset: None) -> None:
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(name="admin_user")

View File

@@ -1,5 +1,7 @@
import json
import os
import pytest
import requests
from onyx.configs.constants import MessageType
@@ -16,10 +18,11 @@ from tests.integration.common_utils.test_models import DATestCCPair
from tests.integration.common_utils.test_models import DATestUser
def test_send_message_simple_with_history(reset: None) -> None:
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(name="admin_user")
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="/chat/send-message-simple-with-history tests are enterprise only",
)
def test_send_message_simple_with_history(reset: None, admin_user: DATestUser) -> None:
# create connectors
cc_pair_1: DATestCCPair = CCPairManager.create_from_scratch(
user_performing_action=admin_user,
@@ -53,18 +56,22 @@ def test_send_message_simple_with_history(reset: None) -> None:
response_json = response.json()
# Check that the top document is the correct document
assert response_json["simple_search_docs"][0]["id"] == cc_pair_1.documents[0].id
assert response_json["top_documents"][0]["document_id"] == cc_pair_1.documents[0].id
# assert that the metadata is correct
for doc in cc_pair_1.documents:
found_doc = next(
(x for x in response_json["simple_search_docs"] if x["id"] == doc.id), None
(x for x in response_json["top_documents"] if x["document_id"] == doc.id),
None,
)
assert found_doc
assert found_doc["metadata"]["document_id"] == doc.id
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="/chat/send-message-simple-with-history tests are enterprise only",
)
def test_using_reference_docs_with_simple_with_history_api_flow(reset: None) -> None:
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(name="admin_user")
@@ -154,6 +161,10 @@ def test_using_reference_docs_with_simple_with_history_api_flow(reset: None) ->
assert response_json["top_documents"][0]["document_id"] == cc_pair_1.documents[2].id
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="/chat/send-message-simple-with-history tests are enterprise only",
)
def test_send_message_simple_with_history_strict_json(
new_admin_user: DATestUser | None,
) -> None:

View File

@@ -2,6 +2,8 @@
This file takes the happy path to adding a curator to a user group and then tests
the permissions of the curator manipulating connector-credential pairs.
"""
import os
import pytest
from requests.exceptions import HTTPError
@@ -15,6 +17,10 @@ from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="Curator and User Group tests are enterprise only",
)
def test_cc_pair_permissions(reset: None) -> None:
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(name="admin_user")

View File

@@ -2,6 +2,8 @@
This file takes the happy path to adding a curator to a user group and then tests
the permissions of the curator manipulating connectors.
"""
import os
import pytest
from requests.exceptions import HTTPError
@@ -13,6 +15,10 @@ from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="Curator and user group tests are enterprise only",
)
def test_connector_permissions(reset: None) -> None:
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(name="admin_user")

Some files were not shown because too many files have changed in this diff.