Compare commits

..

12 Commits

Author SHA1 Message Date
pablonyx
8e71216607 k 2025-03-27 18:35:48 -07:00
evan-danswer
a123661c92 fixed shared folder issue (#4371)
* fixed shared folder issue

* fix existing tests

* default allow files shared with me for service account
2025-03-27 23:39:52 +00:00
pablonyx
c554889baf Fix actions link (#4374) 2025-03-27 16:39:35 -07:00
rkuo-danswer
f08fa878a6 refactor file extension checking and add test for blob s3 (#4369)
* refactor file extension checking and add test for blob s3

* code review

* fix checking ext

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-03-27 18:57:44 +00:00
pablonyx
d307534781 add some debug logging (#4328) 2025-03-27 11:49:32 -07:00
rkuo-danswer
6f54791910 adjust some vars in real time (#4365)
* adjust some vars in real time

* some sanity checking

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-03-27 17:30:08 +00:00
pablonyx
0d5497bb6b Add multi-tenant user invitation flow test (#4360) 2025-03-27 09:53:15 -07:00
Chris Weaver
7648627503 Save all logs + add log persistence to most Onyx-owned containers (#4368)
* Save all logs + add log persistence to most Onyx-owned containers

* Separate volumes for each container

* Small fixes
2025-03-26 22:25:39 -07:00
pablonyx
927554d5ca slight robustification (#4367) 2025-03-27 03:23:36 +00:00
pablonyx
7dcec6caf5 Fix session touching (#4363)
* fix session touching

* Revert "fix session touching"

This reverts commit c473d5c9a2.

* Revert "Revert "fix session touching""

This reverts commit 26a71d40b6.

* update

* quick nit
2025-03-27 01:18:46 +00:00
rkuo-danswer
036648146d possible fix for confluence query filter (#4280)
* possible fix for confluence query filter

* nuke the attachment filter query ... it doesn't work!

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-03-27 00:35:14 +00:00
rkuo-danswer
2aa4697ac8 permission sync runs so often that it starves out other tasks if run at high priority (#4364)
Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-03-27 00:22:53 +00:00
43 changed files with 729 additions and 204 deletions

View File

@@ -9,6 +9,10 @@ on:
- cron: "0 16 * * *"
env:
# AWS
AWS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS: ${{ secrets.AWS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS }}
AWS_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS: ${{ secrets.AWS_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS }}
# Confluence
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_TEST_SPACE: ${{ secrets.CONFLUENCE_TEST_SPACE }}

View File

@@ -25,6 +25,10 @@ SAML_CONF_DIR = os.environ.get("SAML_CONF_DIR") or "/app/ee/onyx/configs/saml_co
#####
# Auto Permission Sync
#####
DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY = int(
os.environ.get("DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
)
# In seconds, default is 5 minutes
CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY = int(
os.environ.get("CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY") or 5 * 60
@@ -39,6 +43,7 @@ CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC = (
CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY = int(
os.environ.get("CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
)
NUM_PERMISSION_WORKERS = int(os.environ.get("NUM_PERMISSION_WORKERS") or 2)
@@ -72,6 +77,13 @@ OAUTH_GOOGLE_DRIVE_CLIENT_SECRET = os.environ.get(
"OAUTH_GOOGLE_DRIVE_CLIENT_SECRET", ""
)
GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY = int(
os.environ.get("GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY") or 5 * 60
)
SLACK_PERMISSION_DOC_SYNC_FREQUENCY = int(
os.environ.get("SLACK_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
)
# The posthog client does not accept empty API keys or hosts however it fails silently
# when the capture is called. These defaults prevent Posthog issues from breaking the Onyx app

View File

@@ -3,6 +3,8 @@ from collections.abc import Generator
from ee.onyx.configs.app_configs import CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import SLACK_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.db.external_perm import ExternalUserGroup
from ee.onyx.external_permissions.confluence.doc_sync import confluence_doc_sync
from ee.onyx.external_permissions.confluence.group_sync import confluence_group_sync
@@ -66,13 +68,13 @@ GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC: set[DocumentSource] = {
DOC_PERMISSION_SYNC_PERIODS: dict[DocumentSource, int] = {
# Polling is not supported so we fetch all doc permissions every 5 minutes
DocumentSource.CONFLUENCE: CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY,
DocumentSource.SLACK: 5 * 60,
DocumentSource.SLACK: SLACK_PERMISSION_DOC_SYNC_FREQUENCY,
}
# If nothing is specified here, we run the doc_sync every time the celery beat runs
EXTERNAL_GROUP_SYNC_PERIODS: dict[DocumentSource, int] = {
# Polling is not supported so we fetch all group permissions every 30 minutes
DocumentSource.GOOGLE_DRIVE: 5 * 60,
DocumentSource.GOOGLE_DRIVE: GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY,
DocumentSource.CONFLUENCE: CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY,
}

View File

@@ -70,6 +70,7 @@ def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
"""
Add users to a tenant with proper transaction handling.
Checks if users already have a tenant mapping to avoid duplicates.
If a user already has an active mapping to any tenant, the new mapping will be added as inactive.
"""
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
try:
@@ -88,9 +89,25 @@ def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
.first()
)
# If user already has an active mapping, add this one as inactive
if not existing_mapping:
# Only add if mapping doesn't exist
db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))
# Check if the user already has an active mapping to any tenant
has_active_mapping = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email == email,
UserTenantMapping.active == True, # noqa: E712
)
.first()
)
db_session.add(
UserTenantMapping(
email=email,
tenant_id=tenant_id,
active=False if has_active_mapping else True,
)
)
# Commit the transaction
db_session.commit()

View File

@@ -1,6 +1,5 @@
from datetime import timedelta
from typing import Any
from typing import cast
from celery import Celery
from celery import signals
@@ -10,12 +9,10 @@ from celery.utils.log import get_task_logger
import onyx.background.celery.apps.app_base as app_base
from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
from onyx.configs.constants import ONYX_CLOUD_REDIS_RUNTIME
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
from onyx.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
from onyx.db.engine import get_all_tenant_ids
from onyx.db.engine import SqlEngine
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.server.runtime.onyx_runtime import OnyxRuntime
from onyx.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST
from shared_configs.configs import MULTI_TENANT
@@ -141,8 +138,6 @@ class DynamicTenantScheduler(PersistentScheduler):
"""Only updates the actual beat schedule on the celery app when it changes"""
do_update = False
r = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
task_logger.debug("_try_updating_schedule starting")
tenant_ids = get_all_tenant_ids()
@@ -152,16 +147,7 @@ class DynamicTenantScheduler(PersistentScheduler):
current_schedule = self.schedule.items()
# get potential new state
beat_multiplier = CLOUD_BEAT_MULTIPLIER_DEFAULT
beat_multiplier_raw = r.get(f"{ONYX_CLOUD_REDIS_RUNTIME}:beat_multiplier")
if beat_multiplier_raw is not None:
try:
beat_multiplier_bytes = cast(bytes, beat_multiplier_raw)
beat_multiplier = float(beat_multiplier_bytes.decode())
except ValueError:
task_logger.error(
f"Invalid beat_multiplier value: {beat_multiplier_raw}"
)
beat_multiplier = OnyxRuntime.get_beat_multiplier()
new_schedule = self._generate_schedule(tenant_ids, beat_multiplier)

View File

@@ -14,7 +14,7 @@ logger = setup_logger()
# Only set up memory monitoring in container environment
if is_running_in_container():
# Set up a dedicated memory monitoring logger
MEMORY_LOG_DIR = "/var/log/persisted-logs/memory"
MEMORY_LOG_DIR = "/var/log/memory"
MEMORY_LOG_FILE = os.path.join(MEMORY_LOG_DIR, "memory_usage.log")
MEMORY_LOG_MAX_BYTES = 10 * 1024 * 1024 # 10MB
MEMORY_LOG_BACKUP_COUNT = 5 # Keep 5 backup files

View File

@@ -21,6 +21,7 @@ BEAT_EXPIRES_DEFAULT = 15 * 60 # 15 minutes (in seconds)
# we have a better implementation (backpressure, etc)
# Note that DynamicTenantScheduler can adjust the runtime value for this via Redis
CLOUD_BEAT_MULTIPLIER_DEFAULT = 8.0
CLOUD_DOC_PERMISSION_SYNC_MULTIPLIER_DEFAULT = 1.0
# tasks that run in either self-hosted on cloud
beat_task_templates: list[dict] = []

View File

@@ -389,6 +389,8 @@ def monitor_connector_deletion_taskset(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
credential_id_to_delete: int | None = None
connector_id_to_delete: int | None = None
if not cc_pair:
task_logger.warning(
f"Connector deletion - cc_pair not found: cc_pair={cc_pair_id}"
@@ -443,26 +445,35 @@ def monitor_connector_deletion_taskset(
db_session=db_session,
)
# Store IDs before potentially expiring cc_pair
connector_id_to_delete = cc_pair.connector_id
credential_id_to_delete = cc_pair.credential_id
# Explicitly delete document by connector credential pair records before deleting the connector
# This is needed because connector_id is a primary key in that table and cascading deletes won't work
delete_all_documents_by_connector_credential_pair__no_commit(
db_session=db_session,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
connector_id=connector_id_to_delete,
credential_id=credential_id_to_delete,
)
# Flush to ensure document deletion happens before connector deletion
db_session.flush()
# Expire the cc_pair to ensure SQLAlchemy doesn't try to manage its state
# related to the deleted DocumentByConnectorCredentialPair during commit
db_session.expire(cc_pair)
# finally, delete the cc-pair
delete_connector_credential_pair__no_commit(
db_session=db_session,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
connector_id=connector_id_to_delete,
credential_id=credential_id_to_delete,
)
# if there are no credentials left, delete the connector
connector = fetch_connector_by_id(
db_session=db_session,
connector_id=cc_pair.connector_id,
connector_id=connector_id_to_delete,
)
if not connector or not len(connector.credentials):
task_logger.info(
@@ -495,15 +506,15 @@ def monitor_connector_deletion_taskset(
task_logger.exception(
f"Connector deletion exceptioned: "
f"cc_pair={cc_pair_id} connector={cc_pair.connector_id} credential={cc_pair.credential_id}"
f"cc_pair={cc_pair_id} connector={connector_id_to_delete} credential={credential_id_to_delete}"
)
raise e
task_logger.info(
f"Connector deletion succeeded: "
f"cc_pair={cc_pair_id} "
f"connector={cc_pair.connector_id} "
f"credential={cc_pair.credential_id} "
f"connector={connector_id_to_delete} "
f"credential={credential_id_to_delete} "
f"docs_deleted={fence_data.num_tasks}"
)
@@ -553,7 +564,7 @@ def validate_connector_deletion_fences(
def validate_connector_deletion_fence(
tenant_id: str,
key_bytes: bytes,
queued_tasks: set[str],
queued_upsert_tasks: set[str],
r: Redis,
) -> None:
"""Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.
@@ -640,7 +651,7 @@ def validate_connector_deletion_fence(
member_bytes = cast(bytes, member)
member_str = member_bytes.decode("utf-8")
if member_str in queued_tasks:
if member_str in queued_upsert_tasks:
continue
tasks_not_in_celery += 1

View File

@@ -17,6 +17,7 @@ from redis.exceptions import LockError
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from ee.onyx.configs.app_configs import DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs
from ee.onyx.db.document import upsert_document_external_perms
from ee.onyx.external_permissions.sync_params import DOC_PERMISSION_SYNC_PERIODS
@@ -63,6 +64,7 @@ from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSyn
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.redis.redis_pool import redis_lock_dump
from onyx.server.runtime.onyx_runtime import OnyxRuntime
from onyx.server.utils import make_short_id
from onyx.utils.logger import doc_permission_sync_ctx
from onyx.utils.logger import format_error_for_logging
@@ -106,9 +108,10 @@ def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> b
source_sync_period = DOC_PERMISSION_SYNC_PERIODS.get(cc_pair.connector.source)
# If RESTRICTED_FETCH_PERIOD[source] is None, we always run the sync.
if not source_sync_period:
return True
source_sync_period = DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY
source_sync_period *= int(OnyxRuntime.get_doc_permission_sync_multiplier())
# If the last sync is greater than the full fetch period, we run the sync
next_sync = last_perm_sync + timedelta(seconds=source_sync_period)
@@ -286,7 +289,7 @@ def try_creating_permissions_sync_task(
),
queue=OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC,
task_id=custom_task_id,
priority=OnyxCeleryPriority.HIGH,
priority=OnyxCeleryPriority.MEDIUM,
)
# fill in the celery task id

View File

@@ -271,7 +271,7 @@ def try_creating_external_group_sync_task(
),
queue=OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC,
task_id=custom_task_id,
priority=OnyxCeleryPriority.HIGH,
priority=OnyxCeleryPriority.MEDIUM,
)
payload.celery_task_id = result.id

View File

@@ -72,6 +72,7 @@ from onyx.redis.redis_pool import get_redis_replica_client
from onyx.redis.redis_pool import redis_lock_dump
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
from onyx.redis.redis_utils import is_fence
from onyx.server.runtime.onyx_runtime import OnyxRuntime
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import global_version
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
@@ -401,7 +402,11 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
logger.warning(f"Adding {key_bytes} to the lookup table.")
redis_client.sadd(OnyxRedisConstants.ACTIVE_FENCES, key_bytes)
redis_client.set(OnyxRedisSignals.BLOCK_BUILD_FENCE_LOOKUP_TABLE, 1, ex=300)
redis_client.set(
OnyxRedisSignals.BLOCK_BUILD_FENCE_LOOKUP_TABLE,
1,
ex=OnyxRuntime.get_build_fence_lookup_table_interval(),
)
# 1/3: KICKOFF

View File

@@ -73,6 +73,7 @@ from onyx.db.chat import get_or_create_root_message
from onyx.db.chat import reserve_message_id
from onyx.db.chat import translate_db_message_to_chat_message_detail
from onyx.db.chat import translate_db_search_doc_to_server_search_doc
from onyx.db.chat import update_chat_session_updated_at_timestamp
from onyx.db.engine import get_session_context_manager
from onyx.db.milestone import check_multi_assistant_milestone
from onyx.db.milestone import create_milestone_if_not_exists
@@ -1069,6 +1070,8 @@ def stream_chat_message_objects(
prev_message = next_answer_message
logger.debug("Committing messages")
# Explicitly update the timestamp on the chat session
update_chat_session_updated_at_timestamp(chat_session_id, db_session)
db_session.commit() # actually save user / assistant message
yield AgenticMessageResponseIDInfo(agentic_message_ids=agentic_message_ids)

View File

@@ -382,6 +382,7 @@ ONYX_CLOUD_TENANT_ID = "cloud"
# the redis namespace for runtime variables
ONYX_CLOUD_REDIS_RUNTIME = "runtime"
CLOUD_BUILD_FENCE_LOOKUP_TABLE_INTERVAL_DEFAULT = 600
class OnyxCeleryTask:

View File

@@ -87,7 +87,7 @@ class BlobStorageConnector(LoadConnector, PollConnector):
credentials.get(key)
for key in ["aws_access_key_id", "aws_secret_access_key"]
):
raise ConnectorMissingCredentialError("Google Cloud Storage")
raise ConnectorMissingCredentialError("Amazon S3")
session = boto3.Session(
aws_access_key_id=credentials["aws_access_key_id"],

View File

@@ -65,20 +65,6 @@ _RESTRICTIONS_EXPANSION_FIELDS = [
_SLIM_DOC_BATCH_SIZE = 5000
_ATTACHMENT_EXTENSIONS_TO_FILTER_OUT = [
"gif",
"mp4",
"mov",
"mp3",
"wav",
]
_FULL_EXTENSION_FILTER_STRING = "".join(
[
f" and title!~'*.{extension}'"
for extension in _ATTACHMENT_EXTENSIONS_TO_FILTER_OUT
]
)
ONE_HOUR = 3600
@@ -209,7 +195,6 @@ class ConfluenceConnector(
def _construct_attachment_query(self, confluence_page_id: str) -> str:
attachment_query = f"type=attachment and container='{confluence_page_id}'"
attachment_query += self.cql_label_filter
attachment_query += _FULL_EXTENSION_FILTER_STRING
return attachment_query
def _get_comment_string_for_page_id(self, page_id: str) -> str:
@@ -374,11 +359,13 @@ class ConfluenceConnector(
if not validate_attachment_filetype(
attachment,
):
logger.info(f"Skipping attachment: {attachment['title']}")
continue
logger.info(f"Processing attachment: {attachment['title']}")
# Attempt to get textual content or image summarization:
try:
logger.info(f"Processing attachment: {attachment['title']}")
response = convert_attachment_to_content(
confluence_client=self.confluence_client,
attachment=attachment,

View File

@@ -28,8 +28,9 @@ from onyx.connectors.models import TextSection
from onyx.file_processing.extract_file_text import detect_encoding
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.extract_file_text import is_accepted_file_ext
from onyx.file_processing.extract_file_text import is_text_file_extension
from onyx.file_processing.extract_file_text import is_valid_file_ext
from onyx.file_processing.extract_file_text import OnyxExtensionType
from onyx.file_processing.extract_file_text import read_text_file
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import request_with_retries
@@ -69,7 +70,9 @@ def _process_egnyte_file(
file_name = file_metadata["name"]
extension = get_file_ext(file_name)
if not is_valid_file_ext(extension):
if not is_accepted_file_ext(
extension, OnyxExtensionType.Plain | OnyxExtensionType.Document
):
logger.warning(f"Skipping file '{file_name}' with extension '{extension}'")
return None

View File

@@ -22,8 +22,9 @@ from onyx.db.engine import get_session_with_current_tenant
from onyx.db.pg_file_store import get_pgfilestore_by_file_name
from onyx.file_processing.extract_file_text import extract_text_and_images
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.extract_file_text import is_valid_file_ext
from onyx.file_processing.extract_file_text import is_accepted_file_ext
from onyx.file_processing.extract_file_text import load_files_from_zip
from onyx.file_processing.extract_file_text import OnyxExtensionType
from onyx.file_processing.image_utils import store_image_and_create_section
from onyx.file_store.file_store import get_default_file_store
from onyx.utils.logger import setup_logger
@@ -51,7 +52,7 @@ def _read_files_and_metadata(
file_content, ignore_dirs=True
):
yield os.path.join(directory_path, file_info.filename), subfile, metadata
elif is_valid_file_ext(extension):
elif is_accepted_file_ext(extension, OnyxExtensionType.All):
yield file_name, file_content, metadata
else:
logger.warning(f"Skipping file '{file_name}' with extension '{extension}'")
@@ -122,7 +123,7 @@ def _process_file(
logger.warning(f"No file record found for '{file_name}' in PG; skipping.")
return []
if not is_valid_file_ext(extension):
if not is_accepted_file_ext(extension, OnyxExtensionType.All):
logger.warning(
f"Skipping file '{file_name}' with unrecognized extension '{extension}'"
)

View File

@@ -28,7 +28,9 @@ from onyx.connectors.google_drive.doc_conversion import (
)
from onyx.connectors.google_drive.file_retrieval import crawl_folders_for_files
from onyx.connectors.google_drive.file_retrieval import get_all_files_for_oauth
from onyx.connectors.google_drive.file_retrieval import get_all_files_in_my_drive
from onyx.connectors.google_drive.file_retrieval import (
get_all_files_in_my_drive_and_shared,
)
from onyx.connectors.google_drive.file_retrieval import get_files_in_shared_drive
from onyx.connectors.google_drive.file_retrieval import get_root_folder_id
from onyx.connectors.google_drive.models import DriveRetrievalStage
@@ -86,13 +88,18 @@ def _extract_ids_from_urls(urls: list[str]) -> list[str]:
def _convert_single_file(
creds: Any,
primary_admin_email: str,
allow_images: bool,
size_threshold: int,
retriever_email: str,
file: dict[str, Any],
) -> Document | ConnectorFailure | None:
user_email = file.get("owners", [{}])[0].get("emailAddress") or primary_admin_email
# We used to always get the user email from the file owners when available,
# but this was causing issues with shared folders where the owner was not included in the service account
# now we use the email of the account that successfully listed the file. Leaving this in case we end up
# wanting to retry with file owners and/or admin email at some point.
# user_email = file.get("owners", [{}])[0].get("emailAddress") or primary_admin_email
user_email = retriever_email
# Only construct these services when needed
user_drive_service = lazy_eval(
lambda: get_drive_service(creds, user_email=user_email)
@@ -450,10 +457,11 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
logger.info(f"Getting all files in my drive as '{user_email}'")
yield from add_retrieval_info(
get_all_files_in_my_drive(
get_all_files_in_my_drive_and_shared(
service=drive_service,
update_traversed_ids_func=self._update_traversed_parent_ids,
is_slim=is_slim,
include_shared_with_me=self.include_files_shared_with_me,
start=curr_stage.completed_until if resuming else start,
end=end,
),
@@ -916,20 +924,28 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
convert_func = partial(
_convert_single_file,
self.creds,
self.primary_admin_email,
self.allow_images,
self.size_threshold,
)
# Fetch files in batches
batches_complete = 0
files_batch: list[GoogleDriveFileType] = []
files_batch: list[RetrievedDriveFile] = []
def _yield_batch(
files_batch: list[GoogleDriveFileType],
files_batch: list[RetrievedDriveFile],
) -> Iterator[Document | ConnectorFailure]:
nonlocal batches_complete
# Process the batch using run_functions_tuples_in_parallel
func_with_args = [(convert_func, (file,)) for file in files_batch]
func_with_args = [
(
convert_func,
(
file.user_email,
file.drive_file,
),
)
for file in files_batch
]
results = cast(
list[Document | ConnectorFailure | None],
run_functions_tuples_in_parallel(func_with_args, max_workers=8),
@@ -967,7 +983,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
)
continue
files_batch.append(retrieved_file.drive_file)
files_batch.append(retrieved_file)
if len(files_batch) < self.batch_size:
continue

View File

@@ -87,35 +87,17 @@ def _download_and_extract_sections_basic(
mime_type = file["mimeType"]
link = file.get("webViewLink", "")
try:
# skip images if not explicitly enabled
if not allow_images and is_gdrive_image_mime_type(mime_type):
return []
# skip images if not explicitly enabled
if not allow_images and is_gdrive_image_mime_type(mime_type):
return []
# For Google Docs, Sheets, and Slides, export as plain text
if mime_type in GOOGLE_MIME_TYPES_TO_EXPORT:
export_mime_type = GOOGLE_MIME_TYPES_TO_EXPORT[mime_type]
# Use the correct API call for exporting files
request = service.files().export_media(
fileId=file_id, mimeType=export_mime_type
)
response_bytes = io.BytesIO()
downloader = MediaIoBaseDownload(response_bytes, request)
done = False
while not done:
_, done = downloader.next_chunk()
response = response_bytes.getvalue()
if not response:
logger.warning(f"Failed to export {file_name} as {export_mime_type}")
return []
text = response.decode("utf-8")
return [TextSection(link=link, text=text)]
# For other file types, download the file
# Use the correct API call for downloading files
request = service.files().get_media(fileId=file_id)
# For Google Docs, Sheets, and Slides, export as plain text
if mime_type in GOOGLE_MIME_TYPES_TO_EXPORT:
export_mime_type = GOOGLE_MIME_TYPES_TO_EXPORT[mime_type]
# Use the correct API call for exporting files
request = service.files().export_media(
fileId=file_id, mimeType=export_mime_type
)
response_bytes = io.BytesIO()
downloader = MediaIoBaseDownload(response_bytes, request)
done = False
@@ -124,88 +106,100 @@ def _download_and_extract_sections_basic(
response = response_bytes.getvalue()
if not response:
logger.warning(f"Failed to download {file_name}")
logger.warning(f"Failed to export {file_name} as {export_mime_type}")
return []
# Process based on mime type
if mime_type == "text/plain":
text = response.decode("utf-8")
return [TextSection(link=link, text=text)]
text = response.decode("utf-8")
return [TextSection(link=link, text=text)]
elif (
mime_type
== "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
):
text, _ = docx_to_text_and_images(io.BytesIO(response))
return [TextSection(link=link, text=text)]
# For other file types, download the file
# Use the correct API call for downloading files
request = service.files().get_media(fileId=file_id)
response_bytes = io.BytesIO()
downloader = MediaIoBaseDownload(response_bytes, request)
done = False
while not done:
_, done = downloader.next_chunk()
elif (
mime_type
== "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
):
text = xlsx_to_text(io.BytesIO(response))
return [TextSection(link=link, text=text)]
response = response_bytes.getvalue()
if not response:
logger.warning(f"Failed to download {file_name}")
return []
elif (
mime_type
== "application/vnd.openxmlformats-officedocument.presentationml.presentation"
):
text = pptx_to_text(io.BytesIO(response))
return [TextSection(link=link, text=text)]
# Process based on mime type
if mime_type == "text/plain":
text = response.decode("utf-8")
return [TextSection(link=link, text=text)]
elif is_gdrive_image_mime_type(mime_type):
# For images, store them for later processing
sections: list[TextSection | ImageSection] = []
try:
with get_session_with_current_tenant() as db_session:
elif (
mime_type
== "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
):
text, _ = docx_to_text_and_images(io.BytesIO(response))
return [TextSection(link=link, text=text)]
elif (
mime_type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
):
text = xlsx_to_text(io.BytesIO(response))
return [TextSection(link=link, text=text)]
elif (
mime_type
== "application/vnd.openxmlformats-officedocument.presentationml.presentation"
):
text = pptx_to_text(io.BytesIO(response))
return [TextSection(link=link, text=text)]
elif is_gdrive_image_mime_type(mime_type):
# For images, store them for later processing
sections: list[TextSection | ImageSection] = []
try:
with get_session_with_current_tenant() as db_session:
section, embedded_id = store_image_and_create_section(
db_session=db_session,
image_data=response,
file_name=file_id,
display_name=file_name,
media_type=mime_type,
file_origin=FileOrigin.CONNECTOR,
link=link,
)
sections.append(section)
except Exception as e:
logger.error(f"Failed to process image {file_name}: {e}")
return sections
elif mime_type == "application/pdf":
text, _pdf_meta, images = read_pdf_file(io.BytesIO(response))
pdf_sections: list[TextSection | ImageSection] = [
TextSection(link=link, text=text)
]
# Process embedded images in the PDF
try:
with get_session_with_current_tenant() as db_session:
for idx, (img_data, img_name) in enumerate(images):
section, embedded_id = store_image_and_create_section(
db_session=db_session,
image_data=response,
file_name=file_id,
display_name=file_name,
media_type=mime_type,
image_data=img_data,
file_name=f"{file_id}_img_{idx}",
display_name=img_name or f"{file_name} - image {idx}",
file_origin=FileOrigin.CONNECTOR,
link=link,
)
sections.append(section)
except Exception as e:
logger.error(f"Failed to process image {file_name}: {e}")
return sections
pdf_sections.append(section)
except Exception as e:
logger.error(f"Failed to process PDF images in {file_name}: {e}")
return pdf_sections
elif mime_type == "application/pdf":
text, _pdf_meta, images = read_pdf_file(io.BytesIO(response))
pdf_sections: list[TextSection | ImageSection] = [
TextSection(link=link, text=text)
]
# Process embedded images in the PDF
try:
with get_session_with_current_tenant() as db_session:
for idx, (img_data, img_name) in enumerate(images):
section, embedded_id = store_image_and_create_section(
db_session=db_session,
image_data=img_data,
file_name=f"{file_id}_img_{idx}",
display_name=img_name or f"{file_name} - image {idx}",
file_origin=FileOrigin.CONNECTOR,
)
pdf_sections.append(section)
except Exception as e:
logger.error(f"Failed to process PDF images in {file_name}: {e}")
return pdf_sections
else:
# For unsupported file types, try to extract text
try:
text = extract_file_text(io.BytesIO(response), file_name)
return [TextSection(link=link, text=text)]
except Exception as e:
logger.warning(f"Failed to extract text from {file_name}: {e}")
return []
except Exception as e:
logger.error(f"Error processing file {file_name}: {e}")
return []
else:
# For unsupported file types, try to extract text
try:
text = extract_file_text(io.BytesIO(response), file_name)
return [TextSection(link=link, text=text)]
except Exception as e:
logger.warning(f"Failed to extract text from {file_name}: {e}")
return []
def convert_drive_item_to_document(

View File

@@ -214,10 +214,11 @@ def get_files_in_shared_drive(
yield file
def get_all_files_in_my_drive(
def get_all_files_in_my_drive_and_shared(
service: GoogleDriveService,
update_traversed_ids_func: Callable,
is_slim: bool,
include_shared_with_me: bool,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
@@ -229,7 +230,8 @@ def get_all_files_in_my_drive(
# Get all folders being queried and add them to the traversed set
folder_query = f"mimeType = '{DRIVE_FOLDER_TYPE}'"
folder_query += " and trashed = false"
folder_query += " and 'me' in owners"
if not include_shared_with_me:
folder_query += " and 'me' in owners"
found_folders = False
for file in execute_paginated_retrieval(
retrieval_function=service.files().list,
@@ -246,7 +248,8 @@ def get_all_files_in_my_drive(
# Then get the files
file_query = f"mimeType != '{DRIVE_FOLDER_TYPE}'"
file_query += " and trashed = false"
file_query += " and 'me' in owners"
if not include_shared_with_me:
file_query += " and 'me' in owners"
file_query += _generate_time_range_filter(start, end)
yield from execute_paginated_retrieval(
retrieval_function=service.files().list,

View File

@@ -20,8 +20,8 @@ from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.file_processing.extract_file_text import ALL_ACCEPTED_FILE_EXTENSIONS
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import VALID_FILE_EXTENSIONS
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
@@ -298,7 +298,7 @@ class HighspotConnector(LoadConnector, PollConnector, SlimConnector):
elif (
is_valid_format
and file_extension in VALID_FILE_EXTENSIONS
and file_extension in ALL_ACCEPTED_FILE_EXTENSIONS
and can_download
):
# For documents, try to get the text content

View File

@@ -1089,3 +1089,20 @@ def log_agent_sub_question_results(
db_session.commit()
return None
def update_chat_session_updated_at_timestamp(
chat_session_id: UUID, db_session: Session
) -> None:
"""
Explicitly update the timestamp on a chat session without modifying other fields.
This is useful when adding messages to a chat session to reflect recent activity.
"""
# Direct SQL update to avoid loading the entire object if it's not already loaded
db_session.execute(
update(ChatSession)
.where(ChatSession.id == chat_session_id)
.values(time_updated=func.now())
)
# No commit - the caller is responsible for committing the transaction

View File

@@ -7,6 +7,8 @@ from collections.abc import Callable
from collections.abc import Iterator
from collections.abc import Sequence
from email.parser import Parser as EmailParser
from enum import auto
from enum import IntFlag
from io import BytesIO
from pathlib import Path
from typing import Any
@@ -35,7 +37,7 @@ logger = setup_logger()
TEXT_SECTION_SEPARATOR = "\n\n"
PLAIN_TEXT_FILE_EXTENSIONS = [
ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS = [
".txt",
".md",
".mdx",
@@ -49,7 +51,7 @@ PLAIN_TEXT_FILE_EXTENSIONS = [
".yaml",
]
VALID_FILE_EXTENSIONS = PLAIN_TEXT_FILE_EXTENSIONS + [
ACCEPTED_DOCUMENT_FILE_EXTENSIONS = [
".pdf",
".docx",
".pptx",
@@ -57,12 +59,21 @@ VALID_FILE_EXTENSIONS = PLAIN_TEXT_FILE_EXTENSIONS + [
".eml",
".epub",
".html",
]
ACCEPTED_IMAGE_FILE_EXTENSIONS = [
".png",
".jpg",
".jpeg",
".webp",
]
ALL_ACCEPTED_FILE_EXTENSIONS = (
ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS
+ ACCEPTED_DOCUMENT_FILE_EXTENSIONS
+ ACCEPTED_IMAGE_FILE_EXTENSIONS
)
IMAGE_MEDIA_TYPES = [
"image/png",
"image/jpeg",
@@ -70,8 +81,15 @@ IMAGE_MEDIA_TYPES = [
]
class OnyxExtensionType(IntFlag):
Plain = auto()
Document = auto()
Multimedia = auto()
All = Plain | Document | Multimedia
def is_text_file_extension(file_name: str) -> bool:
return any(file_name.endswith(ext) for ext in PLAIN_TEXT_FILE_EXTENSIONS)
return any(file_name.endswith(ext) for ext in ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS)
def get_file_ext(file_path_or_name: str | Path) -> str:
@@ -83,8 +101,20 @@ def is_valid_media_type(media_type: str) -> bool:
return media_type in IMAGE_MEDIA_TYPES
def is_valid_file_ext(ext: str) -> bool:
return ext in VALID_FILE_EXTENSIONS
def is_accepted_file_ext(ext: str, ext_type: OnyxExtensionType) -> bool:
if ext_type & OnyxExtensionType.Plain:
if ext in ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS:
return True
if ext_type & OnyxExtensionType.Document:
if ext in ACCEPTED_DOCUMENT_FILE_EXTENSIONS:
return True
if ext_type & OnyxExtensionType.Multimedia:
if ext in ACCEPTED_IMAGE_FILE_EXTENSIONS:
return True
return False
def is_text_file(file: IO[bytes]) -> bool:
@@ -382,6 +412,9 @@ def extract_file_text(
"""
Legacy function that returns *only text*, ignoring embedded images.
For backward-compatibility in code that only wants text.
NOTE: Ignoring seems to be defined as returning an empty string for files it can't
handle (such as images).
"""
extension_to_function: dict[str, Callable[[IO[Any]], str]] = {
".pdf": pdf_to_text,
@@ -405,7 +438,9 @@ def extract_file_text(
if extension is None:
extension = get_file_ext(file_name)
if is_valid_file_ext(extension):
if is_accepted_file_ext(
extension, OnyxExtensionType.Plain | OnyxExtensionType.Document
):
func = extension_to_function.get(extension, file_io_to_text)
file.seek(0)
return func(file)

View File

@@ -15,6 +15,7 @@ EXCLUDED_IMAGE_TYPES = [
"image/tiff",
"image/gif",
"image/svg+xml",
"image/avif",
]

View File

@@ -313,7 +313,7 @@ def bulk_invite_users(
detail=f"Invalid email address: {email} - {str(e)}",
)
if MULTI_TENANT and not DEV_MODE:
if MULTI_TENANT:
try:
fetch_ee_implementation_or_noop(
"onyx.server.tenants.provisioning", "add_users_to_tenant", None
@@ -335,7 +335,7 @@ def bulk_invite_users(
except Exception as e:
logger.error(f"Error sending email invite to invited users: {e}")
if not MULTI_TENANT:
if not MULTI_TENANT or DEV_MODE:
return number_of_invited_users
# for billing purposes, write to the control plane about the number of new users
@@ -376,7 +376,7 @@ def remove_invited_user(
number_of_invited_users = write_invited_users(remaining_users)
try:
if MULTI_TENANT:
if MULTI_TENANT and not DEV_MODE:
fetch_ee_implementation_or_noop(
"onyx.server.tenants.billing", "register_tenant_users", None
)(tenant_id, get_total_users_count(db_session))

View File

@@ -1,10 +1,19 @@
import io
from typing import cast
from PIL import Image
from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
from onyx.background.celery.tasks.beat_schedule import (
CLOUD_DOC_PERMISSION_SYNC_MULTIPLIER_DEFAULT,
)
from onyx.configs.constants import CLOUD_BUILD_FENCE_LOOKUP_TABLE_INTERVAL_DEFAULT
from onyx.configs.constants import ONYX_CLOUD_REDIS_RUNTIME
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
from onyx.configs.constants import ONYX_EMAILABLE_LOGO_MAX_DIM
from onyx.db.engine import get_session_with_shared_schema
from onyx.file_store.file_store import PostgresBackedFileStore
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.utils.file import FileWithMimeType
from onyx.utils.file import OnyxStaticFileManager
from onyx.utils.variable_functionality import (
@@ -87,3 +96,72 @@ class OnyxRuntime:
)
return OnyxRuntime._get_with_static_fallback(db_filename, STATIC_FILENAME)
@staticmethod
def get_beat_multiplier() -> float:
"""the beat multiplier is used to scale up or down the frequency of certain beat
tasks in the cloud. It has a significant effect on load and is useful to adjust
in real time."""
beat_multiplier: float = CLOUD_BEAT_MULTIPLIER_DEFAULT
r = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
beat_multiplier_raw = r.get(f"{ONYX_CLOUD_REDIS_RUNTIME}:beat_multiplier")
if beat_multiplier_raw is not None:
try:
beat_multiplier_bytes = cast(bytes, beat_multiplier_raw)
beat_multiplier = float(beat_multiplier_bytes.decode())
except ValueError:
pass
if beat_multiplier <= 0.0:
return 1.0
return beat_multiplier
@staticmethod
def get_doc_permission_sync_multiplier() -> float:
"""Permission syncs are a significant source of load / queueing in the cloud."""
value: float = CLOUD_DOC_PERMISSION_SYNC_MULTIPLIER_DEFAULT
r = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
value_raw = r.get(f"{ONYX_CLOUD_REDIS_RUNTIME}:doc_permission_sync_multiplier")
if value_raw is not None:
try:
value_bytes = cast(bytes, value_raw)
value = float(value_bytes.decode())
except ValueError:
pass
if value <= 0.0:
return 1.0
return value
@staticmethod
def get_build_fence_lookup_table_interval() -> int:
"""We maintain an active fence table to make lookups of existing fences efficient.
However, reconstructing the table is expensive, so adjusting it in realtime is useful.
"""
interval: int = CLOUD_BUILD_FENCE_LOOKUP_TABLE_INTERVAL_DEFAULT
r = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
interval_raw = r.get(
f"{ONYX_CLOUD_REDIS_RUNTIME}:build_fence_lookup_table_interval"
)
if interval_raw is not None:
try:
interval_bytes = cast(bytes, interval_raw)
interval = int(interval_bytes.decode())
except ValueError:
pass
if interval <= 0.0:
return CLOUD_BUILD_FENCE_LOOKUP_TABLE_INTERVAL_DEFAULT
return interval

View File

@@ -0,0 +1,77 @@
import os
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from onyx.configs.constants import BlobType
from onyx.connectors.blob.connector import BlobStorageConnector
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.file_processing.extract_file_text import ACCEPTED_DOCUMENT_FILE_EXTENSIONS
from onyx.file_processing.extract_file_text import ACCEPTED_IMAGE_FILE_EXTENSIONS
from onyx.file_processing.extract_file_text import ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS
from onyx.file_processing.extract_file_text import get_file_ext
@pytest.fixture
def blob_connector(request: pytest.FixtureRequest) -> BlobStorageConnector:
connector = BlobStorageConnector(
bucket_type=BlobType.S3, bucket_name="onyx-connector-tests"
)
connector.load_credentials(
{
"aws_access_key_id": os.environ["AWS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS"],
"aws_secret_access_key": os.environ[
"AWS_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS"
],
}
)
return connector
@patch(
"onyx.file_processing.extract_file_text.get_unstructured_api_key",
return_value=None,
)
def test_blob_s3_connector(
mock_get_api_key: MagicMock, blob_connector: BlobStorageConnector
) -> None:
"""
Plain and document file types should be fully indexed.
Multimedia and unknown file types will be indexed by title only with one empty section.
This is intentional in order to allow searching by just the title even if we can't
index the file content.
"""
all_docs: list[Document] = []
document_batches = blob_connector.load_from_state()
for doc_batch in document_batches:
for doc in doc_batch:
all_docs.append(doc)
#
assert len(all_docs) == 19
for doc in all_docs:
section = doc.sections[0]
assert isinstance(section, TextSection)
file_extension = get_file_ext(doc.semantic_identifier)
if file_extension in ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS:
assert len(section.text) > 0
continue
if file_extension in ACCEPTED_DOCUMENT_FILE_EXTENSIONS:
assert len(section.text) > 0
continue
if file_extension in ACCEPTED_IMAGE_FILE_EXTENSIONS:
assert len(section.text) == 0
continue
# unknown extension
assert len(section.text) == 0

View File

@@ -58,6 +58,16 @@ SECTIONS_FOLDER_URL = (
"https://drive.google.com/drive/u/5/folders/1loe6XJ-pJxu9YYPv7cF3Hmz296VNzA33"
)
EXTERNAL_SHARED_FOLDER_URL = (
"https://drive.google.com/drive/folders/1sWC7Oi0aQGgifLiMnhTjvkhRWVeDa-XS"
)
EXTERNAL_SHARED_DOCS_IN_FOLDER = [
"https://docs.google.com/document/d/1Sywmv1-H6ENk2GcgieKou3kQHR_0te1mhIUcq8XlcdY"
]
EXTERNAL_SHARED_DOC_SINGLETON = (
"https://docs.google.com/document/d/11kmisDfdvNcw5LYZbkdPVjTOdj-Uc5ma6Jep68xzeeA"
)
SHARED_DRIVE_3_URL = "https://drive.google.com/drive/folders/0AJYm2K_I_vtNUk9PVA"
ADMIN_EMAIL = "admin@onyx-test.com"

View File

@@ -1,6 +1,7 @@
from collections.abc import Callable
from unittest.mock import MagicMock
from unittest.mock import patch
from urllib.parse import urlparse
from onyx.connectors.google_drive.connector import GoogleDriveConnector
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_EMAIL
@@ -9,6 +10,15 @@ from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FOLDER_3_
from tests.daily.connectors.google_drive.consts_and_utils import (
assert_expected_docs_in_retrieved_docs,
)
from tests.daily.connectors.google_drive.consts_and_utils import (
EXTERNAL_SHARED_DOC_SINGLETON,
)
from tests.daily.connectors.google_drive.consts_and_utils import (
EXTERNAL_SHARED_DOCS_IN_FOLDER,
)
from tests.daily.connectors.google_drive.consts_and_utils import (
EXTERNAL_SHARED_FOLDER_URL,
)
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_URL
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_2_FILE_IDS
@@ -100,7 +110,8 @@ def test_include_shared_drives_only_with_size_threshold(
retrieved_docs = load_all_docs(connector)
assert len(retrieved_docs) == 50
# 2 extra files from shared drive owned by non-admin and not shared with admin
assert len(retrieved_docs) == 52
@patch(
@@ -137,7 +148,8 @@ def test_include_shared_drives_only(
+ SECTIONS_FILE_IDS
)
assert len(retrieved_docs) == 51
# 2 extra files from shared drive owned by non-admin and not shared with admin
assert len(retrieved_docs) == 53
assert_expected_docs_in_retrieved_docs(
retrieved_docs=retrieved_docs,
@@ -294,6 +306,64 @@ def test_folders_only(
)
def test_shared_folder_owned_by_external_user(
google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
print("\n\nRunning test_shared_folder_owned_by_external_user")
connector = google_drive_service_acct_connector_factory(
primary_admin_email=ADMIN_EMAIL,
include_shared_drives=False,
include_my_drives=False,
include_files_shared_with_me=False,
shared_drive_urls=None,
shared_folder_urls=EXTERNAL_SHARED_FOLDER_URL,
my_drive_emails=None,
)
retrieved_docs = load_all_docs(connector)
expected_docs = EXTERNAL_SHARED_DOCS_IN_FOLDER
assert len(retrieved_docs) == len(expected_docs) # 1 for now
assert expected_docs[0] in retrieved_docs[0].id
def test_shared_with_me(
google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
print("\n\nRunning test_shared_with_me")
connector = google_drive_service_acct_connector_factory(
primary_admin_email=ADMIN_EMAIL,
include_shared_drives=False,
include_my_drives=True,
include_files_shared_with_me=True,
shared_drive_urls=None,
shared_folder_urls=None,
my_drive_emails=None,
)
retrieved_docs = load_all_docs(connector)
print(retrieved_docs)
expected_file_ids = (
ADMIN_FILE_IDS
+ ADMIN_FOLDER_3_FILE_IDS
+ TEST_USER_1_FILE_IDS
+ TEST_USER_2_FILE_IDS
+ TEST_USER_3_FILE_IDS
)
assert_expected_docs_in_retrieved_docs(
retrieved_docs=retrieved_docs,
expected_file_ids=expected_file_ids,
)
retrieved_ids = {urlparse(doc.id).path.split("/")[-2] for doc in retrieved_docs}
for id in retrieved_ids:
print(id)
assert EXTERNAL_SHARED_DOC_SINGLETON.split("/")[-1] in retrieved_ids
assert EXTERNAL_SHARED_DOCS_IN_FOLDER[0].split("/")[-1] in retrieved_ids
@patch(
"onyx.file_processing.extract_file_text.get_unstructured_api_key",
return_value=None,

View File

@@ -9,7 +9,9 @@ from requests import HTTPError
from onyx.auth.schemas import UserRole
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
from onyx.server.documents.models import PaginatedReturn
from onyx.server.manage.models import UserInfo
from onyx.server.models import FullUserSnapshot
from onyx.server.models import InvitedUserSnapshot
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestUser
@@ -245,3 +247,69 @@ class UserManager:
total_items=data["total_items"],
)
return paginated_result
@staticmethod
def invite_user(
user_to_invite_email: str, user_performing_action: DATestUser
) -> None:
"""Invite a user by email to join the organization.
Args:
user_to_invite_email: Email of the user to invite
user_performing_action: User with admin permissions performing the invitation
"""
response = requests.put(
url=f"{API_SERVER_URL}/manage/admin/users",
headers=user_performing_action.headers,
json={"emails": [user_to_invite_email]},
)
response.raise_for_status()
@staticmethod
def accept_invitation(tenant_id: str, user_performing_action: DATestUser) -> None:
"""Accept an invitation to join the organization.
Args:
tenant_id: ID of the tenant/organization to accept invitation for
user_performing_action: User accepting the invitation
"""
response = requests.post(
url=f"{API_SERVER_URL}/tenants/users/invite/accept",
headers=user_performing_action.headers,
json={"tenant_id": tenant_id},
)
response.raise_for_status()
@staticmethod
def get_invited_users(
user_performing_action: DATestUser,
) -> list[InvitedUserSnapshot]:
"""Get a list of all invited users.
Args:
user_performing_action: User with admin permissions performing the action
Returns:
List of invited user snapshots
"""
response = requests.get(
url=f"{API_SERVER_URL}/manage/users/invited",
headers=user_performing_action.headers,
)
response.raise_for_status()
return [InvitedUserSnapshot(**user) for user in response.json()]
@staticmethod
def get_user_info(user_performing_action: DATestUser) -> UserInfo:
"""Get user info for the current user.
Args:
user_performing_action: User performing the action
"""
response = requests.get(
url=f"{API_SERVER_URL}/me",
headers=user_performing_action.headers,
)
response.raise_for_status()
return UserInfo(**response.json())

View File

@@ -0,0 +1,70 @@
from onyx.db.models import UserRole
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestUser
INVITED_BASIC_USER = "basic_user"
INVITED_BASIC_USER_EMAIL = "basic_user@test.com"
def test_user_invitation_flow(reset_multitenant: None) -> None:
# Create first user (admin)
admin_user: DATestUser = UserManager.create(name="admin")
assert UserManager.is_role(admin_user, UserRole.ADMIN)
# Create second user
invited_user: DATestUser = UserManager.create(name="admin_invited")
assert UserManager.is_role(invited_user, UserRole.ADMIN)
# Admin user invites the previously registered and non-registered user
UserManager.invite_user(invited_user.email, admin_user)
UserManager.invite_user(INVITED_BASIC_USER_EMAIL, admin_user)
invited_basic_user: DATestUser = UserManager.create(
name=INVITED_BASIC_USER, email=INVITED_BASIC_USER_EMAIL
)
assert UserManager.is_role(invited_basic_user, UserRole.BASIC)
# Verify the user is in the invited users list
invited_users = UserManager.get_invited_users(admin_user)
assert invited_user.email in [
user.email for user in invited_users
], f"User {invited_user.email} not found in invited users list"
# Get user info to check tenant information
user_info = UserManager.get_user_info(invited_user)
# Extract the tenant_id from the invitation
invited_tenant_id = (
user_info.tenant_info.invitation.tenant_id
if user_info.tenant_info and user_info.tenant_info.invitation
else None
)
assert invited_tenant_id is not None, "Expected to find an invitation tenant_id"
UserManager.accept_invitation(invited_tenant_id, invited_user)
# Get updated user info after accepting invitation
updated_user_info = UserManager.get_user_info(invited_user)
# Verify the user is no longer in the invited users list
updated_invited_users = UserManager.get_invited_users(admin_user)
assert invited_user.email not in [
user.email for user in updated_invited_users
], f"User {invited_user.email} should not be in invited users list after accepting"
# Verify the user has BASIC role in the organization
assert (
updated_user_info.role == UserRole.BASIC
), f"Expected user to have BASIC role, but got {updated_user_info.role}"
# Verify user is in the organization
user_page = UserManager.get_user_page(
user_performing_action=admin_user, role_filter=[UserRole.BASIC]
)
# Check if the invited user is in the list of users with BASIC role
invited_user_emails = [user.email for user in user_page.items]
assert invited_user.email in invited_user_emails, (
f"User {invited_user.email} not found in the list of basic users "
f"in the organization. Available users: {invited_user_emails}"
)

View File

@@ -129,6 +129,9 @@ services:
options:
max-size: "50m"
max-file: "6"
# optional, only for debugging purposes
volumes:
- api_server_logs:/var/log
background:
image: onyxdotapp/onyx-backend:${IMAGE_TAG:-latest}
@@ -256,7 +259,7 @@ services:
- "host.docker.internal:host-gateway"
# optional, only for debugging purposes
volumes:
- log_store:/var/log/persisted-logs
- background_logs:/var/log
logging:
driver: json-file
options:
@@ -325,6 +328,8 @@ services:
volumes:
# Not necessary, this is just to reduce download time during startup
- model_cache_huggingface:/root/.cache/huggingface/
# optional, only for debugging purposes
- inference_model_server_logs:/var/log
logging:
driver: json-file
options:
@@ -357,6 +362,8 @@ services:
volumes:
# Not necessary, this is just to reduce download time during startup
- indexing_huggingface_model_cache:/root/.cache/huggingface/
# optional, only for debugging purposes
- indexing_model_server_logs:/var/log
logging:
driver: json-file
options:
@@ -434,4 +441,8 @@ volumes:
model_cache_huggingface:
indexing_huggingface_model_cache:
log_store: # for logs that we don't want to lose on container restarts
# for logs that we don't want to lose on container restarts
api_server_logs:
background_logs:
inference_model_server_logs:
indexing_model_server_logs:

View File

@@ -106,6 +106,9 @@ services:
options:
max-size: "50m"
max-file: "6"
volumes:
# optional, only for debugging purposes
- api_server_logs:/var/log
background:
image: onyxdotapp/onyx-backend:${IMAGE_TAG:-latest}
@@ -211,7 +214,7 @@ services:
- "host.docker.internal:host-gateway"
# optional, only for debugging purposes
volumes:
- log_store:/var/log/persisted-logs
- background_logs:/var/log
logging:
driver: json-file
options:
@@ -273,6 +276,8 @@ services:
volumes:
# Not necessary, this is just to reduce download time during startup
- model_cache_huggingface:/root/.cache/huggingface/
# optional, only for debugging purposes
- inference_model_server_logs:/var/log
logging:
driver: json-file
options:
@@ -310,6 +315,8 @@ services:
volumes:
# Not necessary, this is just to reduce download time during startup
- indexing_huggingface_model_cache:/root/.cache/huggingface/
# optional, only for debugging purposes
- indexing_model_server_logs:/var/log
logging:
driver: json-file
options:
@@ -387,4 +394,8 @@ volumes:
# Created by the container itself
model_cache_huggingface:
indexing_huggingface_model_cache:
log_store: # for logs that we don't want to lose on container restarts
# for logs that we don't want to lose on container restarts
api_server_logs:
background_logs:
inference_model_server_logs:
indexing_model_server_logs:

View File

@@ -244,8 +244,6 @@ services:
# - ./bundle.pem:/app/bundle.pem:ro
extra_hosts:
- "host.docker.internal:host-gateway"
volumes:
- log_store:/var/log/persisted-logs
logging:
driver: json-file
options:
@@ -423,4 +421,3 @@ volumes:
model_cache_huggingface:
indexing_huggingface_model_cache:
log_store: # for logs that we don't want to lose on container restarts

View File

@@ -54,9 +54,6 @@ services:
- INDEXING_MODEL_SERVER_HOST=${INDEXING_MODEL_SERVER_HOST:-indexing_model_server}
extra_hosts:
- "host.docker.internal:host-gateway"
# optional, only for debugging purposes
volumes:
- log_store:/var/log/persisted-logs
logging:
driver: json-file
options:
@@ -236,4 +233,3 @@ volumes:
# Created by the container itself
model_cache_huggingface:
indexing_huggingface_model_cache:
log_store: # for logs that we don't want to lose on container restarts

View File

@@ -36,6 +36,10 @@ services:
options:
max-size: "50m"
max-file: "6"
volumes:
# optional, only for debugging purposes
- api_server_logs:/var/log
background:
image: onyxdotapp/onyx-backend:${IMAGE_TAG:-latest}
@@ -69,7 +73,7 @@ services:
extra_hosts:
- "host.docker.internal:host-gateway"
volumes:
- log_store:/var/log/persisted-logs
- background_logs:/var/log
logging:
driver: json-file
options:
@@ -122,6 +126,8 @@ services:
volumes:
# Not necessary, this is just to reduce download time during startup
- model_cache_huggingface:/root/.cache/huggingface/
# optional, only for debugging purposes
- inference_model_server_logs:/var/log
logging:
driver: json-file
options:
@@ -150,6 +156,8 @@ services:
volumes:
# Not necessary, this is just to reduce download time during startup
- indexing_huggingface_model_cache:/root/.cache/huggingface/
# optional, only for debugging purposes
- indexing_model_server_logs:/var/log
logging:
driver: json-file
options:
@@ -231,4 +239,8 @@ volumes:
# Created by the container itself
model_cache_huggingface:
indexing_huggingface_model_cache:
log_store: # for logs that we don't want to lose on container restarts
# for logs that we don't want to lose on container restarts
api_server_logs:
background_logs:
inference_model_server_logs:
indexing_model_server_logs:

View File

@@ -32,13 +32,14 @@ services:
# - ./bundle.pem:/app/bundle.pem:ro
extra_hosts:
- "host.docker.internal:host-gateway"
volumes:
- log_store:/var/log/persisted-logs
logging:
driver: json-file
options:
max-size: "50m"
max-file: "6"
volumes:
- api_server_logs:/var/log
background:
image: onyxdotapp/onyx-backend:${IMAGE_TAG:-latest}
build:
@@ -76,7 +77,7 @@ services:
extra_hosts:
- "host.docker.internal:host-gateway"
volumes:
- log_store:/var/log/persisted-logs
- background_logs:/var/log
logging:
driver: json-file
options:
@@ -152,6 +153,8 @@ services:
volumes:
# Not necessary, this is just to reduce download time during startup
- model_cache_huggingface:/root/.cache/huggingface/
# optional, only for debugging purposes
- inference_model_server_logs:/var/log
logging:
driver: json-file
options:
@@ -180,6 +183,8 @@ services:
volumes:
# Not necessary, this is just to reduce download time during startup
- indexing_huggingface_model_cache:/root/.cache/huggingface/
# optional, only for debugging purposes
- indexing_model_server_logs:/var/log
logging:
driver: json-file
options:
@@ -264,4 +269,8 @@ volumes:
# Created by the container itself
model_cache_huggingface:
indexing_huggingface_model_cache:
log_store: # for logs that we don't want to lose on container restarts
# for logs that we don't want to lose on container restarts
api_server_logs:
background_logs:
inference_model_server_logs:
indexing_model_server_logs:

View File

@@ -63,7 +63,7 @@ services:
extra_hosts:
- "host.docker.internal:host-gateway"
volumes:
- log_store:/var/log/persisted-logs
- log_store:/var/log
logging:
driver: json-file
options:

View File

@@ -45,7 +45,7 @@ export function ActionsTable({ tools }: { tools: ToolSnapshot[] }) {
className="mr-1 my-auto cursor-pointer"
onClick={() =>
router.push(
`/admin/tools/edit/${tool.id}?u=${Date.now()}`
`/admin/actions/edit/${tool.id}?u=${Date.now()}`
)
}
/>

View File

@@ -281,7 +281,7 @@ export default function AddConnector({
return (
<Formik
initialValues={{
...createConnectorInitialValues(connector),
...createConnectorInitialValues(connector, currentCredential),
...Object.fromEntries(
connectorConfigs[connector].advanced_values.map((field) => [
field.name,

View File

@@ -1384,6 +1384,7 @@ export function ChatPage({
if (!packet) {
continue;
}
console.log("Packet:", JSON.stringify(packet));
if (!initialFetchDetails) {
if (!Object.hasOwn(packet, "user_message_id")) {
@@ -1729,6 +1730,7 @@ export function ChatPage({
}
}
} catch (e: any) {
console.log("Error:", e);
const errorMsg = e.message;
upsertToCompleteMessageMap({
messages: [
@@ -1756,11 +1758,13 @@ export function ChatPage({
completeMessageMapOverride: currentMessageMap(completeMessageDetail),
});
}
console.log("Finished streaming");
setAgenticGenerating(false);
resetRegenerationState(currentSessionId());
updateChatState("input");
if (isNewSession) {
console.log("Setting up new session");
if (finalMessage) {
setSelectedMessageForDocDisplay(finalMessage.message_id);
}

View File

@@ -1292,7 +1292,8 @@ For example, specifying .*-support.* as a "channel" will cause the connector to
},
};
export function createConnectorInitialValues(
connector: ConfigurableSources
connector: ConfigurableSources,
currentCredential: Credential<any> | null = null
): Record<string, any> & AccessTypeGroupSelectorFormType {
const configuration = connectorConfigs[connector];
@@ -1307,7 +1308,16 @@ export function createConnectorInitialValues(
} else if (field.type === "list") {
acc[field.name] = field.default || [];
} else if (field.type === "checkbox") {
acc[field.name] = field.default || false;
// Special case for include_files_shared_with_me when using service account
if (
field.name === "include_files_shared_with_me" &&
currentCredential &&
!currentCredential.credential_json?.google_tokens
) {
acc[field.name] = true;
} else {
acc[field.name] = field.default || false;
}
} else if (field.default !== undefined) {
acc[field.name] = field.default;
}