mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-17 07:45:47 +00:00
Compare commits
49 Commits
l
...
foreign_ke
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d842fed37e | ||
|
|
14981162fd | ||
|
|
288daa4e90 | ||
|
|
5e21dc6cb3 | ||
|
|
39b3a503b4 | ||
|
|
a70d472b5c | ||
|
|
0ed2886ad0 | ||
|
|
aabf8a99bc | ||
|
|
95701db1bd | ||
|
|
24105254ac | ||
|
|
4fe99d05fd | ||
|
|
d35f93b233 | ||
|
|
766b0f35df | ||
|
|
a0470a96eb | ||
|
|
b82123563b | ||
|
|
787e25cd78 | ||
|
|
c6375f8abf | ||
|
|
58e5deba01 | ||
|
|
028e877342 | ||
|
|
47bff2b6a9 | ||
|
|
1502bcea12 | ||
|
|
2701f83634 | ||
|
|
601037abb5 | ||
|
|
7e9b12403a | ||
|
|
d903e5912a | ||
|
|
d2aea63573 | ||
|
|
57b4639709 | ||
|
|
1308b6cbe8 | ||
|
|
98abd7d3fa | ||
|
|
e4180cefba | ||
|
|
f67b5356fa | ||
|
|
9bdb581220 | ||
|
|
42d6d935ae | ||
|
|
8d62b992ef | ||
|
|
2ad86aa9a6 | ||
|
|
74a472ece7 | ||
|
|
b2ce848b53 | ||
|
|
519ec20d05 | ||
|
|
3b1e26d0d4 | ||
|
|
118d2b52e6 | ||
|
|
e625884702 | ||
|
|
fa78f50fe3 | ||
|
|
05ab94945b | ||
|
|
7a64a25ff4 | ||
|
|
7f10494bbe | ||
|
|
f2d4024783 | ||
|
|
70795a4047 | ||
|
|
d8a17a7238 | ||
|
|
cbf98c0128 |
@@ -39,6 +39,12 @@ env:
|
||||
AIRTABLE_TEST_TABLE_ID: ${{ secrets.AIRTABLE_TEST_TABLE_ID }}
|
||||
AIRTABLE_TEST_TABLE_NAME: ${{ secrets.AIRTABLE_TEST_TABLE_NAME }}
|
||||
AIRTABLE_ACCESS_TOKEN: ${{ secrets.AIRTABLE_ACCESS_TOKEN }}
|
||||
# Sharepoint
|
||||
SHAREPOINT_CLIENT_ID: ${{ secrets.SHAREPOINT_CLIENT_ID }}
|
||||
SHAREPOINT_CLIENT_SECRET: ${{ secrets.SHAREPOINT_CLIENT_SECRET }}
|
||||
SHAREPOINT_CLIENT_DIRECTORY_ID: ${{ secrets.SHAREPOINT_CLIENT_DIRECTORY_ID }}
|
||||
SHAREPOINT_SITE: ${{ secrets.SHAREPOINT_SITE }}
|
||||
|
||||
jobs:
|
||||
connectors-check:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
|
||||
@@ -0,0 +1,77 @@
|
||||
"""foreign key input prompts
|
||||
|
||||
Revision ID: 33ea50e88f24
|
||||
Revises: a6df6b88ef81
|
||||
Create Date: 2025-01-29 10:54:22.141765
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "33ea50e88f24"
|
||||
down_revision = "a6df6b88ef81"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# First drop the existing FK constraints
|
||||
op.drop_constraint(
|
||||
"inputprompt__user_input_prompt_id_fkey",
|
||||
"inputprompt__user",
|
||||
type_="foreignkey",
|
||||
)
|
||||
op.drop_constraint(
|
||||
"inputprompt__user_user_id_fkey",
|
||||
"inputprompt__user",
|
||||
type_="foreignkey",
|
||||
)
|
||||
|
||||
# Recreate with ON DELETE CASCADE
|
||||
op.create_foreign_key(
|
||||
"inputprompt__user_input_prompt_id_fkey",
|
||||
"inputprompt__user",
|
||||
"inputprompt",
|
||||
["input_prompt_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"inputprompt__user_user_id_fkey",
|
||||
"inputprompt__user",
|
||||
'"user"',
|
||||
["user_id"],
|
||||
["id"],
|
||||
ondelete="CASCADE",
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop the new FKs with ondelete
|
||||
op.drop_constraint(
|
||||
"inputprompt__user_input_prompt_id_fkey",
|
||||
"inputprompt__user",
|
||||
type_="foreignkey",
|
||||
)
|
||||
op.drop_constraint(
|
||||
"inputprompt__user_user_id_fkey",
|
||||
"inputprompt__user",
|
||||
type_="foreignkey",
|
||||
)
|
||||
|
||||
# Recreate them without cascading
|
||||
op.create_foreign_key(
|
||||
"inputprompt__user_input_prompt_id_fkey",
|
||||
"inputprompt__user",
|
||||
"inputprompt",
|
||||
["input_prompt_id"],
|
||||
["id"],
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"inputprompt__user_user_id_fkey",
|
||||
"inputprompt__user",
|
||||
'"user"',
|
||||
["user_id"],
|
||||
["id"],
|
||||
)
|
||||
@@ -0,0 +1,37 @@
|
||||
"""lowercase_user_emails
|
||||
|
||||
Revision ID: 4d58345da04a
|
||||
Revises: f1ca58b2f2ec
|
||||
Create Date: 2025-01-29 07:48:46.784041
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
from sqlalchemy.sql import text
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "4d58345da04a"
|
||||
down_revision = "f1ca58b2f2ec"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Get database connection
|
||||
connection = op.get_bind()
|
||||
|
||||
# Update all user emails to lowercase
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
UPDATE "user"
|
||||
SET email = LOWER(email)
|
||||
WHERE email != LOWER(email)
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Cannot restore original case of emails
|
||||
pass
|
||||
@@ -0,0 +1,29 @@
|
||||
"""remove recent assistants
|
||||
|
||||
Revision ID: a6df6b88ef81
|
||||
Revises: 4d58345da04a
|
||||
Create Date: 2025-01-29 10:25:52.790407
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "a6df6b88ef81"
|
||||
down_revision = "4d58345da04a"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.drop_column("user", "recent_assistants")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"recent_assistants", postgresql.JSONB(), server_default="[]", nullable=False
|
||||
),
|
||||
)
|
||||
@@ -32,6 +32,7 @@ def perform_ttl_management_task(
|
||||
|
||||
@celery_app.task(
|
||||
name="check_ttl_management_task",
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
)
|
||||
def check_ttl_management_task(*, tenant_id: str | None) -> None:
|
||||
@@ -56,6 +57,7 @@ def check_ttl_management_task(*, tenant_id: str | None) -> None:
|
||||
|
||||
@celery_app.task(
|
||||
name="autogenerate_usage_report_task",
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
)
|
||||
def autogenerate_usage_report_task(*, tenant_id: str | None) -> None:
|
||||
|
||||
@@ -14,6 +14,8 @@ def _build_group_member_email_map(
|
||||
) -> dict[str, set[str]]:
|
||||
group_member_emails: dict[str, set[str]] = {}
|
||||
for user_result in confluence_client.paginated_cql_user_retrieval():
|
||||
logger.debug(f"Processing groups for user: {user_result}")
|
||||
|
||||
user = user_result.get("user", {})
|
||||
if not user:
|
||||
logger.warning(f"user result missing user field: {user_result}")
|
||||
@@ -33,10 +35,17 @@ def _build_group_member_email_map(
|
||||
logger.warning(f"user result missing email field: {user_result}")
|
||||
continue
|
||||
|
||||
all_users_groups: set[str] = set()
|
||||
for group in confluence_client.paginated_groups_by_user_retrieval(user):
|
||||
# group name uniqueness is enforced by Confluence, so we can use it as a group ID
|
||||
group_id = group["name"]
|
||||
group_member_emails.setdefault(group_id, set()).add(email)
|
||||
all_users_groups.add(group_id)
|
||||
|
||||
if not group_member_emails:
|
||||
logger.warning(f"No groups found for user with email: {email}")
|
||||
else:
|
||||
logger.debug(f"Found groups {all_users_groups} for user with email {email}")
|
||||
|
||||
return group_member_emails
|
||||
|
||||
|
||||
@@ -42,24 +42,22 @@ def _fetch_permissions_for_permission_ids(
|
||||
if not permission_info or not doc_id:
|
||||
return []
|
||||
|
||||
# Check cache first for all permission IDs
|
||||
permissions = [
|
||||
_PERMISSION_ID_PERMISSION_MAP[pid]
|
||||
for pid in permission_ids
|
||||
if pid in _PERMISSION_ID_PERMISSION_MAP
|
||||
]
|
||||
|
||||
# If we found all permissions in cache, return them
|
||||
if len(permissions) == len(permission_ids):
|
||||
return permissions
|
||||
|
||||
owner_email = permission_info.get("owner_email")
|
||||
|
||||
drive_service = get_drive_service(
|
||||
creds=google_drive_connector.creds,
|
||||
user_email=(owner_email or google_drive_connector.primary_admin_email),
|
||||
)
|
||||
|
||||
# Otherwise, fetch all permissions and update cache
|
||||
fetched_permissions = execute_paginated_retrieval(
|
||||
retrieval_function=drive_service.permissions().list,
|
||||
list_key="permissions",
|
||||
@@ -69,7 +67,6 @@ def _fetch_permissions_for_permission_ids(
|
||||
)
|
||||
|
||||
permissions_for_doc_id = []
|
||||
# Update cache and return all permissions
|
||||
for permission in fetched_permissions:
|
||||
permissions_for_doc_id.append(permission)
|
||||
_PERMISSION_ID_PERMISSION_MAP[permission["id"]] = permission
|
||||
|
||||
@@ -111,6 +111,7 @@ async def login_as_anonymous_user(
|
||||
token = generate_anonymous_user_jwt_token(tenant_id)
|
||||
|
||||
response = Response()
|
||||
response.delete_cookie("fastapiusersauth")
|
||||
response.set_cookie(
|
||||
key=ANONYMOUS_USER_COOKIE_NAME,
|
||||
value=token,
|
||||
|
||||
@@ -58,6 +58,7 @@ class UserGroup(BaseModel):
|
||||
credential=CredentialSnapshot.from_credential_db_model(
|
||||
cc_pair_relationship.cc_pair.credential
|
||||
),
|
||||
access_type=cc_pair_relationship.cc_pair.access_type,
|
||||
)
|
||||
for cc_pair_relationship in user_group_model.cc_pair_relationships
|
||||
if cc_pair_relationship.is_current
|
||||
|
||||
@@ -42,6 +42,10 @@ class UserCreate(schemas.BaseUserCreate):
|
||||
tenant_id: str | None = None
|
||||
|
||||
|
||||
class UserUpdateWithRole(schemas.BaseUserUpdate):
|
||||
role: UserRole
|
||||
|
||||
|
||||
class UserUpdate(schemas.BaseUserUpdate):
|
||||
"""
|
||||
Role updates are not allowed through the user update endpoint for security reasons
|
||||
|
||||
@@ -57,7 +57,7 @@ from onyx.auth.invited_users import get_invited_users
|
||||
from onyx.auth.schemas import AuthBackend
|
||||
from onyx.auth.schemas import UserCreate
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.auth.schemas import UserUpdate
|
||||
from onyx.auth.schemas import UserUpdateWithRole
|
||||
from onyx.configs.app_configs import AUTH_BACKEND
|
||||
from onyx.configs.app_configs import AUTH_COOKIE_EXPIRE_TIME_SECONDS
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
@@ -216,7 +216,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
reset_password_token_secret = USER_AUTH_SECRET
|
||||
verification_token_secret = USER_AUTH_SECRET
|
||||
verification_token_lifetime_seconds = AUTH_COOKIE_EXPIRE_TIME_SECONDS
|
||||
|
||||
user_db: SQLAlchemyUserDatabase[User, uuid.UUID]
|
||||
|
||||
async def create(
|
||||
@@ -246,10 +245,8 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
referral_source=referral_source,
|
||||
request=request,
|
||||
)
|
||||
|
||||
async with get_async_session_with_tenant(tenant_id) as db_session:
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
verify_email_is_invited(user_create.email)
|
||||
verify_email_domain(user_create.email)
|
||||
if MULTI_TENANT:
|
||||
@@ -268,16 +265,16 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
user_create.role = UserRole.ADMIN
|
||||
else:
|
||||
user_create.role = UserRole.BASIC
|
||||
|
||||
try:
|
||||
user = await super().create(user_create, safe=safe, request=request) # type: ignore
|
||||
except exceptions.UserAlreadyExists:
|
||||
user = await self.get_by_email(user_create.email)
|
||||
# Handle case where user has used product outside of web and is now creating an account through web
|
||||
if not user.role.is_web_login() and user_create.role.is_web_login():
|
||||
user_update = UserUpdate(
|
||||
user_update = UserUpdateWithRole(
|
||||
password=user_create.password,
|
||||
is_verified=user_create.is_verified,
|
||||
role=user_create.role,
|
||||
)
|
||||
user = await self.update(user_update, user)
|
||||
else:
|
||||
@@ -285,7 +282,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
return user
|
||||
|
||||
async def validate_password(self, password: str, _: schemas.UC | models.UP) -> None:
|
||||
|
||||
@@ -24,6 +24,7 @@ from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.engine import get_sqlalchemy_engine
|
||||
from onyx.document_index.vespa.shared_utils.utils import wait_for_vespa_with_timeout
|
||||
from onyx.httpx.httpx_pool import HttpxPool
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
|
||||
from onyx.redis.redis_connector_delete import RedisConnectorDelete
|
||||
@@ -316,6 +317,8 @@ def on_worker_ready(sender: Any, **kwargs: Any) -> None:
|
||||
|
||||
|
||||
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
HttpxPool.close_all()
|
||||
|
||||
if not celery_is_worker_primary(sender):
|
||||
return
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from celery import Celery
|
||||
from celery import signals
|
||||
@@ -8,7 +7,6 @@ from celery.beat import PersistentScheduler # type: ignore
|
||||
from celery.signals import beat_init
|
||||
|
||||
import onyx.background.celery.apps.app_base as app_base
|
||||
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
|
||||
from onyx.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
|
||||
from onyx.db.engine import get_all_tenant_ids
|
||||
from onyx.db.engine import SqlEngine
|
||||
@@ -132,21 +130,25 @@ class DynamicTenantScheduler(PersistentScheduler):
|
||||
# get current schedule and extract current tenants
|
||||
current_schedule = self.schedule.items()
|
||||
|
||||
current_tenants = set()
|
||||
for task_name, _ in current_schedule:
|
||||
task_name = cast(str, task_name)
|
||||
if task_name.startswith(ONYX_CLOUD_CELERY_TASK_PREFIX):
|
||||
continue
|
||||
# there are no more per tenant beat tasks, so comment this out
|
||||
# NOTE: we may not actualy need this scheduler any more and should
|
||||
# test reverting to a regular beat schedule implementation
|
||||
|
||||
if "_" in task_name:
|
||||
# example: "check-for-condition-tenant_12345678-abcd-efgh-ijkl-12345678"
|
||||
# -> "12345678-abcd-efgh-ijkl-12345678"
|
||||
current_tenants.add(task_name.split("_")[-1])
|
||||
logger.info(f"Found {len(current_tenants)} existing items in schedule")
|
||||
# current_tenants = set()
|
||||
# for task_name, _ in current_schedule:
|
||||
# task_name = cast(str, task_name)
|
||||
# if task_name.startswith(ONYX_CLOUD_CELERY_TASK_PREFIX):
|
||||
# continue
|
||||
|
||||
for tenant_id in tenant_ids:
|
||||
if tenant_id not in current_tenants:
|
||||
logger.info(f"Processing new tenant: {tenant_id}")
|
||||
# if "_" in task_name:
|
||||
# # example: "check-for-condition-tenant_12345678-abcd-efgh-ijkl-12345678"
|
||||
# # -> "12345678-abcd-efgh-ijkl-12345678"
|
||||
# current_tenants.add(task_name.split("_")[-1])
|
||||
# logger.info(f"Found {len(current_tenants)} existing items in schedule")
|
||||
|
||||
# for tenant_id in tenant_ids:
|
||||
# if tenant_id not in current_tenants:
|
||||
# logger.info(f"Processing new tenant: {tenant_id}")
|
||||
|
||||
new_schedule = self._generate_schedule(tenant_ids)
|
||||
|
||||
|
||||
@@ -10,6 +10,10 @@ from celery.signals import worker_ready
|
||||
from celery.signals import worker_shutdown
|
||||
|
||||
import onyx.background.celery.apps.app_base as app_base
|
||||
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
|
||||
from onyx.configs.app_configs import MANAGED_VESPA
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
|
||||
from onyx.configs.constants import POSTGRES_CELERY_WORKER_LIGHT_APP_NAME
|
||||
from onyx.db.engine import SqlEngine
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -54,12 +58,23 @@ def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
|
||||
|
||||
@worker_init.connect
|
||||
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
EXTRA_CONCURRENCY = 8 # small extra fudge factor for connection limits
|
||||
|
||||
logger.info("worker_init signal received.")
|
||||
|
||||
logger.info(f"Concurrency: {sender.concurrency}") # type: ignore
|
||||
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME)
|
||||
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=8) # type: ignore
|
||||
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=EXTRA_CONCURRENCY) # type: ignore
|
||||
|
||||
if MANAGED_VESPA:
|
||||
httpx_init_vespa_pool(
|
||||
sender.concurrency + EXTRA_CONCURRENCY, # type: ignore
|
||||
ssl_cert=VESPA_CLOUD_CERT_PATH,
|
||||
ssl_key=VESPA_CLOUD_KEY_PATH,
|
||||
)
|
||||
else:
|
||||
httpx_init_vespa_pool(sender.concurrency + EXTRA_CONCURRENCY) # type: ignore
|
||||
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import httpx
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
|
||||
from onyx.configs.app_configs import VESPA_REQUEST_TIMEOUT
|
||||
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
rate_limit_builder,
|
||||
)
|
||||
@@ -17,6 +20,7 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import TaskStatus
|
||||
from onyx.db.models import TaskQueueState
|
||||
from onyx.httpx.httpx_pool import HttpxPool
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.server.documents.models import DeletionAttemptSnapshot
|
||||
@@ -154,3 +158,25 @@ def celery_is_worker_primary(worker: Any) -> bool:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def httpx_init_vespa_pool(
|
||||
max_keepalive_connections: int,
|
||||
timeout: int = VESPA_REQUEST_TIMEOUT,
|
||||
ssl_cert: str | None = None,
|
||||
ssl_key: str | None = None,
|
||||
) -> None:
|
||||
httpx_cert = None
|
||||
httpx_verify = False
|
||||
if ssl_cert and ssl_key:
|
||||
httpx_cert = cast(tuple[str, str], (ssl_cert, ssl_key))
|
||||
httpx_verify = True
|
||||
|
||||
HttpxPool.init_client(
|
||||
name="vespa",
|
||||
cert=httpx_cert,
|
||||
verify=httpx_verify,
|
||||
timeout=timeout,
|
||||
http2=False,
|
||||
limits=httpx.Limits(max_keepalive_connections=max_keepalive_connections),
|
||||
)
|
||||
|
||||
@@ -16,6 +16,10 @@ from shared_configs.configs import MULTI_TENANT
|
||||
# it's only important that they run relatively regularly
|
||||
BEAT_EXPIRES_DEFAULT = 15 * 60 # 15 minutes (in seconds)
|
||||
|
||||
# hack to slow down task dispatch in the cloud until
|
||||
# we have a better implementation (backpressure, etc)
|
||||
CLOUD_BEAT_SCHEDULE_MULTIPLIER = 8
|
||||
|
||||
# tasks that only run in the cloud
|
||||
# the name attribute must start with ONYX_CLOUD_CELERY_TASK_PREFIX = "cloud" to be filtered
|
||||
# by the DynamicTenantScheduler
|
||||
@@ -24,7 +28,7 @@ cloud_tasks_to_schedule = [
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-alembic",
|
||||
"task": OnyxCeleryTask.CLOUD_CHECK_ALEMBIC,
|
||||
"schedule": timedelta(hours=1),
|
||||
"schedule": timedelta(hours=1 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
|
||||
"options": {
|
||||
"queue": OnyxCeleryQueues.MONITORING,
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
@@ -35,7 +39,7 @@ cloud_tasks_to_schedule = [
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-indexing",
|
||||
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
"schedule": timedelta(seconds=15),
|
||||
"schedule": timedelta(seconds=15 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGHEST,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
@@ -47,7 +51,7 @@ cloud_tasks_to_schedule = [
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-connector-deletion",
|
||||
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
"schedule": timedelta(seconds=20),
|
||||
"schedule": timedelta(seconds=20 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGHEST,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
@@ -59,7 +63,7 @@ cloud_tasks_to_schedule = [
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-vespa-sync",
|
||||
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
"schedule": timedelta(seconds=20),
|
||||
"schedule": timedelta(seconds=20 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGHEST,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
@@ -71,7 +75,7 @@ cloud_tasks_to_schedule = [
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-prune",
|
||||
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
"schedule": timedelta(seconds=15),
|
||||
"schedule": timedelta(seconds=15 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGHEST,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
@@ -83,7 +87,7 @@ cloud_tasks_to_schedule = [
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor-vespa-sync",
|
||||
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
"schedule": timedelta(seconds=5),
|
||||
"schedule": timedelta(seconds=15 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGHEST,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
@@ -95,7 +99,7 @@ cloud_tasks_to_schedule = [
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-doc-permissions-sync",
|
||||
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
"schedule": timedelta(seconds=30),
|
||||
"schedule": timedelta(seconds=30 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGHEST,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
@@ -107,7 +111,7 @@ cloud_tasks_to_schedule = [
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-external-group-sync",
|
||||
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
"schedule": timedelta(seconds=20),
|
||||
"schedule": timedelta(seconds=20 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGHEST,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
@@ -119,7 +123,7 @@ cloud_tasks_to_schedule = [
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor-background-processes",
|
||||
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
"schedule": timedelta(minutes=5),
|
||||
"schedule": timedelta(minutes=5 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGHEST,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
@@ -137,7 +141,9 @@ if LLM_MODEL_UPDATE_API_URL:
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-llm-model-update",
|
||||
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
"schedule": timedelta(hours=1), # Check every hour
|
||||
"schedule": timedelta(
|
||||
hours=1 * CLOUD_BEAT_SCHEDULE_MULTIPLIER
|
||||
), # Check every hour
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGHEST,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
@@ -221,7 +227,7 @@ if not MULTI_TENANT:
|
||||
{
|
||||
"name": "monitor-background-processes",
|
||||
"task": OnyxCeleryTask.MONITOR_BACKGROUND_PROCESSES,
|
||||
"schedule": timedelta(minutes=5),
|
||||
"schedule": timedelta(minutes=15),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
|
||||
@@ -33,6 +33,7 @@ class TaskDependencyError(RuntimeError):
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
trail=False,
|
||||
bind=True,
|
||||
@@ -139,13 +140,6 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
submitted=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
# create before setting fence to avoid race condition where the monitoring
|
||||
# task updates the sync record before it is created
|
||||
insert_sync_record(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
sync_type=SyncType.CONNECTOR_DELETION,
|
||||
)
|
||||
redis_connector.delete.set_fence(fence_payload)
|
||||
|
||||
try:
|
||||
@@ -184,6 +178,13 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
)
|
||||
if tasks_generated is None:
|
||||
raise ValueError("RedisConnectorDeletion.generate_tasks returned None")
|
||||
|
||||
insert_sync_record(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
sync_type=SyncType.CONNECTOR_DELETION,
|
||||
)
|
||||
|
||||
except TaskDependencyError:
|
||||
redis_connector.delete.set_fence(None)
|
||||
raise
|
||||
|
||||
@@ -11,6 +11,7 @@ from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs
|
||||
from ee.onyx.db.document import upsert_document_external_perms
|
||||
@@ -31,12 +32,17 @@ from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.connector import mark_cc_pair_as_permissions_synced
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.document import upsert_document_by_connector_credential_pair
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import SyncStatus
|
||||
from onyx.db.enums import SyncType
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.sync_record import insert_sync_record
|
||||
from onyx.db.sync_record import update_sync_record_status
|
||||
from onyx.db.users import batch_add_ext_perm_user_if_not_exists
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.redis.redis_connector_doc_perm_sync import (
|
||||
@@ -57,6 +63,9 @@ LIGHT_SOFT_TIME_LIMIT = 105
|
||||
LIGHT_TIME_LIMIT = LIGHT_SOFT_TIME_LIMIT + 15
|
||||
|
||||
|
||||
"""Jobs / utils for kicking off doc permissions sync tasks."""
|
||||
|
||||
|
||||
def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
"""Returns boolean indicating if external doc permissions sync is due."""
|
||||
|
||||
@@ -91,6 +100,7 @@ def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> b
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
)
|
||||
@@ -173,6 +183,19 @@ def try_creating_permissions_sync_task(
|
||||
|
||||
custom_task_id = f"{redis_connector.permissions.generator_task_key}_{uuid4()}"
|
||||
|
||||
# create before setting fence to avoid race condition where the monitoring
|
||||
# task updates the sync record before it is created
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
insert_sync_record(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
sync_type=SyncType.EXTERNAL_PERMISSIONS,
|
||||
)
|
||||
|
||||
# set a basic fence to start
|
||||
payload = RedisConnectorPermissionSyncPayload(started=None, celery_task_id=None)
|
||||
redis_connector.permissions.set_fence(payload)
|
||||
|
||||
result = app.send_task(
|
||||
OnyxCeleryTask.CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK,
|
||||
kwargs=dict(
|
||||
@@ -184,11 +207,8 @@ def try_creating_permissions_sync_task(
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
)
|
||||
|
||||
# set a basic fence to start
|
||||
payload = RedisConnectorPermissionSyncPayload(
|
||||
started=None, celery_task_id=result.id
|
||||
)
|
||||
|
||||
# fill in the celery task id
|
||||
payload.celery_task_id = result.id
|
||||
redis_connector.permissions.set_fence(payload)
|
||||
except Exception:
|
||||
task_logger.exception(f"Unexpected exception: cc_pair={cc_pair_id}")
|
||||
@@ -398,3 +418,53 @@ def update_external_document_permissions_task(
|
||||
f"Error Syncing Document Permissions: connector_id={connector_id} doc_id={doc_id}"
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
"""Monitoring CCPair permissions utils, called in monitor_vespa_sync"""
|
||||
|
||||
|
||||
def monitor_ccpair_permissions_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
if cc_pair_id_str is None:
|
||||
task_logger.warning(
|
||||
f"monitor_ccpair_permissions_taskset: could not parse cc_pair_id from {fence_key}"
|
||||
)
|
||||
return
|
||||
|
||||
cc_pair_id = int(cc_pair_id_str)
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
if not redis_connector.permissions.fenced:
|
||||
return
|
||||
|
||||
initial = redis_connector.permissions.generator_complete
|
||||
if initial is None:
|
||||
return
|
||||
|
||||
remaining = redis_connector.permissions.get_remaining()
|
||||
task_logger.info(
|
||||
f"Permissions sync progress: cc_pair={cc_pair_id} remaining={remaining} initial={initial}"
|
||||
)
|
||||
if remaining > 0:
|
||||
return
|
||||
|
||||
payload: RedisConnectorPermissionSyncPayload | None = (
|
||||
redis_connector.permissions.payload
|
||||
)
|
||||
start_time: datetime | None = payload.started if payload else None
|
||||
|
||||
mark_cc_pair_as_permissions_synced(db_session, int(cc_pair_id), start_time)
|
||||
task_logger.info(f"Successfully synced permissions for cc_pair={cc_pair_id}")
|
||||
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
sync_type=SyncType.EXTERNAL_PERMISSIONS,
|
||||
sync_status=SyncStatus.SUCCESS,
|
||||
num_docs_synced=initial,
|
||||
)
|
||||
|
||||
redis_connector.permissions.reset()
|
||||
|
||||
@@ -33,7 +33,11 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair_from
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import SyncStatus
|
||||
from onyx.db.enums import SyncType
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.sync_record import insert_sync_record
|
||||
from onyx.db.sync_record import update_sync_record_status
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.redis.redis_connector_ext_group_sync import (
|
||||
RedisConnectorExternalGroupSyncPayload,
|
||||
@@ -91,6 +95,7 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
)
|
||||
@@ -199,6 +204,15 @@ def try_creating_external_group_sync_task(
|
||||
celery_task_id=result.id,
|
||||
)
|
||||
|
||||
# create before setting fence to avoid race condition where the monitoring
|
||||
# task updates the sync record before it is created
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
insert_sync_record(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
sync_type=SyncType.EXTERNAL_GROUP,
|
||||
)
|
||||
|
||||
redis_connector.external_group_sync.set_fence(payload)
|
||||
|
||||
except Exception:
|
||||
@@ -288,11 +302,26 @@ def connector_external_group_sync_generator_task(
|
||||
)
|
||||
|
||||
mark_cc_pair_as_external_group_synced(db_session, cc_pair.id)
|
||||
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
sync_type=SyncType.EXTERNAL_GROUP,
|
||||
sync_status=SyncStatus.SUCCESS,
|
||||
)
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
f"Failed to run external group sync: cc_pair={cc_pair_id}"
|
||||
)
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
sync_type=SyncType.EXTERNAL_GROUP,
|
||||
sync_status=SyncStatus.FAILED,
|
||||
)
|
||||
|
||||
redis_connector.external_group_sync.generator_clear()
|
||||
redis_connector.external_group_sync.taskset_clear()
|
||||
raise e
|
||||
|
||||
@@ -15,6 +15,7 @@ from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
|
||||
from onyx.background.celery.tasks.indexing.utils import _should_index
|
||||
from onyx.background.celery.tasks.indexing.utils import get_unfenced_index_attempt_ids
|
||||
from onyx.background.celery.tasks.indexing.utils import IndexingCallback
|
||||
@@ -22,6 +23,9 @@ from onyx.background.celery.tasks.indexing.utils import try_creating_indexing_ta
|
||||
from onyx.background.celery.tasks.indexing.utils import validate_indexing_fences
|
||||
from onyx.background.indexing.job_client import SimpleJobClient
|
||||
from onyx.background.indexing.run_indexing import run_indexing_entrypoint
|
||||
from onyx.configs.app_configs import MANAGED_VESPA
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
|
||||
@@ -37,14 +41,14 @@ from onyx.db.index_attempt import get_index_attempt
|
||||
from onyx.db.index_attempt import get_last_attempt_for_cc_pair
|
||||
from onyx.db.index_attempt import mark_attempt_canceled
|
||||
from onyx.db.index_attempt import mark_attempt_failed
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.db.search_settings import get_active_search_settings
|
||||
from onyx.db.search_settings import get_active_search_settings_list
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.db.swap_index import check_index_swap
|
||||
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from onyx.natural_language_processing.search_nlp_models import warm_up_bi_encoder
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.redis.redis_pool import redis_lock_dump
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
@@ -69,6 +73,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
tasks_created = 0
|
||||
locked = False
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
redis_client_replica = get_redis_replica_client(tenant_id=tenant_id)
|
||||
|
||||
# we need to use celery's redis client to access its redis data
|
||||
# (which lives on a different db number)
|
||||
@@ -119,9 +124,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
search_settings_list: list[SearchSettings] = get_active_search_settings(
|
||||
db_session
|
||||
)
|
||||
search_settings_list = get_active_search_settings_list(db_session)
|
||||
for search_settings_instance in search_settings_list:
|
||||
redis_connector_index = redis_connector.new_index(
|
||||
search_settings_instance.id
|
||||
@@ -227,7 +230,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
# or be currently executing
|
||||
try:
|
||||
validate_indexing_fences(
|
||||
tenant_id, self.app, redis_client, redis_client_celery, lock_beat
|
||||
tenant_id, redis_client_replica, redis_client_celery, lock_beat
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception("Exception while validating indexing fences")
|
||||
@@ -301,6 +304,14 @@ def connector_indexing_task(
|
||||
attempt_found = False
|
||||
n_final_progress: int | None = None
|
||||
|
||||
# 20 is the documented default for httpx max_keepalive_connections
|
||||
if MANAGED_VESPA:
|
||||
httpx_init_vespa_pool(
|
||||
20, ssl_cert=VESPA_CLOUD_CERT_PATH, ssl_key=VESPA_CLOUD_KEY_PATH
|
||||
)
|
||||
else:
|
||||
httpx_init_vespa_pool(20)
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
redis_connector_index = redis_connector.new_index(search_settings_id)
|
||||
|
||||
|
||||
@@ -291,17 +291,20 @@ def validate_indexing_fence(
|
||||
|
||||
def validate_indexing_fences(
|
||||
tenant_id: str | None,
|
||||
celery_app: Celery,
|
||||
r: Redis,
|
||||
r_replica: Redis,
|
||||
r_celery: Redis,
|
||||
lock_beat: RedisLock,
|
||||
) -> None:
|
||||
"""Validates all indexing fences for this tenant ... aka makes sure
|
||||
indexing tasks sent to celery are still in flight.
|
||||
"""
|
||||
reserved_indexing_tasks = celery_get_unacked_task_ids(
|
||||
OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
|
||||
)
|
||||
|
||||
# validate all existing indexing jobs
|
||||
for key_bytes in r.scan_iter(
|
||||
# Use replica for this because the worst thing that happens
|
||||
# is that we don't run the validation on this pass
|
||||
for key_bytes in r_replica.scan_iter(
|
||||
RedisConnectorIndex.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
|
||||
):
|
||||
lock_beat.reacquire()
|
||||
|
||||
@@ -54,6 +54,7 @@ def _process_model_list_response(model_list_json: Any) -> list[str]:
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CHECK_FOR_LLM_MODEL_UPDATE,
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
trail=False,
|
||||
bind=True,
|
||||
|
||||
@@ -4,6 +4,7 @@ from collections.abc import Callable
|
||||
from datetime import timedelta
|
||||
from itertools import islice
|
||||
from typing import Any
|
||||
from typing import Literal
|
||||
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
@@ -26,18 +27,20 @@ from onyx.db.engine import get_all_tenant_ids
|
||||
from onyx.db.engine import get_db_current_time
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.enums import IndexingStatus
|
||||
from onyx.db.enums import SyncStatus
|
||||
from onyx.db.enums import SyncType
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import IndexAttempt
|
||||
from onyx.db.models import SyncRecord
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.search_settings import get_active_search_settings
|
||||
from onyx.db.search_settings import get_active_search_settings_list
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import redis_lock_dump
|
||||
from onyx.utils.telemetry import optional_telemetry
|
||||
from onyx.utils.telemetry import RecordType
|
||||
|
||||
|
||||
_MONITORING_SOFT_TIME_LIMIT = 60 * 5 # 5 minutes
|
||||
_MONITORING_TIME_LIMIT = _MONITORING_SOFT_TIME_LIMIT + 60 # 6 minutes
|
||||
|
||||
@@ -49,6 +52,17 @@ _CONNECTOR_INDEX_ATTEMPT_RUN_SUCCESS_KEY_FMT = (
|
||||
"monitoring_connector_index_attempt_run_success:{cc_pair_id}:{index_attempt_id}"
|
||||
)
|
||||
|
||||
_FINAL_METRIC_KEY_FMT = "sync_final_metrics:{sync_type}:{entity_id}:{sync_record_id}"
|
||||
|
||||
_SYNC_START_LATENCY_KEY_FMT = (
|
||||
"sync_start_latency:{sync_type}:{entity_id}:{sync_record_id}"
|
||||
)
|
||||
|
||||
_CONNECTOR_START_TIME_KEY_FMT = "connector_start_time:{cc_pair_id}:{index_attempt_id}"
|
||||
_CONNECTOR_END_TIME_KEY_FMT = "connector_end_time:{cc_pair_id}:{index_attempt_id}"
|
||||
_SYNC_START_TIME_KEY_FMT = "sync_start_time:{sync_type}:{entity_id}:{sync_record_id}"
|
||||
_SYNC_END_TIME_KEY_FMT = "sync_end_time:{sync_type}:{entity_id}:{sync_record_id}"
|
||||
|
||||
|
||||
def _mark_metric_as_emitted(redis_std: Redis, key: str) -> None:
|
||||
"""Mark a metric as having been emitted by setting a Redis key with expiration"""
|
||||
@@ -111,6 +125,7 @@ class Metric(BaseModel):
|
||||
}.items()
|
||||
if v is not None
|
||||
}
|
||||
task_logger.info(f"Emitting metric: {data}")
|
||||
optional_telemetry(
|
||||
record_type=RecordType.METRIC,
|
||||
data=data,
|
||||
@@ -189,48 +204,107 @@ def _build_connector_start_latency_metric(
|
||||
f"Start latency for index attempt {recent_attempt.id}: {start_latency:.2f}s "
|
||||
f"(desired: {desired_start_time}, actual: {recent_attempt.time_started})"
|
||||
)
|
||||
|
||||
job_id = build_job_id("connector", str(cc_pair.id), str(recent_attempt.id))
|
||||
|
||||
return Metric(
|
||||
key=metric_key,
|
||||
name="connector_start_latency",
|
||||
value=start_latency,
|
||||
tags={},
|
||||
tags={
|
||||
"job_id": job_id,
|
||||
"connector_id": str(cc_pair.connector.id),
|
||||
"source": str(cc_pair.connector.source),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _build_run_success_metrics(
|
||||
def _build_connector_final_metrics(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
recent_attempts: list[IndexAttempt],
|
||||
redis_std: Redis,
|
||||
) -> list[Metric]:
|
||||
"""
|
||||
Final metrics for connector index attempts:
|
||||
- Boolean success/fail metric
|
||||
- If success, emit:
|
||||
* duration (seconds)
|
||||
* doc_count
|
||||
"""
|
||||
metrics = []
|
||||
for attempt in recent_attempts:
|
||||
metric_key = _CONNECTOR_INDEX_ATTEMPT_RUN_SUCCESS_KEY_FMT.format(
|
||||
cc_pair_id=cc_pair.id,
|
||||
index_attempt_id=attempt.id,
|
||||
)
|
||||
|
||||
if _has_metric_been_emitted(redis_std, metric_key):
|
||||
task_logger.info(
|
||||
f"Skipping metric for connector {cc_pair.connector.id} "
|
||||
f"index attempt {attempt.id} because it has already been "
|
||||
"emitted"
|
||||
f"Skipping final metrics for connector {cc_pair.connector.id} "
|
||||
f"index attempt {attempt.id}, already emitted."
|
||||
)
|
||||
continue
|
||||
|
||||
if attempt.status in [
|
||||
# We only emit final metrics if the attempt is in a terminal state
|
||||
if attempt.status not in [
|
||||
IndexingStatus.SUCCESS,
|
||||
IndexingStatus.FAILED,
|
||||
IndexingStatus.CANCELED,
|
||||
]:
|
||||
task_logger.info(
|
||||
f"Adding run success metric for index attempt {attempt.id} with status {attempt.status}"
|
||||
# Not finished; skip
|
||||
continue
|
||||
|
||||
job_id = build_job_id("connector", str(cc_pair.id), str(attempt.id))
|
||||
success = attempt.status == IndexingStatus.SUCCESS
|
||||
metrics.append(
|
||||
Metric(
|
||||
key=metric_key, # We'll mark the same key for any final metrics
|
||||
name="connector_run_succeeded",
|
||||
value=success,
|
||||
tags={
|
||||
"job_id": job_id,
|
||||
"connector_id": str(cc_pair.connector.id),
|
||||
"source": str(cc_pair.connector.source),
|
||||
"status": attempt.status.value,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
if success:
|
||||
# Make sure we have valid time_started
|
||||
if attempt.time_started and attempt.time_updated:
|
||||
duration_seconds = (
|
||||
attempt.time_updated - attempt.time_started
|
||||
).total_seconds()
|
||||
metrics.append(
|
||||
Metric(
|
||||
key=None, # No need for a new key, or you can reuse the same if you prefer
|
||||
name="connector_index_duration_seconds",
|
||||
value=duration_seconds,
|
||||
tags={
|
||||
"job_id": job_id,
|
||||
"connector_id": str(cc_pair.connector.id),
|
||||
"source": str(cc_pair.connector.source),
|
||||
},
|
||||
)
|
||||
)
|
||||
else:
|
||||
task_logger.error(
|
||||
f"Index attempt {attempt.id} succeeded but has missing time "
|
||||
f"(time_started={attempt.time_started}, time_updated={attempt.time_updated})."
|
||||
)
|
||||
|
||||
# For doc counts, choose whichever field is more relevant
|
||||
doc_count = attempt.total_docs_indexed or 0
|
||||
metrics.append(
|
||||
Metric(
|
||||
key=metric_key,
|
||||
name="connector_run_succeeded",
|
||||
value=attempt.status == IndexingStatus.SUCCESS,
|
||||
tags={"source": str(cc_pair.connector.source)},
|
||||
key=None,
|
||||
name="connector_index_doc_count",
|
||||
value=doc_count,
|
||||
tags={
|
||||
"job_id": job_id,
|
||||
"connector_id": str(cc_pair.connector.id),
|
||||
"source": str(cc_pair.connector.source),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
@@ -239,189 +313,337 @@ def _build_run_success_metrics(
|
||||
|
||||
def _collect_connector_metrics(db_session: Session, redis_std: Redis) -> list[Metric]:
|
||||
"""Collect metrics about connector runs from the past hour"""
|
||||
# NOTE: use get_db_current_time since the IndexAttempt times are set based on DB time
|
||||
one_hour_ago = get_db_current_time(db_session) - timedelta(hours=1)
|
||||
|
||||
# Get all connector credential pairs
|
||||
cc_pairs = db_session.scalars(select(ConnectorCredentialPair)).all()
|
||||
# Might be more than one search setting, or just one
|
||||
active_search_settings_list = get_active_search_settings_list(db_session)
|
||||
|
||||
active_search_settings = get_active_search_settings(db_session)
|
||||
metrics = []
|
||||
|
||||
for cc_pair, search_settings in zip(cc_pairs, active_search_settings):
|
||||
recent_attempts = (
|
||||
db_session.query(IndexAttempt)
|
||||
.filter(
|
||||
IndexAttempt.connector_credential_pair_id == cc_pair.id,
|
||||
IndexAttempt.search_settings_id == search_settings.id,
|
||||
# If you want to process each cc_pair against each search setting:
|
||||
for cc_pair in cc_pairs:
|
||||
for search_settings in active_search_settings_list:
|
||||
recent_attempts = (
|
||||
db_session.query(IndexAttempt)
|
||||
.filter(
|
||||
IndexAttempt.connector_credential_pair_id == cc_pair.id,
|
||||
IndexAttempt.search_settings_id == search_settings.id,
|
||||
)
|
||||
.order_by(IndexAttempt.time_created.desc())
|
||||
.limit(2)
|
||||
.all()
|
||||
)
|
||||
.order_by(IndexAttempt.time_created.desc())
|
||||
.limit(2)
|
||||
.all()
|
||||
)
|
||||
if not recent_attempts:
|
||||
continue
|
||||
|
||||
most_recent_attempt = recent_attempts[0]
|
||||
second_most_recent_attempt = (
|
||||
recent_attempts[1] if len(recent_attempts) > 1 else None
|
||||
)
|
||||
if not recent_attempts:
|
||||
continue
|
||||
|
||||
if one_hour_ago > most_recent_attempt.time_created:
|
||||
continue
|
||||
most_recent_attempt = recent_attempts[0]
|
||||
second_most_recent_attempt = (
|
||||
recent_attempts[1] if len(recent_attempts) > 1 else None
|
||||
)
|
||||
|
||||
# Connector start latency
|
||||
start_latency_metric = _build_connector_start_latency_metric(
|
||||
cc_pair, most_recent_attempt, second_most_recent_attempt, redis_std
|
||||
)
|
||||
if start_latency_metric:
|
||||
metrics.append(start_latency_metric)
|
||||
if one_hour_ago > most_recent_attempt.time_created:
|
||||
continue
|
||||
|
||||
# Connector run success/failure
|
||||
run_success_metrics = _build_run_success_metrics(
|
||||
cc_pair, recent_attempts, redis_std
|
||||
)
|
||||
metrics.extend(run_success_metrics)
|
||||
# Build a job_id for correlation
|
||||
job_id = build_job_id(
|
||||
"connector", str(cc_pair.id), str(most_recent_attempt.id)
|
||||
)
|
||||
|
||||
# Add raw start time metric if available
|
||||
if most_recent_attempt.time_started:
|
||||
start_time_key = _CONNECTOR_START_TIME_KEY_FMT.format(
|
||||
cc_pair_id=cc_pair.id,
|
||||
index_attempt_id=most_recent_attempt.id,
|
||||
)
|
||||
metrics.append(
|
||||
Metric(
|
||||
key=start_time_key,
|
||||
name="connector_start_time",
|
||||
value=most_recent_attempt.time_started.timestamp(),
|
||||
tags={
|
||||
"job_id": job_id,
|
||||
"connector_id": str(cc_pair.connector.id),
|
||||
"source": str(cc_pair.connector.source),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Add raw end time metric if available and in terminal state
|
||||
if (
|
||||
most_recent_attempt.status.is_terminal()
|
||||
and most_recent_attempt.time_updated
|
||||
):
|
||||
end_time_key = _CONNECTOR_END_TIME_KEY_FMT.format(
|
||||
cc_pair_id=cc_pair.id,
|
||||
index_attempt_id=most_recent_attempt.id,
|
||||
)
|
||||
metrics.append(
|
||||
Metric(
|
||||
key=end_time_key,
|
||||
name="connector_end_time",
|
||||
value=most_recent_attempt.time_updated.timestamp(),
|
||||
tags={
|
||||
"job_id": job_id,
|
||||
"connector_id": str(cc_pair.connector.id),
|
||||
"source": str(cc_pair.connector.source),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Connector start latency
|
||||
start_latency_metric = _build_connector_start_latency_metric(
|
||||
cc_pair, most_recent_attempt, second_most_recent_attempt, redis_std
|
||||
)
|
||||
|
||||
if start_latency_metric:
|
||||
metrics.append(start_latency_metric)
|
||||
|
||||
# Connector run success/failure
|
||||
final_metrics = _build_connector_final_metrics(
|
||||
cc_pair, recent_attempts, redis_std
|
||||
)
|
||||
metrics.extend(final_metrics)
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
def _collect_sync_metrics(db_session: Session, redis_std: Redis) -> list[Metric]:
|
||||
"""Collect metrics about document set and group syncing speed"""
|
||||
# NOTE: use get_db_current_time since the SyncRecord times are set based on DB time
|
||||
"""
|
||||
Collect metrics for document set and group syncing:
|
||||
- Success/failure status
|
||||
- Start latency (for doc sets / user groups)
|
||||
- Duration & doc count (only if success)
|
||||
- Throughput (docs/min) (only if success)
|
||||
- Raw start/end times for each sync
|
||||
"""
|
||||
one_hour_ago = get_db_current_time(db_session) - timedelta(hours=1)
|
||||
|
||||
# Get all sync records from the last hour
|
||||
# Get all sync records that ended in the last hour
|
||||
recent_sync_records = db_session.scalars(
|
||||
select(SyncRecord)
|
||||
.where(SyncRecord.sync_start_time >= one_hour_ago)
|
||||
.order_by(SyncRecord.sync_start_time.desc())
|
||||
.where(SyncRecord.sync_end_time.isnot(None))
|
||||
.where(SyncRecord.sync_end_time >= one_hour_ago)
|
||||
.order_by(SyncRecord.sync_end_time.desc())
|
||||
).all()
|
||||
|
||||
task_logger.info(
|
||||
f"Collecting sync metrics for {len(recent_sync_records)} sync records"
|
||||
)
|
||||
|
||||
metrics = []
|
||||
|
||||
for sync_record in recent_sync_records:
|
||||
# Skip if no end time (sync still in progress)
|
||||
if not sync_record.sync_end_time:
|
||||
continue
|
||||
# Build a job_id for correlation
|
||||
job_id = build_job_id("sync_record", str(sync_record.id))
|
||||
|
||||
# Check if we already emitted a metric for this sync record
|
||||
metric_key = (
|
||||
f"sync_speed:{sync_record.sync_type}:"
|
||||
f"{sync_record.entity_id}:{sync_record.id}"
|
||||
)
|
||||
if _has_metric_been_emitted(redis_std, metric_key):
|
||||
task_logger.info(
|
||||
f"Skipping metric for sync record {sync_record.id} "
|
||||
"because it has already been emitted"
|
||||
)
|
||||
continue
|
||||
|
||||
# Calculate sync duration in minutes
|
||||
sync_duration_mins = (
|
||||
sync_record.sync_end_time - sync_record.sync_start_time
|
||||
).total_seconds() / 60.0
|
||||
|
||||
# Calculate sync speed (docs/min) - avoid division by zero
|
||||
sync_speed = (
|
||||
sync_record.num_docs_synced / sync_duration_mins
|
||||
if sync_duration_mins > 0
|
||||
else None
|
||||
)
|
||||
|
||||
if sync_speed is None:
|
||||
task_logger.error(
|
||||
f"Something went wrong with sync speed calculation. "
|
||||
f"Sync record: {sync_record.id}, duration: {sync_duration_mins}, "
|
||||
f"docs synced: {sync_record.num_docs_synced}"
|
||||
)
|
||||
continue
|
||||
|
||||
task_logger.info(
|
||||
f"Calculated sync speed for record {sync_record.id}: {sync_speed} docs/min"
|
||||
# Add raw start time metric
|
||||
start_time_key = _SYNC_START_TIME_KEY_FMT.format(
|
||||
sync_type=sync_record.sync_type,
|
||||
entity_id=sync_record.entity_id,
|
||||
sync_record_id=sync_record.id,
|
||||
)
|
||||
metrics.append(
|
||||
Metric(
|
||||
key=metric_key,
|
||||
name="sync_speed_docs_per_min",
|
||||
value=sync_speed,
|
||||
tags={
|
||||
"sync_type": str(sync_record.sync_type),
|
||||
"status": str(sync_record.sync_status),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Add sync start latency metric
|
||||
start_latency_key = (
|
||||
f"sync_start_latency:{sync_record.sync_type}"
|
||||
f":{sync_record.entity_id}:{sync_record.id}"
|
||||
)
|
||||
if _has_metric_been_emitted(redis_std, start_latency_key):
|
||||
task_logger.info(
|
||||
f"Skipping start latency metric for sync record {sync_record.id} "
|
||||
"because it has already been emitted"
|
||||
)
|
||||
continue
|
||||
|
||||
# Get the entity's last update time based on sync type
|
||||
entity: DocumentSet | UserGroup | None = None
|
||||
if sync_record.sync_type == SyncType.DOCUMENT_SET:
|
||||
entity = db_session.scalar(
|
||||
select(DocumentSet).where(DocumentSet.id == sync_record.entity_id)
|
||||
)
|
||||
elif sync_record.sync_type == SyncType.USER_GROUP:
|
||||
entity = db_session.scalar(
|
||||
select(UserGroup).where(UserGroup.id == sync_record.entity_id)
|
||||
)
|
||||
else:
|
||||
# Skip other sync types
|
||||
task_logger.info(
|
||||
f"Skipping sync record {sync_record.id} "
|
||||
f"with type {sync_record.sync_type} "
|
||||
f"and id {sync_record.entity_id} "
|
||||
"because it is not a document set or user group"
|
||||
)
|
||||
continue
|
||||
|
||||
if entity is None:
|
||||
task_logger.error(
|
||||
f"Could not find entity for sync record {sync_record.id} "
|
||||
f"with type {sync_record.sync_type} and id {sync_record.entity_id}"
|
||||
)
|
||||
continue
|
||||
|
||||
# Calculate start latency in seconds
|
||||
start_latency = (
|
||||
sync_record.sync_start_time - entity.time_last_modified_by_user
|
||||
).total_seconds()
|
||||
task_logger.info(
|
||||
f"Calculated start latency for sync record {sync_record.id}: {start_latency} seconds"
|
||||
)
|
||||
if start_latency < 0:
|
||||
task_logger.error(
|
||||
f"Start latency is negative for sync record {sync_record.id} "
|
||||
f"with type {sync_record.sync_type} and id {sync_record.entity_id}. "
|
||||
f"Sync start time: {sync_record.sync_start_time}, "
|
||||
f"Entity last modified: {entity.time_last_modified_by_user}"
|
||||
)
|
||||
continue
|
||||
|
||||
metrics.append(
|
||||
Metric(
|
||||
key=start_latency_key,
|
||||
name="sync_start_latency_seconds",
|
||||
value=start_latency,
|
||||
key=start_time_key,
|
||||
name="sync_start_time",
|
||||
value=sync_record.sync_start_time.timestamp(),
|
||||
tags={
|
||||
"job_id": job_id,
|
||||
"sync_type": str(sync_record.sync_type),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Add raw end time metric if available
|
||||
if sync_record.sync_end_time:
|
||||
end_time_key = _SYNC_END_TIME_KEY_FMT.format(
|
||||
sync_type=sync_record.sync_type,
|
||||
entity_id=sync_record.entity_id,
|
||||
sync_record_id=sync_record.id,
|
||||
)
|
||||
metrics.append(
|
||||
Metric(
|
||||
key=end_time_key,
|
||||
name="sync_end_time",
|
||||
value=sync_record.sync_end_time.timestamp(),
|
||||
tags={
|
||||
"job_id": job_id,
|
||||
"sync_type": str(sync_record.sync_type),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Emit a SUCCESS/FAIL boolean metric
|
||||
# Use a single Redis key to avoid re-emitting final metrics
|
||||
final_metric_key = _FINAL_METRIC_KEY_FMT.format(
|
||||
sync_type=sync_record.sync_type,
|
||||
entity_id=sync_record.entity_id,
|
||||
sync_record_id=sync_record.id,
|
||||
)
|
||||
if not _has_metric_been_emitted(redis_std, final_metric_key):
|
||||
# Evaluate success
|
||||
sync_succeeded = sync_record.sync_status == SyncStatus.SUCCESS
|
||||
|
||||
metrics.append(
|
||||
Metric(
|
||||
key=final_metric_key,
|
||||
name="sync_run_succeeded",
|
||||
value=sync_succeeded,
|
||||
tags={
|
||||
"job_id": job_id,
|
||||
"sync_type": str(sync_record.sync_type),
|
||||
"status": str(sync_record.sync_status),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# If successful, emit additional metrics
|
||||
if sync_succeeded:
|
||||
if sync_record.sync_end_time and sync_record.sync_start_time:
|
||||
duration_seconds = (
|
||||
sync_record.sync_end_time - sync_record.sync_start_time
|
||||
).total_seconds()
|
||||
else:
|
||||
task_logger.error(
|
||||
f"Invalid times for sync record {sync_record.id}: "
|
||||
f"start={sync_record.sync_start_time}, end={sync_record.sync_end_time}"
|
||||
)
|
||||
duration_seconds = None
|
||||
|
||||
doc_count = sync_record.num_docs_synced or 0
|
||||
|
||||
sync_speed = None
|
||||
if duration_seconds and duration_seconds > 0:
|
||||
duration_mins = duration_seconds / 60.0
|
||||
sync_speed = (
|
||||
doc_count / duration_mins if duration_mins > 0 else None
|
||||
)
|
||||
|
||||
# Emit duration, doc count, speed
|
||||
if duration_seconds is not None:
|
||||
metrics.append(
|
||||
Metric(
|
||||
key=final_metric_key,
|
||||
name="sync_duration_seconds",
|
||||
value=duration_seconds,
|
||||
tags={
|
||||
"job_id": job_id,
|
||||
"sync_type": str(sync_record.sync_type),
|
||||
},
|
||||
)
|
||||
)
|
||||
else:
|
||||
task_logger.error(
|
||||
f"Invalid sync record {sync_record.id} with no duration"
|
||||
)
|
||||
|
||||
metrics.append(
|
||||
Metric(
|
||||
key=final_metric_key,
|
||||
name="sync_doc_count",
|
||||
value=doc_count,
|
||||
tags={
|
||||
"job_id": job_id,
|
||||
"sync_type": str(sync_record.sync_type),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
if sync_speed is not None:
|
||||
metrics.append(
|
||||
Metric(
|
||||
key=final_metric_key,
|
||||
name="sync_speed_docs_per_min",
|
||||
value=sync_speed,
|
||||
tags={
|
||||
"job_id": job_id,
|
||||
"sync_type": str(sync_record.sync_type),
|
||||
},
|
||||
)
|
||||
)
|
||||
else:
|
||||
task_logger.error(
|
||||
f"Invalid sync record {sync_record.id} with no duration"
|
||||
)
|
||||
|
||||
# Emit start latency
|
||||
start_latency_key = _SYNC_START_LATENCY_KEY_FMT.format(
|
||||
sync_type=sync_record.sync_type,
|
||||
entity_id=sync_record.entity_id,
|
||||
sync_record_id=sync_record.id,
|
||||
)
|
||||
if not _has_metric_been_emitted(redis_std, start_latency_key):
|
||||
# Get the entity's last update time based on sync type
|
||||
entity: DocumentSet | UserGroup | None = None
|
||||
if sync_record.sync_type == SyncType.DOCUMENT_SET:
|
||||
entity = db_session.scalar(
|
||||
select(DocumentSet).where(DocumentSet.id == sync_record.entity_id)
|
||||
)
|
||||
elif sync_record.sync_type == SyncType.USER_GROUP:
|
||||
entity = db_session.scalar(
|
||||
select(UserGroup).where(UserGroup.id == sync_record.entity_id)
|
||||
)
|
||||
|
||||
if entity is None:
|
||||
task_logger.error(
|
||||
f"Sync record of type {sync_record.sync_type} doesn't have an entity "
|
||||
f"associated with it (id={sync_record.entity_id}). Skipping start latency metric."
|
||||
)
|
||||
|
||||
# Calculate start latency in seconds:
|
||||
# (actual sync start) - (last modified time)
|
||||
if (
|
||||
entity is not None
|
||||
and entity.time_last_modified_by_user
|
||||
and sync_record.sync_start_time
|
||||
):
|
||||
start_latency = (
|
||||
sync_record.sync_start_time - entity.time_last_modified_by_user
|
||||
).total_seconds()
|
||||
|
||||
if start_latency < 0:
|
||||
task_logger.error(
|
||||
f"Negative start latency for sync record {sync_record.id} "
|
||||
f"(start={sync_record.sync_start_time}, entity_modified={entity.time_last_modified_by_user})"
|
||||
)
|
||||
continue
|
||||
|
||||
metrics.append(
|
||||
Metric(
|
||||
key=start_latency_key,
|
||||
name="sync_start_latency_seconds",
|
||||
value=start_latency,
|
||||
tags={
|
||||
"job_id": job_id,
|
||||
"sync_type": str(sync_record.sync_type),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
def build_job_id(
|
||||
job_type: Literal["connector", "sync_record"],
|
||||
primary_id: str,
|
||||
secondary_id: str | None = None,
|
||||
) -> str:
|
||||
if job_type == "connector":
|
||||
if secondary_id is None:
|
||||
raise ValueError(
|
||||
"secondary_id (attempt_id) is required for connector job_type"
|
||||
)
|
||||
return f"connector:{primary_id}:attempt:{secondary_id}"
|
||||
elif job_type == "sync_record":
|
||||
return f"sync_record:{primary_id}"
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.MONITOR_BACKGROUND_PROCESSES,
|
||||
ignore_result=True,
|
||||
soft_time_limit=_MONITORING_SOFT_TIME_LIMIT,
|
||||
time_limit=_MONITORING_TIME_LIMIT,
|
||||
queue=OnyxCeleryQueues.MONITORING,
|
||||
@@ -459,14 +681,18 @@ def monitor_background_processes(self: Task, *, tenant_id: str | None) -> None:
|
||||
lambda: _collect_connector_metrics(db_session, redis_std),
|
||||
lambda: _collect_sync_metrics(db_session, redis_std),
|
||||
]
|
||||
|
||||
# Collect and log each metric
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
for metric_fn in metric_functions:
|
||||
metrics = metric_fn()
|
||||
for metric in metrics:
|
||||
metric.log()
|
||||
metric.emit(tenant_id)
|
||||
if metric.key:
|
||||
# double check to make sure we aren't double-emitting metrics
|
||||
if metric.key is not None and not _has_metric_been_emitted(
|
||||
redis_std, metric.key
|
||||
):
|
||||
metric.log()
|
||||
metric.emit(tenant_id)
|
||||
_mark_metric_as_emitted(redis_std, metric.key)
|
||||
|
||||
task_logger.info("Successfully collected background metrics")
|
||||
|
||||
@@ -25,13 +25,18 @@ from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.connectors.factory import instantiate_connector
|
||||
from onyx.connectors.models import InputType
|
||||
from onyx.db.connector import mark_ccpair_as_pruned
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from onyx.db.document import get_documents_for_connector_credential_pair
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import SyncStatus
|
||||
from onyx.db.enums import SyncType
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.sync_record import insert_sync_record
|
||||
from onyx.db.sync_record import update_sync_record_status
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.utils.logger import pruning_ctx
|
||||
@@ -40,6 +45,9 @@ from onyx.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
"""Jobs / utils for kicking off pruning tasks."""
|
||||
|
||||
|
||||
def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
"""Returns boolean indicating if pruning is due.
|
||||
|
||||
@@ -78,6 +86,7 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CHECK_FOR_PRUNING,
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
)
|
||||
@@ -203,6 +212,14 @@ def try_creating_prune_generator_task(
|
||||
priority=OnyxCeleryPriority.LOW,
|
||||
)
|
||||
|
||||
# create before setting fence to avoid race condition where the monitoring
|
||||
# task updates the sync record before it is created
|
||||
insert_sync_record(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair.id,
|
||||
sync_type=SyncType.PRUNING,
|
||||
)
|
||||
|
||||
# set this only after all tasks have been added
|
||||
redis_connector.prune.set_fence(True)
|
||||
except Exception:
|
||||
@@ -347,3 +364,52 @@ def connector_pruning_generator_task(
|
||||
lock.release()
|
||||
|
||||
task_logger.info(f"Pruning generator finished: cc_pair={cc_pair_id}")
|
||||
|
||||
|
||||
"""Monitoring pruning utils, called in monitor_vespa_sync"""
|
||||
|
||||
|
||||
def monitor_ccpair_pruning_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
if cc_pair_id_str is None:
|
||||
task_logger.warning(
|
||||
f"monitor_ccpair_pruning_taskset: could not parse cc_pair_id from {fence_key}"
|
||||
)
|
||||
return
|
||||
|
||||
cc_pair_id = int(cc_pair_id_str)
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
if not redis_connector.prune.fenced:
|
||||
return
|
||||
|
||||
initial = redis_connector.prune.generator_complete
|
||||
if initial is None:
|
||||
return
|
||||
|
||||
remaining = redis_connector.prune.get_remaining()
|
||||
task_logger.info(
|
||||
f"Connector pruning progress: cc_pair={cc_pair_id} remaining={remaining} initial={initial}"
|
||||
)
|
||||
if remaining > 0:
|
||||
return
|
||||
|
||||
mark_ccpair_as_pruned(int(cc_pair_id), db_session)
|
||||
task_logger.info(
|
||||
f"Successfully pruned connector credential pair. cc_pair={cc_pair_id}"
|
||||
)
|
||||
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
sync_type=SyncType.PRUNING,
|
||||
sync_status=SyncStatus.SUCCESS,
|
||||
num_docs_synced=initial,
|
||||
)
|
||||
|
||||
redis_connector.prune.taskset_clear()
|
||||
redis_connector.prune.generator_clear()
|
||||
redis_connector.prune.set_fence(False)
|
||||
|
||||
@@ -27,12 +27,14 @@ from onyx.db.document import mark_document_as_synced
|
||||
from onyx.db.document_set import fetch_document_sets_for_document
|
||||
from onyx.db.engine import get_all_tenant_ids
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.document_index.document_index_utils import get_both_index_names
|
||||
from onyx.db.search_settings import get_active_search_settings
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.document_index.interfaces import VespaDocumentFields
|
||||
from onyx.httpx.httpx_pool import HttpxPool
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import redis_lock_dump
|
||||
from onyx.server.documents.models import ConnectorCredentialPairIdentifier
|
||||
from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST
|
||||
|
||||
DOCUMENT_BY_CC_PAIR_CLEANUP_MAX_RETRIES = 3
|
||||
|
||||
@@ -78,9 +80,11 @@ def document_by_cc_pair_cleanup_task(
|
||||
action = "skip"
|
||||
chunks_affected = 0
|
||||
|
||||
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
|
||||
active_search_settings = get_active_search_settings(db_session)
|
||||
doc_index = get_default_document_index(
|
||||
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
|
||||
active_search_settings.primary,
|
||||
active_search_settings.secondary,
|
||||
httpx_client=HttpxPool.get("vespa"),
|
||||
)
|
||||
|
||||
retry_index = RetryDocumentIndex(doc_index)
|
||||
@@ -213,6 +217,7 @@ def document_by_cc_pair_cleanup_task(
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
ignore_result=True,
|
||||
trail=False,
|
||||
bind=True,
|
||||
)
|
||||
@@ -247,6 +252,10 @@ def cloud_beat_task_generator(
|
||||
lock_beat.reacquire()
|
||||
last_lock_time = current_time
|
||||
|
||||
# needed in the cloud
|
||||
if IGNORED_SYNCING_TENANT_LIST and tenant_id in IGNORED_SYNCING_TENANT_LIST:
|
||||
continue
|
||||
|
||||
self.app.send_task(
|
||||
task_name,
|
||||
kwargs=dict(
|
||||
|
||||
@@ -24,6 +24,10 @@ from onyx.access.access import get_access_for_document
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.background.celery.tasks.doc_permission_syncing.tasks import (
|
||||
monitor_ccpair_permissions_taskset,
|
||||
)
|
||||
from onyx.background.celery.tasks.pruning.tasks import monitor_ccpair_pruning_taskset
|
||||
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
|
||||
from onyx.background.celery.tasks.shared.tasks import LIGHT_SOFT_TIME_LIMIT
|
||||
from onyx.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
|
||||
@@ -34,8 +38,6 @@ from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.connector import fetch_connector_by_id
|
||||
from onyx.db.connector import mark_cc_pair_as_permissions_synced
|
||||
from onyx.db.connector import mark_ccpair_as_pruned
|
||||
from onyx.db.connector_credential_pair import add_deletion_failure_message
|
||||
from onyx.db.connector_credential_pair import (
|
||||
delete_connector_credential_pair__no_commit,
|
||||
@@ -61,23 +63,22 @@ from onyx.db.index_attempt import get_index_attempt
|
||||
from onyx.db.index_attempt import mark_attempt_failed
|
||||
from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.search_settings import get_active_search_settings
|
||||
from onyx.db.sync_record import cleanup_sync_records
|
||||
from onyx.db.sync_record import insert_sync_record
|
||||
from onyx.db.sync_record import update_sync_record_status
|
||||
from onyx.document_index.document_index_utils import get_both_index_names
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.document_index.interfaces import VespaDocumentFields
|
||||
from onyx.httpx.httpx_pool import HttpxPool
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
|
||||
from onyx.redis.redis_connector_delete import RedisConnectorDelete
|
||||
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
|
||||
from onyx.redis.redis_connector_doc_perm_sync import (
|
||||
RedisConnectorPermissionSyncPayload,
|
||||
)
|
||||
from onyx.redis.redis_connector_index import RedisConnectorIndex
|
||||
from onyx.redis.redis_connector_prune import RedisConnectorPrune
|
||||
from onyx.redis.redis_document_set import RedisDocumentSet
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.redis.redis_pool import redis_lock_dump
|
||||
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
|
||||
from onyx.redis.redis_usergroup import RedisUserGroup
|
||||
@@ -97,6 +98,7 @@ logger = setup_logger()
|
||||
# which bloats the result metadata considerably. trail=False prevents this.
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
trail=False,
|
||||
bind=True,
|
||||
@@ -650,83 +652,6 @@ def monitor_connector_deletion_taskset(
|
||||
redis_connector.delete.reset()
|
||||
|
||||
|
||||
def monitor_ccpair_pruning_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
if cc_pair_id_str is None:
|
||||
task_logger.warning(
|
||||
f"monitor_ccpair_pruning_taskset: could not parse cc_pair_id from {fence_key}"
|
||||
)
|
||||
return
|
||||
|
||||
cc_pair_id = int(cc_pair_id_str)
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
if not redis_connector.prune.fenced:
|
||||
return
|
||||
|
||||
initial = redis_connector.prune.generator_complete
|
||||
if initial is None:
|
||||
return
|
||||
|
||||
remaining = redis_connector.prune.get_remaining()
|
||||
task_logger.info(
|
||||
f"Connector pruning progress: cc_pair={cc_pair_id} remaining={remaining} initial={initial}"
|
||||
)
|
||||
if remaining > 0:
|
||||
return
|
||||
|
||||
mark_ccpair_as_pruned(int(cc_pair_id), db_session)
|
||||
task_logger.info(
|
||||
f"Successfully pruned connector credential pair. cc_pair={cc_pair_id}"
|
||||
)
|
||||
|
||||
redis_connector.prune.taskset_clear()
|
||||
redis_connector.prune.generator_clear()
|
||||
redis_connector.prune.set_fence(False)
|
||||
|
||||
|
||||
def monitor_ccpair_permissions_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
if cc_pair_id_str is None:
|
||||
task_logger.warning(
|
||||
f"monitor_ccpair_permissions_taskset: could not parse cc_pair_id from {fence_key}"
|
||||
)
|
||||
return
|
||||
|
||||
cc_pair_id = int(cc_pair_id_str)
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
if not redis_connector.permissions.fenced:
|
||||
return
|
||||
|
||||
initial = redis_connector.permissions.generator_complete
|
||||
if initial is None:
|
||||
return
|
||||
|
||||
remaining = redis_connector.permissions.get_remaining()
|
||||
task_logger.info(
|
||||
f"Permissions sync progress: cc_pair={cc_pair_id} remaining={remaining} initial={initial}"
|
||||
)
|
||||
if remaining > 0:
|
||||
return
|
||||
|
||||
payload: RedisConnectorPermissionSyncPayload | None = (
|
||||
redis_connector.permissions.payload
|
||||
)
|
||||
start_time: datetime | None = payload.started if payload else None
|
||||
|
||||
mark_cc_pair_as_permissions_synced(db_session, int(cc_pair_id), start_time)
|
||||
task_logger.info(f"Successfully synced permissions for cc_pair={cc_pair_id}")
|
||||
|
||||
redis_connector.permissions.reset()
|
||||
|
||||
|
||||
def monitor_ccpair_indexing_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
@@ -871,7 +796,12 @@ def monitor_ccpair_indexing_taskset(
|
||||
redis_connector_index.reset()
|
||||
|
||||
|
||||
@shared_task(name=OnyxCeleryTask.MONITOR_VESPA_SYNC, soft_time_limit=300, bind=True)
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.MONITOR_VESPA_SYNC,
|
||||
ignore_result=True,
|
||||
soft_time_limit=300,
|
||||
bind=True,
|
||||
)
|
||||
def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool | None:
|
||||
"""This is a celery beat task that monitors and finalizes various long running tasks.
|
||||
|
||||
@@ -895,6 +825,17 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool | None:
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# Replica usage notes
|
||||
#
|
||||
# False negatives are OK. (aka fail to to see a key that exists on the master).
|
||||
# We simply skip the monitoring work and it will be caught on the next pass.
|
||||
#
|
||||
# False positives are not OK, and are possible if we clear a fence on the master and
|
||||
# then read from the replica. In this case, monitoring work could be done on a fence
|
||||
# that no longer exists. To avoid this, we scan from the replica, but double check
|
||||
# the result on the master.
|
||||
r_replica = get_redis_replica_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
@@ -954,17 +895,19 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool | None:
|
||||
# scan and monitor activity to completion
|
||||
phase_start = time.monotonic()
|
||||
lock_beat.reacquire()
|
||||
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
|
||||
monitor_connector_taskset(r)
|
||||
if r_replica.exists(RedisConnectorCredentialPair.get_fence_key()):
|
||||
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
|
||||
monitor_connector_taskset(r)
|
||||
timings["connector"] = time.monotonic() - phase_start
|
||||
timings["connector_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
|
||||
|
||||
phase_start = time.monotonic()
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(
|
||||
for key_bytes in r_replica.scan_iter(
|
||||
RedisConnectorDelete.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
|
||||
):
|
||||
monitor_connector_deletion_taskset(tenant_id, key_bytes, r)
|
||||
if r.exists(key_bytes):
|
||||
monitor_connector_deletion_taskset(tenant_id, key_bytes, r)
|
||||
lock_beat.reacquire()
|
||||
|
||||
timings["connector_deletion"] = time.monotonic() - phase_start
|
||||
@@ -974,66 +917,74 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool | None:
|
||||
|
||||
phase_start = time.monotonic()
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(
|
||||
for key_bytes in r_replica.scan_iter(
|
||||
RedisDocumentSet.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
|
||||
):
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_document_set_taskset(tenant_id, key_bytes, r, db_session)
|
||||
if r.exists(key_bytes):
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_document_set_taskset(tenant_id, key_bytes, r, db_session)
|
||||
lock_beat.reacquire()
|
||||
timings["documentset"] = time.monotonic() - phase_start
|
||||
timings["documentset_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
|
||||
|
||||
phase_start = time.monotonic()
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(
|
||||
for key_bytes in r_replica.scan_iter(
|
||||
RedisUserGroup.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
|
||||
):
|
||||
monitor_usergroup_taskset = fetch_versioned_implementation_with_fallback(
|
||||
"onyx.background.celery.tasks.vespa.tasks",
|
||||
"monitor_usergroup_taskset",
|
||||
noop_fallback,
|
||||
)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_usergroup_taskset(tenant_id, key_bytes, r, db_session)
|
||||
if r.exists(key_bytes):
|
||||
monitor_usergroup_taskset = (
|
||||
fetch_versioned_implementation_with_fallback(
|
||||
"onyx.background.celery.tasks.vespa.tasks",
|
||||
"monitor_usergroup_taskset",
|
||||
noop_fallback,
|
||||
)
|
||||
)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_usergroup_taskset(tenant_id, key_bytes, r, db_session)
|
||||
lock_beat.reacquire()
|
||||
timings["usergroup"] = time.monotonic() - phase_start
|
||||
timings["usergroup_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
|
||||
|
||||
phase_start = time.monotonic()
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(
|
||||
for key_bytes in r_replica.scan_iter(
|
||||
RedisConnectorPrune.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
|
||||
):
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_pruning_taskset(tenant_id, key_bytes, r, db_session)
|
||||
if r.exists(key_bytes):
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_pruning_taskset(tenant_id, key_bytes, r, db_session)
|
||||
lock_beat.reacquire()
|
||||
timings["pruning"] = time.monotonic() - phase_start
|
||||
timings["pruning_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
|
||||
|
||||
phase_start = time.monotonic()
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(
|
||||
for key_bytes in r_replica.scan_iter(
|
||||
RedisConnectorIndex.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
|
||||
):
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_indexing_taskset(tenant_id, key_bytes, r, db_session)
|
||||
if r.exists(key_bytes):
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_indexing_taskset(tenant_id, key_bytes, r, db_session)
|
||||
lock_beat.reacquire()
|
||||
timings["indexing"] = time.monotonic() - phase_start
|
||||
timings["indexing_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
|
||||
|
||||
phase_start = time.monotonic()
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r.scan_iter(
|
||||
for key_bytes in r_replica.scan_iter(
|
||||
RedisConnectorPermissionSync.FENCE_PREFIX + "*",
|
||||
count=SCAN_ITER_COUNT_DEFAULT,
|
||||
):
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_permissions_taskset(tenant_id, key_bytes, r, db_session)
|
||||
if r.exists(key_bytes):
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_permissions_taskset(
|
||||
tenant_id, key_bytes, r, db_session
|
||||
)
|
||||
lock_beat.reacquire()
|
||||
|
||||
timings["permissions"] = time.monotonic() - phase_start
|
||||
timings["permissions_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
|
||||
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
@@ -1068,9 +1019,11 @@ def vespa_metadata_sync_task(
|
||||
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
|
||||
active_search_settings = get_active_search_settings(db_session)
|
||||
doc_index = get_default_document_index(
|
||||
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
|
||||
search_settings=active_search_settings.primary,
|
||||
secondary_search_settings=active_search_settings.secondary,
|
||||
httpx_client=HttpxPool.get("vespa"),
|
||||
)
|
||||
|
||||
retry_index = RetryDocumentIndex(doc_index)
|
||||
|
||||
@@ -35,6 +35,7 @@ from onyx.db.models import IndexAttempt
|
||||
from onyx.db.models import IndexingStatus
|
||||
from onyx.db.models import IndexModelStatus
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.httpx.httpx_pool import HttpxPool
|
||||
from onyx.indexing.embedder import DefaultIndexingEmbedder
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.indexing.indexing_pipeline import build_indexing_pipeline
|
||||
@@ -219,9 +220,10 @@ def _run_indexing(
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
# Indexing is only done into one index at a time
|
||||
document_index = get_default_document_index(
|
||||
primary_index_name=ctx.index_name, secondary_index_name=None
|
||||
index_attempt_start.search_settings,
|
||||
None,
|
||||
httpx_client=HttpxPool.get("vespa"),
|
||||
)
|
||||
|
||||
indexing_pipeline = build_indexing_pipeline(
|
||||
|
||||
@@ -254,6 +254,7 @@ def _get_force_search_settings(
|
||||
and new_msg_req.retrieval_options.run_search
|
||||
== OptionalSearchSetting.ALWAYS,
|
||||
new_msg_req.search_doc_ids,
|
||||
new_msg_req.query_override is not None,
|
||||
DISABLE_LLM_CHOOSE_SEARCH,
|
||||
]
|
||||
)
|
||||
@@ -425,9 +426,7 @@ def stream_chat_message_objects(
|
||||
)
|
||||
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
document_index = get_default_document_index(
|
||||
primary_index_name=search_settings.index_name, secondary_index_name=None
|
||||
)
|
||||
document_index = get_default_document_index(search_settings, None)
|
||||
|
||||
# Every chat Session begins with an empty root message
|
||||
root_message = get_or_create_root_message(
|
||||
@@ -499,14 +498,6 @@ def stream_chat_message_objects(
|
||||
f"existing assistant message id: {existing_assistant_message_id}"
|
||||
)
|
||||
|
||||
# Disable Query Rephrasing for the first message
|
||||
# This leads to a better first response since the LLM rephrasing the question
|
||||
# leads to worst search quality
|
||||
if not history_msgs:
|
||||
new_msg_req.query_override = (
|
||||
new_msg_req.query_override or new_msg_req.message
|
||||
)
|
||||
|
||||
# load all files needed for this chat chain in memory
|
||||
files = load_all_chat_files(
|
||||
history_msgs, new_msg_req.file_descriptors, db_session
|
||||
|
||||
@@ -200,6 +200,8 @@ REDIS_HOST = os.environ.get("REDIS_HOST") or "localhost"
|
||||
REDIS_PORT = int(os.environ.get("REDIS_PORT", 6379))
|
||||
REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD") or ""
|
||||
|
||||
# this assumes that other redis settings remain the same as the primary
|
||||
REDIS_REPLICA_HOST = os.environ.get("REDIS_REPLICA_HOST") or REDIS_HOST
|
||||
|
||||
REDIS_AUTH_KEY_PREFIX = "fastapi_users_token:"
|
||||
|
||||
@@ -476,6 +478,12 @@ INDEXING_SIZE_WARNING_THRESHOLD = int(
|
||||
# 0 disables this behavior and is the default.
|
||||
INDEXING_TRACER_INTERVAL = int(os.environ.get("INDEXING_TRACER_INTERVAL") or 0)
|
||||
|
||||
# Enable multi-threaded embedding model calls for parallel processing
|
||||
# Note: only applies for API-based embedding models
|
||||
INDEXING_EMBEDDING_MODEL_NUM_THREADS = int(
|
||||
os.environ.get("INDEXING_EMBEDDING_MODEL_NUM_THREADS") or 1
|
||||
)
|
||||
|
||||
# During an indexing attempt, specifies the number of batches which are allowed to
|
||||
# exception without aborting the attempt.
|
||||
INDEXING_EXCEPTION_LIMIT = int(os.environ.get("INDEXING_EXCEPTION_LIMIT") or 0)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from concurrent.futures import as_completed
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from io import BytesIO
|
||||
from typing import Any
|
||||
|
||||
@@ -20,9 +22,9 @@ from onyx.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
# NOTE: all are made lowercase to avoid case sensitivity issues
|
||||
# these are the field types that are considered metadata rather
|
||||
# than sections
|
||||
_METADATA_FIELD_TYPES = {
|
||||
# These field types are considered metadata by default when
|
||||
# treat_all_non_attachment_fields_as_metadata is False
|
||||
DEFAULT_METADATA_FIELD_TYPES = {
|
||||
"singlecollaborator",
|
||||
"collaborator",
|
||||
"createdby",
|
||||
@@ -60,12 +62,16 @@ class AirtableConnector(LoadConnector):
|
||||
self,
|
||||
base_id: str,
|
||||
table_name_or_id: str,
|
||||
treat_all_non_attachment_fields_as_metadata: bool = False,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
) -> None:
|
||||
self.base_id = base_id
|
||||
self.table_name_or_id = table_name_or_id
|
||||
self.batch_size = batch_size
|
||||
self.airtable_client: AirtableApi | None = None
|
||||
self.treat_all_non_attachment_fields_as_metadata = (
|
||||
treat_all_non_attachment_fields_as_metadata
|
||||
)
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
self.airtable_client = AirtableApi(credentials["airtable_access_token"])
|
||||
@@ -166,8 +172,14 @@ class AirtableConnector(LoadConnector):
|
||||
return [(str(field_info), default_link)]
|
||||
|
||||
def _should_be_metadata(self, field_type: str) -> bool:
|
||||
"""Determine if a field type should be treated as metadata."""
|
||||
return field_type.lower() in _METADATA_FIELD_TYPES
|
||||
"""Determine if a field type should be treated as metadata.
|
||||
|
||||
When treat_all_non_attachment_fields_as_metadata is True, all fields except
|
||||
attachments are treated as metadata. Otherwise, only fields with types listed
|
||||
in DEFAULT_METADATA_FIELD_TYPES are treated as metadata."""
|
||||
if self.treat_all_non_attachment_fields_as_metadata:
|
||||
return field_type.lower() != "multipleattachments"
|
||||
return field_type.lower() in DEFAULT_METADATA_FIELD_TYPES
|
||||
|
||||
def _process_field(
|
||||
self,
|
||||
@@ -233,7 +245,7 @@ class AirtableConnector(LoadConnector):
|
||||
record: RecordDict,
|
||||
table_schema: TableSchema,
|
||||
primary_field_name: str | None,
|
||||
) -> Document:
|
||||
) -> Document | None:
|
||||
"""Process a single Airtable record into a Document.
|
||||
|
||||
Args:
|
||||
@@ -264,6 +276,11 @@ class AirtableConnector(LoadConnector):
|
||||
field_val = fields.get(field_name)
|
||||
field_type = field_schema.type
|
||||
|
||||
logger.debug(
|
||||
f"Processing field '{field_name}' of type '{field_type}' "
|
||||
f"for record '{record_id}'."
|
||||
)
|
||||
|
||||
field_sections, field_metadata = self._process_field(
|
||||
field_id=field_schema.id,
|
||||
field_name=field_name,
|
||||
@@ -277,6 +294,10 @@ class AirtableConnector(LoadConnector):
|
||||
sections.extend(field_sections)
|
||||
metadata.update(field_metadata)
|
||||
|
||||
if not sections:
|
||||
logger.warning(f"No sections found for record {record_id}")
|
||||
return None
|
||||
|
||||
semantic_id = (
|
||||
f"{table_name}: {primary_field_value}"
|
||||
if primary_field_value
|
||||
@@ -313,18 +334,45 @@ class AirtableConnector(LoadConnector):
|
||||
primary_field_name = field.name
|
||||
break
|
||||
|
||||
record_documents: list[Document] = []
|
||||
for record in records:
|
||||
document = self._process_record(
|
||||
record=record,
|
||||
table_schema=table_schema,
|
||||
primary_field_name=primary_field_name,
|
||||
)
|
||||
record_documents.append(document)
|
||||
logger.info(f"Starting to process Airtable records for {table.name}.")
|
||||
|
||||
# Process records in parallel batches using ThreadPoolExecutor
|
||||
PARALLEL_BATCH_SIZE = 16
|
||||
max_workers = min(PARALLEL_BATCH_SIZE, len(records))
|
||||
|
||||
# Process records in batches
|
||||
for i in range(0, len(records), PARALLEL_BATCH_SIZE):
|
||||
batch_records = records[i : i + PARALLEL_BATCH_SIZE]
|
||||
record_documents: list[Document] = []
|
||||
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
# Submit batch tasks
|
||||
future_to_record = {
|
||||
executor.submit(
|
||||
self._process_record,
|
||||
record=record,
|
||||
table_schema=table_schema,
|
||||
primary_field_name=primary_field_name,
|
||||
): record
|
||||
for record in batch_records
|
||||
}
|
||||
|
||||
# Wait for all tasks in this batch to complete
|
||||
for future in as_completed(future_to_record):
|
||||
record = future_to_record[future]
|
||||
try:
|
||||
document = future.result()
|
||||
if document:
|
||||
record_documents.append(document)
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to process record {record['id']}")
|
||||
raise e
|
||||
|
||||
# After batch is complete, yield if we've hit the batch size
|
||||
if len(record_documents) >= self.batch_size:
|
||||
yield record_documents
|
||||
record_documents = []
|
||||
|
||||
# Yield any remaining records
|
||||
if record_documents:
|
||||
yield record_documents
|
||||
|
||||
@@ -232,20 +232,29 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
}
|
||||
|
||||
# Get labels
|
||||
label_dicts = confluence_object["metadata"]["labels"]["results"]
|
||||
page_labels = [label["name"] for label in label_dicts]
|
||||
label_dicts = (
|
||||
confluence_object.get("metadata", {}).get("labels", {}).get("results", [])
|
||||
)
|
||||
page_labels = [label.get("name") for label in label_dicts if label.get("name")]
|
||||
if page_labels:
|
||||
doc_metadata["labels"] = page_labels
|
||||
|
||||
# Get last modified and author email
|
||||
last_modified = datetime_from_string(confluence_object["version"]["when"])
|
||||
author_email = confluence_object["version"].get("by", {}).get("email")
|
||||
version_dict = confluence_object.get("version", {})
|
||||
last_modified = (
|
||||
datetime_from_string(version_dict.get("when"))
|
||||
if version_dict.get("when")
|
||||
else None
|
||||
)
|
||||
author_email = version_dict.get("by", {}).get("email")
|
||||
|
||||
title = confluence_object.get("title", "Untitled Document")
|
||||
|
||||
return Document(
|
||||
id=object_url,
|
||||
sections=[Section(link=object_url, text=object_text)],
|
||||
source=DocumentSource.CONFLUENCE,
|
||||
semantic_identifier=confluence_object["title"],
|
||||
semantic_identifier=title,
|
||||
doc_updated_at=last_modified,
|
||||
primary_owners=(
|
||||
[BasicExpertInfo(email=author_email)] if author_email else None
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
from onyx.connectors.interfaces import BaseConnector
|
||||
@@ -45,7 +46,17 @@ class ConnectorRunner:
|
||||
def run(self) -> GenerateDocumentsOutput:
|
||||
"""Adds additional exception logging to the connector."""
|
||||
try:
|
||||
yield from self.doc_batch_generator
|
||||
start = time.monotonic()
|
||||
for batch in self.doc_batch_generator:
|
||||
# to know how long connector is taking
|
||||
logger.debug(
|
||||
f"Connector took {time.monotonic() - start} seconds to build a batch."
|
||||
)
|
||||
|
||||
yield batch
|
||||
|
||||
start = time.monotonic()
|
||||
|
||||
except Exception:
|
||||
exc_type, _, exc_traceback = sys.exc_info()
|
||||
|
||||
|
||||
@@ -50,6 +50,9 @@ def _create_doc_from_transcript(transcript: dict) -> Document | None:
|
||||
current_link = ""
|
||||
current_text = ""
|
||||
|
||||
if transcript["sentences"] is None:
|
||||
return None
|
||||
|
||||
for sentence in transcript["sentences"]:
|
||||
if sentence["speaker_name"] != current_speaker_name:
|
||||
if current_speaker_name is not None:
|
||||
|
||||
@@ -150,6 +150,16 @@ class Document(DocumentBase):
|
||||
id: str # This must be unique or during indexing/reindexing, chunks will be overwritten
|
||||
source: DocumentSource
|
||||
|
||||
def get_total_char_length(self) -> int:
|
||||
"""Calculate the total character length of the document including sections, metadata, and identifiers."""
|
||||
section_length = sum(len(section.text) for section in self.sections)
|
||||
identifier_length = len(self.semantic_identifier) + len(self.title or "")
|
||||
metadata_length = sum(
|
||||
len(k) + len(v) if isinstance(v, str) else len(k) + sum(len(x) for x in v)
|
||||
for k, v in self.metadata.items()
|
||||
)
|
||||
return section_length + identifier_length + metadata_length
|
||||
|
||||
def to_short_descriptor(self) -> str:
|
||||
"""Used when logging the identity of a document"""
|
||||
return f"ID: '{self.id}'; Semantic ID: '{self.semantic_identifier}'"
|
||||
|
||||
@@ -1,16 +1,14 @@
|
||||
import io
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import field
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
from urllib.parse import unquote
|
||||
|
||||
import msal # type: ignore
|
||||
from office365.graph_client import GraphClient # type: ignore
|
||||
from office365.onedrive.driveitems.driveItem import DriveItem # type: ignore
|
||||
from office365.onedrive.sites.site import Site # type: ignore
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
@@ -29,16 +27,25 @@ from onyx.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@dataclass
|
||||
class SiteData:
|
||||
url: str | None
|
||||
folder: Optional[str]
|
||||
sites: list = field(default_factory=list)
|
||||
driveitems: list = field(default_factory=list)
|
||||
class SiteDescriptor(BaseModel):
|
||||
"""Data class for storing SharePoint site information.
|
||||
|
||||
Args:
|
||||
url: The base site URL (e.g. https://danswerai.sharepoint.com/sites/sharepoint-tests)
|
||||
drive_name: The name of the drive to access (e.g. "Shared Documents", "Other Library")
|
||||
If None, all drives will be accessed.
|
||||
folder_path: The folder path within the drive to access (e.g. "test/nested with spaces")
|
||||
If None, all folders will be accessed.
|
||||
"""
|
||||
|
||||
url: str
|
||||
drive_name: str | None
|
||||
folder_path: str | None
|
||||
|
||||
|
||||
def _convert_driveitem_to_document(
|
||||
driveitem: DriveItem,
|
||||
drive_name: str,
|
||||
) -> Document:
|
||||
file_text = extract_file_text(
|
||||
file=io.BytesIO(driveitem.get_content().execute_query().value),
|
||||
@@ -58,7 +65,7 @@ def _convert_driveitem_to_document(
|
||||
email=driveitem.last_modified_by.user.email,
|
||||
)
|
||||
],
|
||||
metadata={},
|
||||
metadata={"drive": drive_name},
|
||||
)
|
||||
return doc
|
||||
|
||||
@@ -70,93 +77,179 @@ class SharepointConnector(LoadConnector, PollConnector):
|
||||
sites: list[str] = [],
|
||||
) -> None:
|
||||
self.batch_size = batch_size
|
||||
self.graph_client: GraphClient | None = None
|
||||
self.site_data: list[SiteData] = self._extract_site_and_folder(sites)
|
||||
self._graph_client: GraphClient | None = None
|
||||
self.site_descriptors: list[SiteDescriptor] = self._extract_site_and_drive_info(
|
||||
sites
|
||||
)
|
||||
self.msal_app: msal.ConfidentialClientApplication | None = None
|
||||
|
||||
@property
|
||||
def graph_client(self) -> GraphClient:
|
||||
if self._graph_client is None:
|
||||
raise ConnectorMissingCredentialError("Sharepoint")
|
||||
|
||||
return self._graph_client
|
||||
|
||||
@staticmethod
|
||||
def _extract_site_and_folder(site_urls: list[str]) -> list[SiteData]:
|
||||
def _extract_site_and_drive_info(site_urls: list[str]) -> list[SiteDescriptor]:
|
||||
site_data_list = []
|
||||
for url in site_urls:
|
||||
parts = url.strip().split("/")
|
||||
if "sites" in parts:
|
||||
sites_index = parts.index("sites")
|
||||
site_url = "/".join(parts[: sites_index + 2])
|
||||
folder = (
|
||||
parts[sites_index + 2] if len(parts) > sites_index + 2 else None
|
||||
)
|
||||
remaining_parts = parts[sites_index + 2 :]
|
||||
|
||||
# Extract drive name and folder path
|
||||
if remaining_parts:
|
||||
drive_name = unquote(remaining_parts[0])
|
||||
folder_path = (
|
||||
"/".join(unquote(part) for part in remaining_parts[1:])
|
||||
if len(remaining_parts) > 1
|
||||
else None
|
||||
)
|
||||
else:
|
||||
drive_name = None
|
||||
folder_path = None
|
||||
|
||||
site_data_list.append(
|
||||
SiteData(url=site_url, folder=folder, sites=[], driveitems=[])
|
||||
SiteDescriptor(
|
||||
url=site_url,
|
||||
drive_name=drive_name,
|
||||
folder_path=folder_path,
|
||||
)
|
||||
)
|
||||
return site_data_list
|
||||
|
||||
def _populate_sitedata_driveitems(
|
||||
def _fetch_driveitems(
|
||||
self,
|
||||
site_descriptor: SiteDescriptor,
|
||||
start: datetime | None = None,
|
||||
end: datetime | None = None,
|
||||
) -> None:
|
||||
filter_str = ""
|
||||
if start is not None and end is not None:
|
||||
filter_str = f"last_modified_datetime ge {start.isoformat()} and last_modified_datetime le {end.isoformat()}"
|
||||
) -> list[tuple[DriveItem, str]]:
|
||||
final_driveitems: list[tuple[DriveItem, str]] = []
|
||||
try:
|
||||
site = self.graph_client.sites.get_by_url(site_descriptor.url)
|
||||
|
||||
for element in self.site_data:
|
||||
sites: list[Site] = []
|
||||
for site in element.sites:
|
||||
site_sublist = site.lists.get().execute_query()
|
||||
sites.extend(site_sublist)
|
||||
# Get all drives in the site
|
||||
drives = site.drives.get().execute_query()
|
||||
logger.debug(f"Found drives: {[drive.name for drive in drives]}")
|
||||
|
||||
for site in sites:
|
||||
# Filter drives based on the requested drive name
|
||||
if site_descriptor.drive_name:
|
||||
drives = [
|
||||
drive
|
||||
for drive in drives
|
||||
if drive.name == site_descriptor.drive_name
|
||||
or (
|
||||
drive.name == "Documents"
|
||||
and site_descriptor.drive_name == "Shared Documents"
|
||||
)
|
||||
]
|
||||
if not drives:
|
||||
logger.warning(f"Drive '{site_descriptor.drive_name}' not found")
|
||||
return []
|
||||
|
||||
# Process each matching drive
|
||||
for drive in drives:
|
||||
try:
|
||||
query = site.drive.root.get_files(True, 1000)
|
||||
if filter_str:
|
||||
query = query.filter(filter_str)
|
||||
root_folder = drive.root
|
||||
if site_descriptor.folder_path:
|
||||
# If a specific folder is requested, navigate to it
|
||||
for folder_part in site_descriptor.folder_path.split("/"):
|
||||
root_folder = root_folder.get_by_path(folder_part)
|
||||
|
||||
# Get all items recursively
|
||||
query = root_folder.get_files(
|
||||
recursive=True,
|
||||
page_size=1000,
|
||||
)
|
||||
driveitems = query.execute_query()
|
||||
if element.folder:
|
||||
filtered_driveitems = [
|
||||
logger.debug(
|
||||
f"Found {len(driveitems)} items in drive '{drive.name}'"
|
||||
)
|
||||
|
||||
# Use "Shared Documents" as the library name for the default "Documents" drive
|
||||
drive_name = (
|
||||
"Shared Documents" if drive.name == "Documents" else drive.name
|
||||
)
|
||||
|
||||
# Filter items based on folder path if specified
|
||||
if site_descriptor.folder_path:
|
||||
# Filter items to ensure they're in the specified folder or its subfolders
|
||||
# The path will be in format: /drives/{drive_id}/root:/folder/path
|
||||
driveitems = [
|
||||
item
|
||||
for item in driveitems
|
||||
if element.folder in item.parent_reference.path
|
||||
if any(
|
||||
path_part == site_descriptor.folder_path
|
||||
or path_part.startswith(
|
||||
site_descriptor.folder_path + "/"
|
||||
)
|
||||
for path_part in item.parent_reference.path.split(
|
||||
"root:/"
|
||||
)[1].split("/")
|
||||
)
|
||||
]
|
||||
element.driveitems.extend(filtered_driveitems)
|
||||
else:
|
||||
element.driveitems.extend(driveitems)
|
||||
if len(driveitems) == 0:
|
||||
all_paths = [
|
||||
item.parent_reference.path for item in driveitems
|
||||
]
|
||||
logger.warning(
|
||||
f"Nothing found for folder '{site_descriptor.folder_path}' "
|
||||
f"in; any of valid paths: {all_paths}"
|
||||
)
|
||||
|
||||
except Exception:
|
||||
# Sites include things that do not contain .drive.root so this fails
|
||||
# but this is fine, as there are no actually documents in those
|
||||
pass
|
||||
# Filter items based on time window if specified
|
||||
if start is not None and end is not None:
|
||||
driveitems = [
|
||||
item
|
||||
for item in driveitems
|
||||
if start
|
||||
<= item.last_modified_datetime.replace(tzinfo=timezone.utc)
|
||||
<= end
|
||||
]
|
||||
logger.debug(
|
||||
f"Found {len(driveitems)} items within time window in drive '{drive.name}'"
|
||||
)
|
||||
|
||||
def _populate_sitedata_sites(self) -> None:
|
||||
if self.graph_client is None:
|
||||
raise ConnectorMissingCredentialError("Sharepoint")
|
||||
for item in driveitems:
|
||||
final_driveitems.append((item, drive_name))
|
||||
|
||||
if self.site_data:
|
||||
for element in self.site_data:
|
||||
element.sites = [
|
||||
self.graph_client.sites.get_by_url(element.url)
|
||||
.get()
|
||||
.execute_query()
|
||||
]
|
||||
else:
|
||||
sites = self.graph_client.sites.get_all().execute_query()
|
||||
self.site_data = [
|
||||
SiteData(url=None, folder=None, sites=sites, driveitems=[])
|
||||
]
|
||||
except Exception as e:
|
||||
# Some drives might not be accessible
|
||||
logger.warning(f"Failed to process drive: {str(e)}")
|
||||
|
||||
except Exception as e:
|
||||
# Sites include things that do not contain drives so this fails
|
||||
# but this is fine, as there are no actual documents in those
|
||||
logger.warning(f"Failed to process site: {str(e)}")
|
||||
|
||||
return final_driveitems
|
||||
|
||||
def _fetch_sites(self) -> list[SiteDescriptor]:
|
||||
sites = self.graph_client.sites.get_all().execute_query()
|
||||
site_descriptors = [
|
||||
SiteDescriptor(
|
||||
url=sites.resource_url,
|
||||
drive_name=None,
|
||||
folder_path=None,
|
||||
)
|
||||
]
|
||||
return site_descriptors
|
||||
|
||||
def _fetch_from_sharepoint(
|
||||
self, start: datetime | None = None, end: datetime | None = None
|
||||
) -> GenerateDocumentsOutput:
|
||||
if self.graph_client is None:
|
||||
raise ConnectorMissingCredentialError("Sharepoint")
|
||||
|
||||
self._populate_sitedata_sites()
|
||||
self._populate_sitedata_driveitems(start=start, end=end)
|
||||
site_descriptors = self.site_descriptors or self._fetch_sites()
|
||||
|
||||
# goes over all urls, converts them into Document objects and then yields them in batches
|
||||
doc_batch: list[Document] = []
|
||||
for element in self.site_data:
|
||||
for driveitem in element.driveitems:
|
||||
for site_descriptor in site_descriptors:
|
||||
driveitems = self._fetch_driveitems(site_descriptor, start=start, end=end)
|
||||
for driveitem, drive_name in driveitems:
|
||||
logger.debug(f"Processing: {driveitem.web_url}")
|
||||
doc_batch.append(_convert_driveitem_to_document(driveitem))
|
||||
doc_batch.append(_convert_driveitem_to_document(driveitem, drive_name))
|
||||
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
yield doc_batch
|
||||
@@ -168,22 +261,26 @@ class SharepointConnector(LoadConnector, PollConnector):
|
||||
sp_client_secret = credentials["sp_client_secret"]
|
||||
sp_directory_id = credentials["sp_directory_id"]
|
||||
|
||||
authority_url = f"https://login.microsoftonline.com/{sp_directory_id}"
|
||||
self.msal_app = msal.ConfidentialClientApplication(
|
||||
authority=authority_url,
|
||||
client_id=sp_client_id,
|
||||
client_credential=sp_client_secret,
|
||||
)
|
||||
|
||||
def _acquire_token_func() -> dict[str, Any]:
|
||||
"""
|
||||
Acquire token via MSAL
|
||||
"""
|
||||
authority_url = f"https://login.microsoftonline.com/{sp_directory_id}"
|
||||
app = msal.ConfidentialClientApplication(
|
||||
authority=authority_url,
|
||||
client_id=sp_client_id,
|
||||
client_credential=sp_client_secret,
|
||||
)
|
||||
token = app.acquire_token_for_client(
|
||||
if self.msal_app is None:
|
||||
raise RuntimeError("MSAL app is not initialized")
|
||||
|
||||
token = self.msal_app.acquire_token_for_client(
|
||||
scopes=["https://graph.microsoft.com/.default"]
|
||||
)
|
||||
return token
|
||||
|
||||
self.graph_client = GraphClient(_acquire_token_func)
|
||||
self._graph_client = GraphClient(_acquire_token_func)
|
||||
return None
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
@@ -192,19 +289,19 @@ class SharepointConnector(LoadConnector, PollConnector):
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
start_datetime = datetime.utcfromtimestamp(start)
|
||||
end_datetime = datetime.utcfromtimestamp(end)
|
||||
start_datetime = datetime.fromtimestamp(start, timezone.utc)
|
||||
end_datetime = datetime.fromtimestamp(end, timezone.utc)
|
||||
return self._fetch_from_sharepoint(start=start_datetime, end=end_datetime)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
connector = SharepointConnector(sites=os.environ["SITES"].split(","))
|
||||
connector = SharepointConnector(sites=os.environ["SHAREPOINT_SITES"].split(","))
|
||||
|
||||
connector.load_credentials(
|
||||
{
|
||||
"sp_client_id": os.environ["SP_CLIENT_ID"],
|
||||
"sp_client_secret": os.environ["SP_CLIENT_SECRET"],
|
||||
"sp_directory_id": os.environ["SP_CLIENT_DIRECTORY_ID"],
|
||||
"sp_client_id": os.environ["SHAREPOINT_CLIENT_ID"],
|
||||
"sp_client_secret": os.environ["SHAREPOINT_CLIENT_SECRET"],
|
||||
"sp_directory_id": os.environ["SHAREPOINT_CLIENT_DIRECTORY_ID"],
|
||||
}
|
||||
)
|
||||
document_batches = connector.load_from_state()
|
||||
|
||||
@@ -104,8 +104,11 @@ def make_slack_api_rate_limited(
|
||||
f"Slack call rate limited, retrying after {retry_after} seconds. Exception: {e}"
|
||||
)
|
||||
time.sleep(retry_after)
|
||||
elif error in ["already_reacted", "no_reaction"]:
|
||||
# The response isn't used for reactions, this is basically just a pass
|
||||
elif error in ["already_reacted", "no_reaction", "internal_error"]:
|
||||
# Log internal_error and return the response instead of failing
|
||||
logger.warning(
|
||||
f"Slack call encountered '{error}', skipping and continuing..."
|
||||
)
|
||||
return e.response
|
||||
else:
|
||||
# Raise the error for non-transient errors
|
||||
|
||||
@@ -180,23 +180,28 @@ class TeamsConnector(LoadConnector, PollConnector):
|
||||
self.batch_size = batch_size
|
||||
self.graph_client: GraphClient | None = None
|
||||
self.requested_team_list: list[str] = teams
|
||||
self.msal_app: msal.ConfidentialClientApplication | None = None
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
teams_client_id = credentials["teams_client_id"]
|
||||
teams_client_secret = credentials["teams_client_secret"]
|
||||
teams_directory_id = credentials["teams_directory_id"]
|
||||
|
||||
authority_url = f"https://login.microsoftonline.com/{teams_directory_id}"
|
||||
self.msal_app = msal.ConfidentialClientApplication(
|
||||
authority=authority_url,
|
||||
client_id=teams_client_id,
|
||||
client_credential=teams_client_secret,
|
||||
)
|
||||
|
||||
def _acquire_token_func() -> dict[str, Any]:
|
||||
"""
|
||||
Acquire token via MSAL
|
||||
"""
|
||||
authority_url = f"https://login.microsoftonline.com/{teams_directory_id}"
|
||||
app = msal.ConfidentialClientApplication(
|
||||
authority=authority_url,
|
||||
client_id=teams_client_id,
|
||||
client_credential=teams_client_secret,
|
||||
)
|
||||
token = app.acquire_token_for_client(
|
||||
if self.msal_app is None:
|
||||
raise RuntimeError("MSAL app is not initialized")
|
||||
|
||||
token = self.msal_app.acquire_token_for_client(
|
||||
scopes=["https://graph.microsoft.com/.default"]
|
||||
)
|
||||
return token
|
||||
|
||||
@@ -67,10 +67,7 @@ class SearchPipeline:
|
||||
self.rerank_metrics_callback = rerank_metrics_callback
|
||||
|
||||
self.search_settings = get_current_search_settings(db_session)
|
||||
self.document_index = get_default_document_index(
|
||||
primary_index_name=self.search_settings.index_name,
|
||||
secondary_index_name=None,
|
||||
)
|
||||
self.document_index = get_default_document_index(self.search_settings, None)
|
||||
self.prompt_config: PromptConfig | None = prompt_config
|
||||
|
||||
# Preprocessing steps generate this
|
||||
|
||||
@@ -28,6 +28,9 @@ class SyncType(str, PyEnum):
|
||||
DOCUMENT_SET = "document_set"
|
||||
USER_GROUP = "user_group"
|
||||
CONNECTOR_DELETION = "connector_deletion"
|
||||
PRUNING = "pruning" # not really a sync, but close enough
|
||||
EXTERNAL_PERMISSIONS = "external_permissions"
|
||||
EXTERNAL_GROUP = "external_group"
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.value
|
||||
|
||||
@@ -3,6 +3,8 @@ from sqlalchemy import or_
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel
|
||||
from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import LLMProvider as LLMProviderModel
|
||||
@@ -124,10 +126,29 @@ def fetch_existing_tools(db_session: Session, tool_ids: list[int]) -> list[ToolM
|
||||
|
||||
def fetch_existing_llm_providers(
|
||||
db_session: Session,
|
||||
) -> list[LLMProviderModel]:
|
||||
stmt = select(LLMProviderModel)
|
||||
return list(db_session.scalars(stmt).all())
|
||||
|
||||
|
||||
def fetch_existing_llm_providers_for_user(
|
||||
db_session: Session,
|
||||
user: User | None = None,
|
||||
) -> list[LLMProviderModel]:
|
||||
if not user:
|
||||
return list(db_session.scalars(select(LLMProviderModel)).all())
|
||||
if AUTH_TYPE != AuthType.DISABLED:
|
||||
# User is anonymous
|
||||
return list(
|
||||
db_session.scalars(
|
||||
select(LLMProviderModel).where(
|
||||
LLMProviderModel.is_public == True # noqa: E712
|
||||
)
|
||||
).all()
|
||||
)
|
||||
else:
|
||||
# If auth is disabled, user has access to all providers
|
||||
return fetch_existing_llm_providers(db_session)
|
||||
|
||||
stmt = select(LLMProviderModel).distinct()
|
||||
user_groups_select = select(User__UserGroup.user_group_id).where(
|
||||
User__UserGroup.user_id == user.id
|
||||
|
||||
@@ -161,9 +161,7 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
hidden_assistants: Mapped[list[int]] = mapped_column(
|
||||
postgresql.JSONB(), nullable=False, default=[]
|
||||
)
|
||||
recent_assistants: Mapped[list[dict]] = mapped_column(
|
||||
postgresql.JSONB(), nullable=False, default=list, server_default="[]"
|
||||
)
|
||||
|
||||
pinned_assistants: Mapped[list[int] | None] = mapped_column(
|
||||
postgresql.JSONB(), nullable=True, default=None
|
||||
)
|
||||
@@ -747,6 +745,34 @@ class SearchSettings(Base):
|
||||
def api_key(self) -> str | None:
|
||||
return self.cloud_provider.api_key if self.cloud_provider is not None else None
|
||||
|
||||
@property
|
||||
def large_chunks_enabled(self) -> bool:
|
||||
"""
|
||||
Given multipass usage and an embedder, decides whether large chunks are allowed
|
||||
based on model/provider constraints.
|
||||
"""
|
||||
# Only local models that support a larger context are from Nomic
|
||||
# Cohere does not support larger contexts (they recommend not going above ~512 tokens)
|
||||
return SearchSettings.can_use_large_chunks(
|
||||
self.multipass_indexing, self.model_name, self.provider_type
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def can_use_large_chunks(
|
||||
multipass: bool, model_name: str, provider_type: EmbeddingProvider | None
|
||||
) -> bool:
|
||||
"""
|
||||
Given multipass usage and an embedder, decides whether large chunks are allowed
|
||||
based on model/provider constraints.
|
||||
"""
|
||||
# Only local models that support a larger context are from Nomic
|
||||
# Cohere does not support larger contexts (they recommend not going above ~512 tokens)
|
||||
return (
|
||||
multipass
|
||||
and model_name.startswith("nomic-ai")
|
||||
and provider_type != EmbeddingProvider.COHERE
|
||||
)
|
||||
|
||||
|
||||
class IndexAttempt(Base):
|
||||
"""
|
||||
|
||||
@@ -11,7 +11,7 @@ from sqlalchemy import Select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import aliased
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
@@ -291,8 +291,9 @@ def get_personas_for_user(
|
||||
include_deleted: bool = False,
|
||||
joinedload_all: bool = False,
|
||||
) -> Sequence[Persona]:
|
||||
stmt = select(Persona).distinct()
|
||||
stmt = _add_user_filters(stmt=stmt, user=user, get_editable=get_editable)
|
||||
stmt = select(Persona)
|
||||
stmt = _add_user_filters(stmt, user, get_editable)
|
||||
|
||||
if not include_default:
|
||||
stmt = stmt.where(Persona.builtin_persona.is_(False))
|
||||
if not include_slack_bot_personas:
|
||||
@@ -302,14 +303,16 @@ def get_personas_for_user(
|
||||
|
||||
if joinedload_all:
|
||||
stmt = stmt.options(
|
||||
joinedload(Persona.prompts),
|
||||
joinedload(Persona.tools),
|
||||
joinedload(Persona.document_sets),
|
||||
joinedload(Persona.groups),
|
||||
joinedload(Persona.users),
|
||||
selectinload(Persona.prompts),
|
||||
selectinload(Persona.tools),
|
||||
selectinload(Persona.document_sets),
|
||||
selectinload(Persona.groups),
|
||||
selectinload(Persona.users),
|
||||
selectinload(Persona.labels),
|
||||
)
|
||||
|
||||
return db_session.execute(stmt).unique().scalars().all()
|
||||
results = db_session.execute(stmt).scalars().all()
|
||||
return results
|
||||
|
||||
|
||||
def get_personas(db_session: Session) -> Sequence[Persona]:
|
||||
|
||||
@@ -29,9 +29,21 @@ from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import PRESERVED_SEARCH_FIELDS
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class ActiveSearchSettings:
|
||||
primary: SearchSettings
|
||||
secondary: SearchSettings | None
|
||||
|
||||
def __init__(
|
||||
self, primary: SearchSettings, secondary: SearchSettings | None
|
||||
) -> None:
|
||||
self.primary = primary
|
||||
self.secondary = secondary
|
||||
|
||||
|
||||
def create_search_settings(
|
||||
search_settings: SavedSearchSettings,
|
||||
db_session: Session,
|
||||
@@ -143,21 +155,27 @@ def get_secondary_search_settings(db_session: Session) -> SearchSettings | None:
|
||||
return latest_settings
|
||||
|
||||
|
||||
def get_active_search_settings(db_session: Session) -> list[SearchSettings]:
|
||||
"""Returns active search settings. The first entry will always be the current search
|
||||
settings. If there are new search settings that are being migrated to, those will be
|
||||
the second entry."""
|
||||
def get_active_search_settings(db_session: Session) -> ActiveSearchSettings:
|
||||
"""Returns active search settings. Secondary search settings may be None."""
|
||||
|
||||
# Get the primary and secondary search settings
|
||||
primary_search_settings = get_current_search_settings(db_session)
|
||||
secondary_search_settings = get_secondary_search_settings(db_session)
|
||||
return ActiveSearchSettings(
|
||||
primary=primary_search_settings, secondary=secondary_search_settings
|
||||
)
|
||||
|
||||
|
||||
def get_active_search_settings_list(db_session: Session) -> list[SearchSettings]:
|
||||
"""Returns active search settings as a list. Primary settings are the first element,
|
||||
and if secondary search settings exist, they will be the second element."""
|
||||
|
||||
search_settings_list: list[SearchSettings] = []
|
||||
|
||||
# Get the primary search settings
|
||||
primary_search_settings = get_current_search_settings(db_session)
|
||||
search_settings_list.append(primary_search_settings)
|
||||
|
||||
# Check for secondary search settings
|
||||
secondary_search_settings = get_secondary_search_settings(db_session)
|
||||
if secondary_search_settings is not None:
|
||||
# If secondary settings exist, add them to the list
|
||||
search_settings_list.append(secondary_search_settings)
|
||||
active_search_settings = get_active_search_settings(db_session)
|
||||
search_settings_list.append(active_search_settings.primary)
|
||||
if active_search_settings.secondary:
|
||||
search_settings_list.append(active_search_settings.secondary)
|
||||
|
||||
return search_settings_list
|
||||
|
||||
|
||||
@@ -8,20 +8,64 @@ from sqlalchemy.orm import Session
|
||||
from onyx.db.enums import SyncStatus
|
||||
from onyx.db.enums import SyncType
|
||||
from onyx.db.models import SyncRecord
|
||||
from onyx.setup import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def insert_sync_record(
|
||||
db_session: Session,
|
||||
entity_id: int | None,
|
||||
entity_id: int,
|
||||
sync_type: SyncType,
|
||||
) -> SyncRecord:
|
||||
"""Insert a new sync record into the database.
|
||||
"""Insert a new sync record into the database, cancelling any existing in-progress records.
|
||||
|
||||
Args:
|
||||
db_session: The database session to use
|
||||
entity_id: The ID of the entity being synced (document set ID, user group ID, etc.)
|
||||
sync_type: The type of sync operation
|
||||
"""
|
||||
# If an existing in-progress sync record exists, mark as cancelled
|
||||
existing_in_progress_sync_record = fetch_latest_sync_record(
|
||||
db_session, entity_id, sync_type, sync_status=SyncStatus.IN_PROGRESS
|
||||
)
|
||||
|
||||
if existing_in_progress_sync_record is not None:
|
||||
logger.info(
|
||||
f"Cancelling existing in-progress sync record {existing_in_progress_sync_record.id} "
|
||||
f"for entity_id={entity_id} sync_type={sync_type}"
|
||||
)
|
||||
mark_sync_records_as_cancelled(db_session, entity_id, sync_type)
|
||||
|
||||
return _create_sync_record(db_session, entity_id, sync_type)
|
||||
|
||||
|
||||
def mark_sync_records_as_cancelled(
|
||||
db_session: Session,
|
||||
entity_id: int | None,
|
||||
sync_type: SyncType,
|
||||
) -> None:
|
||||
stmt = (
|
||||
update(SyncRecord)
|
||||
.where(
|
||||
and_(
|
||||
SyncRecord.entity_id == entity_id,
|
||||
SyncRecord.sync_type == sync_type,
|
||||
SyncRecord.sync_status == SyncStatus.IN_PROGRESS,
|
||||
)
|
||||
)
|
||||
.values(sync_status=SyncStatus.CANCELED)
|
||||
)
|
||||
db_session.execute(stmt)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def _create_sync_record(
|
||||
db_session: Session,
|
||||
entity_id: int | None,
|
||||
sync_type: SyncType,
|
||||
) -> SyncRecord:
|
||||
"""Create and insert a new sync record into the database."""
|
||||
sync_record = SyncRecord(
|
||||
entity_id=entity_id,
|
||||
sync_type=sync_type,
|
||||
@@ -39,6 +83,7 @@ def fetch_latest_sync_record(
|
||||
db_session: Session,
|
||||
entity_id: int,
|
||||
sync_type: SyncType,
|
||||
sync_status: SyncStatus | None = None,
|
||||
) -> SyncRecord | None:
|
||||
"""Fetch the most recent sync record for a given entity ID and status.
|
||||
|
||||
@@ -59,6 +104,9 @@ def fetch_latest_sync_record(
|
||||
.limit(1)
|
||||
)
|
||||
|
||||
if sync_status is not None:
|
||||
stmt = stmt.where(SyncRecord.sync_status == sync_status)
|
||||
|
||||
result = db_session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
|
||||
@@ -4,24 +4,63 @@ from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import ENABLE_MULTIPASS_INDEXING
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.db.search_settings import get_secondary_search_settings
|
||||
from onyx.document_index.interfaces import EnrichedDocumentIndexingInfo
|
||||
from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
from onyx.indexing.models import MultipassConfig
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
DEFAULT_BATCH_SIZE = 30
|
||||
DEFAULT_INDEX_NAME = "danswer_chunk"
|
||||
|
||||
|
||||
def get_both_index_names(db_session: Session) -> tuple[str, str | None]:
|
||||
def should_use_multipass(search_settings: SearchSettings | None) -> bool:
|
||||
"""
|
||||
Determines whether multipass should be used based on the search settings
|
||||
or the default config if settings are unavailable.
|
||||
"""
|
||||
if search_settings is not None:
|
||||
return search_settings.multipass_indexing
|
||||
return ENABLE_MULTIPASS_INDEXING
|
||||
|
||||
|
||||
def get_multipass_config(search_settings: SearchSettings) -> MultipassConfig:
|
||||
"""
|
||||
Determines whether to enable multipass and large chunks by examining
|
||||
the current search settings and the embedder configuration.
|
||||
"""
|
||||
if not search_settings:
|
||||
return MultipassConfig(multipass_indexing=False, enable_large_chunks=False)
|
||||
|
||||
multipass = should_use_multipass(search_settings)
|
||||
enable_large_chunks = SearchSettings.can_use_large_chunks(
|
||||
multipass, search_settings.model_name, search_settings.provider_type
|
||||
)
|
||||
return MultipassConfig(
|
||||
multipass_indexing=multipass, enable_large_chunks=enable_large_chunks
|
||||
)
|
||||
|
||||
|
||||
def get_both_index_properties(
|
||||
db_session: Session,
|
||||
) -> tuple[str, str | None, bool, bool | None]:
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
config_1 = get_multipass_config(search_settings)
|
||||
|
||||
search_settings_new = get_secondary_search_settings(db_session)
|
||||
if not search_settings_new:
|
||||
return search_settings.index_name, None
|
||||
return search_settings.index_name, None, config_1.enable_large_chunks, None
|
||||
|
||||
return search_settings.index_name, search_settings_new.index_name
|
||||
config_2 = get_multipass_config(search_settings)
|
||||
return (
|
||||
search_settings.index_name,
|
||||
search_settings_new.index_name,
|
||||
config_1.enable_large_chunks,
|
||||
config_2.enable_large_chunks,
|
||||
)
|
||||
|
||||
|
||||
def translate_boost_count_to_multiplier(boost: int) -> float:
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import httpx
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.document_index.vespa.index import VespaIndex
|
||||
@@ -7,17 +9,28 @@ from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
def get_default_document_index(
|
||||
primary_index_name: str,
|
||||
secondary_index_name: str | None,
|
||||
search_settings: SearchSettings,
|
||||
secondary_search_settings: SearchSettings | None,
|
||||
httpx_client: httpx.Client | None = None,
|
||||
) -> DocumentIndex:
|
||||
"""Primary index is the index that is used for querying/updating etc.
|
||||
Secondary index is for when both the currently used index and the upcoming
|
||||
index both need to be updated, updates are applied to both indices"""
|
||||
|
||||
secondary_index_name: str | None = None
|
||||
secondary_large_chunks_enabled: bool | None = None
|
||||
if secondary_search_settings:
|
||||
secondary_index_name = secondary_search_settings.index_name
|
||||
secondary_large_chunks_enabled = secondary_search_settings.large_chunks_enabled
|
||||
|
||||
# Currently only supporting Vespa
|
||||
return VespaIndex(
|
||||
index_name=primary_index_name,
|
||||
index_name=search_settings.index_name,
|
||||
secondary_index_name=secondary_index_name,
|
||||
large_chunks_enabled=search_settings.large_chunks_enabled,
|
||||
secondary_large_chunks_enabled=secondary_large_chunks_enabled,
|
||||
multitenant=MULTI_TENANT,
|
||||
httpx_client=httpx_client,
|
||||
)
|
||||
|
||||
|
||||
@@ -27,6 +40,6 @@ def get_current_primary_default_document_index(db_session: Session) -> DocumentI
|
||||
"""
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
return get_default_document_index(
|
||||
primary_index_name=search_settings.index_name,
|
||||
secondary_index_name=None,
|
||||
search_settings,
|
||||
None,
|
||||
)
|
||||
|
||||
@@ -231,21 +231,22 @@ def _get_chunks_via_visit_api(
|
||||
return document_chunks
|
||||
|
||||
|
||||
@retry(tries=10, delay=1, backoff=2)
|
||||
def get_all_vespa_ids_for_document_id(
|
||||
document_id: str,
|
||||
index_name: str,
|
||||
filters: IndexFilters | None = None,
|
||||
get_large_chunks: bool = False,
|
||||
) -> list[str]:
|
||||
document_chunks = _get_chunks_via_visit_api(
|
||||
chunk_request=VespaChunkRequest(document_id=document_id),
|
||||
index_name=index_name,
|
||||
filters=filters or IndexFilters(access_control_list=None),
|
||||
field_names=[DOCUMENT_ID],
|
||||
get_large_chunks=get_large_chunks,
|
||||
)
|
||||
return [chunk["id"].split("::", 1)[-1] for chunk in document_chunks]
|
||||
# TODO(rkuo): candidate for removal if not being used
|
||||
# @retry(tries=10, delay=1, backoff=2)
|
||||
# def get_all_vespa_ids_for_document_id(
|
||||
# document_id: str,
|
||||
# index_name: str,
|
||||
# filters: IndexFilters | None = None,
|
||||
# get_large_chunks: bool = False,
|
||||
# ) -> list[str]:
|
||||
# document_chunks = _get_chunks_via_visit_api(
|
||||
# chunk_request=VespaChunkRequest(document_id=document_id),
|
||||
# index_name=index_name,
|
||||
# filters=filters or IndexFilters(access_control_list=None),
|
||||
# field_names=[DOCUMENT_ID],
|
||||
# get_large_chunks=get_large_chunks,
|
||||
# )
|
||||
# return [chunk["id"].split("::", 1)[-1] for chunk in document_chunks]
|
||||
|
||||
|
||||
def parallel_visit_api_retrieval(
|
||||
|
||||
@@ -25,7 +25,6 @@ from onyx.configs.chat_configs import VESPA_SEARCHER_THREADS
|
||||
from onyx.configs.constants import KV_REINDEX_KEY
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import InferenceChunkUncleaned
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.document_index.document_index_utils import get_document_chunk_ids
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.document_index.interfaces import DocumentInsertionRecord
|
||||
@@ -41,12 +40,12 @@ from onyx.document_index.vespa.chunk_retrieval import (
|
||||
)
|
||||
from onyx.document_index.vespa.chunk_retrieval import query_vespa
|
||||
from onyx.document_index.vespa.deletion import delete_vespa_chunks
|
||||
from onyx.document_index.vespa.indexing_utils import BaseHTTPXClientContext
|
||||
from onyx.document_index.vespa.indexing_utils import batch_index_vespa_chunks
|
||||
from onyx.document_index.vespa.indexing_utils import check_for_final_chunk_existence
|
||||
from onyx.document_index.vespa.indexing_utils import clean_chunk_id_copy
|
||||
from onyx.document_index.vespa.indexing_utils import (
|
||||
get_multipass_config,
|
||||
)
|
||||
from onyx.document_index.vespa.indexing_utils import GlobalHTTPXClientContext
|
||||
from onyx.document_index.vespa.indexing_utils import TemporaryHTTPXClientContext
|
||||
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
|
||||
from onyx.document_index.vespa.shared_utils.utils import (
|
||||
replace_invalid_doc_id_characters,
|
||||
@@ -132,12 +131,34 @@ class VespaIndex(DocumentIndex):
|
||||
self,
|
||||
index_name: str,
|
||||
secondary_index_name: str | None,
|
||||
large_chunks_enabled: bool,
|
||||
secondary_large_chunks_enabled: bool | None,
|
||||
multitenant: bool = False,
|
||||
httpx_client: httpx.Client | None = None,
|
||||
) -> None:
|
||||
self.index_name = index_name
|
||||
self.secondary_index_name = secondary_index_name
|
||||
|
||||
self.large_chunks_enabled = large_chunks_enabled
|
||||
self.secondary_large_chunks_enabled = secondary_large_chunks_enabled
|
||||
|
||||
self.multitenant = multitenant
|
||||
self.http_client = get_vespa_http_client()
|
||||
|
||||
self.httpx_client_context: BaseHTTPXClientContext
|
||||
|
||||
if httpx_client:
|
||||
self.httpx_client_context = GlobalHTTPXClientContext(httpx_client)
|
||||
else:
|
||||
self.httpx_client_context = TemporaryHTTPXClientContext(
|
||||
get_vespa_http_client
|
||||
)
|
||||
|
||||
self.index_to_large_chunks_enabled: dict[str, bool] = {}
|
||||
self.index_to_large_chunks_enabled[index_name] = large_chunks_enabled
|
||||
if secondary_index_name and secondary_large_chunks_enabled:
|
||||
self.index_to_large_chunks_enabled[
|
||||
secondary_index_name
|
||||
] = secondary_large_chunks_enabled
|
||||
|
||||
def ensure_indices_exist(
|
||||
self,
|
||||
@@ -331,7 +352,7 @@ class VespaIndex(DocumentIndex):
|
||||
# indexing / updates / deletes since we have to make a large volume of requests.
|
||||
with (
|
||||
concurrent.futures.ThreadPoolExecutor(max_workers=NUM_THREADS) as executor,
|
||||
get_vespa_http_client() as http_client,
|
||||
self.httpx_client_context as http_client,
|
||||
):
|
||||
# We require the start and end index for each document in order to
|
||||
# know precisely which chunks to delete. This information exists for
|
||||
@@ -390,9 +411,11 @@ class VespaIndex(DocumentIndex):
|
||||
for doc_id in all_doc_ids
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
@classmethod
|
||||
def _apply_updates_batched(
|
||||
cls,
|
||||
updates: list[_VespaUpdateRequest],
|
||||
httpx_client: httpx.Client,
|
||||
batch_size: int = BATCH_SIZE,
|
||||
) -> None:
|
||||
"""Runs a batch of updates in parallel via the ThreadPoolExecutor."""
|
||||
@@ -414,7 +437,7 @@ class VespaIndex(DocumentIndex):
|
||||
|
||||
with (
|
||||
concurrent.futures.ThreadPoolExecutor(max_workers=NUM_THREADS) as executor,
|
||||
get_vespa_http_client() as http_client,
|
||||
httpx_client as http_client,
|
||||
):
|
||||
for update_batch in batch_generator(updates, batch_size):
|
||||
future_to_document_id = {
|
||||
@@ -455,7 +478,7 @@ class VespaIndex(DocumentIndex):
|
||||
index_names.append(self.secondary_index_name)
|
||||
|
||||
chunk_id_start_time = time.monotonic()
|
||||
with get_vespa_http_client() as http_client:
|
||||
with self.httpx_client_context as http_client:
|
||||
for update_request in update_requests:
|
||||
for doc_info in update_request.minimal_document_indexing_info:
|
||||
for index_name in index_names:
|
||||
@@ -511,7 +534,8 @@ class VespaIndex(DocumentIndex):
|
||||
)
|
||||
)
|
||||
|
||||
self._apply_updates_batched(processed_updates_requests)
|
||||
with self.httpx_client_context as httpx_client:
|
||||
self._apply_updates_batched(processed_updates_requests, httpx_client)
|
||||
logger.debug(
|
||||
"Finished updating Vespa documents in %.2f seconds",
|
||||
time.monotonic() - update_start,
|
||||
@@ -523,6 +547,7 @@ class VespaIndex(DocumentIndex):
|
||||
index_name: str,
|
||||
fields: VespaDocumentFields,
|
||||
doc_id: str,
|
||||
http_client: httpx.Client,
|
||||
) -> None:
|
||||
"""
|
||||
Update a single "chunk" (document) in Vespa using its chunk ID.
|
||||
@@ -554,18 +579,17 @@ class VespaIndex(DocumentIndex):
|
||||
|
||||
vespa_url = f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{doc_chunk_id}?create=true"
|
||||
|
||||
with get_vespa_http_client(http2=False) as http_client:
|
||||
try:
|
||||
resp = http_client.put(
|
||||
vespa_url,
|
||||
headers={"Content-Type": "application/json"},
|
||||
json=update_dict,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
except httpx.HTTPStatusError as e:
|
||||
error_message = f"Failed to update doc chunk {doc_chunk_id} (doc_id={doc_id}). Details: {e.response.text}"
|
||||
logger.error(error_message)
|
||||
raise
|
||||
try:
|
||||
resp = http_client.put(
|
||||
vespa_url,
|
||||
headers={"Content-Type": "application/json"},
|
||||
json=update_dict,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
except httpx.HTTPStatusError as e:
|
||||
error_message = f"Failed to update doc chunk {doc_chunk_id} (doc_id={doc_id}). Details: {e.response.text}"
|
||||
logger.error(error_message)
|
||||
raise
|
||||
|
||||
def update_single(
|
||||
self,
|
||||
@@ -579,24 +603,16 @@ class VespaIndex(DocumentIndex):
|
||||
function will complete with no errors or exceptions.
|
||||
Handle other exceptions if you wish to implement retry behavior
|
||||
"""
|
||||
|
||||
doc_chunk_count = 0
|
||||
|
||||
index_names = [self.index_name]
|
||||
if self.secondary_index_name:
|
||||
index_names.append(self.secondary_index_name)
|
||||
|
||||
with get_vespa_http_client(http2=False) as http_client:
|
||||
for index_name in index_names:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
multipass_config = get_multipass_config(
|
||||
db_session=db_session,
|
||||
primary_index=index_name == self.index_name,
|
||||
)
|
||||
large_chunks_enabled = multipass_config.enable_large_chunks
|
||||
with self.httpx_client_context as httpx_client:
|
||||
for (
|
||||
index_name,
|
||||
large_chunks_enabled,
|
||||
) in self.index_to_large_chunks_enabled.items():
|
||||
enriched_doc_infos = VespaIndex.enrich_basic_chunk_info(
|
||||
index_name=index_name,
|
||||
http_client=http_client,
|
||||
http_client=httpx_client,
|
||||
document_id=doc_id,
|
||||
previous_chunk_count=chunk_count,
|
||||
new_chunk_count=0,
|
||||
@@ -612,10 +628,7 @@ class VespaIndex(DocumentIndex):
|
||||
|
||||
for doc_chunk_id in doc_chunk_ids:
|
||||
self.update_single_chunk(
|
||||
doc_chunk_id=doc_chunk_id,
|
||||
index_name=index_name,
|
||||
fields=fields,
|
||||
doc_id=doc_id,
|
||||
doc_chunk_id, index_name, fields, doc_id, httpx_client
|
||||
)
|
||||
|
||||
return doc_chunk_count
|
||||
@@ -637,19 +650,13 @@ class VespaIndex(DocumentIndex):
|
||||
if self.secondary_index_name:
|
||||
index_names.append(self.secondary_index_name)
|
||||
|
||||
with get_vespa_http_client(
|
||||
http2=False
|
||||
) as http_client, concurrent.futures.ThreadPoolExecutor(
|
||||
with self.httpx_client_context as http_client, concurrent.futures.ThreadPoolExecutor(
|
||||
max_workers=NUM_THREADS
|
||||
) as executor:
|
||||
for index_name in index_names:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
multipass_config = get_multipass_config(
|
||||
db_session=db_session,
|
||||
primary_index=index_name == self.index_name,
|
||||
)
|
||||
large_chunks_enabled = multipass_config.enable_large_chunks
|
||||
|
||||
for (
|
||||
index_name,
|
||||
large_chunks_enabled,
|
||||
) in self.index_to_large_chunks_enabled.items():
|
||||
enriched_doc_infos = VespaIndex.enrich_basic_chunk_info(
|
||||
index_name=index_name,
|
||||
http_client=http_client,
|
||||
@@ -818,6 +825,9 @@ class VespaIndex(DocumentIndex):
|
||||
"""
|
||||
Deletes all entries in the specified index with the given tenant_id.
|
||||
|
||||
Currently unused, but we anticipate this being useful. The entire flow does not
|
||||
use the httpx connection pool of an instance.
|
||||
|
||||
Parameters:
|
||||
tenant_id (str): The tenant ID whose documents are to be deleted.
|
||||
index_name (str): The name of the index from which to delete documents.
|
||||
@@ -850,6 +860,8 @@ class VespaIndex(DocumentIndex):
|
||||
"""
|
||||
Retrieves all document IDs with the specified tenant_id, handling pagination.
|
||||
|
||||
Internal helper function for delete_entries_by_tenant_id.
|
||||
|
||||
Parameters:
|
||||
tenant_id (str): The tenant ID to search for.
|
||||
index_name (str): The name of the index to search in.
|
||||
@@ -882,8 +894,8 @@ class VespaIndex(DocumentIndex):
|
||||
f"Querying for document IDs with tenant_id: {tenant_id}, offset: {offset}"
|
||||
)
|
||||
|
||||
with get_vespa_http_client(no_timeout=True) as http_client:
|
||||
response = http_client.get(url, params=query_params)
|
||||
with get_vespa_http_client() as http_client:
|
||||
response = http_client.get(url, params=query_params, timeout=None)
|
||||
response.raise_for_status()
|
||||
|
||||
search_result = response.json()
|
||||
@@ -913,6 +925,11 @@ class VespaIndex(DocumentIndex):
|
||||
"""
|
||||
Deletes documents in batches using multiple threads.
|
||||
|
||||
Internal helper function for delete_entries_by_tenant_id.
|
||||
|
||||
This is a class method and does not use the httpx pool of the instance.
|
||||
This is OK because we don't use this method often.
|
||||
|
||||
Parameters:
|
||||
delete_requests (List[_VespaDeleteRequest]): The list of delete requests.
|
||||
batch_size (int): The number of documents to delete in each batch.
|
||||
@@ -925,13 +942,14 @@ class VespaIndex(DocumentIndex):
|
||||
response = http_client.delete(
|
||||
delete_request.url,
|
||||
headers={"Content-Type": "application/json"},
|
||||
timeout=None,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
logger.debug(f"Starting batch deletion for {len(delete_requests)} documents")
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=NUM_THREADS) as executor:
|
||||
with get_vespa_http_client(no_timeout=True) as http_client:
|
||||
with get_vespa_http_client() as http_client:
|
||||
for batch_start in range(0, len(delete_requests), batch_size):
|
||||
batch = delete_requests[batch_start : batch_start + batch_size]
|
||||
|
||||
|
||||
@@ -1,21 +1,19 @@
|
||||
import concurrent.futures
|
||||
import json
|
||||
import uuid
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from http import HTTPStatus
|
||||
|
||||
import httpx
|
||||
from retry import retry
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import ENABLE_MULTIPASS_INDEXING
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||
get_experts_stores_representations,
|
||||
)
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.db.search_settings import get_secondary_search_settings
|
||||
from onyx.document_index.document_index_utils import get_uuid_from_chunk
|
||||
from onyx.document_index.document_index_utils import get_uuid_from_chunk_info_old
|
||||
from onyx.document_index.interfaces import MinimalDocumentIndexingInfo
|
||||
@@ -50,10 +48,9 @@ from onyx.document_index.vespa_constants import TENANT_ID
|
||||
from onyx.document_index.vespa_constants import TITLE
|
||||
from onyx.document_index.vespa_constants import TITLE_EMBEDDING
|
||||
from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
from onyx.indexing.models import EmbeddingProvider
|
||||
from onyx.indexing.models import MultipassConfig
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@@ -275,46 +272,42 @@ def check_for_final_chunk_existence(
|
||||
index += 1
|
||||
|
||||
|
||||
def should_use_multipass(search_settings: SearchSettings | None) -> bool:
|
||||
"""
|
||||
Determines whether multipass should be used based on the search settings
|
||||
or the default config if settings are unavailable.
|
||||
"""
|
||||
if search_settings is not None:
|
||||
return search_settings.multipass_indexing
|
||||
return ENABLE_MULTIPASS_INDEXING
|
||||
class BaseHTTPXClientContext(ABC):
|
||||
"""Abstract base class for an HTTPX client context manager."""
|
||||
|
||||
@abstractmethod
|
||||
def __enter__(self) -> httpx.Client:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def __exit__(self, exc_type, exc_value, traceback): # type: ignore
|
||||
pass
|
||||
|
||||
|
||||
def can_use_large_chunks(multipass: bool, search_settings: SearchSettings) -> bool:
|
||||
"""
|
||||
Given multipass usage and an embedder, decides whether large chunks are allowed
|
||||
based on model/provider constraints.
|
||||
"""
|
||||
# Only local models that support a larger context are from Nomic
|
||||
# Cohere does not support larger contexts (they recommend not going above ~512 tokens)
|
||||
return (
|
||||
multipass
|
||||
and search_settings.model_name.startswith("nomic-ai")
|
||||
and search_settings.provider_type != EmbeddingProvider.COHERE
|
||||
)
|
||||
class GlobalHTTPXClientContext(BaseHTTPXClientContext):
|
||||
"""Context manager for a global HTTPX client that does not close it."""
|
||||
|
||||
def __init__(self, client: httpx.Client):
|
||||
self._client = client
|
||||
|
||||
def __enter__(self) -> httpx.Client:
|
||||
return self._client # Reuse the global client
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback): # type: ignore
|
||||
pass # Do nothing; don't close the global client
|
||||
|
||||
|
||||
def get_multipass_config(
|
||||
db_session: Session, primary_index: bool = True
|
||||
) -> MultipassConfig:
|
||||
"""
|
||||
Determines whether to enable multipass and large chunks by examining
|
||||
the current search settings and the embedder configuration.
|
||||
"""
|
||||
search_settings = (
|
||||
get_current_search_settings(db_session)
|
||||
if primary_index
|
||||
else get_secondary_search_settings(db_session)
|
||||
)
|
||||
multipass = should_use_multipass(search_settings)
|
||||
if not search_settings:
|
||||
return MultipassConfig(multipass_indexing=False, enable_large_chunks=False)
|
||||
enable_large_chunks = can_use_large_chunks(multipass, search_settings)
|
||||
return MultipassConfig(
|
||||
multipass_indexing=multipass, enable_large_chunks=enable_large_chunks
|
||||
)
|
||||
class TemporaryHTTPXClientContext(BaseHTTPXClientContext):
|
||||
"""Context manager for a temporary HTTPX client that closes it after use."""
|
||||
|
||||
def __init__(self, client_factory: Callable[[], httpx.Client]):
|
||||
self._client_factory = client_factory
|
||||
self._client: httpx.Client | None = None # Client will be created in __enter__
|
||||
|
||||
def __enter__(self) -> httpx.Client:
|
||||
self._client = self._client_factory() # Create a new client
|
||||
return self._client
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback): # type: ignore
|
||||
if self._client:
|
||||
self._client.close()
|
||||
|
||||
@@ -55,7 +55,7 @@ def remove_invalid_unicode_chars(text: str) -> str:
|
||||
"""Vespa does not take in unicode chars that aren't valid for XML.
|
||||
This removes them."""
|
||||
_illegal_xml_chars_RE: re.Pattern = re.compile(
|
||||
"[\x00-\x08\x0b\x0c\x0e-\x1F\uD800-\uDFFF\uFFFE\uFFFF]"
|
||||
"[\x00-\x08\x0b\x0c\x0e-\x1F\uD800-\uDFFF\uFDD0-\uFDEF\uFFFE\uFFFF]"
|
||||
)
|
||||
return _illegal_xml_chars_RE.sub("", text)
|
||||
|
||||
|
||||
@@ -358,7 +358,13 @@ def extract_file_text(
|
||||
|
||||
try:
|
||||
if get_unstructured_api_key():
|
||||
return unstructured_to_text(file, file_name)
|
||||
try:
|
||||
return unstructured_to_text(file, file_name)
|
||||
except Exception as unstructured_error:
|
||||
logger.error(
|
||||
f"Failed to process with Unstructured: {str(unstructured_error)}. Falling back to normal processing."
|
||||
)
|
||||
# Fall through to normal processing
|
||||
|
||||
if file_name or extension:
|
||||
if extension is not None:
|
||||
|
||||
@@ -52,7 +52,7 @@ def _sdk_partition_request(
|
||||
|
||||
def unstructured_to_text(file: IO[Any], file_name: str) -> str:
|
||||
logger.debug(f"Starting to read file: {file_name}")
|
||||
req = _sdk_partition_request(file, file_name, strategy="auto")
|
||||
req = _sdk_partition_request(file, file_name, strategy="fast")
|
||||
|
||||
unstructured_client = UnstructuredClient(api_key_auth=get_unstructured_api_key())
|
||||
|
||||
|
||||
57
backend/onyx/httpx/httpx_pool.py
Normal file
57
backend/onyx/httpx/httpx_pool.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
class HttpxPool:
|
||||
"""Class to manage a global httpx Client instance"""
|
||||
|
||||
_clients: dict[str, httpx.Client] = {}
|
||||
_lock: threading.Lock = threading.Lock()
|
||||
|
||||
# Default parameters for creation
|
||||
DEFAULT_KWARGS = {
|
||||
"http2": True,
|
||||
"limits": lambda: httpx.Limits(),
|
||||
}
|
||||
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def _init_client(cls, **kwargs: Any) -> httpx.Client:
|
||||
"""Private helper method to create and return an httpx.Client."""
|
||||
merged_kwargs = {**cls.DEFAULT_KWARGS, **kwargs}
|
||||
return httpx.Client(**merged_kwargs)
|
||||
|
||||
@classmethod
|
||||
def init_client(cls, name: str, **kwargs: Any) -> None:
|
||||
"""Allow the caller to init the client with extra params."""
|
||||
with cls._lock:
|
||||
if name not in cls._clients:
|
||||
cls._clients[name] = cls._init_client(**kwargs)
|
||||
|
||||
@classmethod
|
||||
def close_client(cls, name: str) -> None:
|
||||
"""Allow the caller to close the client."""
|
||||
with cls._lock:
|
||||
client = cls._clients.pop(name, None)
|
||||
if client:
|
||||
client.close()
|
||||
|
||||
@classmethod
|
||||
def close_all(cls) -> None:
|
||||
"""Close all registered clients."""
|
||||
with cls._lock:
|
||||
for client in cls._clients.values():
|
||||
client.close()
|
||||
cls._clients.clear()
|
||||
|
||||
@classmethod
|
||||
def get(cls, name: str) -> httpx.Client:
|
||||
"""Gets the httpx.Client. Will init to default settings if not init'd."""
|
||||
with cls._lock:
|
||||
if name not in cls._clients:
|
||||
cls._clients[name] = cls._init_client()
|
||||
return cls._clients[name]
|
||||
@@ -31,14 +31,15 @@ from onyx.db.document import upsert_documents
|
||||
from onyx.db.document_set import fetch_document_sets_for_documents
|
||||
from onyx.db.index_attempt import create_index_attempt_error
|
||||
from onyx.db.models import Document as DBDocument
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.db.tag import create_or_add_document_tag
|
||||
from onyx.db.tag import create_or_add_document_tag_list
|
||||
from onyx.document_index.document_index_utils import (
|
||||
get_multipass_config,
|
||||
)
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.document_index.interfaces import DocumentMetadata
|
||||
from onyx.document_index.interfaces import IndexBatchParams
|
||||
from onyx.document_index.vespa.indexing_utils import (
|
||||
get_multipass_config,
|
||||
)
|
||||
from onyx.indexing.chunker import Chunker
|
||||
from onyx.indexing.embedder import IndexingEmbedder
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
@@ -357,7 +358,6 @@ def index_doc_batch(
|
||||
is_public=False,
|
||||
)
|
||||
|
||||
logger.debug("Filtering Documents")
|
||||
filtered_documents = filter_fnc(document_batch)
|
||||
|
||||
ctx = index_doc_batch_prepare(
|
||||
@@ -380,6 +380,15 @@ def index_doc_batch(
|
||||
new_docs=0, total_docs=len(filtered_documents), total_chunks=0
|
||||
)
|
||||
|
||||
doc_descriptors = [
|
||||
{
|
||||
"doc_id": doc.id,
|
||||
"doc_length": doc.get_total_char_length(),
|
||||
}
|
||||
for doc in ctx.updatable_docs
|
||||
]
|
||||
logger.debug(f"Starting indexing process for documents: {doc_descriptors}")
|
||||
|
||||
logger.debug("Starting chunking")
|
||||
chunks: list[DocAwareChunk] = chunker.chunk(ctx.updatable_docs)
|
||||
|
||||
@@ -527,7 +536,8 @@ def build_indexing_pipeline(
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> IndexingPipelineProtocol:
|
||||
"""Builds a pipeline which takes in a list (batch) of docs and indexes them."""
|
||||
multipass_config = get_multipass_config(db_session, primary_index=True)
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
multipass_config = get_multipass_config(search_settings)
|
||||
|
||||
chunker = chunker or Chunker(
|
||||
tokenizer=embedder.embedding_model.tokenizer,
|
||||
|
||||
@@ -55,9 +55,7 @@ class DocAwareChunk(BaseChunk):
|
||||
|
||||
def to_short_descriptor(self) -> str:
|
||||
"""Used when logging the identity of a chunk"""
|
||||
return (
|
||||
f"Chunk ID: '{self.chunk_id}'; {self.source_document.to_short_descriptor()}"
|
||||
)
|
||||
return f"{self.source_document.to_short_descriptor()} Chunk ID: {self.chunk_id}"
|
||||
|
||||
|
||||
class IndexChunk(DocAwareChunk):
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import as_completed
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from functools import wraps
|
||||
from typing import Any
|
||||
|
||||
@@ -11,6 +13,7 @@ from requests import RequestException
|
||||
from requests import Response
|
||||
from retry import retry
|
||||
|
||||
from onyx.configs.app_configs import INDEXING_EMBEDDING_MODEL_NUM_THREADS
|
||||
from onyx.configs.app_configs import LARGE_CHUNK_RATIO
|
||||
from onyx.configs.app_configs import SKIP_WARM_UP
|
||||
from onyx.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
|
||||
@@ -155,6 +158,7 @@ class EmbeddingModel:
|
||||
text_type: EmbedTextType,
|
||||
batch_size: int,
|
||||
max_seq_length: int,
|
||||
num_threads: int = INDEXING_EMBEDDING_MODEL_NUM_THREADS,
|
||||
) -> list[Embedding]:
|
||||
text_batches = batch_list(texts, batch_size)
|
||||
|
||||
@@ -163,12 +167,15 @@ class EmbeddingModel:
|
||||
)
|
||||
|
||||
embeddings: list[Embedding] = []
|
||||
for idx, text_batch in enumerate(text_batches, start=1):
|
||||
|
||||
def process_batch(
|
||||
batch_idx: int, text_batch: list[str]
|
||||
) -> tuple[int, list[Embedding]]:
|
||||
if self.callback:
|
||||
if self.callback.should_stop():
|
||||
raise RuntimeError("_batch_encode_texts detected stop signal")
|
||||
|
||||
logger.debug(f"Encoding batch {idx} of {len(text_batches)}")
|
||||
logger.debug(f"Encoding batch {batch_idx} of {len(text_batches)}")
|
||||
embed_request = EmbedRequest(
|
||||
model_name=self.model_name,
|
||||
texts=text_batch,
|
||||
@@ -185,10 +192,43 @@ class EmbeddingModel:
|
||||
)
|
||||
|
||||
response = self._make_model_server_request(embed_request)
|
||||
embeddings.extend(response.embeddings)
|
||||
return batch_idx, response.embeddings
|
||||
|
||||
# only multi thread if:
|
||||
# 1. num_threads is greater than 1
|
||||
# 2. we are using an API-based embedding model (provider_type is not None)
|
||||
# 3. there are more than 1 batch (no point in threading if only 1)
|
||||
if num_threads >= 1 and self.provider_type and len(text_batches) > 1:
|
||||
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
||||
future_to_batch = {
|
||||
executor.submit(process_batch, idx, batch): idx
|
||||
for idx, batch in enumerate(text_batches, start=1)
|
||||
}
|
||||
|
||||
# Collect results in order
|
||||
batch_results: list[tuple[int, list[Embedding]]] = []
|
||||
for future in as_completed(future_to_batch):
|
||||
try:
|
||||
result = future.result()
|
||||
batch_results.append(result)
|
||||
if self.callback:
|
||||
self.callback.progress("_batch_encode_texts", 1)
|
||||
except Exception as e:
|
||||
logger.exception("Embedding model failed to process batch")
|
||||
raise e
|
||||
|
||||
# Sort by batch index and extend embeddings
|
||||
batch_results.sort(key=lambda x: x[0])
|
||||
for _, batch_embeddings in batch_results:
|
||||
embeddings.extend(batch_embeddings)
|
||||
else:
|
||||
# Original sequential processing
|
||||
for idx, text_batch in enumerate(text_batches, start=1):
|
||||
_, batch_embeddings = process_batch(idx, text_batch)
|
||||
embeddings.extend(batch_embeddings)
|
||||
if self.callback:
|
||||
self.callback.progress("_batch_encode_texts", 1)
|
||||
|
||||
if self.callback:
|
||||
self.callback.progress("_batch_encode_texts", 1)
|
||||
return embeddings
|
||||
|
||||
def encode(
|
||||
|
||||
@@ -537,30 +537,36 @@ def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -
|
||||
# Let the tag flow handle this case, don't reply twice
|
||||
return False
|
||||
|
||||
if event.get("bot_profile"):
|
||||
# Check if this is a bot message (either via bot_profile or bot_message subtype)
|
||||
is_bot_message = bool(
|
||||
event.get("bot_profile") or event.get("subtype") == "bot_message"
|
||||
)
|
||||
if is_bot_message:
|
||||
channel_name, _ = get_channel_name_from_id(
|
||||
client=client.web_client, channel_id=channel
|
||||
)
|
||||
|
||||
with get_session_with_tenant(client.tenant_id) as db_session:
|
||||
slack_channel_config = get_slack_channel_config_for_bot_and_channel(
|
||||
db_session=db_session,
|
||||
slack_bot_id=client.slack_bot_id,
|
||||
channel_name=channel_name,
|
||||
)
|
||||
|
||||
# If OnyxBot is not specifically tagged and the channel is not set to respond to bots, ignore the message
|
||||
if (not bot_tag_id or bot_tag_id not in msg) and (
|
||||
not slack_channel_config
|
||||
or not slack_channel_config.channel_config.get("respond_to_bots")
|
||||
):
|
||||
channel_specific_logger.info("Ignoring message from bot")
|
||||
channel_specific_logger.info(
|
||||
"Ignoring message from bot since respond_to_bots is disabled"
|
||||
)
|
||||
return False
|
||||
|
||||
# Ignore things like channel_join, channel_leave, etc.
|
||||
# NOTE: "file_share" is just a message with a file attachment, so we
|
||||
# should not ignore it
|
||||
message_subtype = event.get("subtype")
|
||||
if message_subtype not in [None, "file_share"]:
|
||||
if message_subtype not in [None, "file_share", "bot_message"]:
|
||||
channel_specific_logger.info(
|
||||
f"Ignoring message with subtype '{message_subtype}' since it is a special message type"
|
||||
)
|
||||
|
||||
@@ -92,7 +92,7 @@ class RedisConnectorPrune:
|
||||
if fence_bytes is None:
|
||||
return None
|
||||
|
||||
fence_int = cast(int, fence_bytes)
|
||||
fence_int = int(cast(bytes, fence_bytes))
|
||||
return fence_int
|
||||
|
||||
@generator_complete.setter
|
||||
|
||||
@@ -21,6 +21,7 @@ from onyx.configs.app_configs import REDIS_HOST
|
||||
from onyx.configs.app_configs import REDIS_PASSWORD
|
||||
from onyx.configs.app_configs import REDIS_POOL_MAX_CONNECTIONS
|
||||
from onyx.configs.app_configs import REDIS_PORT
|
||||
from onyx.configs.app_configs import REDIS_REPLICA_HOST
|
||||
from onyx.configs.app_configs import REDIS_SSL
|
||||
from onyx.configs.app_configs import REDIS_SSL_CA_CERTS
|
||||
from onyx.configs.app_configs import REDIS_SSL_CERT_REQS
|
||||
@@ -132,23 +133,32 @@ class RedisPool:
|
||||
_instance: Optional["RedisPool"] = None
|
||||
_lock: threading.Lock = threading.Lock()
|
||||
_pool: redis.BlockingConnectionPool
|
||||
_replica_pool: redis.BlockingConnectionPool
|
||||
|
||||
def __new__(cls) -> "RedisPool":
|
||||
if not cls._instance:
|
||||
with cls._lock:
|
||||
if not cls._instance:
|
||||
cls._instance = super(RedisPool, cls).__new__(cls)
|
||||
cls._instance._init_pool()
|
||||
cls._instance._init_pools()
|
||||
return cls._instance
|
||||
|
||||
def _init_pool(self) -> None:
|
||||
def _init_pools(self) -> None:
|
||||
self._pool = RedisPool.create_pool(ssl=REDIS_SSL)
|
||||
self._replica_pool = RedisPool.create_pool(
|
||||
host=REDIS_REPLICA_HOST, ssl=REDIS_SSL
|
||||
)
|
||||
|
||||
def get_client(self, tenant_id: str | None) -> Redis:
|
||||
if tenant_id is None:
|
||||
tenant_id = "public"
|
||||
return TenantRedis(tenant_id, connection_pool=self._pool)
|
||||
|
||||
def get_replica_client(self, tenant_id: str | None) -> Redis:
|
||||
if tenant_id is None:
|
||||
tenant_id = "public"
|
||||
return TenantRedis(tenant_id, connection_pool=self._replica_pool)
|
||||
|
||||
@staticmethod
|
||||
def create_pool(
|
||||
host: str = REDIS_HOST,
|
||||
@@ -212,6 +222,10 @@ def get_redis_client(*, tenant_id: str | None) -> Redis:
|
||||
return redis_pool.get_client(tenant_id)
|
||||
|
||||
|
||||
def get_redis_replica_client(*, tenant_id: str | None) -> Redis:
|
||||
return redis_pool.get_replica_client(tenant_id)
|
||||
|
||||
|
||||
SSL_CERT_REQS_MAP = {
|
||||
"none": ssl.CERT_NONE,
|
||||
"optional": ssl.CERT_OPTIONAL,
|
||||
|
||||
@@ -16,7 +16,7 @@ from onyx.context.search.preprocessing.access_filters import (
|
||||
from onyx.db.document_set import get_document_sets_by_ids
|
||||
from onyx.db.models import StarterMessageModel as StarterMessage
|
||||
from onyx.db.models import User
|
||||
from onyx.document_index.document_index_utils import get_both_index_names
|
||||
from onyx.db.search_settings import get_active_search_settings
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.llm.factory import get_default_llms
|
||||
from onyx.prompts.starter_messages import format_persona_starter_message_prompt
|
||||
@@ -34,8 +34,11 @@ def get_random_chunks_from_doc_sets(
|
||||
"""
|
||||
Retrieves random chunks from the specified document sets.
|
||||
"""
|
||||
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
|
||||
document_index = get_default_document_index(curr_ind_name, sec_ind_name)
|
||||
active_search_settings = get_active_search_settings(db_session)
|
||||
document_index = get_default_document_index(
|
||||
search_settings=active_search_settings.primary,
|
||||
secondary_search_settings=active_search_settings.secondary,
|
||||
)
|
||||
|
||||
acl_filters = build_access_filters_for_user(user, db_session)
|
||||
filters = IndexFilters(document_set=doc_sets, access_control_list=acl_filters)
|
||||
|
||||
@@ -6184,7 +6184,7 @@
|
||||
"chunk_ind": 0
|
||||
},
|
||||
{
|
||||
"url": "https://docs.onyx.app/more/use_cases/customer_support",
|
||||
"url": "https://docs.onyx.app/more/use_cases/support",
|
||||
"title": "Customer Support",
|
||||
"content": "Help your customer support team instantly answer any question across your entire product.\n\nAI Enabled Support\nCustomer support agents have one of the highest breadth jobs. They field requests that cover the entire surface area of the product and need to help your users find success on extremely short timelines. Because they're not the same people who designed or built the system, they often lack the depth of understanding needed - resulting in delays and escalations to other teams. Modern teams are leveraging AI to help their CS team optimize the speed and quality of these critical customer-facing interactions.\n\nThe Importance of Context\nThere are two critical components of AI copilots for customer support. The first is that the AI system needs to be connected with as much information as possible (not just support tools like Zendesk or Intercom) and that the knowledge needs to be as fresh as possible. Sometimes a fix might even be in places rarely checked by CS such as pull requests in a code repository. The second critical component is the ability of the AI system to break down difficult concepts and convoluted processes into more digestible descriptions and for your team members to be able to chat back and forth with the system to build a better understanding.\n\nOnyx takes care of both of these. The system connects up to over 30+ different applications and the knowledge is pulled in constantly so that the information access is always up to date.",
|
||||
"title_embedding": [
|
||||
|
||||
@@ -24,7 +24,7 @@
|
||||
"chunk_ind": 0
|
||||
},
|
||||
{
|
||||
"url": "https://docs.onyx.app/more/use_cases/customer_support",
|
||||
"url": "https://docs.onyx.app/more/use_cases/support",
|
||||
"title": "Customer Support",
|
||||
"content": "Help your customer support team instantly answer any question across your entire product.\n\nAI Enabled Support\nCustomer support agents have one of the highest breadth jobs. They field requests that cover the entire surface area of the product and need to help your users find success on extremely short timelines. Because they're not the same people who designed or built the system, they often lack the depth of understanding needed - resulting in delays and escalations to other teams. Modern teams are leveraging AI to help their CS team optimize the speed and quality of these critical customer-facing interactions.\n\nThe Importance of Context\nThere are two critical components of AI copilots for customer support. The first is that the AI system needs to be connected with as much information as possible (not just support tools like Zendesk or Intercom) and that the knowledge needs to be as fresh as possible. Sometimes a fix might even be in places rarely checked by CS such as pull requests in a code repository. The second critical component is the ability of the AI system to break down difficult concepts and convoluted processes into more digestible descriptions and for your team members to be able to chat back and forth with the system to build a better understanding.\n\nOnyx takes care of both of these. The system connects up to over 30+ different applications and the knowledge is pulled in constantly so that the information access is always up to date.",
|
||||
"chunk_ind": 0
|
||||
|
||||
@@ -3,6 +3,7 @@ import json
|
||||
import os
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.access.models import default_public_access
|
||||
@@ -23,6 +24,7 @@ from onyx.db.document import check_docs_exist
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.index_attempt import mock_successful_index_attempt
|
||||
from onyx.db.models import Document as DbDocument
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.document_index.interfaces import IndexBatchParams
|
||||
@@ -59,6 +61,7 @@ def _create_indexable_chunks(
|
||||
doc_updated_at=None,
|
||||
primary_owners=[],
|
||||
secondary_owners=[],
|
||||
chunk_count=1,
|
||||
)
|
||||
if preprocessed_doc["chunk_ind"] == 0:
|
||||
ids_to_documents[document.id] = document
|
||||
@@ -155,9 +158,7 @@ def seed_initial_documents(
|
||||
logger.info("Embedding model has been updated, skipping")
|
||||
return
|
||||
|
||||
document_index = get_default_document_index(
|
||||
primary_index_name=search_settings.index_name, secondary_index_name=None
|
||||
)
|
||||
document_index = get_default_document_index(search_settings, None)
|
||||
|
||||
# Create a connector so the user can delete it if they want
|
||||
# or reindex it with a new search model if they want
|
||||
@@ -240,4 +241,12 @@ def seed_initial_documents(
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Since we bypass the indexing flow, we need to manually update the chunk count
|
||||
for doc in docs:
|
||||
db_session.execute(
|
||||
update(DbDocument)
|
||||
.where(DbDocument.id == doc.id)
|
||||
.values(chunk_count=doc.chunk_count)
|
||||
)
|
||||
|
||||
kv_store.store(KV_DOCUMENTS_SEEDED_KEY, True)
|
||||
|
||||
@@ -15,6 +15,9 @@ from onyx.background.celery.celery_utils import get_deletion_attempt_snapshot
|
||||
from onyx.background.celery.tasks.doc_permission_syncing.tasks import (
|
||||
try_creating_permissions_sync_task,
|
||||
)
|
||||
from onyx.background.celery.tasks.external_group_syncing.tasks import (
|
||||
try_creating_external_group_sync_task,
|
||||
)
|
||||
from onyx.background.celery.tasks.pruning.tasks import (
|
||||
try_creating_prune_generator_task,
|
||||
)
|
||||
@@ -39,7 +42,7 @@ from onyx.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
|
||||
from onyx.db.index_attempt import get_paginated_index_attempts_for_cc_pair_id
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.db.models import User
|
||||
from onyx.db.search_settings import get_active_search_settings
|
||||
from onyx.db.search_settings import get_active_search_settings_list
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
@@ -189,7 +192,7 @@ def update_cc_pair_status(
|
||||
if status_update_request.status == ConnectorCredentialPairStatus.PAUSED:
|
||||
redis_connector.stop.set_fence(True)
|
||||
|
||||
search_settings_list: list[SearchSettings] = get_active_search_settings(
|
||||
search_settings_list: list[SearchSettings] = get_active_search_settings_list(
|
||||
db_session
|
||||
)
|
||||
|
||||
@@ -443,6 +446,78 @@ def sync_cc_pair(
|
||||
)
|
||||
|
||||
|
||||
@router.get("/admin/cc-pair/{cc_pair_id}/sync-groups")
|
||||
def get_cc_pair_latest_group_sync(
|
||||
cc_pair_id: int,
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> datetime | None:
|
||||
cc_pair = get_connector_credential_pair_from_id_for_user(
|
||||
cc_pair_id=cc_pair_id,
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
get_editable=False,
|
||||
)
|
||||
if not cc_pair:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="cc_pair not found for current user's permissions",
|
||||
)
|
||||
|
||||
return cc_pair.last_time_external_group_sync
|
||||
|
||||
|
||||
@router.post("/admin/cc-pair/{cc_pair_id}/sync-groups")
|
||||
def sync_cc_pair_groups(
|
||||
cc_pair_id: int,
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> StatusResponse[list[int]]:
|
||||
"""Triggers group sync on a particular cc_pair immediately"""
|
||||
|
||||
cc_pair = get_connector_credential_pair_from_id_for_user(
|
||||
cc_pair_id=cc_pair_id,
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
get_editable=False,
|
||||
)
|
||||
if not cc_pair:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Connection not found for current user's permissions",
|
||||
)
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
if redis_connector.external_group_sync.fenced:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.CONFLICT,
|
||||
detail="External group sync task already in progress.",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"External group sync cc_pair={cc_pair_id} "
|
||||
f"connector_id={cc_pair.connector_id} "
|
||||
f"credential_id={cc_pair.credential_id} "
|
||||
f"{cc_pair.connector.name} connector."
|
||||
)
|
||||
tasks_created = try_creating_external_group_sync_task(
|
||||
primary_app, cc_pair_id, r, CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
)
|
||||
if not tasks_created:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
|
||||
detail="External group sync task creation failed.",
|
||||
)
|
||||
|
||||
return StatusResponse(
|
||||
success=True,
|
||||
message="Successfully created the external group sync task.",
|
||||
)
|
||||
|
||||
|
||||
@router.get("/admin/cc-pair/{cc_pair_id}/get-docs-sync-status")
|
||||
def get_docs_sync_status(
|
||||
cc_pair_id: int,
|
||||
|
||||
@@ -32,10 +32,7 @@ def get_document_info(
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> DocumentInfo:
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
|
||||
document_index = get_default_document_index(
|
||||
primary_index_name=search_settings.index_name, secondary_index_name=None
|
||||
)
|
||||
document_index = get_default_document_index(search_settings, None)
|
||||
|
||||
user_acl_filters = build_access_filters_for_user(user, db_session)
|
||||
inference_chunks = document_index.id_based_retrieval(
|
||||
@@ -79,10 +76,7 @@ def get_chunk_info(
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ChunkInfo:
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
|
||||
document_index = get_default_document_index(
|
||||
primary_index_name=search_settings.index_name, secondary_index_name=None
|
||||
)
|
||||
document_index = get_default_document_index(search_settings, None)
|
||||
|
||||
user_acl_filters = build_access_filters_for_user(user, db_session)
|
||||
chunk_request = VespaChunkRequest(
|
||||
|
||||
@@ -357,6 +357,7 @@ class ConnectorCredentialPairDescriptor(BaseModel):
|
||||
name: str | None = None
|
||||
connector: ConnectorSnapshot
|
||||
credential: CredentialSnapshot
|
||||
access_type: AccessType
|
||||
|
||||
|
||||
class RunConnectorRequest(BaseModel):
|
||||
|
||||
@@ -68,6 +68,7 @@ class DocumentSet(BaseModel):
|
||||
credential=CredentialSnapshot.from_credential_db_model(
|
||||
cc_pair.credential
|
||||
),
|
||||
access_type=cc_pair.access_type,
|
||||
)
|
||||
for cc_pair in document_set_model.connector_credential_pairs
|
||||
],
|
||||
|
||||
@@ -10,6 +10,7 @@ from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_chat_accesssible_user
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.llm import fetch_existing_llm_providers
|
||||
from onyx.db.llm import fetch_existing_llm_providers_for_user
|
||||
from onyx.db.llm import fetch_provider
|
||||
from onyx.db.llm import remove_llm_provider
|
||||
from onyx.db.llm import update_default_provider
|
||||
@@ -195,5 +196,7 @@ def list_llm_provider_basics(
|
||||
) -> list[LLMProviderDescriptor]:
|
||||
return [
|
||||
LLMProviderDescriptor.from_model(llm_provider_model)
|
||||
for llm_provider_model in fetch_existing_llm_providers(db_session, user)
|
||||
for llm_provider_model in fetch_existing_llm_providers_for_user(
|
||||
db_session, user
|
||||
)
|
||||
]
|
||||
|
||||
@@ -44,7 +44,6 @@ class UserPreferences(BaseModel):
|
||||
chosen_assistants: list[int] | None = None
|
||||
hidden_assistants: list[int] = []
|
||||
visible_assistants: list[int] = []
|
||||
recent_assistants: list[int] | None = None
|
||||
default_model: str | None = None
|
||||
auto_scroll: bool | None = None
|
||||
pinned_assistants: list[int] | None = None
|
||||
|
||||
@@ -22,6 +22,7 @@ from onyx.db.search_settings import get_embedding_provider_from_provider_type
|
||||
from onyx.db.search_settings import get_secondary_search_settings
|
||||
from onyx.db.search_settings import update_current_search_settings
|
||||
from onyx.db.search_settings import update_search_settings_status
|
||||
from onyx.document_index.document_index_utils import get_multipass_config
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.file_processing.unstructured import delete_unstructured_api_key
|
||||
from onyx.file_processing.unstructured import get_unstructured_api_key
|
||||
@@ -97,10 +98,9 @@ def set_new_search_settings(
|
||||
)
|
||||
|
||||
# Ensure Vespa has the new index immediately
|
||||
document_index = get_default_document_index(
|
||||
primary_index_name=search_settings.index_name,
|
||||
secondary_index_name=new_search_settings.index_name,
|
||||
)
|
||||
get_multipass_config(search_settings)
|
||||
get_multipass_config(new_search_settings)
|
||||
document_index = get_default_document_index(search_settings, new_search_settings)
|
||||
|
||||
document_index.ensure_indices_exist(
|
||||
index_embedding_dim=search_settings.model_dim,
|
||||
|
||||
@@ -572,59 +572,6 @@ class ChosenDefaultModelRequest(BaseModel):
|
||||
default_model: str | None = None
|
||||
|
||||
|
||||
class RecentAssistantsRequest(BaseModel):
|
||||
current_assistant: int
|
||||
|
||||
|
||||
def update_recent_assistants(
|
||||
recent_assistants: list[int] | None, current_assistant: int
|
||||
) -> list[int]:
|
||||
if recent_assistants is None:
|
||||
recent_assistants = []
|
||||
else:
|
||||
recent_assistants = [x for x in recent_assistants if x != current_assistant]
|
||||
|
||||
# Add current assistant to start of list
|
||||
recent_assistants.insert(0, current_assistant)
|
||||
|
||||
# Keep only the 5 most recent assistants
|
||||
recent_assistants = recent_assistants[:5]
|
||||
return recent_assistants
|
||||
|
||||
|
||||
@router.patch("/user/recent-assistants")
|
||||
def update_user_recent_assistants(
|
||||
request: RecentAssistantsRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
if user is None:
|
||||
if AUTH_TYPE == AuthType.DISABLED:
|
||||
store = get_kv_store()
|
||||
no_auth_user = fetch_no_auth_user(store)
|
||||
preferences = no_auth_user.preferences
|
||||
recent_assistants = preferences.recent_assistants
|
||||
updated_preferences = update_recent_assistants(
|
||||
recent_assistants, request.current_assistant
|
||||
)
|
||||
preferences.recent_assistants = updated_preferences
|
||||
set_no_auth_user_preferences(store, preferences)
|
||||
return
|
||||
else:
|
||||
raise RuntimeError("This should never happen")
|
||||
|
||||
recent_assistants = UserInfo.from_model(user).preferences.recent_assistants
|
||||
updated_recent_assistants = update_recent_assistants(
|
||||
recent_assistants, request.current_assistant
|
||||
)
|
||||
db_session.execute(
|
||||
update(User)
|
||||
.where(User.id == user.id) # type: ignore
|
||||
.values(recent_assistants=updated_recent_assistants)
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@router.patch("/shortcut-enabled")
|
||||
def update_user_shortcut_enabled(
|
||||
shortcut_enabled: bool,
|
||||
@@ -731,30 +678,6 @@ class ChosenAssistantsRequest(BaseModel):
|
||||
chosen_assistants: list[int]
|
||||
|
||||
|
||||
@router.patch("/user/assistant-list")
|
||||
def update_user_assistant_list(
|
||||
request: ChosenAssistantsRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
if user is None:
|
||||
if AUTH_TYPE == AuthType.DISABLED:
|
||||
store = get_kv_store()
|
||||
no_auth_user = fetch_no_auth_user(store)
|
||||
no_auth_user.preferences.chosen_assistants = request.chosen_assistants
|
||||
set_no_auth_user_preferences(store, no_auth_user.preferences)
|
||||
return
|
||||
else:
|
||||
raise RuntimeError("This should never happen")
|
||||
|
||||
db_session.execute(
|
||||
update(User)
|
||||
.where(User.id == user.id) # type: ignore
|
||||
.values(chosen_assistants=request.chosen_assistants)
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def update_assistant_visibility(
|
||||
preferences: UserPreferences, assistant_id: int, show: bool
|
||||
) -> UserPreferences:
|
||||
|
||||
@@ -14,9 +14,9 @@ from onyx.db.document import get_ingestion_documents
|
||||
from onyx.db.engine import get_current_tenant_id
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.db.search_settings import get_active_search_settings
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.db.search_settings import get_secondary_search_settings
|
||||
from onyx.document_index.document_index_utils import get_both_index_names
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.indexing.embedder import DefaultIndexingEmbedder
|
||||
from onyx.indexing.indexing_pipeline import build_indexing_pipeline
|
||||
@@ -89,9 +89,10 @@ def upsert_ingestion_doc(
|
||||
)
|
||||
|
||||
# Need to index for both the primary and secondary index if possible
|
||||
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
|
||||
active_search_settings = get_active_search_settings(db_session)
|
||||
curr_doc_index = get_default_document_index(
|
||||
primary_index_name=curr_ind_name, secondary_index_name=None
|
||||
active_search_settings.primary,
|
||||
None,
|
||||
)
|
||||
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
@@ -117,11 +118,7 @@ def upsert_ingestion_doc(
|
||||
)
|
||||
|
||||
# If there's a secondary index being built, index the doc but don't use it for return here
|
||||
if sec_ind_name:
|
||||
sec_doc_index = get_default_document_index(
|
||||
primary_index_name=curr_ind_name, secondary_index_name=None
|
||||
)
|
||||
|
||||
if active_search_settings.secondary:
|
||||
sec_search_settings = get_secondary_search_settings(db_session)
|
||||
|
||||
if sec_search_settings is None:
|
||||
@@ -134,6 +131,10 @@ def upsert_ingestion_doc(
|
||||
search_settings=sec_search_settings
|
||||
)
|
||||
|
||||
sec_doc_index = get_default_document_index(
|
||||
active_search_settings.secondary, None
|
||||
)
|
||||
|
||||
sec_ind_pipeline = build_indexing_pipeline(
|
||||
embedder=new_index_embedding_model,
|
||||
document_index=sec_doc_index,
|
||||
|
||||
@@ -672,23 +672,25 @@ def upload_files_for_chat(
|
||||
else ChatFileType.PLAIN_TEXT
|
||||
)
|
||||
|
||||
file_content = file.file.read() # Read the file content
|
||||
|
||||
if file_type == ChatFileType.IMAGE:
|
||||
file_content = file.file
|
||||
file_content_io = file.file
|
||||
# NOTE: Image conversion to JPEG used to be enforced here.
|
||||
# This was removed to:
|
||||
# 1. Preserve original file content for downloads
|
||||
# 2. Maintain transparency in formats like PNG
|
||||
# 3. Ameliorate issue with file conversion
|
||||
else:
|
||||
file_content = io.BytesIO(file.file.read())
|
||||
file_content_io = io.BytesIO(file_content)
|
||||
|
||||
new_content_type = file.content_type
|
||||
|
||||
# store the file (now JPEG for images)
|
||||
# Store the file normally
|
||||
file_id = str(uuid.uuid4())
|
||||
file_store.save_file(
|
||||
file_name=file_id,
|
||||
content=file_content,
|
||||
content=file_content_io,
|
||||
display_name=file.filename,
|
||||
file_origin=FileOrigin.CHAT_UPLOAD,
|
||||
file_type=new_content_type or file_type.value,
|
||||
@@ -698,7 +700,7 @@ def upload_files_for_chat(
|
||||
# to re-extract it every time we send a message
|
||||
if file_type == ChatFileType.DOC:
|
||||
extracted_text = extract_file_text(
|
||||
file=file.file,
|
||||
file=io.BytesIO(file_content), # use the bytes we already read
|
||||
file_name=file.filename or "",
|
||||
)
|
||||
text_file_id = str(uuid.uuid4())
|
||||
|
||||
@@ -64,9 +64,8 @@ def admin_search(
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
document_index = get_default_document_index(
|
||||
primary_index_name=search_settings.index_name, secondary_index_name=None
|
||||
)
|
||||
document_index = get_default_document_index(search_settings, None)
|
||||
|
||||
if not isinstance(document_index, VespaIndex):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
|
||||
@@ -25,6 +25,7 @@ from onyx.db.llm import fetch_default_provider
|
||||
from onyx.db.llm import update_default_provider
|
||||
from onyx.db.llm import upsert_llm_provider
|
||||
from onyx.db.persona import delete_old_default_personas
|
||||
from onyx.db.search_settings import get_active_search_settings
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.db.search_settings import get_secondary_search_settings
|
||||
from onyx.db.search_settings import update_current_search_settings
|
||||
@@ -70,8 +71,19 @@ def setup_onyx(
|
||||
The Tenant Service calls the tenants/create endpoint which runs this.
|
||||
"""
|
||||
check_index_swap(db_session=db_session)
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
secondary_search_settings = get_secondary_search_settings(db_session)
|
||||
|
||||
active_search_settings = get_active_search_settings(db_session)
|
||||
search_settings = active_search_settings.primary
|
||||
secondary_search_settings = active_search_settings.secondary
|
||||
|
||||
# search_settings = get_current_search_settings(db_session)
|
||||
# multipass_config_1 = get_multipass_config(search_settings)
|
||||
|
||||
# secondary_large_chunks_enabled: bool | None = None
|
||||
# secondary_search_settings = get_secondary_search_settings(db_session)
|
||||
# if secondary_search_settings:
|
||||
# multipass_config_2 = get_multipass_config(secondary_search_settings)
|
||||
# secondary_large_chunks_enabled = multipass_config_2.enable_large_chunks
|
||||
|
||||
# Break bad state for thrashing indexes
|
||||
if secondary_search_settings and DISABLE_INDEX_UPDATE_ON_SWAP:
|
||||
@@ -122,10 +134,8 @@ def setup_onyx(
|
||||
# takes a bit of time to start up
|
||||
logger.notice("Verifying Document Index(s) is/are available.")
|
||||
document_index = get_default_document_index(
|
||||
primary_index_name=search_settings.index_name,
|
||||
secondary_index_name=secondary_search_settings.index_name
|
||||
if secondary_search_settings
|
||||
else None,
|
||||
search_settings,
|
||||
secondary_search_settings,
|
||||
)
|
||||
|
||||
success = setup_vespa(
|
||||
|
||||
@@ -220,6 +220,13 @@ class InternetSearchTool(Tool):
|
||||
)
|
||||
results = response.json()
|
||||
|
||||
# If no hits, Bing does not include the webPages key
|
||||
search_results = (
|
||||
results["webPages"]["value"][: self.num_results]
|
||||
if "webPages" in results
|
||||
else []
|
||||
)
|
||||
|
||||
return InternetSearchResponse(
|
||||
revised_query=query,
|
||||
internet_results=[
|
||||
@@ -228,7 +235,7 @@ class InternetSearchTool(Tool):
|
||||
link=result["url"],
|
||||
snippet=result["snippet"],
|
||||
)
|
||||
for result in results["webPages"]["value"][: self.num_results]
|
||||
for result in search_results
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@@ -123,6 +123,7 @@ def optional_telemetry(
|
||||
headers={"Content-Type": "application/json"},
|
||||
json=payload,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
# This way it silences all thread level logging as well
|
||||
pass
|
||||
|
||||
@@ -81,6 +81,7 @@ hubspot-api-client==8.1.0
|
||||
asana==5.0.8
|
||||
dropbox==11.36.2
|
||||
boto3-stubs[s3]==1.34.133
|
||||
shapely==2.0.6
|
||||
stripe==10.12.0
|
||||
urllib3==2.2.3
|
||||
mistune==0.8.4
|
||||
|
||||
@@ -197,7 +197,7 @@ ai_platform_doc = SeedPresaveDocument(
|
||||
)
|
||||
|
||||
customer_support_doc = SeedPresaveDocument(
|
||||
url="https://docs.onyx.app/more/use_cases/customer_support",
|
||||
url="https://docs.onyx.app/more/use_cases/support",
|
||||
title=customer_support_title,
|
||||
content=customer_support,
|
||||
title_embedding=model.encode(f"search_document: {customer_support_title}"),
|
||||
|
||||
@@ -7,6 +7,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.document import delete_documents_complete__no_commit
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.search_settings import get_active_search_settings
|
||||
|
||||
# Modify sys.path
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
@@ -38,7 +39,6 @@ from onyx.db.connector_credential_pair import (
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.document_index.document_index_utils import get_both_index_names
|
||||
|
||||
# pylint: enable=E402
|
||||
# flake8: noqa: E402
|
||||
@@ -191,9 +191,10 @@ def _delete_connector(cc_pair_id: int, db_session: Session) -> None:
|
||||
)
|
||||
try:
|
||||
logger.notice("Deleting information from Vespa and Postgres")
|
||||
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
|
||||
active_search_settings = get_active_search_settings(db_session)
|
||||
document_index = get_default_document_index(
|
||||
primary_index_name=curr_ind_name, secondary_index_name=sec_ind_name
|
||||
active_search_settings.primary,
|
||||
active_search_settings.secondary,
|
||||
)
|
||||
|
||||
files_deleted_count = _unsafe_deletion(
|
||||
|
||||
@@ -21,35 +21,144 @@ Options:
|
||||
--doc-id : Document ID
|
||||
--fields : Fields to update (JSON)
|
||||
|
||||
Example: (gets docs for a given tenant id and connector id)
|
||||
Example:
|
||||
python vespa_debug_tool.py --action list_docs --tenant-id my_tenant --connector-id 1 --n 5
|
||||
"""
|
||||
import argparse
|
||||
import json
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import and_
|
||||
|
||||
from onyx.configs.constants import INDEX_SEPARATOR
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import SearchRequest
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import Document
|
||||
from onyx.db.models import DocumentByConnectorCredentialPair
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.document_index.document_index_utils import get_document_chunk_ids
|
||||
from onyx.document_index.interfaces import EnrichedDocumentIndexingInfo
|
||||
from onyx.document_index.vespa.index import VespaIndex
|
||||
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
|
||||
from onyx.document_index.vespa_constants import ACCESS_CONTROL_LIST
|
||||
from onyx.document_index.vespa_constants import DOC_UPDATED_AT
|
||||
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
|
||||
from onyx.document_index.vespa_constants import DOCUMENT_SETS
|
||||
from onyx.document_index.vespa_constants import HIDDEN
|
||||
from onyx.document_index.vespa_constants import METADATA_LIST
|
||||
from onyx.document_index.vespa_constants import SEARCH_ENDPOINT
|
||||
from onyx.document_index.vespa_constants import SOURCE_TYPE
|
||||
from onyx.document_index.vespa_constants import TENANT_ID
|
||||
from onyx.document_index.vespa_constants import VESPA_APP_CONTAINER_URL
|
||||
from onyx.document_index.vespa_constants import VESPA_APPLICATION_ENDPOINT
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class DocumentFilter(BaseModel):
|
||||
# Document filter for link matching.
|
||||
link: str | None = None
|
||||
|
||||
|
||||
def build_vespa_filters(
|
||||
filters: IndexFilters,
|
||||
*,
|
||||
include_hidden: bool = False,
|
||||
remove_trailing_and: bool = False,
|
||||
) -> str:
|
||||
# Build a combined Vespa filter string from the given IndexFilters.
|
||||
def _build_or_filters(key: str, vals: list[str] | None) -> str:
|
||||
if vals is None:
|
||||
return ""
|
||||
valid_vals = [val for val in vals if val]
|
||||
if not key or not valid_vals:
|
||||
return ""
|
||||
eq_elems = [f'{key} contains "{elem}"' for elem in valid_vals]
|
||||
or_clause = " or ".join(eq_elems)
|
||||
return f"({or_clause})"
|
||||
|
||||
def _build_time_filter(
|
||||
cutoff: datetime | None,
|
||||
untimed_doc_cutoff: timedelta = timedelta(days=92),
|
||||
) -> str:
|
||||
if not cutoff:
|
||||
return ""
|
||||
include_untimed = datetime.now(timezone.utc) - untimed_doc_cutoff > cutoff
|
||||
cutoff_secs = int(cutoff.timestamp())
|
||||
if include_untimed:
|
||||
return f"!({DOC_UPDATED_AT} < {cutoff_secs})"
|
||||
return f"({DOC_UPDATED_AT} >= {cutoff_secs})"
|
||||
|
||||
filter_str = ""
|
||||
if not include_hidden:
|
||||
filter_str += f"AND !({HIDDEN}=true) "
|
||||
|
||||
if filters.tenant_id and MULTI_TENANT:
|
||||
filter_str += f'AND ({TENANT_ID} contains "{filters.tenant_id}") '
|
||||
|
||||
if filters.access_control_list is not None:
|
||||
acl_str = _build_or_filters(ACCESS_CONTROL_LIST, filters.access_control_list)
|
||||
if acl_str:
|
||||
filter_str += f"AND {acl_str} "
|
||||
|
||||
source_strs = (
|
||||
[s.value for s in filters.source_type] if filters.source_type else None
|
||||
)
|
||||
source_str = _build_or_filters(SOURCE_TYPE, source_strs)
|
||||
if source_str:
|
||||
filter_str += f"AND {source_str} "
|
||||
|
||||
tags = filters.tags
|
||||
if tags:
|
||||
tag_attributes = [tag.tag_key + INDEX_SEPARATOR + tag.tag_value for tag in tags]
|
||||
else:
|
||||
tag_attributes = None
|
||||
tag_str = _build_or_filters(METADATA_LIST, tag_attributes)
|
||||
if tag_str:
|
||||
filter_str += f"AND {tag_str} "
|
||||
|
||||
doc_set_str = _build_or_filters(DOCUMENT_SETS, filters.document_set)
|
||||
if doc_set_str:
|
||||
filter_str += f"AND {doc_set_str} "
|
||||
|
||||
time_filter = _build_time_filter(filters.time_cutoff)
|
||||
if time_filter:
|
||||
filter_str += f"AND {time_filter} "
|
||||
|
||||
if remove_trailing_and:
|
||||
while filter_str.endswith(" and "):
|
||||
filter_str = filter_str[:-5]
|
||||
while filter_str.endswith("AND "):
|
||||
filter_str = filter_str[:-4]
|
||||
|
||||
return filter_str.strip()
|
||||
|
||||
|
||||
# Print Vespa configuration URLs
|
||||
def print_vespa_config() -> None:
|
||||
# Print Vespa configuration.
|
||||
logger.info("Printing Vespa configuration.")
|
||||
print(f"Vespa Application Endpoint: {VESPA_APPLICATION_ENDPOINT}")
|
||||
print(f"Vespa App Container URL: {VESPA_APP_CONTAINER_URL}")
|
||||
print(f"Vespa Search Endpoint: {SEARCH_ENDPOINT}")
|
||||
print(f"Vespa Document ID Endpoint: {DOCUMENT_ID_ENDPOINT}")
|
||||
|
||||
|
||||
# Check connectivity to Vespa endpoints
|
||||
def check_vespa_connectivity() -> None:
|
||||
# Check connectivity to Vespa endpoints.
|
||||
logger.info("Checking Vespa connectivity.")
|
||||
endpoints = [
|
||||
f"{VESPA_APPLICATION_ENDPOINT}/ApplicationStatus",
|
||||
f"{VESPA_APPLICATION_ENDPOINT}/tenant",
|
||||
@@ -61,17 +170,21 @@ def check_vespa_connectivity() -> None:
|
||||
try:
|
||||
with get_vespa_http_client() as client:
|
||||
response = client.get(endpoint)
|
||||
logger.info(
|
||||
f"Connected to Vespa at {endpoint}, status code {response.status_code}"
|
||||
)
|
||||
print(f"Successfully connected to Vespa at {endpoint}")
|
||||
print(f"Status code: {response.status_code}")
|
||||
print(f"Response: {response.text[:200]}...")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to Vespa at {endpoint}: {str(e)}")
|
||||
print(f"Failed to connect to Vespa at {endpoint}: {str(e)}")
|
||||
|
||||
print("Vespa connectivity check completed.")
|
||||
|
||||
|
||||
# Get info about the default Vespa application
|
||||
def get_vespa_info() -> Dict[str, Any]:
|
||||
# Get info about the default Vespa application.
|
||||
url = f"{VESPA_APPLICATION_ENDPOINT}/tenant/default/application/default"
|
||||
with get_vespa_http_client() as client:
|
||||
response = client.get(url)
|
||||
@@ -79,121 +192,298 @@ def get_vespa_info() -> Dict[str, Any]:
|
||||
return response.json()
|
||||
|
||||
|
||||
# Get index name for a tenant and connector pair
|
||||
def get_index_name(tenant_id: str, connector_id: int) -> str:
|
||||
def get_index_name(tenant_id: str) -> str:
|
||||
# Return the index name for a given tenant.
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
cc_pair = get_connector_credential_pair_from_id(db_session, connector_id)
|
||||
if not cc_pair:
|
||||
raise ValueError(f"No connector found for id {connector_id}")
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
return search_settings.index_name if search_settings else "public"
|
||||
if not search_settings:
|
||||
raise ValueError(f"No search settings found for tenant {tenant_id}")
|
||||
return search_settings.index_name
|
||||
|
||||
|
||||
# Perform a Vespa query using YQL syntax
|
||||
def query_vespa(yql: str) -> List[Dict[str, Any]]:
|
||||
params = {
|
||||
"yql": yql,
|
||||
"timeout": "10s",
|
||||
}
|
||||
def query_vespa(
|
||||
yql: str, tenant_id: Optional[str] = None, limit: int = 10
|
||||
) -> List[Dict[str, Any]]:
|
||||
# Perform a Vespa query using YQL syntax.
|
||||
filters = IndexFilters(tenant_id=tenant_id, access_control_list=[])
|
||||
filter_string = build_vespa_filters(filters, remove_trailing_and=True)
|
||||
full_yql = yql.strip()
|
||||
if filter_string:
|
||||
full_yql = f"{full_yql} {filter_string}"
|
||||
full_yql = f"{full_yql} limit {limit}"
|
||||
|
||||
params = {"yql": full_yql, "timeout": "10s"}
|
||||
search_request = SearchRequest(query="", limit=limit, offset=0)
|
||||
params.update(search_request.model_dump())
|
||||
|
||||
logger.info(f"Executing Vespa query: {full_yql}")
|
||||
with get_vespa_http_client() as client:
|
||||
response = client.get(SEARCH_ENDPOINT, params=params)
|
||||
response.raise_for_status()
|
||||
return response.json()["root"]["children"]
|
||||
result = response.json()
|
||||
documents = result.get("root", {}).get("children", [])
|
||||
logger.info(f"Found {len(documents)} documents from query.")
|
||||
return documents
|
||||
|
||||
|
||||
# Get first N documents
|
||||
def get_first_n_documents(n: int = 10) -> List[Dict[str, Any]]:
|
||||
yql = f"select * from sources * where true limit {n};"
|
||||
return query_vespa(yql)
|
||||
# Get the first n documents from any source.
|
||||
yql = "select * from sources * where true"
|
||||
return query_vespa(yql, limit=n)
|
||||
|
||||
|
||||
# Pretty-print a list of documents
|
||||
def print_documents(documents: List[Dict[str, Any]]) -> None:
|
||||
# Pretty-print a list of documents.
|
||||
for doc in documents:
|
||||
print(json.dumps(doc, indent=2))
|
||||
print("-" * 80)
|
||||
|
||||
|
||||
# Get and print documents for a specific tenant and connector
|
||||
def get_documents_for_tenant_connector(
|
||||
tenant_id: str, connector_id: int, n: int = 10
|
||||
) -> None:
|
||||
get_index_name(tenant_id, connector_id)
|
||||
documents = get_first_n_documents(n)
|
||||
print(f"First {n} documents for tenant {tenant_id}, connector {connector_id}:")
|
||||
# Get and print documents for a specific tenant and connector.
|
||||
index_name = get_index_name(tenant_id)
|
||||
logger.info(
|
||||
f"Fetching documents for tenant={tenant_id}, connector_id={connector_id}"
|
||||
)
|
||||
yql = f"select * from sources {index_name} where true"
|
||||
documents = query_vespa(yql, tenant_id, limit=n)
|
||||
print(
|
||||
f"First {len(documents)} documents for tenant {tenant_id}, connector {connector_id}:"
|
||||
)
|
||||
print_documents(documents)
|
||||
|
||||
|
||||
# Search documents for a specific tenant and connector
|
||||
def search_documents(
|
||||
tenant_id: str, connector_id: int, query: str, n: int = 10
|
||||
) -> None:
|
||||
index_name = get_index_name(tenant_id, connector_id)
|
||||
yql = f"select * from sources {index_name} where userInput(@query) limit {n};"
|
||||
documents = query_vespa(yql)
|
||||
print(f"Search results for query '{query}':")
|
||||
# Search documents for a specific tenant and connector.
|
||||
index_name = get_index_name(tenant_id)
|
||||
logger.info(
|
||||
f"Searching documents for tenant={tenant_id}, connector_id={connector_id}, query='{query}'"
|
||||
)
|
||||
yql = f"select * from sources {index_name} where userInput(@query)"
|
||||
documents = query_vespa(yql, tenant_id, limit=n)
|
||||
print(f"Search results for query '{query}' in tenant {tenant_id}:")
|
||||
print_documents(documents)
|
||||
|
||||
|
||||
# Update a specific document
|
||||
def update_document(
|
||||
tenant_id: str, connector_id: int, doc_id: str, fields: Dict[str, Any]
|
||||
) -> None:
|
||||
index_name = get_index_name(tenant_id, connector_id)
|
||||
# Update a specific document.
|
||||
index_name = get_index_name(tenant_id)
|
||||
logger.info(
|
||||
f"Updating document doc_id={doc_id} in tenant={tenant_id}, connector_id={connector_id}"
|
||||
)
|
||||
url = DOCUMENT_ID_ENDPOINT.format(index_name=index_name) + f"/{doc_id}"
|
||||
update_request = {"fields": {k: {"assign": v} for k, v in fields.items()}}
|
||||
|
||||
with get_vespa_http_client() as client:
|
||||
response = client.put(url, json=update_request)
|
||||
response.raise_for_status()
|
||||
logger.info(f"Document {doc_id} updated successfully.")
|
||||
print(f"Document {doc_id} updated successfully")
|
||||
|
||||
|
||||
# Delete a specific document
|
||||
def delete_document(tenant_id: str, connector_id: int, doc_id: str) -> None:
|
||||
index_name = get_index_name(tenant_id, connector_id)
|
||||
# Delete a specific document.
|
||||
index_name = get_index_name(tenant_id)
|
||||
logger.info(
|
||||
f"Deleting document doc_id={doc_id} in tenant={tenant_id}, connector_id={connector_id}"
|
||||
)
|
||||
url = DOCUMENT_ID_ENDPOINT.format(index_name=index_name) + f"/{doc_id}"
|
||||
|
||||
with get_vespa_http_client() as client:
|
||||
response = client.delete(url)
|
||||
response.raise_for_status()
|
||||
logger.info(f"Document {doc_id} deleted successfully.")
|
||||
print(f"Document {doc_id} deleted successfully")
|
||||
|
||||
|
||||
# List documents from any source
|
||||
def list_documents(n: int = 10) -> None:
|
||||
yql = f"select * from sources * where true limit {n};"
|
||||
url = f"{VESPA_APP_CONTAINER_URL}/search/"
|
||||
params = {
|
||||
"yql": yql,
|
||||
"timeout": "10s",
|
||||
}
|
||||
try:
|
||||
with get_vespa_http_client() as client:
|
||||
response = client.get(url, params=params)
|
||||
response.raise_for_status()
|
||||
documents = response.json()["root"]["children"]
|
||||
print(f"First {n} documents:")
|
||||
print_documents(documents)
|
||||
except Exception as e:
|
||||
print(f"Failed to list documents: {str(e)}")
|
||||
|
||||
|
||||
# Get and print ACLs for documents of a specific tenant and connector
|
||||
def get_document_acls(tenant_id: str, connector_id: int, n: int = 10) -> None:
|
||||
index_name = get_index_name(tenant_id, connector_id)
|
||||
yql = f"select documentid, access_control_list from sources {index_name} where true limit {n};"
|
||||
documents = query_vespa(yql)
|
||||
print(f"ACLs for {n} documents from tenant {tenant_id}, connector {connector_id}:")
|
||||
for doc in documents:
|
||||
print(f"Document ID: {doc['fields']['documentid']}")
|
||||
print(
|
||||
f"ACL: {json.dumps(doc['fields'].get('access_control_list', {}), indent=2)}"
|
||||
)
|
||||
def list_documents(n: int = 10, tenant_id: Optional[str] = None) -> None:
|
||||
# List documents from any source, filtered by tenant if provided.
|
||||
logger.info(f"Listing up to {n} documents for tenant={tenant_id or 'ALL'}")
|
||||
yql = "select * from sources * where true"
|
||||
if tenant_id:
|
||||
yql += f" and tenant_id contains '{tenant_id}'"
|
||||
documents = query_vespa(yql, tenant_id=tenant_id, limit=n)
|
||||
print(f"Total documents found: {len(documents)}")
|
||||
logger.info(f"Total documents found: {len(documents)}")
|
||||
print(f"First {min(n, len(documents))} documents:")
|
||||
for doc in documents[:n]:
|
||||
print(json.dumps(doc, indent=2))
|
||||
print("-" * 80)
|
||||
|
||||
|
||||
def get_document_and_chunk_counts(
|
||||
tenant_id: str, cc_pair_id: int, filter_doc: DocumentFilter | None = None
|
||||
) -> Dict[str, int]:
|
||||
# Return a dict mapping each document ID to its chunk count for a given connector.
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as session:
|
||||
doc_ids_data = (
|
||||
session.query(DocumentByConnectorCredentialPair.id, Document.link)
|
||||
.join(
|
||||
ConnectorCredentialPair,
|
||||
and_(
|
||||
DocumentByConnectorCredentialPair.connector_id
|
||||
== ConnectorCredentialPair.connector_id,
|
||||
DocumentByConnectorCredentialPair.credential_id
|
||||
== ConnectorCredentialPair.credential_id,
|
||||
),
|
||||
)
|
||||
.join(Document, DocumentByConnectorCredentialPair.id == Document.id)
|
||||
.filter(ConnectorCredentialPair.id == cc_pair_id)
|
||||
.distinct()
|
||||
.all()
|
||||
)
|
||||
doc_ids = []
|
||||
for doc_id, link in doc_ids_data:
|
||||
if filter_doc and filter_doc.link:
|
||||
if link and filter_doc.link.lower() in link.lower():
|
||||
doc_ids.append(doc_id)
|
||||
else:
|
||||
doc_ids.append(doc_id)
|
||||
chunk_counts_data = (
|
||||
session.query(Document.id, Document.chunk_count)
|
||||
.filter(Document.id.in_(doc_ids))
|
||||
.all()
|
||||
)
|
||||
return {
|
||||
doc_id: chunk_count
|
||||
for doc_id, chunk_count in chunk_counts_data
|
||||
if chunk_count is not None
|
||||
}
|
||||
|
||||
|
||||
def get_chunk_ids_for_connector(
|
||||
tenant_id: str,
|
||||
cc_pair_id: int,
|
||||
index_name: str,
|
||||
filter_doc: DocumentFilter | None = None,
|
||||
) -> List[UUID]:
|
||||
# Return chunk IDs for a given connector.
|
||||
doc_id_to_new_chunk_cnt = get_document_and_chunk_counts(
|
||||
tenant_id, cc_pair_id, filter_doc
|
||||
)
|
||||
doc_infos: List[EnrichedDocumentIndexingInfo] = [
|
||||
VespaIndex.enrich_basic_chunk_info(
|
||||
index_name=index_name,
|
||||
http_client=get_vespa_http_client(),
|
||||
document_id=doc_id,
|
||||
previous_chunk_count=doc_id_to_new_chunk_cnt.get(doc_id, 0),
|
||||
new_chunk_count=0,
|
||||
)
|
||||
for doc_id in doc_id_to_new_chunk_cnt.keys()
|
||||
]
|
||||
chunk_ids = get_document_chunk_ids(
|
||||
enriched_document_info_list=doc_infos,
|
||||
tenant_id=tenant_id,
|
||||
large_chunks_enabled=False,
|
||||
)
|
||||
if not isinstance(chunk_ids, list):
|
||||
raise ValueError(f"Expected list of chunk IDs, got {type(chunk_ids)}")
|
||||
return chunk_ids
|
||||
|
||||
|
||||
def get_document_acls(
|
||||
tenant_id: str,
|
||||
cc_pair_id: int,
|
||||
n: int | None = 10,
|
||||
filter_doc: DocumentFilter | None = None,
|
||||
) -> None:
|
||||
# Fetch document ACLs for the given tenant and connector pair.
|
||||
index_name = get_index_name(tenant_id)
|
||||
logger.info(
|
||||
f"Fetching document ACLs for tenant={tenant_id}, cc_pair_id={cc_pair_id}"
|
||||
)
|
||||
chunk_ids: List[UUID] = get_chunk_ids_for_connector(
|
||||
tenant_id, cc_pair_id, index_name, filter_doc
|
||||
)
|
||||
vespa_client = get_vespa_http_client()
|
||||
|
||||
target_ids = chunk_ids if n is None else chunk_ids[:n]
|
||||
logger.info(
|
||||
f"Found {len(chunk_ids)} chunk IDs, showing ACLs for {len(target_ids)}."
|
||||
)
|
||||
for doc_chunk_id in target_ids:
|
||||
document_url = (
|
||||
f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{str(doc_chunk_id)}"
|
||||
)
|
||||
response = vespa_client.get(document_url)
|
||||
if response.status_code == 200:
|
||||
fields = response.json().get("fields", {})
|
||||
document_id = fields.get("document_id") or fields.get(
|
||||
"documentid", "Unknown"
|
||||
)
|
||||
acls = fields.get("access_control_list", {})
|
||||
title = fields.get("title", "")
|
||||
source_type = fields.get("source_type", "")
|
||||
source_links_raw = fields.get("source_links", "{}")
|
||||
try:
|
||||
source_links = json.loads(source_links_raw)
|
||||
except json.JSONDecodeError:
|
||||
source_links = {}
|
||||
|
||||
print(f"Document Chunk ID: {doc_chunk_id}")
|
||||
print(f"Document ID: {document_id}")
|
||||
print(f"ACLs:\n{json.dumps(acls, indent=2)}")
|
||||
print(f"Source Links: {source_links}")
|
||||
print(f"Title: {title}")
|
||||
print(f"Source Type: {source_type}")
|
||||
if MULTI_TENANT:
|
||||
print(f"Tenant ID: {fields.get('tenant_id', 'N/A')}")
|
||||
print("-" * 80)
|
||||
else:
|
||||
logger.error(f"Failed to fetch document for chunk ID: {doc_chunk_id}")
|
||||
print(f"Failed to fetch document for chunk ID: {doc_chunk_id}")
|
||||
print(f"Status Code: {response.status_code}")
|
||||
print("-" * 80)
|
||||
|
||||
|
||||
class VespaDebugging:
|
||||
# Class for managing Vespa debugging actions.
|
||||
def __init__(self, tenant_id: str | None = None):
|
||||
self.tenant_id = POSTGRES_DEFAULT_SCHEMA if not tenant_id else tenant_id
|
||||
|
||||
def print_config(self) -> None:
|
||||
# Print Vespa config.
|
||||
print_vespa_config()
|
||||
|
||||
def check_connectivity(self) -> None:
|
||||
# Check Vespa connectivity.
|
||||
check_vespa_connectivity()
|
||||
|
||||
def list_documents(self, n: int = 10) -> None:
|
||||
# List documents for a tenant.
|
||||
list_documents(n, self.tenant_id)
|
||||
|
||||
def search_documents(self, connector_id: int, query: str, n: int = 10) -> None:
|
||||
# Search documents for a tenant and connector.
|
||||
search_documents(self.tenant_id, connector_id, query, n)
|
||||
|
||||
def update_document(
|
||||
self, connector_id: int, doc_id: str, fields: Dict[str, Any]
|
||||
) -> None:
|
||||
# Update a document.
|
||||
update_document(self.tenant_id, connector_id, doc_id, fields)
|
||||
|
||||
def delete_document(self, connector_id: int, doc_id: str) -> None:
|
||||
# Delete a document.
|
||||
delete_document(self.tenant_id, connector_id, doc_id)
|
||||
|
||||
def acls_by_link(self, cc_pair_id: int, link: str) -> None:
|
||||
# Get ACLs for a document matching a link.
|
||||
get_document_acls(
|
||||
self.tenant_id, cc_pair_id, n=None, filter_doc=DocumentFilter(link=link)
|
||||
)
|
||||
|
||||
def acls(self, cc_pair_id: int, n: int | None = 10) -> None:
|
||||
# Get ACLs for a connector.
|
||||
get_document_acls(self.tenant_id, cc_pair_id, n)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
# Main CLI entry point.
|
||||
parser = argparse.ArgumentParser(description="Vespa debugging tool")
|
||||
parser.add_argument(
|
||||
"--action",
|
||||
@@ -209,60 +499,45 @@ def main() -> None:
|
||||
required=True,
|
||||
help="Action to perform",
|
||||
)
|
||||
parser.add_argument("--tenant-id", help="Tenant ID")
|
||||
parser.add_argument("--connector-id", type=int, help="Connector ID")
|
||||
parser.add_argument(
|
||||
"--tenant-id", help="Tenant ID (for update, delete, and get_acls actions)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--connector-id",
|
||||
type=int,
|
||||
help="Connector ID (for update, delete, and get_acls actions)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--n",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Number of documents to retrieve (for list_docs, search, update, and get_acls actions)",
|
||||
"--n", type=int, default=10, help="Number of documents to retrieve"
|
||||
)
|
||||
parser.add_argument("--query", help="Search query (for search action)")
|
||||
parser.add_argument("--doc-id", help="Document ID (for update and delete actions)")
|
||||
parser.add_argument(
|
||||
"--fields", help="Fields to update, in JSON format (for update action)"
|
||||
"--fields", help="Fields to update, in JSON format (for update)"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
vespa_debug = VespaDebugging(args.tenant_id)
|
||||
|
||||
if args.action == "config":
|
||||
print_vespa_config()
|
||||
vespa_debug.print_config()
|
||||
elif args.action == "connect":
|
||||
check_vespa_connectivity()
|
||||
vespa_debug.check_connectivity()
|
||||
elif args.action == "list_docs":
|
||||
# If tenant_id and connector_id are provided, list docs for that tenant/connector.
|
||||
# Otherwise, list documents from any source.
|
||||
if args.tenant_id and args.connector_id:
|
||||
get_documents_for_tenant_connector(
|
||||
args.tenant_id, args.connector_id, args.n
|
||||
)
|
||||
else:
|
||||
list_documents(args.n)
|
||||
vespa_debug.list_documents(args.n)
|
||||
elif args.action == "search":
|
||||
if not args.query:
|
||||
parser.error("--query is required for search action")
|
||||
search_documents(args.tenant_id, args.connector_id, args.query, args.n)
|
||||
if not args.query or args.connector_id is None:
|
||||
parser.error("--query and --connector-id are required for search action")
|
||||
vespa_debug.search_documents(args.connector_id, args.query, args.n)
|
||||
elif args.action == "update":
|
||||
if not args.doc_id or not args.fields:
|
||||
parser.error("--doc-id and --fields are required for update action")
|
||||
fields = json.loads(args.fields)
|
||||
update_document(args.tenant_id, args.connector_id, args.doc_id, fields)
|
||||
elif args.action == "delete":
|
||||
if not args.doc_id:
|
||||
parser.error("--doc-id is required for delete action")
|
||||
delete_document(args.tenant_id, args.connector_id, args.doc_id)
|
||||
elif args.action == "get_acls":
|
||||
if not args.tenant_id or args.connector_id is None:
|
||||
if not args.doc_id or not args.fields or args.connector_id is None:
|
||||
parser.error(
|
||||
"--tenant-id and --connector-id are required for get_acls action"
|
||||
"--doc-id, --fields, and --connector-id are required for update action"
|
||||
)
|
||||
get_document_acls(args.tenant_id, args.connector_id, args.n)
|
||||
fields = json.loads(args.fields)
|
||||
vespa_debug.update_document(args.connector_id, args.doc_id, fields)
|
||||
elif args.action == "delete":
|
||||
if not args.doc_id or args.connector_id is None:
|
||||
parser.error("--doc-id and --connector-id are required for delete action")
|
||||
vespa_debug.delete_document(args.connector_id, args.doc_id)
|
||||
elif args.action == "get_acls":
|
||||
if args.connector_id is None:
|
||||
parser.error("--connector-id is required for get_acls action")
|
||||
vespa_debug.acls(args.connector_id, args.n)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -5,6 +5,8 @@ import sys
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.document_index.document_index_utils import get_multipass_config
|
||||
|
||||
# makes it so `PYTHONPATH=.` is not required when running this script
|
||||
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.append(parent_dir)
|
||||
@@ -54,8 +56,14 @@ def main() -> None:
|
||||
|
||||
# Setup Vespa index
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
multipass_config = get_multipass_config(search_settings)
|
||||
index_name = search_settings.index_name
|
||||
vespa_index = VespaIndex(index_name=index_name, secondary_index_name=None)
|
||||
vespa_index = VespaIndex(
|
||||
index_name=index_name,
|
||||
secondary_index_name=None,
|
||||
large_chunks_enabled=multipass_config.enable_large_chunks,
|
||||
secondary_large_chunks_enabled=None,
|
||||
)
|
||||
|
||||
# Delete chunks from Vespa first
|
||||
print("Deleting orphaned document chunks from Vespa")
|
||||
|
||||
@@ -16,6 +16,7 @@ from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.document_index.document_index_utils import get_multipass_config
|
||||
from onyx.document_index.vespa.index import VespaIndex
|
||||
from onyx.indexing.indexing_pipeline import IndexBatchParams
|
||||
from onyx.indexing.models import ChunkEmbedding
|
||||
@@ -133,10 +134,16 @@ def seed_dummy_docs(
|
||||
) -> None:
|
||||
with get_session_context_manager() as db_session:
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
multipass_config = get_multipass_config(search_settings)
|
||||
index_name = search_settings.index_name
|
||||
embedding_dim = search_settings.model_dim
|
||||
|
||||
vespa_index = VespaIndex(index_name=index_name, secondary_index_name=None)
|
||||
vespa_index = VespaIndex(
|
||||
index_name=index_name,
|
||||
secondary_index_name=None,
|
||||
large_chunks_enabled=multipass_config.enable_large_chunks,
|
||||
secondary_large_chunks_enabled=None,
|
||||
)
|
||||
print(index_name)
|
||||
|
||||
all_chunks = []
|
||||
|
||||
@@ -9,6 +9,7 @@ from onyx.configs.model_configs import DOC_EMBEDDING_DIM
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.document_index.document_index_utils import get_multipass_config
|
||||
from onyx.document_index.vespa.index import VespaIndex
|
||||
from scripts.query_time_check.seed_dummy_docs import TOTAL_ACL_ENTRIES_PER_CATEGORY
|
||||
from scripts.query_time_check.seed_dummy_docs import TOTAL_DOC_SETS
|
||||
@@ -62,9 +63,15 @@ def test_hybrid_retrieval_times(
|
||||
) -> None:
|
||||
with get_session_context_manager() as db_session:
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
multipass_config = get_multipass_config(search_settings)
|
||||
index_name = search_settings.index_name
|
||||
|
||||
vespa_index = VespaIndex(index_name=index_name, secondary_index_name=None)
|
||||
vespa_index = VespaIndex(
|
||||
index_name=index_name,
|
||||
secondary_index_name=None,
|
||||
large_chunks_enabled=multipass_config.enable_large_chunks,
|
||||
secondary_large_chunks_enabled=None,
|
||||
)
|
||||
|
||||
# Generate random queries
|
||||
queries = [f"Random Query {i}" for i in range(number_of_queries)]
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import os
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.airtable.airtable_connector import AirtableConnector
|
||||
@@ -10,25 +10,24 @@ from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import Section
|
||||
|
||||
|
||||
@pytest.fixture(
|
||||
params=[
|
||||
("table_name", os.environ["AIRTABLE_TEST_TABLE_NAME"]),
|
||||
("table_id", os.environ["AIRTABLE_TEST_TABLE_ID"]),
|
||||
]
|
||||
)
|
||||
def airtable_connector(request: pytest.FixtureRequest) -> AirtableConnector:
|
||||
param_type, table_identifier = request.param
|
||||
connector = AirtableConnector(
|
||||
base_id=os.environ["AIRTABLE_TEST_BASE_ID"],
|
||||
table_name_or_id=table_identifier,
|
||||
)
|
||||
class AirtableConfig(BaseModel):
|
||||
base_id: str
|
||||
table_identifier: str
|
||||
access_token: str
|
||||
|
||||
connector.load_credentials(
|
||||
{
|
||||
"airtable_access_token": os.environ["AIRTABLE_ACCESS_TOKEN"],
|
||||
}
|
||||
|
||||
@pytest.fixture(params=[True, False])
|
||||
def airtable_config(request: pytest.FixtureRequest) -> AirtableConfig:
|
||||
table_identifier = (
|
||||
os.environ["AIRTABLE_TEST_TABLE_NAME"]
|
||||
if request.param
|
||||
else os.environ["AIRTABLE_TEST_TABLE_ID"]
|
||||
)
|
||||
return AirtableConfig(
|
||||
base_id=os.environ["AIRTABLE_TEST_BASE_ID"],
|
||||
table_identifier=table_identifier,
|
||||
access_token=os.environ["AIRTABLE_ACCESS_TOKEN"],
|
||||
)
|
||||
return connector
|
||||
|
||||
|
||||
def create_test_document(
|
||||
@@ -46,18 +45,37 @@ def create_test_document(
|
||||
assignee: str,
|
||||
days_since_status_change: int | None,
|
||||
attachments: list[tuple[str, str]] | None = None,
|
||||
all_fields_as_metadata: bool = False,
|
||||
) -> Document:
|
||||
link_base = f"https://airtable.com/{os.environ['AIRTABLE_TEST_BASE_ID']}/{os.environ['AIRTABLE_TEST_TABLE_ID']}"
|
||||
sections = [
|
||||
Section(
|
||||
text=f"Title:\n------------------------\n{title}\n------------------------",
|
||||
link=f"{link_base}/{id}",
|
||||
),
|
||||
Section(
|
||||
text=f"Description:\n------------------------\n{description}\n------------------------",
|
||||
link=f"{link_base}/{id}",
|
||||
),
|
||||
]
|
||||
base_id = os.environ.get("AIRTABLE_TEST_BASE_ID")
|
||||
table_id = os.environ.get("AIRTABLE_TEST_TABLE_ID")
|
||||
missing_vars = []
|
||||
if not base_id:
|
||||
missing_vars.append("AIRTABLE_TEST_BASE_ID")
|
||||
if not table_id:
|
||||
missing_vars.append("AIRTABLE_TEST_TABLE_ID")
|
||||
|
||||
if missing_vars:
|
||||
raise RuntimeError(
|
||||
f"Required environment variables not set: {', '.join(missing_vars)}. "
|
||||
"These variables are required to run Airtable connector tests."
|
||||
)
|
||||
link_base = f"https://airtable.com/{base_id}/{table_id}"
|
||||
sections = []
|
||||
|
||||
if not all_fields_as_metadata:
|
||||
sections.extend(
|
||||
[
|
||||
Section(
|
||||
text=f"Title:\n------------------------\n{title}\n------------------------",
|
||||
link=f"{link_base}/{id}",
|
||||
),
|
||||
Section(
|
||||
text=f"Description:\n------------------------\n{description}\n------------------------",
|
||||
link=f"{link_base}/{id}",
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
if attachments:
|
||||
for attachment_text, attachment_link in attachments:
|
||||
@@ -68,26 +86,36 @@ def create_test_document(
|
||||
),
|
||||
)
|
||||
|
||||
metadata: dict[str, str | list[str]] = {
|
||||
# "Category": category,
|
||||
"Assignee": assignee,
|
||||
"Submitted by": submitted_by,
|
||||
"Priority": priority,
|
||||
"Status": status,
|
||||
"Created time": created_time,
|
||||
"ID": ticket_id,
|
||||
"Status last changed": status_last_changed,
|
||||
**(
|
||||
{"Days since status change": str(days_since_status_change)}
|
||||
if days_since_status_change is not None
|
||||
else {}
|
||||
),
|
||||
}
|
||||
|
||||
if all_fields_as_metadata:
|
||||
metadata.update(
|
||||
{
|
||||
"Title": title,
|
||||
"Description": description,
|
||||
}
|
||||
)
|
||||
|
||||
return Document(
|
||||
id=f"airtable__{id}",
|
||||
sections=sections,
|
||||
source=DocumentSource.AIRTABLE,
|
||||
semantic_identifier=f"{os.environ['AIRTABLE_TEST_TABLE_NAME']}: {title}",
|
||||
metadata={
|
||||
# "Category": category,
|
||||
"Assignee": assignee,
|
||||
"Submitted by": submitted_by,
|
||||
"Priority": priority,
|
||||
"Status": status,
|
||||
"Created time": created_time,
|
||||
"ID": ticket_id,
|
||||
"Status last changed": status_last_changed,
|
||||
**(
|
||||
{"Days since status change": str(days_since_status_change)}
|
||||
if days_since_status_change is not None
|
||||
else {}
|
||||
),
|
||||
},
|
||||
semantic_identifier=f"{os.environ.get('AIRTABLE_TEST_TABLE_NAME', '')}: {title}",
|
||||
metadata=metadata,
|
||||
doc_updated_at=None,
|
||||
primary_owners=None,
|
||||
secondary_owners=None,
|
||||
@@ -97,15 +125,75 @@ def create_test_document(
|
||||
)
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.file_processing.extract_file_text.get_unstructured_api_key",
|
||||
return_value=None,
|
||||
)
|
||||
def test_airtable_connector_basic(
|
||||
mock_get_api_key: MagicMock, airtable_connector: AirtableConnector
|
||||
def compare_documents(
|
||||
actual_docs: list[Document], expected_docs: list[Document]
|
||||
) -> None:
|
||||
doc_batch_generator = airtable_connector.load_from_state()
|
||||
"""Utility function to compare actual and expected documents, ignoring order."""
|
||||
actual_docs_dict = {doc.id: doc for doc in actual_docs}
|
||||
expected_docs_dict = {doc.id: doc for doc in expected_docs}
|
||||
|
||||
assert actual_docs_dict.keys() == expected_docs_dict.keys(), "Document ID mismatch"
|
||||
|
||||
for doc_id in actual_docs_dict:
|
||||
actual = actual_docs_dict[doc_id]
|
||||
expected = expected_docs_dict[doc_id]
|
||||
|
||||
assert (
|
||||
actual.source == expected.source
|
||||
), f"Source mismatch for document {doc_id}"
|
||||
assert (
|
||||
actual.semantic_identifier == expected.semantic_identifier
|
||||
), f"Semantic identifier mismatch for document {doc_id}"
|
||||
assert (
|
||||
actual.metadata == expected.metadata
|
||||
), f"Metadata mismatch for document {doc_id}"
|
||||
assert (
|
||||
actual.doc_updated_at == expected.doc_updated_at
|
||||
), f"Updated at mismatch for document {doc_id}"
|
||||
assert (
|
||||
actual.primary_owners == expected.primary_owners
|
||||
), f"Primary owners mismatch for document {doc_id}"
|
||||
assert (
|
||||
actual.secondary_owners == expected.secondary_owners
|
||||
), f"Secondary owners mismatch for document {doc_id}"
|
||||
assert actual.title == expected.title, f"Title mismatch for document {doc_id}"
|
||||
assert (
|
||||
actual.from_ingestion_api == expected.from_ingestion_api
|
||||
), f"Ingestion API flag mismatch for document {doc_id}"
|
||||
assert (
|
||||
actual.additional_info == expected.additional_info
|
||||
), f"Additional info mismatch for document {doc_id}"
|
||||
|
||||
# Compare sections
|
||||
assert len(actual.sections) == len(
|
||||
expected.sections
|
||||
), f"Number of sections mismatch for document {doc_id}"
|
||||
for i, (actual_section, expected_section) in enumerate(
|
||||
zip(actual.sections, expected.sections)
|
||||
):
|
||||
assert (
|
||||
actual_section.text == expected_section.text
|
||||
), f"Section {i} text mismatch for document {doc_id}"
|
||||
assert (
|
||||
actual_section.link == expected_section.link
|
||||
), f"Section {i} link mismatch for document {doc_id}"
|
||||
|
||||
|
||||
def test_airtable_connector_basic(
|
||||
mock_get_unstructured_api_key: MagicMock, airtable_config: AirtableConfig
|
||||
) -> None:
|
||||
"""Test behavior when all non-attachment fields are treated as metadata."""
|
||||
connector = AirtableConnector(
|
||||
base_id=airtable_config.base_id,
|
||||
table_name_or_id=airtable_config.table_identifier,
|
||||
treat_all_non_attachment_fields_as_metadata=False,
|
||||
)
|
||||
connector.load_credentials(
|
||||
{
|
||||
"airtable_access_token": airtable_config.access_token,
|
||||
}
|
||||
)
|
||||
doc_batch_generator = connector.load_from_state()
|
||||
doc_batch = next(doc_batch_generator)
|
||||
with pytest.raises(StopIteration):
|
||||
next(doc_batch_generator)
|
||||
@@ -119,15 +207,62 @@ def test_airtable_connector_basic(
|
||||
description="The internet connection is very slow.",
|
||||
priority="Medium",
|
||||
status="In Progress",
|
||||
# Link to another record is skipped for now
|
||||
# category="Data Science",
|
||||
ticket_id="2",
|
||||
created_time="2024-12-24T21:02:49.000Z",
|
||||
status_last_changed="2024-12-24T21:02:49.000Z",
|
||||
days_since_status_change=0,
|
||||
assignee="Chris Weaver (chris@onyx.app)",
|
||||
submitted_by="Chris Weaver (chris@onyx.app)",
|
||||
all_fields_as_metadata=False,
|
||||
),
|
||||
create_test_document(
|
||||
id="reccSlIA4pZEFxPBg",
|
||||
title="Printer Issue",
|
||||
description="The office printer is not working.",
|
||||
priority="High",
|
||||
status="Open",
|
||||
ticket_id="1",
|
||||
created_time="2024-12-24T21:02:49.000Z",
|
||||
status_last_changed="2024-12-24T21:02:49.000Z",
|
||||
days_since_status_change=0,
|
||||
assignee="Chris Weaver (chris@onyx.app)",
|
||||
submitted_by="Chris Weaver (chris@onyx.app)",
|
||||
attachments=[
|
||||
(
|
||||
"Test.pdf:\ntesting!!!",
|
||||
"https://airtable.com/appCXJqDFS4gea8tn/tblRxFQsTlBBZdRY1/viwVUEJjWPd8XYjh8/reccSlIA4pZEFxPBg/fld1u21zkJACIvAEF/attlj2UBWNEDZngCc?blocks=hide",
|
||||
)
|
||||
],
|
||||
all_fields_as_metadata=False,
|
||||
),
|
||||
]
|
||||
|
||||
# Compare documents using the utility function
|
||||
compare_documents(doc_batch, expected_docs)
|
||||
|
||||
|
||||
def test_airtable_connector_all_metadata(
|
||||
mock_get_unstructured_api_key: MagicMock, airtable_config: AirtableConfig
|
||||
) -> None:
|
||||
connector = AirtableConnector(
|
||||
base_id=airtable_config.base_id,
|
||||
table_name_or_id=airtable_config.table_identifier,
|
||||
treat_all_non_attachment_fields_as_metadata=True,
|
||||
)
|
||||
connector.load_credentials(
|
||||
{
|
||||
"airtable_access_token": airtable_config.access_token,
|
||||
}
|
||||
)
|
||||
doc_batch_generator = connector.load_from_state()
|
||||
doc_batch = next(doc_batch_generator)
|
||||
with pytest.raises(StopIteration):
|
||||
next(doc_batch_generator)
|
||||
|
||||
# NOTE: one of the rows has no attachments -> no content -> no document
|
||||
assert len(doc_batch) == 1
|
||||
|
||||
expected_docs = [
|
||||
create_test_document(
|
||||
id="reccSlIA4pZEFxPBg",
|
||||
title="Printer Issue",
|
||||
@@ -149,50 +284,9 @@ def test_airtable_connector_basic(
|
||||
"https://airtable.com/appCXJqDFS4gea8tn/tblRxFQsTlBBZdRY1/viwVUEJjWPd8XYjh8/reccSlIA4pZEFxPBg/fld1u21zkJACIvAEF/attlj2UBWNEDZngCc?blocks=hide",
|
||||
)
|
||||
],
|
||||
all_fields_as_metadata=True,
|
||||
),
|
||||
]
|
||||
|
||||
# Compare each document field by field
|
||||
for actual, expected in zip(doc_batch, expected_docs):
|
||||
assert actual.id == expected.id, f"ID mismatch for document {actual.id}"
|
||||
assert (
|
||||
actual.source == expected.source
|
||||
), f"Source mismatch for document {actual.id}"
|
||||
assert (
|
||||
actual.semantic_identifier == expected.semantic_identifier
|
||||
), f"Semantic identifier mismatch for document {actual.id}"
|
||||
assert (
|
||||
actual.metadata == expected.metadata
|
||||
), f"Metadata mismatch for document {actual.id}"
|
||||
assert (
|
||||
actual.doc_updated_at == expected.doc_updated_at
|
||||
), f"Updated at mismatch for document {actual.id}"
|
||||
assert (
|
||||
actual.primary_owners == expected.primary_owners
|
||||
), f"Primary owners mismatch for document {actual.id}"
|
||||
assert (
|
||||
actual.secondary_owners == expected.secondary_owners
|
||||
), f"Secondary owners mismatch for document {actual.id}"
|
||||
assert (
|
||||
actual.title == expected.title
|
||||
), f"Title mismatch for document {actual.id}"
|
||||
assert (
|
||||
actual.from_ingestion_api == expected.from_ingestion_api
|
||||
), f"Ingestion API flag mismatch for document {actual.id}"
|
||||
assert (
|
||||
actual.additional_info == expected.additional_info
|
||||
), f"Additional info mismatch for document {actual.id}"
|
||||
|
||||
# Compare sections
|
||||
assert len(actual.sections) == len(
|
||||
expected.sections
|
||||
), f"Number of sections mismatch for document {actual.id}"
|
||||
for i, (actual_section, expected_section) in enumerate(
|
||||
zip(actual.sections, expected.sections)
|
||||
):
|
||||
assert (
|
||||
actual_section.text == expected_section.text
|
||||
), f"Section {i} text mismatch for document {actual.id}"
|
||||
assert (
|
||||
actual_section.link == expected_section.link
|
||||
), f"Section {i} link mismatch for document {actual.id}"
|
||||
# Compare documents using the utility function
|
||||
compare_documents(doc_batch, expected_docs)
|
||||
|
||||
14
backend/tests/daily/connectors/conftest.py
Normal file
14
backend/tests/daily/connectors/conftest.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_get_unstructured_api_key() -> Generator[MagicMock, None, None]:
|
||||
with patch(
|
||||
"onyx.file_processing.extract_file_text.get_unstructured_api_key",
|
||||
return_value=None,
|
||||
) as mock:
|
||||
yield mock
|
||||
@@ -0,0 +1,210 @@
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.sharepoint.connector import SharepointConnector
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExpectedDocument:
|
||||
semantic_identifier: str
|
||||
content: str
|
||||
folder_path: str | None = None
|
||||
library: str = "Shared Documents" # Default to main library
|
||||
|
||||
|
||||
EXPECTED_DOCUMENTS = [
|
||||
ExpectedDocument(
|
||||
semantic_identifier="test1.docx",
|
||||
content="test1",
|
||||
folder_path="test",
|
||||
),
|
||||
ExpectedDocument(
|
||||
semantic_identifier="test2.docx",
|
||||
content="test2",
|
||||
folder_path="test/nested with spaces",
|
||||
),
|
||||
ExpectedDocument(
|
||||
semantic_identifier="should-not-index-on-specific-folder.docx",
|
||||
content="should-not-index-on-specific-folder",
|
||||
folder_path=None, # root folder
|
||||
),
|
||||
ExpectedDocument(
|
||||
semantic_identifier="other.docx",
|
||||
content="other",
|
||||
folder_path=None,
|
||||
library="Other Library",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def verify_document_metadata(doc: Document) -> None:
|
||||
"""Verify common metadata that should be present on all documents."""
|
||||
assert isinstance(doc.doc_updated_at, datetime)
|
||||
assert doc.doc_updated_at.tzinfo == timezone.utc
|
||||
assert doc.source == DocumentSource.SHAREPOINT
|
||||
assert doc.primary_owners is not None
|
||||
assert len(doc.primary_owners) == 1
|
||||
owner = doc.primary_owners[0]
|
||||
assert owner.display_name is not None
|
||||
assert owner.email is not None
|
||||
|
||||
|
||||
def verify_document_content(doc: Document, expected: ExpectedDocument) -> None:
|
||||
"""Verify a document matches its expected content."""
|
||||
assert doc.semantic_identifier == expected.semantic_identifier
|
||||
assert len(doc.sections) == 1
|
||||
assert expected.content in doc.sections[0].text
|
||||
verify_document_metadata(doc)
|
||||
|
||||
|
||||
def find_document(documents: list[Document], semantic_identifier: str) -> Document:
|
||||
"""Find a document by its semantic identifier."""
|
||||
matching_docs = [
|
||||
d for d in documents if d.semantic_identifier == semantic_identifier
|
||||
]
|
||||
assert (
|
||||
len(matching_docs) == 1
|
||||
), f"Expected exactly one document with identifier {semantic_identifier}"
|
||||
return matching_docs[0]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sharepoint_credentials() -> dict[str, str]:
|
||||
return {
|
||||
"sp_client_id": os.environ["SHAREPOINT_CLIENT_ID"],
|
||||
"sp_client_secret": os.environ["SHAREPOINT_CLIENT_SECRET"],
|
||||
"sp_directory_id": os.environ["SHAREPOINT_CLIENT_DIRECTORY_ID"],
|
||||
}
|
||||
|
||||
|
||||
def test_sharepoint_connector_specific_folder(
|
||||
mock_get_unstructured_api_key: MagicMock,
|
||||
sharepoint_credentials: dict[str, str],
|
||||
) -> None:
|
||||
# Initialize connector with the test site URL and specific folder
|
||||
connector = SharepointConnector(
|
||||
sites=[os.environ["SHAREPOINT_SITE"] + "/Shared Documents/test"]
|
||||
)
|
||||
|
||||
# Load credentials
|
||||
connector.load_credentials(sharepoint_credentials)
|
||||
|
||||
# Get all documents
|
||||
document_batches = list(connector.load_from_state())
|
||||
found_documents: list[Document] = [
|
||||
doc for batch in document_batches for doc in batch
|
||||
]
|
||||
|
||||
# Should only find documents in the test folder
|
||||
test_folder_docs = [
|
||||
doc
|
||||
for doc in EXPECTED_DOCUMENTS
|
||||
if doc.folder_path and doc.folder_path.startswith("test")
|
||||
]
|
||||
assert len(found_documents) == len(
|
||||
test_folder_docs
|
||||
), "Should only find documents in test folder"
|
||||
|
||||
# Verify each expected document
|
||||
for expected in test_folder_docs:
|
||||
doc = find_document(found_documents, expected.semantic_identifier)
|
||||
verify_document_content(doc, expected)
|
||||
|
||||
|
||||
def test_sharepoint_connector_root_folder(
|
||||
mock_get_unstructured_api_key: MagicMock,
|
||||
sharepoint_credentials: dict[str, str],
|
||||
) -> None:
|
||||
# Initialize connector with the base site URL
|
||||
connector = SharepointConnector(sites=[os.environ["SHAREPOINT_SITE"]])
|
||||
|
||||
# Load credentials
|
||||
connector.load_credentials(sharepoint_credentials)
|
||||
|
||||
# Get all documents
|
||||
document_batches = list(connector.load_from_state())
|
||||
found_documents: list[Document] = [
|
||||
doc for batch in document_batches for doc in batch
|
||||
]
|
||||
|
||||
assert len(found_documents) == len(
|
||||
EXPECTED_DOCUMENTS
|
||||
), "Should find all documents in main library"
|
||||
|
||||
# Verify each expected document
|
||||
for expected in EXPECTED_DOCUMENTS:
|
||||
doc = find_document(found_documents, expected.semantic_identifier)
|
||||
verify_document_content(doc, expected)
|
||||
|
||||
|
||||
def test_sharepoint_connector_other_library(
|
||||
mock_get_unstructured_api_key: MagicMock,
|
||||
sharepoint_credentials: dict[str, str],
|
||||
) -> None:
|
||||
# Initialize connector with the other library
|
||||
connector = SharepointConnector(
|
||||
sites=[
|
||||
os.environ["SHAREPOINT_SITE"] + "/Other Library",
|
||||
]
|
||||
)
|
||||
|
||||
# Load credentials
|
||||
connector.load_credentials(sharepoint_credentials)
|
||||
|
||||
# Get all documents
|
||||
document_batches = list(connector.load_from_state())
|
||||
found_documents: list[Document] = [
|
||||
doc for batch in document_batches for doc in batch
|
||||
]
|
||||
expected_documents: list[ExpectedDocument] = [
|
||||
doc for doc in EXPECTED_DOCUMENTS if doc.library == "Other Library"
|
||||
]
|
||||
|
||||
# Should find all documents in `Other Library`
|
||||
assert len(found_documents) == len(
|
||||
expected_documents
|
||||
), "Should find all documents in `Other Library`"
|
||||
|
||||
# Verify each expected document
|
||||
for expected in expected_documents:
|
||||
doc = find_document(found_documents, expected.semantic_identifier)
|
||||
verify_document_content(doc, expected)
|
||||
|
||||
|
||||
def test_sharepoint_connector_poll(
|
||||
mock_get_unstructured_api_key: MagicMock,
|
||||
sharepoint_credentials: dict[str, str],
|
||||
) -> None:
|
||||
# Initialize connector with the base site URL
|
||||
connector = SharepointConnector(
|
||||
sites=["https://danswerai.sharepoint.com/sites/sharepoint-tests"]
|
||||
)
|
||||
|
||||
# Load credentials
|
||||
connector.load_credentials(sharepoint_credentials)
|
||||
|
||||
# Set time window to only capture test1.docx (modified at 2025-01-28 20:51:42+00:00)
|
||||
start = datetime(2025, 1, 28, 20, 51, 30, tzinfo=timezone.utc) # 12 seconds before
|
||||
end = datetime(2025, 1, 28, 20, 51, 50, tzinfo=timezone.utc) # 8 seconds after
|
||||
|
||||
# Get documents within the time window
|
||||
document_batches = list(connector._fetch_from_sharepoint(start=start, end=end))
|
||||
found_documents: list[Document] = [
|
||||
doc for batch in document_batches for doc in batch
|
||||
]
|
||||
|
||||
# Should only find test1.docx
|
||||
assert len(found_documents) == 1, "Should only find one document in the time window"
|
||||
doc = found_documents[0]
|
||||
assert doc.semantic_identifier == "test1.docx"
|
||||
verify_document_metadata(doc)
|
||||
verify_document_content(
|
||||
doc, [d for d in EXPECTED_DOCUMENTS if d.semantic_identifier == "test1.docx"][0]
|
||||
)
|
||||
@@ -432,30 +432,61 @@ class CCPairManager:
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
#
|
||||
if result.status_code != 409:
|
||||
result.raise_for_status()
|
||||
|
||||
group_sync_result = requests.post(
|
||||
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/sync-groups",
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
if group_sync_result.status_code != 409:
|
||||
group_sync_result.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def get_sync_task(
|
||||
def get_doc_sync_task(
|
||||
cc_pair: DATestCCPair,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> datetime | None:
|
||||
response = requests.get(
|
||||
doc_sync_response = requests.get(
|
||||
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/sync-permissions",
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
response.raise_for_status()
|
||||
response_str = response.json()
|
||||
doc_sync_response.raise_for_status()
|
||||
doc_sync_response_str = doc_sync_response.json()
|
||||
|
||||
# If the response itself is a datetime string, parse it
|
||||
if not isinstance(response_str, str):
|
||||
if not isinstance(doc_sync_response_str, str):
|
||||
return None
|
||||
|
||||
try:
|
||||
return datetime.fromisoformat(response_str)
|
||||
return datetime.fromisoformat(doc_sync_response_str)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_group_sync_task(
|
||||
cc_pair: DATestCCPair,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> datetime | None:
|
||||
group_sync_response = requests.get(
|
||||
url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair.id}/sync-groups",
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
)
|
||||
group_sync_response.raise_for_status()
|
||||
group_sync_response_str = group_sync_response.json()
|
||||
|
||||
# If the response itself is a datetime string, parse it
|
||||
if not isinstance(group_sync_response_str, str):
|
||||
return None
|
||||
|
||||
try:
|
||||
return datetime.fromisoformat(group_sync_response_str)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
@@ -498,15 +529,37 @@ class CCPairManager:
|
||||
timeout: float = MAX_DELAY,
|
||||
number_of_updated_docs: int = 0,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
# Sometimes waiting for a group sync is not necessary
|
||||
should_wait_for_group_sync: bool = True,
|
||||
# Sometimes waiting for a vespa sync is not necessary
|
||||
should_wait_for_vespa_sync: bool = True,
|
||||
) -> None:
|
||||
"""after: The task register time must be after this time."""
|
||||
doc_synced = False
|
||||
group_synced = False
|
||||
start = time.monotonic()
|
||||
while True:
|
||||
last_synced = CCPairManager.get_sync_task(cc_pair, user_performing_action)
|
||||
if last_synced and last_synced > after:
|
||||
print(f"last_synced: {last_synced}")
|
||||
# We are treating both syncs as part of one larger permission sync job
|
||||
doc_last_synced = CCPairManager.get_doc_sync_task(
|
||||
cc_pair, user_performing_action
|
||||
)
|
||||
group_last_synced = CCPairManager.get_group_sync_task(
|
||||
cc_pair, user_performing_action
|
||||
)
|
||||
|
||||
if not doc_synced and doc_last_synced and doc_last_synced > after:
|
||||
print(f"doc_last_synced: {doc_last_synced}")
|
||||
print(f"sync command start time: {after}")
|
||||
print(f"permission sync complete: cc_pair={cc_pair.id}")
|
||||
doc_synced = True
|
||||
|
||||
if not group_synced and group_last_synced and group_last_synced > after:
|
||||
print(f"group_last_synced: {group_last_synced}")
|
||||
print(f"sync command start time: {after}")
|
||||
print(f"group sync complete: cc_pair={cc_pair.id}")
|
||||
group_synced = True
|
||||
|
||||
if doc_synced and (group_synced or not should_wait_for_group_sync):
|
||||
break
|
||||
|
||||
elapsed = time.monotonic() - start
|
||||
@@ -524,6 +577,9 @@ class CCPairManager:
|
||||
# this shouldnt be necessary but something is off with the timing for the sync jobs
|
||||
time.sleep(5)
|
||||
|
||||
if not should_wait_for_vespa_sync:
|
||||
return
|
||||
|
||||
print("waiting for vespa sync")
|
||||
# wait for the vespa sync to complete once the permission sync is complete
|
||||
start = time.monotonic()
|
||||
|
||||
@@ -18,6 +18,7 @@ from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.engine import SYNC_DB_API
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.db.swap_index import check_index_swap
|
||||
from onyx.document_index.document_index_utils import get_multipass_config
|
||||
from onyx.document_index.vespa.index import DOCUMENT_ID_ENDPOINT
|
||||
from onyx.document_index.vespa.index import VespaIndex
|
||||
from onyx.indexing.models import IndexingSetting
|
||||
@@ -173,10 +174,16 @@ def reset_vespa() -> None:
|
||||
check_index_swap(db_session)
|
||||
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
multipass_config = get_multipass_config(search_settings)
|
||||
index_name = search_settings.index_name
|
||||
|
||||
success = setup_vespa(
|
||||
document_index=VespaIndex(index_name=index_name, secondary_index_name=None),
|
||||
document_index=VespaIndex(
|
||||
index_name=index_name,
|
||||
secondary_index_name=None,
|
||||
large_chunks_enabled=multipass_config.enable_large_chunks,
|
||||
secondary_large_chunks_enabled=None,
|
||||
),
|
||||
index_setting=IndexingSetting.from_db_model(search_settings),
|
||||
secondary_index_setting=None,
|
||||
)
|
||||
@@ -250,10 +257,16 @@ def reset_vespa_multitenant() -> None:
|
||||
check_index_swap(db_session)
|
||||
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
multipass_config = get_multipass_config(search_settings)
|
||||
index_name = search_settings.index_name
|
||||
|
||||
success = setup_vespa(
|
||||
document_index=VespaIndex(index_name=index_name, secondary_index_name=None),
|
||||
document_index=VespaIndex(
|
||||
index_name=index_name,
|
||||
secondary_index_name=None,
|
||||
large_chunks_enabled=multipass_config.enable_large_chunks,
|
||||
secondary_large_chunks_enabled=None,
|
||||
),
|
||||
index_setting=IndexingSetting.from_db_model(search_settings),
|
||||
secondary_index_setting=None,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,186 @@
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from google.oauth2.service_account import Credentials
|
||||
|
||||
from onyx.connectors.google_utils.resources import get_drive_service
|
||||
from onyx.connectors.google_utils.resources import get_google_docs_service
|
||||
from onyx.connectors.google_utils.resources import GoogleDocsService
|
||||
from onyx.connectors.google_utils.resources import GoogleDriveService
|
||||
|
||||
|
||||
GOOGLE_SCOPES = {
|
||||
"google_drive": [
|
||||
"https://www.googleapis.com/auth/drive",
|
||||
"https://www.googleapis.com/auth/admin.directory.group",
|
||||
"https://www.googleapis.com/auth/admin.directory.user",
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def _create_doc_service(drive_service: GoogleDriveService) -> GoogleDocsService:
|
||||
docs_service = get_google_docs_service(
|
||||
creds=drive_service._http.credentials,
|
||||
user_email=drive_service._http.credentials._subject,
|
||||
)
|
||||
return docs_service
|
||||
|
||||
|
||||
class GoogleDriveManager:
|
||||
@staticmethod
|
||||
def create_impersonated_drive_service(
|
||||
service_account_key: dict, impersonated_user_email: str
|
||||
) -> GoogleDriveService:
|
||||
"""Gets a drive service that impersonates a specific user"""
|
||||
credentials = Credentials.from_service_account_info(
|
||||
service_account_key,
|
||||
scopes=GOOGLE_SCOPES["google_drive"],
|
||||
subject=impersonated_user_email,
|
||||
)
|
||||
|
||||
service = get_drive_service(credentials, impersonated_user_email)
|
||||
|
||||
# Verify impersonation
|
||||
about = service.about().get(fields="user").execute()
|
||||
if about.get("user", {}).get("emailAddress") != impersonated_user_email:
|
||||
raise ValueError(
|
||||
f"Failed to impersonate {impersonated_user_email}. "
|
||||
f"Instead got {about.get('user', {}).get('emailAddress')}"
|
||||
)
|
||||
return service
|
||||
|
||||
@staticmethod
|
||||
def create_shared_drive(
|
||||
drive_service: GoogleDriveService, admin_email: str, test_id: str
|
||||
) -> str:
|
||||
"""
|
||||
Creates a shared drive and returns the drive's ID
|
||||
"""
|
||||
try:
|
||||
about = drive_service.about().get(fields="user").execute()
|
||||
creating_user = about["user"]["emailAddress"]
|
||||
|
||||
# Verify we're still impersonating the admin
|
||||
if creating_user != admin_email:
|
||||
raise ValueError(
|
||||
f"Expected to create drive as {admin_email}, but instead created drive as {creating_user}"
|
||||
)
|
||||
|
||||
drive_metadata = {"name": f"perm_sync_drive_{test_id}"}
|
||||
|
||||
request_id = str(uuid4())
|
||||
drive = (
|
||||
drive_service.drives()
|
||||
.create(
|
||||
body=drive_metadata,
|
||||
requestId=request_id,
|
||||
fields="id,name,capabilities",
|
||||
)
|
||||
.execute()
|
||||
)
|
||||
|
||||
return drive["id"]
|
||||
except Exception as e:
|
||||
print(f"Error creating shared drive: {str(e)}")
|
||||
raise
|
||||
|
||||
@staticmethod
|
||||
def create_empty_doc(
|
||||
drive_service: Any,
|
||||
drive_id: str,
|
||||
) -> str:
|
||||
"""
|
||||
Creates an empty document in the given drive and returns the document's ID
|
||||
"""
|
||||
file_metadata = {
|
||||
"name": f"perm_sync_doc_{drive_id}_{str(uuid4())}",
|
||||
"mimeType": "application/vnd.google-apps.document",
|
||||
"parents": [drive_id],
|
||||
}
|
||||
file = (
|
||||
drive_service.files()
|
||||
.create(body=file_metadata, supportsAllDrives=True)
|
||||
.execute()
|
||||
)
|
||||
|
||||
return file["id"]
|
||||
|
||||
@staticmethod
|
||||
def append_text_to_doc(
|
||||
drive_service: GoogleDriveService, doc_id: str, text: str
|
||||
) -> None:
|
||||
docs_service = _create_doc_service(drive_service)
|
||||
|
||||
docs_service.documents().batchUpdate(
|
||||
documentId=doc_id,
|
||||
body={
|
||||
"requests": [{"insertText": {"location": {"index": 1}, "text": text}}]
|
||||
},
|
||||
).execute()
|
||||
|
||||
@staticmethod
|
||||
def update_file_permissions(
|
||||
drive_service: Any, file_id: str, email: str, role: str = "reader"
|
||||
) -> None:
|
||||
permission = {"type": "user", "role": role, "emailAddress": email}
|
||||
drive_service.permissions().create(
|
||||
fileId=file_id,
|
||||
body=permission,
|
||||
supportsAllDrives=True,
|
||||
sendNotificationEmail=False,
|
||||
).execute()
|
||||
|
||||
@staticmethod
|
||||
def remove_file_permissions(drive_service: Any, file_id: str, email: str) -> None:
|
||||
permissions = (
|
||||
drive_service.permissions()
|
||||
.list(fileId=file_id, supportsAllDrives=True)
|
||||
.execute()
|
||||
)
|
||||
# TODO: This is a hacky way to remove permissions. Removes anyone with reader role.
|
||||
# Need to find a way to map a user's email to a permission id.
|
||||
# The permissions.get returns a permissionID but email field is None,
|
||||
# something to do with it being a group or domain wide delegation.
|
||||
for permission in permissions.get("permissions", []):
|
||||
if permission.get("role") == "reader":
|
||||
drive_service.permissions().delete(
|
||||
fileId=file_id,
|
||||
permissionId=permission["id"],
|
||||
supportsAllDrives=True,
|
||||
).execute()
|
||||
break
|
||||
|
||||
@staticmethod
|
||||
def make_file_public(drive_service: Any, file_id: str) -> None:
|
||||
permission = {"type": "anyone", "role": "reader"}
|
||||
drive_service.permissions().create(
|
||||
fileId=file_id, body=permission, supportsAllDrives=True
|
||||
).execute()
|
||||
|
||||
@staticmethod
|
||||
def cleanup_drive(drive_service: Any, drive_id: str) -> None:
|
||||
try:
|
||||
# Delete up to 2 files that match our pattern
|
||||
file_name_prefix = f"perm_sync_doc_{drive_id}"
|
||||
files = (
|
||||
drive_service.files()
|
||||
.list(
|
||||
q=f"name contains '{file_name_prefix}'",
|
||||
driveId=drive_id,
|
||||
includeItemsFromAllDrives=True,
|
||||
supportsAllDrives=True,
|
||||
corpora="drive",
|
||||
fields="files(id)",
|
||||
)
|
||||
.execute()
|
||||
)
|
||||
|
||||
for file in files.get("files", []):
|
||||
drive_service.files().delete(
|
||||
fileId=file["id"], supportsAllDrives=True
|
||||
).execute()
|
||||
|
||||
# Then delete the drive
|
||||
drive_service.drives().delete(driveId=drive_id).execute()
|
||||
except Exception as e:
|
||||
print(f"Error cleaning up drive {drive_id}: {e}")
|
||||
@@ -0,0 +1,332 @@
|
||||
import json
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.google_utils.resources import GoogleDriveService
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
|
||||
)
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
|
||||
)
|
||||
from onyx.connectors.models import InputType
|
||||
from onyx.db.enums import AccessType
|
||||
from tests.integration.common_utils.managers.cc_pair import CCPairManager
|
||||
from tests.integration.common_utils.managers.connector import ConnectorManager
|
||||
from tests.integration.common_utils.managers.credential import CredentialManager
|
||||
from tests.integration.common_utils.managers.document_search import (
|
||||
DocumentSearchManager,
|
||||
)
|
||||
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import DATestCCPair
|
||||
from tests.integration.common_utils.test_models import DATestConnector
|
||||
from tests.integration.common_utils.test_models import DATestCredential
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
from tests.integration.common_utils.vespa import vespa_fixture
|
||||
from tests.integration.connector_job_tests.google.google_drive_api_utils import (
|
||||
GoogleDriveManager,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def google_drive_test_env_setup() -> (
|
||||
Generator[
|
||||
tuple[
|
||||
GoogleDriveService, str, DATestCCPair, DATestUser, DATestUser, DATestUser
|
||||
],
|
||||
None,
|
||||
None,
|
||||
]
|
||||
):
|
||||
# Creating an admin user (first user created is automatically an admin)
|
||||
admin_user: DATestUser = UserManager.create(email="admin@onyx-test.com")
|
||||
# Creating a non-admin user
|
||||
test_user_1: DATestUser = UserManager.create(email="test_user_1@onyx-test.com")
|
||||
# Creating a non-admin user
|
||||
test_user_2: DATestUser = UserManager.create(email="test_user_2@onyx-test.com")
|
||||
|
||||
service_account_key = os.environ["FULL_CONTROL_DRIVE_SERVICE_ACCOUNT"]
|
||||
drive_id: str | None = None
|
||||
|
||||
try:
|
||||
credentials = {
|
||||
DB_CREDENTIALS_PRIMARY_ADMIN_KEY: admin_user.email,
|
||||
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY: service_account_key,
|
||||
}
|
||||
|
||||
# Setup Google Drive
|
||||
drive_service = GoogleDriveManager.create_impersonated_drive_service(
|
||||
json.loads(service_account_key), admin_user.email
|
||||
)
|
||||
test_id = str(uuid4())
|
||||
drive_id = GoogleDriveManager.create_shared_drive(
|
||||
drive_service, admin_user.email, test_id
|
||||
)
|
||||
|
||||
# Setup Onyx infrastructure
|
||||
LLMProviderManager.create(user_performing_action=admin_user)
|
||||
|
||||
before = datetime.now(timezone.utc)
|
||||
credential: DATestCredential = CredentialManager.create(
|
||||
source=DocumentSource.GOOGLE_DRIVE,
|
||||
credential_json=credentials,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
connector: DATestConnector = ConnectorManager.create(
|
||||
name="Google Drive Test",
|
||||
input_type=InputType.POLL,
|
||||
source=DocumentSource.GOOGLE_DRIVE,
|
||||
connector_specific_config={
|
||||
"shared_drive_urls": f"https://drive.google.com/drive/folders/{drive_id}"
|
||||
},
|
||||
access_type=AccessType.SYNC,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
cc_pair: DATestCCPair = CCPairManager.create(
|
||||
credential_id=credential.id,
|
||||
connector_id=connector.id,
|
||||
access_type=AccessType.SYNC,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
CCPairManager.wait_for_indexing_completion(
|
||||
cc_pair=cc_pair, after=before, user_performing_action=admin_user
|
||||
)
|
||||
|
||||
yield drive_service, drive_id, cc_pair, admin_user, test_user_1, test_user_2
|
||||
|
||||
except json.JSONDecodeError:
|
||||
pytest.skip("FULL_CONTROL_DRIVE_SERVICE_ACCOUNT is not valid JSON")
|
||||
finally:
|
||||
# Cleanup drive and file
|
||||
if drive_id is not None:
|
||||
GoogleDriveManager.cleanup_drive(drive_service, drive_id)
|
||||
|
||||
|
||||
@pytest.mark.xfail(reason="Needs to be tested for flakiness")
|
||||
def test_google_permission_sync(
|
||||
reset: None,
|
||||
vespa_client: vespa_fixture,
|
||||
google_drive_test_env_setup: tuple[
|
||||
GoogleDriveService, str, DATestCCPair, DATestUser, DATestUser, DATestUser
|
||||
],
|
||||
) -> None:
|
||||
(
|
||||
drive_service,
|
||||
drive_id,
|
||||
cc_pair,
|
||||
admin_user,
|
||||
test_user_1,
|
||||
test_user_2,
|
||||
) = google_drive_test_env_setup
|
||||
|
||||
# ----------------------BASELINE TEST----------------------
|
||||
before = datetime.now(timezone.utc)
|
||||
|
||||
# Create empty test doc in drive
|
||||
doc_id_1 = GoogleDriveManager.create_empty_doc(drive_service, drive_id)
|
||||
|
||||
# Append text to doc
|
||||
doc_text_1 = "The secret number is 12345"
|
||||
GoogleDriveManager.append_text_to_doc(drive_service, doc_id_1, doc_text_1)
|
||||
|
||||
# run indexing
|
||||
CCPairManager.run_once(cc_pair, admin_user)
|
||||
CCPairManager.wait_for_indexing_completion(
|
||||
cc_pair=cc_pair, after=before, user_performing_action=admin_user
|
||||
)
|
||||
|
||||
# run permission sync
|
||||
CCPairManager.sync(
|
||||
cc_pair=cc_pair,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
CCPairManager.wait_for_sync(
|
||||
cc_pair=cc_pair,
|
||||
after=before,
|
||||
number_of_updated_docs=1,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Verify admin has access to document
|
||||
admin_results = DocumentSearchManager.search_documents(
|
||||
query="secret number", user_performing_action=admin_user
|
||||
)
|
||||
assert doc_text_1 in [result.strip("\ufeff") for result in admin_results]
|
||||
|
||||
# Verify test_user_1 cannot access document
|
||||
user1_results = DocumentSearchManager.search_documents(
|
||||
query="secret number", user_performing_action=test_user_1
|
||||
)
|
||||
assert doc_text_1 not in [result.strip("\ufeff") for result in user1_results]
|
||||
|
||||
# ----------------------GRANT USER 1 DOC PERMISSIONS TEST--------------------------
|
||||
before = datetime.now(timezone.utc)
|
||||
|
||||
# Grant user 1 access to document 1
|
||||
GoogleDriveManager.update_file_permissions(
|
||||
drive_service=drive_service,
|
||||
file_id=doc_id_1,
|
||||
email=test_user_1.email,
|
||||
role="reader",
|
||||
)
|
||||
|
||||
# Create a second doc in the drive which user 1 should not have access to
|
||||
doc_id_2 = GoogleDriveManager.create_empty_doc(drive_service, drive_id)
|
||||
doc_text_2 = "The secret number is 67890"
|
||||
GoogleDriveManager.append_text_to_doc(drive_service, doc_id_2, doc_text_2)
|
||||
|
||||
# Run indexing
|
||||
CCPairManager.run_once(cc_pair, admin_user)
|
||||
CCPairManager.wait_for_indexing_completion(
|
||||
cc_pair=cc_pair,
|
||||
after=before,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Run permission sync
|
||||
CCPairManager.sync(
|
||||
cc_pair=cc_pair,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
CCPairManager.wait_for_sync(
|
||||
cc_pair=cc_pair,
|
||||
after=before,
|
||||
number_of_updated_docs=1,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Verify admin can access both documents
|
||||
admin_results = DocumentSearchManager.search_documents(
|
||||
query="secret number", user_performing_action=admin_user
|
||||
)
|
||||
assert {doc_text_1, doc_text_2} == {
|
||||
result.strip("\ufeff") for result in admin_results
|
||||
}
|
||||
|
||||
# Verify user 1 can access document 1
|
||||
user1_results = DocumentSearchManager.search_documents(
|
||||
query="secret number", user_performing_action=test_user_1
|
||||
)
|
||||
assert doc_text_1 in [result.strip("\ufeff") for result in user1_results]
|
||||
|
||||
# Verify user 1 cannot access document 2
|
||||
user1_results_2 = DocumentSearchManager.search_documents(
|
||||
query="secret number", user_performing_action=test_user_1
|
||||
)
|
||||
assert doc_text_2 not in [result.strip("\ufeff") for result in user1_results_2]
|
||||
|
||||
# ----------------------REMOVE USER 1 DOC PERMISSIONS TEST--------------------------
|
||||
before = datetime.now(timezone.utc)
|
||||
|
||||
# Remove user 1 access to document 1
|
||||
GoogleDriveManager.remove_file_permissions(
|
||||
drive_service=drive_service, file_id=doc_id_1, email=test_user_1.email
|
||||
)
|
||||
# Run permission sync
|
||||
CCPairManager.sync(
|
||||
cc_pair=cc_pair,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
CCPairManager.wait_for_sync(
|
||||
cc_pair=cc_pair,
|
||||
after=before,
|
||||
number_of_updated_docs=1,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Verify admin can access both documents
|
||||
admin_results = DocumentSearchManager.search_documents(
|
||||
query="secret number", user_performing_action=admin_user
|
||||
)
|
||||
assert {doc_text_1, doc_text_2} == {
|
||||
result.strip("\ufeff") for result in admin_results
|
||||
}
|
||||
|
||||
# Verify user 1 cannot access either document
|
||||
user1_results = DocumentSearchManager.search_documents(
|
||||
query="secret numbers", user_performing_action=test_user_1
|
||||
)
|
||||
assert {result.strip("\ufeff") for result in user1_results} == set()
|
||||
|
||||
# ----------------------GRANT USER 1 DRIVE PERMISSIONS TEST--------------------------
|
||||
before = datetime.now(timezone.utc)
|
||||
|
||||
# Grant user 1 access to drive
|
||||
GoogleDriveManager.update_file_permissions(
|
||||
drive_service=drive_service,
|
||||
file_id=drive_id,
|
||||
email=test_user_1.email,
|
||||
role="reader",
|
||||
)
|
||||
|
||||
# Run permission sync
|
||||
CCPairManager.sync(
|
||||
cc_pair=cc_pair,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
CCPairManager.wait_for_sync(
|
||||
cc_pair=cc_pair,
|
||||
after=before,
|
||||
number_of_updated_docs=2,
|
||||
user_performing_action=admin_user,
|
||||
# if we are only updating the group definition for this test we use this varaiable,
|
||||
# since it doesn't result in a vespa sync so we don't want to wait for it
|
||||
should_wait_for_vespa_sync=False,
|
||||
)
|
||||
|
||||
# Verify user 1 can access both documents
|
||||
user1_results = DocumentSearchManager.search_documents(
|
||||
query="secret numbers", user_performing_action=test_user_1
|
||||
)
|
||||
assert {doc_text_1, doc_text_2} == {
|
||||
result.strip("\ufeff") for result in user1_results
|
||||
}
|
||||
|
||||
# ----------------------MAKE DRIVE PUBLIC TEST--------------------------
|
||||
before = datetime.now(timezone.utc)
|
||||
|
||||
# Unable to make drive itself public as Google's security policies prevent this, so we make the documents public instead
|
||||
GoogleDriveManager.make_file_public(drive_service, doc_id_1)
|
||||
GoogleDriveManager.make_file_public(drive_service, doc_id_2)
|
||||
|
||||
# Run permission sync
|
||||
CCPairManager.sync(
|
||||
cc_pair=cc_pair,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
CCPairManager.wait_for_sync(
|
||||
cc_pair=cc_pair,
|
||||
after=before,
|
||||
number_of_updated_docs=2,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Verify all users can access both documents
|
||||
admin_results = DocumentSearchManager.search_documents(
|
||||
query="secret number", user_performing_action=admin_user
|
||||
)
|
||||
assert {doc_text_1, doc_text_2} == {
|
||||
result.strip("\ufeff") for result in admin_results
|
||||
}
|
||||
|
||||
user1_results = DocumentSearchManager.search_documents(
|
||||
query="secret number", user_performing_action=test_user_1
|
||||
)
|
||||
assert {doc_text_1, doc_text_2} == {
|
||||
result.strip("\ufeff") for result in user1_results
|
||||
}
|
||||
|
||||
user2_results = DocumentSearchManager.search_documents(
|
||||
query="secret number", user_performing_action=test_user_2
|
||||
)
|
||||
assert {doc_text_1, doc_text_2} == {
|
||||
result.strip("\ufeff") for result in user2_results
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
from onyx.document_index.vespa.shared_utils.utils import remove_invalid_unicode_chars
|
||||
|
||||
|
||||
def test_remove_invalid_unicode_chars() -> None:
|
||||
"""Test that invalid Unicode characters are properly removed."""
|
||||
# Test removal of illegal XML character 0xFDDB
|
||||
text_with_illegal_char = "Valid text \uFDDB more text"
|
||||
sanitized = remove_invalid_unicode_chars(text_with_illegal_char)
|
||||
assert "\uFDDB" not in sanitized
|
||||
assert sanitized == "Valid text more text"
|
||||
|
||||
# Test that valid characters are preserved
|
||||
valid_text = "Hello, world! 你好世界"
|
||||
assert remove_invalid_unicode_chars(valid_text) == valid_text
|
||||
|
||||
# Test multiple invalid characters including 0xFDDB
|
||||
text_with_multiple_illegal = "\x00Hello\uFDDB World\uFFFE!"
|
||||
sanitized = remove_invalid_unicode_chars(text_with_multiple_illegal)
|
||||
assert all(c not in sanitized for c in ["\x00", "\uFDDB", "\uFFFE"])
|
||||
assert sanitized == "Hello World!"
|
||||
@@ -6,7 +6,7 @@ import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.engine import get_sqlalchemy_engine
|
||||
from onyx.document_index.document_index_utils import get_both_index_names
|
||||
from onyx.document_index.document_index_utils import get_both_index_properties
|
||||
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ def test_vespa_update() -> None:
|
||||
doc_id = "test-vespa-update"
|
||||
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
primary_index_name, _ = get_both_index_names(db_session)
|
||||
primary_index_name, _, _, _ = get_both_index_properties(db_session)
|
||||
endpoint = (
|
||||
f"{DOCUMENT_ID_ENDPOINT.format(index_name=primary_index_name)}/{doc_id}"
|
||||
)
|
||||
|
||||
183
package-lock.json
generated
183
package-lock.json
generated
@@ -1,183 +0,0 @@
|
||||
{
|
||||
"name": "onyx",
|
||||
"lockfileVersion": 3,
|
||||
"requires": true,
|
||||
"packages": {
|
||||
"": {
|
||||
"dependencies": {
|
||||
"react-datepicker": "^7.6.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/react-datepicker": "^6.2.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@floating-ui/core": {
|
||||
"version": "1.6.9",
|
||||
"resolved": "https://registry.npmjs.org/@floating-ui/core/-/core-1.6.9.tgz",
|
||||
"integrity": "sha512-uMXCuQ3BItDUbAMhIXw7UPXRfAlOAvZzdK9BWpE60MCn+Svt3aLn9jsPTi/WNGlRUu2uI0v5S7JiIUsbsvh3fw==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@floating-ui/utils": "^0.2.9"
|
||||
}
|
||||
},
|
||||
"node_modules/@floating-ui/dom": {
|
||||
"version": "1.6.13",
|
||||
"resolved": "https://registry.npmjs.org/@floating-ui/dom/-/dom-1.6.13.tgz",
|
||||
"integrity": "sha512-umqzocjDgNRGTuO7Q8CU32dkHkECqI8ZdMZ5Swb6QAM0t5rnlrN3lGo1hdpscRd3WS8T6DKYK4ephgIH9iRh3w==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@floating-ui/core": "^1.6.0",
|
||||
"@floating-ui/utils": "^0.2.9"
|
||||
}
|
||||
},
|
||||
"node_modules/@floating-ui/react": {
|
||||
"version": "0.27.3",
|
||||
"resolved": "https://registry.npmjs.org/@floating-ui/react/-/react-0.27.3.tgz",
|
||||
"integrity": "sha512-CLHnes3ixIFFKVQDdICjel8muhFLOBdQH7fgtHNPY8UbCNqbeKZ262G7K66lGQOUQWWnYocf7ZbUsLJgGfsLHg==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@floating-ui/react-dom": "^2.1.2",
|
||||
"@floating-ui/utils": "^0.2.9",
|
||||
"tabbable": "^6.0.0"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"react": ">=17.0.0",
|
||||
"react-dom": ">=17.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@floating-ui/react-dom": {
|
||||
"version": "2.1.2",
|
||||
"resolved": "https://registry.npmjs.org/@floating-ui/react-dom/-/react-dom-2.1.2.tgz",
|
||||
"integrity": "sha512-06okr5cgPzMNBy+Ycse2A6udMi4bqwW/zgBF/rwjcNqWkyr82Mcg8b0vjX8OJpZFy/FKjJmw6wV7t44kK6kW7A==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@floating-ui/dom": "^1.0.0"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"react": ">=16.8.0",
|
||||
"react-dom": ">=16.8.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@floating-ui/utils": {
|
||||
"version": "0.2.9",
|
||||
"resolved": "https://registry.npmjs.org/@floating-ui/utils/-/utils-0.2.9.tgz",
|
||||
"integrity": "sha512-MDWhGtE+eHw5JW7lq4qhc5yRLS11ERl1c7Z6Xd0a58DozHES6EnNNwUWbMiG4J9Cgj053Bhk8zvlhFYKVhULwg==",
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/@types/react": {
|
||||
"version": "19.0.4",
|
||||
"resolved": "https://registry.npmjs.org/@types/react/-/react-19.0.4.tgz",
|
||||
"integrity": "sha512-3O4QisJDYr1uTUMZHA2YswiQZRq+Pd8D+GdVFYikTutYsTz+QZgWkAPnP7rx9txoI6EXKcPiluMqWPFV3tT9Wg==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"csstype": "^3.0.2"
|
||||
}
|
||||
},
|
||||
"node_modules/@types/react-datepicker": {
|
||||
"version": "6.2.0",
|
||||
"resolved": "https://registry.npmjs.org/@types/react-datepicker/-/react-datepicker-6.2.0.tgz",
|
||||
"integrity": "sha512-+JtO4Fm97WLkJTH8j8/v3Ldh7JCNRwjMYjRaKh4KHH0M3jJoXtwiD3JBCsdlg3tsFIw9eQSqyAPeVDN2H2oM9Q==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@floating-ui/react": "^0.26.2",
|
||||
"@types/react": "*",
|
||||
"date-fns": "^3.3.1"
|
||||
}
|
||||
},
|
||||
"node_modules/@types/react-datepicker/node_modules/@floating-ui/react": {
|
||||
"version": "0.26.28",
|
||||
"resolved": "https://registry.npmjs.org/@floating-ui/react/-/react-0.26.28.tgz",
|
||||
"integrity": "sha512-yORQuuAtVpiRjpMhdc0wJj06b9JFjrYF4qp96j++v2NBpbi6SEGF7donUJ3TMieerQ6qVkAv1tgr7L4r5roTqw==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@floating-ui/react-dom": "^2.1.2",
|
||||
"@floating-ui/utils": "^0.2.8",
|
||||
"tabbable": "^6.0.0"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"react": ">=16.8.0",
|
||||
"react-dom": ">=16.8.0"
|
||||
}
|
||||
},
|
||||
"node_modules/clsx": {
|
||||
"version": "2.1.1",
|
||||
"resolved": "https://registry.npmjs.org/clsx/-/clsx-2.1.1.tgz",
|
||||
"integrity": "sha512-eYm0QWBtUrBWZWG0d386OGAw16Z995PiOVo2B7bjWSbHedGl5e0ZWaq65kOGgUSNesEIDkB9ISbTg/JK9dhCZA==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=6"
|
||||
}
|
||||
},
|
||||
"node_modules/csstype": {
|
||||
"version": "3.1.3",
|
||||
"resolved": "https://registry.npmjs.org/csstype/-/csstype-3.1.3.tgz",
|
||||
"integrity": "sha512-M1uQkMl8rQK/szD0LNhtqxIPLpimGm8sOBwU7lLnCpSbTyY3yeU1Vc7l4KT5zT4s/yOxHH5O7tIuuLOCnLADRw==",
|
||||
"dev": true,
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/date-fns": {
|
||||
"version": "3.6.0",
|
||||
"resolved": "https://registry.npmjs.org/date-fns/-/date-fns-3.6.0.tgz",
|
||||
"integrity": "sha512-fRHTG8g/Gif+kSh50gaGEdToemgfj74aRX3swtiouboip5JDLAyDE9F11nHMIcvOaXeOC6D7SpNhi7uFyB7Uww==",
|
||||
"license": "MIT",
|
||||
"funding": {
|
||||
"type": "github",
|
||||
"url": "https://github.com/sponsors/kossnocorp"
|
||||
}
|
||||
},
|
||||
"node_modules/react": {
|
||||
"version": "19.0.0",
|
||||
"resolved": "https://registry.npmjs.org/react/-/react-19.0.0.tgz",
|
||||
"integrity": "sha512-V8AVnmPIICiWpGfm6GLzCR/W5FXLchHop40W4nXBmdlEceh16rCN8O8LNWm5bh5XUX91fh7KpA+W0TgMKmgTpQ==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=0.10.0"
|
||||
}
|
||||
},
|
||||
"node_modules/react-datepicker": {
|
||||
"version": "7.6.0",
|
||||
"resolved": "https://registry.npmjs.org/react-datepicker/-/react-datepicker-7.6.0.tgz",
|
||||
"integrity": "sha512-9cQH6Z/qa4LrGhzdc3XoHbhrxNcMi9MKjZmYgF/1MNNaJwvdSjv3Xd+jjvrEEbKEf71ZgCA3n7fQbdwd70qCRw==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@floating-ui/react": "^0.27.0",
|
||||
"clsx": "^2.1.1",
|
||||
"date-fns": "^3.6.0"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"react": "^16.9.0 || ^17 || ^18 || ^19 || ^19.0.0-rc",
|
||||
"react-dom": "^16.9.0 || ^17 || ^18 || ^19 || ^19.0.0-rc"
|
||||
}
|
||||
},
|
||||
"node_modules/react-dom": {
|
||||
"version": "19.0.0",
|
||||
"resolved": "https://registry.npmjs.org/react-dom/-/react-dom-19.0.0.tgz",
|
||||
"integrity": "sha512-4GV5sHFG0e/0AD4X+ySy6UJd3jVl1iNsNHdpad0qhABJ11twS3TTBnseqsKurKcsNqCEFeGL3uLpVChpIO3QfQ==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"scheduler": "^0.25.0"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"react": "^19.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/scheduler": {
|
||||
"version": "0.25.0",
|
||||
"resolved": "https://registry.npmjs.org/scheduler/-/scheduler-0.25.0.tgz",
|
||||
"integrity": "sha512-xFVuu11jh+xcO7JOAGJNOXld8/TcEHK/4CituBUeUb5hqxJLj9YuemAEuvm9gQ/+pgXYfbQuqAkiYu+u7YEsNA==",
|
||||
"license": "MIT",
|
||||
"peer": true
|
||||
},
|
||||
"node_modules/tabbable": {
|
||||
"version": "6.2.0",
|
||||
"resolved": "https://registry.npmjs.org/tabbable/-/tabbable-6.2.0.tgz",
|
||||
"integrity": "sha512-Cat63mxsVJlzYvN51JmVXIgNoUokrIaT2zLclCXjRd8boZ0004U4KCs/sToJ75C6sdlByWxpYnb5Boif1VSFew==",
|
||||
"license": "MIT"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,8 +0,0 @@
|
||||
{
|
||||
"dependencies": {
|
||||
"react-datepicker": "^7.6.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@types/react-datepicker": "^6.2.0"
|
||||
}
|
||||
}
|
||||
@@ -18,6 +18,9 @@ FROM base AS builder
|
||||
RUN apk add --no-cache libc6-compat
|
||||
WORKDIR /app
|
||||
|
||||
# Add NODE_OPTIONS argument
|
||||
ARG NODE_OPTIONS
|
||||
|
||||
# pull in source code / package.json / package-lock.json
|
||||
COPY . .
|
||||
|
||||
@@ -78,7 +81,8 @@ ENV NEXT_PUBLIC_GTM_ENABLED=${NEXT_PUBLIC_GTM_ENABLED}
|
||||
ARG NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED
|
||||
ENV NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=${NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED}
|
||||
|
||||
RUN NODE_OPTIONS="--max-old-space-size=8192" npx next build
|
||||
# Use NODE_OPTIONS in the build command
|
||||
RUN NODE_OPTIONS="${NODE_OPTIONS}" npx next build
|
||||
|
||||
# Step 2. Production image, copy all the files and run next
|
||||
FROM base AS runner
|
||||
|
||||
@@ -86,14 +86,16 @@ const sentryWebpackPluginOptions = {
|
||||
authToken: process.env.SENTRY_AUTH_TOKEN,
|
||||
silent: !sentryEnabled, // Silence output when Sentry is disabled
|
||||
dryRun: !sentryEnabled, // Don't upload source maps when Sentry is disabled
|
||||
sourceMaps: {
|
||||
include: ["./.next"],
|
||||
ignore: ["node_modules"],
|
||||
urlPrefix: "~/_next",
|
||||
stripPrefix: ["webpack://_N_E/"],
|
||||
validate: true,
|
||||
cleanArtifacts: true,
|
||||
},
|
||||
...(sentryEnabled && {
|
||||
sourceMaps: {
|
||||
include: ["./.next"],
|
||||
ignore: ["node_modules"],
|
||||
urlPrefix: "~/_next",
|
||||
stripPrefix: ["webpack://_N_E/"],
|
||||
validate: true,
|
||||
cleanArtifacts: true,
|
||||
},
|
||||
}),
|
||||
};
|
||||
|
||||
// Export the module with conditional Sentry configuration
|
||||
|
||||
132
web/package-lock.json
generated
132
web/package-lock.json
generated
@@ -16,6 +16,7 @@
|
||||
"@headlessui/tailwindcss": "^0.2.1",
|
||||
"@phosphor-icons/react": "^2.0.8",
|
||||
"@radix-ui/react-checkbox": "^1.1.2",
|
||||
"@radix-ui/react-collapsible": "^1.1.2",
|
||||
"@radix-ui/react-dialog": "^1.1.2",
|
||||
"@radix-ui/react-dropdown-menu": "^2.1.4",
|
||||
"@radix-ui/react-label": "^2.1.1",
|
||||
@@ -3507,6 +3508,137 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/@radix-ui/react-collapsible": {
|
||||
"version": "1.1.2",
|
||||
"resolved": "https://registry.npmjs.org/@radix-ui/react-collapsible/-/react-collapsible-1.1.2.tgz",
|
||||
"integrity": "sha512-PliMB63vxz7vggcyq0IxNYk8vGDrLXVWw4+W4B8YnwI1s18x7YZYqlG9PLX7XxAJUi0g2DxP4XKJMFHh/iVh9A==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@radix-ui/primitive": "1.1.1",
|
||||
"@radix-ui/react-compose-refs": "1.1.1",
|
||||
"@radix-ui/react-context": "1.1.1",
|
||||
"@radix-ui/react-id": "1.1.0",
|
||||
"@radix-ui/react-presence": "1.1.2",
|
||||
"@radix-ui/react-primitive": "2.0.1",
|
||||
"@radix-ui/react-use-controllable-state": "1.1.0",
|
||||
"@radix-ui/react-use-layout-effect": "1.1.0"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"@types/react": "*",
|
||||
"@types/react-dom": "*",
|
||||
"react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc",
|
||||
"react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
|
||||
},
|
||||
"peerDependenciesMeta": {
|
||||
"@types/react": {
|
||||
"optional": true
|
||||
},
|
||||
"@types/react-dom": {
|
||||
"optional": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/@radix-ui/react-collapsible/node_modules/@radix-ui/primitive": {
|
||||
"version": "1.1.1",
|
||||
"resolved": "https://registry.npmjs.org/@radix-ui/primitive/-/primitive-1.1.1.tgz",
|
||||
"integrity": "sha512-SJ31y+Q/zAyShtXJc8x83i9TYdbAfHZ++tUZnvjJJqFjzsdUnKsxPL6IEtBlxKkU7yzer//GQtZSV4GbldL3YA==",
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/@radix-ui/react-collapsible/node_modules/@radix-ui/react-compose-refs": {
|
||||
"version": "1.1.1",
|
||||
"resolved": "https://registry.npmjs.org/@radix-ui/react-compose-refs/-/react-compose-refs-1.1.1.tgz",
|
||||
"integrity": "sha512-Y9VzoRDSJtgFMUCoiZBDVo084VQ5hfpXxVE+NgkdNsjiDBByiImMZKKhxMwCbdHvhlENG6a833CbFkOQvTricw==",
|
||||
"license": "MIT",
|
||||
"peerDependencies": {
|
||||
"@types/react": "*",
|
||||
"react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
|
||||
},
|
||||
"peerDependenciesMeta": {
|
||||
"@types/react": {
|
||||
"optional": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/@radix-ui/react-collapsible/node_modules/@radix-ui/react-context": {
|
||||
"version": "1.1.1",
|
||||
"resolved": "https://registry.npmjs.org/@radix-ui/react-context/-/react-context-1.1.1.tgz",
|
||||
"integrity": "sha512-UASk9zi+crv9WteK/NU4PLvOoL3OuE6BWVKNF6hPRBtYBDXQ2u5iu3O59zUlJiTVvkyuycnqrztsHVJwcK9K+Q==",
|
||||
"license": "MIT",
|
||||
"peerDependencies": {
|
||||
"@types/react": "*",
|
||||
"react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
|
||||
},
|
||||
"peerDependenciesMeta": {
|
||||
"@types/react": {
|
||||
"optional": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/@radix-ui/react-collapsible/node_modules/@radix-ui/react-presence": {
|
||||
"version": "1.1.2",
|
||||
"resolved": "https://registry.npmjs.org/@radix-ui/react-presence/-/react-presence-1.1.2.tgz",
|
||||
"integrity": "sha512-18TFr80t5EVgL9x1SwF/YGtfG+l0BS0PRAlCWBDoBEiDQjeKgnNZRVJp/oVBl24sr3Gbfwc/Qpj4OcWTQMsAEg==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@radix-ui/react-compose-refs": "1.1.1",
|
||||
"@radix-ui/react-use-layout-effect": "1.1.0"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"@types/react": "*",
|
||||
"@types/react-dom": "*",
|
||||
"react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc",
|
||||
"react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
|
||||
},
|
||||
"peerDependenciesMeta": {
|
||||
"@types/react": {
|
||||
"optional": true
|
||||
},
|
||||
"@types/react-dom": {
|
||||
"optional": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/@radix-ui/react-collapsible/node_modules/@radix-ui/react-primitive": {
|
||||
"version": "2.0.1",
|
||||
"resolved": "https://registry.npmjs.org/@radix-ui/react-primitive/-/react-primitive-2.0.1.tgz",
|
||||
"integrity": "sha512-sHCWTtxwNn3L3fH8qAfnF3WbUZycW93SM1j3NFDzXBiz8D6F5UTTy8G1+WFEaiCdvCVRJWj6N2R4Xq6HdiHmDg==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@radix-ui/react-slot": "1.1.1"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"@types/react": "*",
|
||||
"@types/react-dom": "*",
|
||||
"react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc",
|
||||
"react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
|
||||
},
|
||||
"peerDependenciesMeta": {
|
||||
"@types/react": {
|
||||
"optional": true
|
||||
},
|
||||
"@types/react-dom": {
|
||||
"optional": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/@radix-ui/react-collapsible/node_modules/@radix-ui/react-slot": {
|
||||
"version": "1.1.1",
|
||||
"resolved": "https://registry.npmjs.org/@radix-ui/react-slot/-/react-slot-1.1.1.tgz",
|
||||
"integrity": "sha512-RApLLOcINYJA+dMVbOju7MYv1Mb2EBp2nH4HdDzXTSyaR5optlm6Otrz1euW3HbdOR8UmmFK06TD+A9frYWv+g==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@radix-ui/react-compose-refs": "1.1.1"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"@types/react": "*",
|
||||
"react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
|
||||
},
|
||||
"peerDependenciesMeta": {
|
||||
"@types/react": {
|
||||
"optional": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/@radix-ui/react-collection": {
|
||||
"version": "1.1.0",
|
||||
"resolved": "https://registry.npmjs.org/@radix-ui/react-collection/-/react-collection-1.1.0.tgz",
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
"@headlessui/tailwindcss": "^0.2.1",
|
||||
"@phosphor-icons/react": "^2.0.8",
|
||||
"@radix-ui/react-checkbox": "^1.1.2",
|
||||
"@radix-ui/react-collapsible": "^1.1.2",
|
||||
"@radix-ui/react-dialog": "^1.1.2",
|
||||
"@radix-ui/react-dropdown-menu": "^2.1.4",
|
||||
"@radix-ui/react-label": "^2.1.1",
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user