mirror of https://github.com/onyx-dot-app/onyx.git
synced 2026-02-19 08:45:47 +00:00

Compare commits: 2 commits ("testing")

| Author | SHA1 | Date |
|---|---|---|
|  | 25b38212e9 |  |
|  | 3096b0b2a7 |  |

.github/pull_request_template.md (vendored): 1 change
@@ -11,4 +11,5 @@
Note: You have to check that the action passes, otherwise resolve the conflicts manually and tag the patches.

- [ ] This PR should be backported (make sure to check that the backport attempt succeeds)
- [ ] I have included a link to a Linear ticket in my description.
- [ ] [Optional] Override Linear Check
@@ -67,7 +67,6 @@ jobs:
            NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
            NEXT_PUBLIC_GTM_ENABLED=true
            NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
            NODE_OPTIONS=--max-old-space-size=8192
          # needed due to weird interactions with the builds for different platforms
          no-cache: true
          labels: ${{ steps.meta.outputs.labels }}

@@ -60,8 +60,6 @@ jobs:
          push: true
          build-args: |
            ONYX_VERSION=${{ github.ref_name }}
            NODE_OPTIONS=--max-old-space-size=8192

          # needed due to weird interactions with the builds for different platforms
          no-cache: true
          labels: ${{ steps.meta.outputs.labels }}
.github/workflows/pr-linear-check.yml (vendored): 4 changes
@@ -9,9 +9,9 @@ jobs:
    runs-on: ubuntu-latest
    steps:
      - name: Check PR body for Linear link or override
        env:
          PR_BODY: ${{ github.event.pull_request.body }}
        run: |
          PR_BODY="${{ github.event.pull_request.body }}"

          # Looking for "https://linear.app" in the body
          if echo "$PR_BODY" | grep -qE "https://linear\.app"; then
            echo "Found a Linear link. Check passed."
@@ -119,7 +119,7 @@ There are two editions of Onyx:
- Whitelabeling
- API key authentication
- Encryption of secrets
- And many more! Checkout [our website](https://www.onyx.app/) for the latest.
- Any many more! Checkout [our website](https://www.onyx.app/) for the latest.

To try the Onyx Enterprise Edition:
@@ -9,10 +9,8 @@ founders@onyx.app for more information. Please visit https://github.com/onyx-dot

# Default ONYX_VERSION, typically overriden during builds by GitHub Actions.
ARG ONYX_VERSION=0.8-dev
# DO_NOT_TRACK is used to disable telemetry for Unstructured
ENV ONYX_VERSION=${ONYX_VERSION} \
    DANSWER_RUNNING_IN_DOCKER="true" \
    DO_NOT_TRACK="true"
    DANSWER_RUNNING_IN_DOCKER="true"

RUN echo "ONYX_VERSION: ${ONYX_VERSION}"
@@ -1,75 +0,0 @@
"""add user files

Revision ID: 9aadf32dfeb4
Revises: f1ca58b2f2ec
Create Date: 2025-01-26 16:08:21.551022

"""
from alembic import op
import sqlalchemy as sa
import datetime


# revision identifiers, used by Alembic.
revision = "9aadf32dfeb4"
down_revision = "f1ca58b2f2ec"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Create user_folder table without parent_id
    op.create_table(
        "user_folder",
        sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
        sa.Column("user_id", sa.UUID(), sa.ForeignKey("user.id"), nullable=True),
        sa.Column("name", sa.String(length=255), nullable=True),
        sa.Column("description", sa.String(length=255), nullable=True),
        sa.Column("display_priority", sa.Integer(), nullable=True, default=0),
        sa.Column("created_at", sa.DateTime(), default=datetime.datetime.utcnow),
    )

    # Create user_file table with folder_id instead of parent_folder_id
    op.create_table(
        "user_file",
        sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
        sa.Column("user_id", sa.UUID(), sa.ForeignKey("user.id"), nullable=True),
        sa.Column(
            "folder_id",
            sa.Integer(),
            sa.ForeignKey("user_folder.id"),
            nullable=True,
        ),
        sa.Column("file_type", sa.String(), nullable=True),
        sa.Column("file_id", sa.String(length=255), nullable=False),
        sa.Column("document_id", sa.String(length=255), nullable=False),
        sa.Column("name", sa.String(length=255), nullable=False),
        sa.Column(
            "created_at",
            sa.DateTime(),
            default=datetime.datetime.utcnow,
        ),
    )

    # Create persona__user_file table
    op.create_table(
        "persona__user_file",
        sa.Column(
            "persona_id", sa.Integer(), sa.ForeignKey("persona.id"), primary_key=True
        ),
        sa.Column(
            "user_file_id",
            sa.Integer(),
            sa.ForeignKey("user_file.id"),
            primary_key=True,
        ),
    )


def downgrade() -> None:
    # Drop the persona__user_file table
    op.drop_table("persona__user_file")
    # Drop the user_file table
    op.drop_table("user_file")
    # Drop the user_folder table
    op.drop_table("user_folder")
@@ -1,33 +0,0 @@
"""add passthrough auth to tool

Revision ID: f1ca58b2f2ec
Revises: c7bf5721733e
Create Date: 2024-03-19

"""
from typing import Sequence, Union

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision: str = "f1ca58b2f2ec"
down_revision: Union[str, None] = "c7bf5721733e"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
    # Add passthrough_auth column to tool table with default value of False
    op.add_column(
        "tool",
        sa.Column(
            "passthrough_auth", sa.Boolean(), nullable=False, server_default=sa.false()
        ),
    )


def downgrade() -> None:
    # Remove passthrough_auth column from tool table
    op.drop_column("tool", "passthrough_auth")
@@ -32,7 +32,6 @@ def perform_ttl_management_task(

@celery_app.task(
    name="check_ttl_management_task",
    ignore_result=True,
    soft_time_limit=JOB_TIMEOUT,
)
def check_ttl_management_task(*, tenant_id: str | None) -> None:

@@ -57,7 +56,6 @@ def check_ttl_management_task(*, tenant_id: str | None) -> None:

@celery_app.task(
    name="autogenerate_usage_report_task",
    ignore_result=True,
    soft_time_limit=JOB_TIMEOUT,
)
def autogenerate_usage_report_task(*, tenant_id: str | None) -> None:
@@ -1,72 +1,30 @@
from datetime import timedelta
from typing import Any

from onyx.background.celery.tasks.beat_schedule import BEAT_EXPIRES_DEFAULT
from onyx.background.celery.tasks.beat_schedule import (
    cloud_tasks_to_schedule as base_cloud_tasks_to_schedule,
)
from onyx.background.celery.tasks.beat_schedule import (
    tasks_to_schedule as base_tasks_to_schedule,
)
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryTask
from shared_configs.configs import MULTI_TENANT

ee_cloud_tasks_to_schedule = [
ee_tasks_to_schedule = [
    {
        "name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_autogenerate-usage-report",
        "task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
        "schedule": timedelta(days=30),
        "options": {
            "priority": OnyxCeleryPriority.HIGHEST,
            "expires": BEAT_EXPIRES_DEFAULT,
        },
        "kwargs": {
            "task_name": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
        },
        "name": "autogenerate-usage-report",
        "task": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
        "schedule": timedelta(days=30),  # TODO: change this to config flag
    },
    {
        "name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-ttl-management",
        "task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
        "name": "check-ttl-management",
        "task": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
        "schedule": timedelta(hours=1),
        "options": {
            "priority": OnyxCeleryPriority.HIGHEST,
            "expires": BEAT_EXPIRES_DEFAULT,
        },
        "kwargs": {
            "task_name": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
        },
    },
]

ee_tasks_to_schedule: list[dict] = []

if not MULTI_TENANT:
    ee_tasks_to_schedule = [
        {
            "name": "autogenerate-usage-report",
            "task": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
            "schedule": timedelta(days=30),  # TODO: change this to config flag
            "options": {
                "priority": OnyxCeleryPriority.MEDIUM,
                "expires": BEAT_EXPIRES_DEFAULT,
            },
        },
        {
            "name": "check-ttl-management",
            "task": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
            "schedule": timedelta(hours=1),
            "options": {
                "priority": OnyxCeleryPriority.MEDIUM,
                "expires": BEAT_EXPIRES_DEFAULT,
            },
        },
    ]


def get_cloud_tasks_to_schedule() -> list[dict[str, Any]]:
    return ee_cloud_tasks_to_schedule + base_cloud_tasks_to_schedule
    return base_cloud_tasks_to_schedule


def get_tasks_to_schedule() -> list[dict[str, Any]]:
@@ -4,20 +4,6 @@ import os
# Applicable for OIDC Auth
OPENID_CONFIG_URL = os.environ.get("OPENID_CONFIG_URL", "")

# Applicable for OIDC Auth, allows you to override the scopes that
# are requested from the OIDC provider. Currently used when passing
# over access tokens to tool calls and the tool needs more scopes
OIDC_SCOPE_OVERRIDE: list[str] | None = None
_OIDC_SCOPE_OVERRIDE = os.environ.get("OIDC_SCOPE_OVERRIDE")

if _OIDC_SCOPE_OVERRIDE:
    try:
        OIDC_SCOPE_OVERRIDE = [
            scope.strip() for scope in _OIDC_SCOPE_OVERRIDE.split(",")
        ]
    except Exception:
        pass

# Applicable for SAML Auth
SAML_CONF_DIR = os.environ.get("SAML_CONF_DIR") or "/app/ee/onyx/configs/saml_config"
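A note on the OIDC_SCOPE_OVERRIDE setting added above: it is parsed from a comma-separated environment variable. A minimal usage sketch is below; the scope values are illustrative examples, not taken from this diff.

```python
import os

# Illustrative value only; any comma-separated list of OIDC scopes works.
os.environ["OIDC_SCOPE_OVERRIDE"] = "openid, email, profile, offline_access"

# Same parsing as the config module: split on commas and strip whitespace.
scopes = [s.strip() for s in os.environ["OIDC_SCOPE_OVERRIDE"].split(",")]
print(scopes)  # ['openid', 'email', 'profile', 'offline_access']
```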
@@ -98,9 +98,10 @@ def get_page_of_chat_sessions(
    conditions = _build_filter_conditions(start_time, end_time, feedback_filter)

    subquery = (
        select(ChatSession.id)
        select(ChatSession.id, ChatSession.time_created)
        .filter(*conditions)
        .order_by(desc(ChatSession.time_created), ChatSession.id)
        .order_by(ChatSession.id, desc(ChatSession.time_created))
        .distinct(ChatSession.id)
        .limit(page_size)
        .offset(page_num * page_size)
        .subquery()

@@ -117,11 +118,7 @@ def get_page_of_chat_sessions(
            ChatMessage.chat_message_feedbacks
        ),
    )
    .order_by(
        desc(ChatSession.time_created),
        ChatSession.id,
        asc(ChatMessage.id),  # Ensure chronological message order
    )
    .order_by(desc(ChatSession.time_created), asc(ChatMessage.id))
    )

    return db_session.scalars(stmt).unique().all()
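The reordered subquery above pairs `.distinct(ChatSession.id)` with an ORDER BY that now leads with `ChatSession.id`; in PostgreSQL, `SELECT DISTINCT ON (expr)` requires that expression to match the leftmost ORDER BY expression. A minimal sketch of the pattern, using a placeholder model rather than Onyx's real schema:

```python
from sqlalchemy import Column, DateTime, Integer, desc, select
from sqlalchemy.orm import DeclarativeBase


class Base(DeclarativeBase):
    pass


class ChatSessionStub(Base):  # placeholder model for illustration only
    __tablename__ = "chat_session_stub"
    id = Column(Integer, primary_key=True)
    time_created = Column(DateTime)


# DISTINCT ON (id): the ORDER BY must start with id for Postgres to accept it.
stmt = (
    select(ChatSessionStub.id, ChatSessionStub.time_created)
    .order_by(ChatSessionStub.id, desc(ChatSessionStub.time_created))
    .distinct(ChatSessionStub.id)
    .limit(50)
)
```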
@@ -42,22 +42,24 @@ def _fetch_permissions_for_permission_ids(
    if not permission_info or not doc_id:
        return []

    # Check cache first for all permission IDs
    permissions = [
        _PERMISSION_ID_PERMISSION_MAP[pid]
        for pid in permission_ids
        if pid in _PERMISSION_ID_PERMISSION_MAP
    ]

    # If we found all permissions in cache, return them
    if len(permissions) == len(permission_ids):
        return permissions

    owner_email = permission_info.get("owner_email")

    drive_service = get_drive_service(
        creds=google_drive_connector.creds,
        user_email=(owner_email or google_drive_connector.primary_admin_email),
    )

    # Otherwise, fetch all permissions and update cache
    fetched_permissions = execute_paginated_retrieval(
        retrieval_function=drive_service.permissions().list,
        list_key="permissions",

@@ -67,6 +69,7 @@ def _fetch_permissions_for_permission_ids(
    )

    permissions_for_doc_id = []
    # Update cache and return all permissions
    for permission in fetched_permissions:
        permissions_for_doc_id.append(permission)
        _PERMISSION_ID_PERMISSION_MAP[permission["id"]] = permission
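The hunk above adds a module-level cache keyed by permission ID so that repeat lookups can skip the Drive API entirely. A generic sketch of that cache-aside pattern follows; the names and the `fetch_remote` callable are illustrative stand-ins, not the connector's real API.

```python
# Illustrative cache-aside pattern; fetch_remote() stands in for the paginated Drive call.
_PERMISSION_CACHE: dict[str, dict] = {}


def get_permissions(permission_ids: list[str], fetch_remote) -> list[dict]:
    # Serve entirely from cache when every requested ID is present.
    cached = [
        _PERMISSION_CACHE[pid] for pid in permission_ids if pid in _PERMISSION_CACHE
    ]
    if len(cached) == len(permission_ids):
        return cached

    # Otherwise refresh from the remote source and repopulate the cache.
    fetched = list(fetch_remote())
    for permission in fetched:
        _PERMISSION_CACHE[permission["id"]] = permission

    wanted = set(permission_ids)
    return [p for p in fetched if p["id"] in wanted]
```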
@@ -1,9 +1,7 @@
from fastapi import FastAPI
from httpx_oauth.clients.google import GoogleOAuth2
from httpx_oauth.clients.openid import BASE_SCOPES
from httpx_oauth.clients.openid import OpenID

from ee.onyx.configs.app_configs import OIDC_SCOPE_OVERRIDE
from ee.onyx.configs.app_configs import OPENID_CONFIG_URL
from ee.onyx.server.analytics.api import router as analytics_router
from ee.onyx.server.auth_check import check_ee_router_auth

@@ -90,13 +88,7 @@ def get_application() -> FastAPI:
        include_auth_router_with_prefix(
            application,
            create_onyx_oauth_router(
                OpenID(
                    OAUTH_CLIENT_ID,
                    OAUTH_CLIENT_SECRET,
                    OPENID_CONFIG_URL,
                    # BASE_SCOPES is the same as not setting this
                    base_scopes=OIDC_SCOPE_OVERRIDE or BASE_SCOPES,
                ),
                OpenID(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET, OPENID_CONFIG_URL),
                auth_backend,
                USER_AUTH_SECRET,
                associate_by_email=True,
@@ -23,6 +23,7 @@ def load_no_auth_user_preferences(store: KeyValueStore) -> UserPreferences:
        preferences_data = cast(
            Mapping[str, Any], store.load(KV_NO_AUTH_USER_PREFERENCES_KEY)
        )
        print("preferences_data", preferences_data)
        return UserPreferences(**preferences_data)
    except KvKeyNotFoundError:
        return UserPreferences(
||||
@@ -47,8 +47,3 @@ class UserUpdate(schemas.BaseUserUpdate):
|
||||
Role updates are not allowed through the user update endpoint for security reasons
|
||||
Role changes should be handled through a separate, admin-only process
|
||||
"""
|
||||
|
||||
|
||||
class AuthBackend(str, Enum):
|
||||
REDIS = "redis"
|
||||
POSTGRES = "postgres"
|
||||
|
||||
@@ -33,8 +33,6 @@ from fastapi_users.authentication import AuthenticationBackend
from fastapi_users.authentication import CookieTransport
from fastapi_users.authentication import RedisStrategy
from fastapi_users.authentication import Strategy
from fastapi_users.authentication.strategy.db import AccessTokenDatabase
from fastapi_users.authentication.strategy.db import DatabaseStrategy
from fastapi_users.exceptions import UserAlreadyExists
from fastapi_users.jwt import decode_jwt
from fastapi_users.jwt import generate_jwt

@@ -54,15 +52,13 @@ from onyx.auth.api_key import get_hashed_api_key_from_request
from onyx.auth.email_utils import send_forgot_password_email
from onyx.auth.email_utils import send_user_verification_email
from onyx.auth.invited_users import get_invited_users
from onyx.auth.schemas import AuthBackend
from onyx.auth.schemas import UserCreate
from onyx.auth.schemas import UserRole
from onyx.auth.schemas import UserUpdate
from onyx.configs.app_configs import AUTH_BACKEND
from onyx.configs.app_configs import AUTH_COOKIE_EXPIRE_TIME_SECONDS
from onyx.configs.app_configs import AUTH_TYPE
from onyx.configs.app_configs import DISABLE_AUTH
from onyx.configs.app_configs import EMAIL_CONFIGURED
from onyx.configs.app_configs import REDIS_AUTH_EXPIRE_TIME_SECONDS
from onyx.configs.app_configs import REDIS_AUTH_KEY_PREFIX
from onyx.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
from onyx.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS

@@ -78,7 +74,6 @@ from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import PASSWORD_SPECIAL_CHARS
from onyx.configs.constants import UNNAMED_KEY_PLACEHOLDER
from onyx.db.api_key import fetch_user_for_api_key
from onyx.db.auth import get_access_token_db
from onyx.db.auth import get_default_admin_user_emails
from onyx.db.auth import get_user_count
from onyx.db.auth import get_user_db

@@ -87,7 +82,6 @@ from onyx.db.engine import get_async_session
from onyx.db.engine import get_async_session_with_tenant
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session_with_tenant
from onyx.db.models import AccessToken
from onyx.db.models import OAuthAccount
from onyx.db.models import User
from onyx.db.users import get_user_by_email

@@ -215,7 +209,6 @@ def verify_email_domain(email: str) -> None:
class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
    reset_password_token_secret = USER_AUTH_SECRET
    verification_token_secret = USER_AUTH_SECRET
    verification_token_lifetime_seconds = AUTH_COOKIE_EXPIRE_TIME_SECONDS

    user_db: SQLAlchemyUserDatabase[User, uuid.UUID]

@@ -587,14 +580,6 @@ def get_redis_strategy() -> RedisStrategy:
    return TenantAwareRedisStrategy()


def get_database_strategy(
    access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db),
) -> DatabaseStrategy:
    return DatabaseStrategy(
        access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS
    )


class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]):
    """
    A custom strategy that fetches the actual async Redis connection inside each method.

@@ -603,7 +588,7 @@ class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]):

    def __init__(
        self,
        lifetime_seconds: Optional[int] = SESSION_EXPIRE_TIME_SECONDS,
        lifetime_seconds: Optional[int] = REDIS_AUTH_EXPIRE_TIME_SECONDS,
        key_prefix: str = REDIS_AUTH_KEY_PREFIX,
    ):
        self.lifetime_seconds = lifetime_seconds

@@ -652,16 +637,9 @@ class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]):
        await redis.delete(f"{self.key_prefix}{token}")


if AUTH_BACKEND == AuthBackend.REDIS:
    auth_backend = AuthenticationBackend(
        name="redis", transport=cookie_transport, get_strategy=get_redis_strategy
    )
elif AUTH_BACKEND == AuthBackend.POSTGRES:
    auth_backend = AuthenticationBackend(
        name="postgres", transport=cookie_transport, get_strategy=get_database_strategy
    )
else:
    raise ValueError(f"Invalid auth backend: {AUTH_BACKEND}")
auth_backend = AuthenticationBackend(
    name="redis", transport=cookie_transport, get_strategy=get_redis_strategy
)


class FastAPIUserWithLogoutRouter(FastAPIUsers[models.UP, models.ID]):
@@ -23,7 +23,8 @@ from onyx.background.celery.celery_utils import celery_is_worker_primary
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine import get_sqlalchemy_engine
from onyx.document_index.vespa.shared_utils.utils import wait_for_vespa_with_timeout
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
from onyx.document_index.vespa_constants import VESPA_CONFIG_SERVER_URL
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
from onyx.redis.redis_connector_delete import RedisConnectorDelete

@@ -279,6 +280,51 @@ def wait_for_db(sender: Any, **kwargs: Any) -> None:
    return


def wait_for_vespa(sender: Any, **kwargs: Any) -> None:
    """Waits for Vespa to become ready subject to a hardcoded timeout.
    Will raise WorkerShutdown to kill the celery worker if the timeout is reached."""

    WAIT_INTERVAL = 5
    WAIT_LIMIT = 60

    ready = False
    time_start = time.monotonic()
    logger.info("Vespa: Readiness probe starting.")
    while True:
        try:
            client = get_vespa_http_client()
            response = client.get(f"{VESPA_CONFIG_SERVER_URL}/state/v1/health")
            response.raise_for_status()

            response_dict = response.json()
            if response_dict["status"]["code"] == "up":
                ready = True
                break
        except Exception:
            pass

        time_elapsed = time.monotonic() - time_start
        if time_elapsed > WAIT_LIMIT:
            break

        logger.info(
            f"Vespa: Readiness probe ongoing. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}"
        )

        time.sleep(WAIT_INTERVAL)

    if not ready:
        msg = (
            f"Vespa: Readiness probe did not succeed within the timeout "
            f"({WAIT_LIMIT} seconds). Exiting..."
        )
        logger.error(msg)
        raise WorkerShutdown(msg)

    logger.info("Vespa: Readiness probe succeeded. Continuing...")
    return


def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
    logger.info("Running as a secondary celery worker.")

@@ -464,13 +510,3 @@ def reset_tenant_id(
) -> None:
    """Signal handler to reset tenant ID in context var after task ends."""
    CURRENT_TENANT_ID_CONTEXTVAR.set(POSTGRES_DEFAULT_SCHEMA)


def wait_for_vespa_or_shutdown(sender: Any, **kwargs: Any) -> None:
    """Waits for Vespa to become ready subject to a timeout.
    Raises WorkerShutdown if the timeout is reached."""

    if not wait_for_vespa_with_timeout():
        msg = "Vespa: Readiness probe did not succeed within the timeout. Exiting..."
        logger.error(msg)
        raise WorkerShutdown(msg)
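For reference, the refactor above replaces the inline polling loop with a reusable `wait_for_vespa_with_timeout` helper. A minimal sketch of what such a helper could look like is below; only the call site (it returns a bool) is known from this diff, so the signature, defaults, and body are assumptions and the real implementation in `onyx/document_index/vespa/shared_utils/utils.py` may differ.

```python
# Hypothetical sketch only: the real helper may differ.
import time


def wait_for_vespa_with_timeout(wait_interval: int = 5, wait_limit: int = 60) -> bool:
    """Poll the Vespa config server health endpoint until it reports "up" or the limit is hit."""
    start = time.monotonic()
    while time.monotonic() - start <= wait_limit:
        try:
            # get_vespa_http_client and VESPA_CONFIG_SERVER_URL are assumed to be
            # importable from the same modules used in the diff above.
            client = get_vespa_http_client()
            response = client.get(f"{VESPA_CONFIG_SERVER_URL}/state/v1/health")
            response.raise_for_status()
            if response.json()["status"]["code"] == "up":
                return True
        except Exception:
            pass  # not ready yet; keep polling
        time.sleep(wait_interval)
    return False
```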
@@ -1,5 +1,6 @@
from datetime import timedelta
from typing import Any
from typing import cast

from celery import Celery
from celery import signals

@@ -7,6 +8,7 @@ from celery.beat import PersistentScheduler  # type: ignore
from celery.signals import beat_init

import onyx.background.celery.apps.app_base as app_base
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
from onyx.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
from onyx.db.engine import get_all_tenant_ids
from onyx.db.engine import SqlEngine

@@ -79,7 +81,7 @@ class DynamicTenantScheduler(PersistentScheduler):
            cloud_task = {
                "task": task["task"],
                "schedule": task["schedule"],
                "kwargs": task.get("kwargs", {}),
                "kwargs": {},
            }
            if options := task.get("options"):
                logger.debug(f"Adding options to task {task_name}: {options}")

@@ -130,25 +132,21 @@ class DynamicTenantScheduler(PersistentScheduler):
        # get current schedule and extract current tenants
        current_schedule = self.schedule.items()

        # there are no more per tenant beat tasks, so comment this out
        # NOTE: we may not actualy need this scheduler any more and should
        # test reverting to a regular beat schedule implementation
        current_tenants = set()
        for task_name, _ in current_schedule:
            task_name = cast(str, task_name)
            if task_name.startswith(ONYX_CLOUD_CELERY_TASK_PREFIX):
                continue

        # current_tenants = set()
        # for task_name, _ in current_schedule:
        #     task_name = cast(str, task_name)
        #     if task_name.startswith(ONYX_CLOUD_CELERY_TASK_PREFIX):
        #         continue
            if "_" in task_name:
                # example: "check-for-condition-tenant_12345678-abcd-efgh-ijkl-12345678"
                # -> "12345678-abcd-efgh-ijkl-12345678"
                current_tenants.add(task_name.split("_")[-1])
        logger.info(f"Found {len(current_tenants)} existing items in schedule")

        # if "_" in task_name:
        #     # example: "check-for-condition-tenant_12345678-abcd-efgh-ijkl-12345678"
        #     # -> "12345678-abcd-efgh-ijkl-12345678"
        #     current_tenants.add(task_name.split("_")[-1])
        # logger.info(f"Found {len(current_tenants)} existing items in schedule")

        # for tenant_id in tenant_ids:
        #     if tenant_id not in current_tenants:
        #         logger.info(f"Processing new tenant: {tenant_id}")
        for tenant_id in tenant_ids:
            if tenant_id not in current_tenants:
                logger.info(f"Processing new tenant: {tenant_id}")

        new_schedule = self._generate_schedule(tenant_ids)
@@ -62,7 +62,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:

    app_base.wait_for_redis(sender, **kwargs)
    app_base.wait_for_db(sender, **kwargs)
    app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
    app_base.wait_for_vespa(sender, **kwargs)

    # Less startup checks in multi-tenant case
    if MULTI_TENANT:

@@ -68,7 +68,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:

    app_base.wait_for_redis(sender, **kwargs)
    app_base.wait_for_db(sender, **kwargs)
    app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
    app_base.wait_for_vespa(sender, **kwargs)

    # Less startup checks in multi-tenant case
    if MULTI_TENANT:

@@ -63,7 +63,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:

    app_base.wait_for_redis(sender, **kwargs)
    app_base.wait_for_db(sender, **kwargs)
    app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
    app_base.wait_for_vespa(sender, **kwargs)

    # Less startup checks in multi-tenant case
    if MULTI_TENANT:

@@ -86,7 +86,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:

    app_base.wait_for_redis(sender, **kwargs)
    app_base.wait_for_db(sender, **kwargs)
    app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
    app_base.wait_for_vespa(sender, **kwargs)

    logger.info("Running as the primary celery worker.")
@@ -16,241 +16,125 @@ from shared_configs.configs import MULTI_TENANT
# it's only important that they run relatively regularly
BEAT_EXPIRES_DEFAULT = 15 * 60  # 15 minutes (in seconds)

# hack to slow down task dispatch in the cloud until
# we have a better implementation (backpressure, etc)
CLOUD_BEAT_SCHEDULE_MULTIPLIER = 8

# tasks that only run in the cloud
# the name attribute must start with ONYX_CLOUD_CELERY_TASK_PREFIX = "cloud" to be filtered
# the name attribute must start with ONYX_CELERY_CLOUD_PREFIX = "cloud" to be filtered
# by the DynamicTenantScheduler
cloud_tasks_to_schedule = [
    # cloud specific tasks
    {
        "name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-alembic",
        "task": OnyxCeleryTask.CLOUD_CHECK_ALEMBIC,
        "schedule": timedelta(hours=1 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
        "options": {
            "queue": OnyxCeleryQueues.MONITORING,
            "priority": OnyxCeleryPriority.HIGH,
            "expires": BEAT_EXPIRES_DEFAULT,
        },
    },
    # remaining tasks are cloud generators for per tenant tasks
    {
        "name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-indexing",
        "task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
        "schedule": timedelta(seconds=15 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
        "task": OnyxCeleryTask.CLOUD_CHECK_FOR_INDEXING,
        "schedule": timedelta(seconds=15),
        "options": {
            "priority": OnyxCeleryPriority.HIGHEST,
            "expires": BEAT_EXPIRES_DEFAULT,
        },
        "kwargs": {
            "task_name": OnyxCeleryTask.CHECK_FOR_INDEXING,
        },
    },
    {
        "name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-connector-deletion",
        "task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
        "schedule": timedelta(seconds=20 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
        "options": {
            "priority": OnyxCeleryPriority.HIGHEST,
            "expires": BEAT_EXPIRES_DEFAULT,
        },
        "kwargs": {
            "task_name": OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
        },
    },
    {
        "name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-vespa-sync",
        "task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
        "schedule": timedelta(seconds=20 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
        "options": {
            "priority": OnyxCeleryPriority.HIGHEST,
            "expires": BEAT_EXPIRES_DEFAULT,
        },
        "kwargs": {
            "task_name": OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
        },
    },
    {
        "name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-prune",
        "task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
        "schedule": timedelta(seconds=15 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
        "options": {
            "priority": OnyxCeleryPriority.HIGHEST,
            "expires": BEAT_EXPIRES_DEFAULT,
        },
        "kwargs": {
            "task_name": OnyxCeleryTask.CHECK_FOR_PRUNING,
        },
    },
    {
        "name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor-vespa-sync",
        "task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
        "schedule": timedelta(seconds=15 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
        "options": {
            "priority": OnyxCeleryPriority.HIGHEST,
            "expires": BEAT_EXPIRES_DEFAULT,
        },
        "kwargs": {
            "task_name": OnyxCeleryTask.MONITOR_VESPA_SYNC,
        },
    },
    {
        "name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-doc-permissions-sync",
        "task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
        "schedule": timedelta(seconds=30 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
        "options": {
            "priority": OnyxCeleryPriority.HIGHEST,
            "expires": BEAT_EXPIRES_DEFAULT,
        },
        "kwargs": {
            "task_name": OnyxCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
        },
    },
    {
        "name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-external-group-sync",
        "task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
        "schedule": timedelta(seconds=20 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
        "options": {
            "priority": OnyxCeleryPriority.HIGHEST,
            "expires": BEAT_EXPIRES_DEFAULT,
        },
        "kwargs": {
            "task_name": OnyxCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
        },
    },
    {
        "name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor-background-processes",
        "task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
        "schedule": timedelta(minutes=5 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
        "options": {
            "priority": OnyxCeleryPriority.HIGHEST,
            "expires": BEAT_EXPIRES_DEFAULT,
        },
        "kwargs": {
            "task_name": OnyxCeleryTask.MONITOR_BACKGROUND_PROCESSES,
            "queue": OnyxCeleryQueues.MONITORING,
            "priority": OnyxCeleryPriority.LOW,
        },
    },
]

if LLM_MODEL_UPDATE_API_URL:
    cloud_tasks_to_schedule.append(
# tasks that run in either self-hosted on cloud
tasks_to_schedule = [
    {
        "name": "check-for-vespa-sync",
        "task": OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
        "schedule": timedelta(seconds=20),
        "options": {
            "priority": OnyxCeleryPriority.MEDIUM,
            "expires": BEAT_EXPIRES_DEFAULT,
        },
    },
    {
        "name": "check-for-connector-deletion",
        "task": OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
        "schedule": timedelta(seconds=20),
        "options": {
            "priority": OnyxCeleryPriority.MEDIUM,
            "expires": BEAT_EXPIRES_DEFAULT,
        },
    },
    {
        "name": "check-for-prune",
        "task": OnyxCeleryTask.CHECK_FOR_PRUNING,
        "schedule": timedelta(seconds=15),
        "options": {
            "priority": OnyxCeleryPriority.MEDIUM,
            "expires": BEAT_EXPIRES_DEFAULT,
        },
    },
    {
        "name": "kombu-message-cleanup",
        "task": OnyxCeleryTask.KOMBU_MESSAGE_CLEANUP_TASK,
        "schedule": timedelta(seconds=3600),
        "options": {
            "priority": OnyxCeleryPriority.LOWEST,
            "expires": BEAT_EXPIRES_DEFAULT,
        },
    },
    {
        "name": "monitor-vespa-sync",
        "task": OnyxCeleryTask.MONITOR_VESPA_SYNC,
        "schedule": timedelta(seconds=5),
        "options": {
            "priority": OnyxCeleryPriority.MEDIUM,
            "expires": BEAT_EXPIRES_DEFAULT,
        },
    },
    {
        "name": "monitor-background-processes",
        "task": OnyxCeleryTask.MONITOR_BACKGROUND_PROCESSES,
        "schedule": timedelta(minutes=5),
        "options": {
            "priority": OnyxCeleryPriority.LOW,
            "expires": BEAT_EXPIRES_DEFAULT,
            "queue": OnyxCeleryQueues.MONITORING,
        },
    },
    {
        "name": "check-for-doc-permissions-sync",
        "task": OnyxCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
        "schedule": timedelta(seconds=30),
        "options": {
            "priority": OnyxCeleryPriority.MEDIUM,
            "expires": BEAT_EXPIRES_DEFAULT,
        },
    },
    {
        "name": "check-for-external-group-sync",
        "task": OnyxCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
        "schedule": timedelta(seconds=20),
        "options": {
            "priority": OnyxCeleryPriority.MEDIUM,
            "expires": BEAT_EXPIRES_DEFAULT,
        },
    },
]

if not MULTI_TENANT:
    tasks_to_schedule.append(
        {
            "name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-llm-model-update",
            "task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
            "schedule": timedelta(
                hours=1 * CLOUD_BEAT_SCHEDULE_MULTIPLIER
            ),  # Check every hour
            "name": "check-for-indexing",
            "task": OnyxCeleryTask.CHECK_FOR_INDEXING,
            "schedule": timedelta(seconds=15),
            "options": {
                "priority": OnyxCeleryPriority.HIGHEST,
                "priority": OnyxCeleryPriority.MEDIUM,
                "expires": BEAT_EXPIRES_DEFAULT,
            },
            "kwargs": {
                "task_name": OnyxCeleryTask.CHECK_FOR_LLM_MODEL_UPDATE,
                "priority": OnyxCeleryPriority.LOW,
            },
        }
    )

# tasks that run in either self-hosted on cloud
tasks_to_schedule: list[dict] = []

if not MULTI_TENANT:
    tasks_to_schedule.extend(
        [
            {
                "name": "check-for-indexing",
                "task": OnyxCeleryTask.CHECK_FOR_INDEXING,
                "schedule": timedelta(seconds=15),
                "options": {
                    "priority": OnyxCeleryPriority.MEDIUM,
                    "expires": BEAT_EXPIRES_DEFAULT,
                },
            # Only add the LLM model update task if the API URL is configured
            if LLM_MODEL_UPDATE_API_URL:
                tasks_to_schedule.append(
                    {
                        "name": "check-for-llm-model-update",
                        "task": OnyxCeleryTask.CHECK_FOR_LLM_MODEL_UPDATE,
                        "schedule": timedelta(hours=1),  # Check every hour
                        "options": {
                            "priority": OnyxCeleryPriority.LOW,
                            "expires": BEAT_EXPIRES_DEFAULT,
                        },
            {
                "name": "check-for-connector-deletion",
                "task": OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
                "schedule": timedelta(seconds=20),
                "options": {
                    "priority": OnyxCeleryPriority.MEDIUM,
                    "expires": BEAT_EXPIRES_DEFAULT,
                },
            },
            {
                "name": "check-for-vespa-sync",
                "task": OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
                "schedule": timedelta(seconds=20),
                "options": {
                    "priority": OnyxCeleryPriority.MEDIUM,
                    "expires": BEAT_EXPIRES_DEFAULT,
                },
            },
            {
                "name": "check-for-pruning",
                "task": OnyxCeleryTask.CHECK_FOR_PRUNING,
                "schedule": timedelta(hours=1),
                "options": {
                    "priority": OnyxCeleryPriority.MEDIUM,
                    "expires": BEAT_EXPIRES_DEFAULT,
                },
            },
            {
                "name": "monitor-vespa-sync",
                "task": OnyxCeleryTask.MONITOR_VESPA_SYNC,
                "schedule": timedelta(seconds=5),
                "options": {
                    "priority": OnyxCeleryPriority.MEDIUM,
                    "expires": BEAT_EXPIRES_DEFAULT,
                },
            },
            {
                "name": "check-for-doc-permissions-sync",
                "task": OnyxCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
                "schedule": timedelta(seconds=30),
                "options": {
                    "priority": OnyxCeleryPriority.MEDIUM,
                    "expires": BEAT_EXPIRES_DEFAULT,
                },
            },
            {
                "name": "check-for-external-group-sync",
                "task": OnyxCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
                "schedule": timedelta(seconds=20),
                "options": {
                    "priority": OnyxCeleryPriority.MEDIUM,
                    "expires": BEAT_EXPIRES_DEFAULT,
                },
            },
            {
                "name": "monitor-background-processes",
                "task": OnyxCeleryTask.MONITOR_BACKGROUND_PROCESSES,
                "schedule": timedelta(minutes=15),
                "options": {
                    "priority": OnyxCeleryPriority.LOW,
                    "expires": BEAT_EXPIRES_DEFAULT,
                    "queue": OnyxCeleryQueues.MONITORING,
                },
            },
        ]
        }
    )

# Only add the LLM model update task if the API URL is configured
if LLM_MODEL_UPDATE_API_URL:
    tasks_to_schedule.append(
        {
            "name": "check-for-llm-model-update",
            "task": OnyxCeleryTask.CHECK_FOR_LLM_MODEL_UPDATE,
            "schedule": timedelta(hours=1),  # Check every hour
            "options": {
                "priority": OnyxCeleryPriority.LOW,
                "expires": BEAT_EXPIRES_DEFAULT,
            },
        }
    )


def get_cloud_tasks_to_schedule() -> list[dict[str, Any]]:
    return cloud_tasks_to_schedule
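For reference, the cloud entries above stretch each base interval by CLOUD_BEAT_SCHEDULE_MULTIPLIER = 8, so for example the cloud check-for-indexing generator fires every 15 * 8 = 120 seconds while the self-hosted task keeps its 15-second cadence. A small sketch of that arithmetic:

```python
from datetime import timedelta

CLOUD_BEAT_SCHEDULE_MULTIPLIER = 8  # value taken from the diff above

base = timedelta(seconds=15)
cloud = timedelta(seconds=15 * CLOUD_BEAT_SCHEDULE_MULTIPLIER)
print(base, cloud)  # 0:00:15 0:02:00
```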
@@ -33,7 +33,6 @@ class TaskDependencyError(RuntimeError):

@shared_task(
    name=OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
    ignore_result=True,
    soft_time_limit=JOB_TIMEOUT,
    trail=False,
    bind=True,

@@ -140,6 +139,13 @@ def try_generate_document_cc_pair_cleanup_tasks(
        submitted=datetime.now(timezone.utc),
    )

    # create before setting fence to avoid race condition where the monitoring
    # task updates the sync record before it is created
    insert_sync_record(
        db_session=db_session,
        entity_id=cc_pair_id,
        sync_type=SyncType.CONNECTOR_DELETION,
    )
    redis_connector.delete.set_fence(fence_payload)

    try:

@@ -178,13 +184,6 @@ def try_generate_document_cc_pair_cleanup_tasks(
        )
        if tasks_generated is None:
            raise ValueError("RedisConnectorDeletion.generate_tasks returned None")

        insert_sync_record(
            db_session=db_session,
            entity_id=cc_pair_id,
            sync_type=SyncType.CONNECTOR_DELETION,
        )

    except TaskDependencyError:
        redis_connector.delete.set_fence(None)
        raise

@@ -91,7 +91,6 @@ def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> b

@shared_task(
    name=OnyxCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
    ignore_result=True,
    soft_time_limit=JOB_TIMEOUT,
    bind=True,
)

@@ -91,7 +91,6 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:

@shared_task(
    name=OnyxCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
    ignore_result=True,
    soft_time_limit=JOB_TIMEOUT,
    bind=True,
)
@@ -15,6 +15,7 @@ from redis import Redis
from redis.lock import Lock as RedisLock

from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.tasks.beat_schedule import BEAT_EXPIRES_DEFAULT
from onyx.background.celery.tasks.indexing.utils import _should_index
from onyx.background.celery.tasks.indexing.utils import get_unfenced_index_attempt_ids
from onyx.background.celery.tasks.indexing.utils import IndexingCallback

@@ -25,12 +26,15 @@ from onyx.background.indexing.run_indexing import run_indexing_entrypoint
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import OnyxRedisSignals
from onyx.db.connector import mark_ccpair_with_indexing_trigger
from onyx.db.connector_credential_pair import fetch_connector_credential_pairs
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.engine import get_all_tenant_ids
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import IndexingMode
from onyx.db.index_attempt import get_index_attempt

@@ -45,7 +49,6 @@ from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
from onyx.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.redis.redis_pool import redis_lock_dump
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import global_version

@@ -65,12 +68,15 @@ logger = setup_logger()
def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
    """a lightweight task used to kick off indexing tasks.
    Occcasionally does some validation of existing state to clear up error conditions"""
    debug_tenants = {
        "tenant_i-043470d740845ec56",
        "tenant_82b497ce-88aa-4fbd-841a-92cae43529c8",
    }
    time_start = time.monotonic()

    tasks_created = 0
    locked = False
    redis_client = get_redis_client(tenant_id=tenant_id)
    redis_client_replica = get_redis_replica_client(tenant_id=tenant_id)

    # we need to use celery's redis client to access its redis data
    # (which lives on a different db number)

@@ -117,6 +123,16 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:

        # kick off index attempts
        for cc_pair_id in cc_pair_ids:
            # debugging logic - remove after we're done
            if tenant_id in debug_tenants:
                ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
                task_logger.info(
                    f"check_for_indexing cc_pair lock: "
                    f"tenant={tenant_id} "
                    f"cc_pair={cc_pair_id} "
                    f"ttl={ttl}"
                )

            lock_beat.reacquire()

            redis_connector = RedisConnector(tenant_id, cc_pair_id)

@@ -125,12 +141,30 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
                db_session
            )
            for search_settings_instance in search_settings_list:
                if tenant_id in debug_tenants:
                    ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
                    task_logger.info(
                        f"check_for_indexing cc_pair search settings lock: "
                        f"tenant={tenant_id} "
                        f"cc_pair={cc_pair_id} "
                        f"ttl={ttl}"
                    )

                redis_connector_index = redis_connector.new_index(
                    search_settings_instance.id
                )
                if redis_connector_index.fenced:
                    continue

                if tenant_id in debug_tenants:
                    ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
                    task_logger.info(
                        f"check_for_indexing get_connector_credential_pair_from_id: "
                        f"tenant={tenant_id} "
                        f"cc_pair={cc_pair_id} "
                        f"ttl={ttl}"
                    )

                cc_pair = get_connector_credential_pair_from_id(
                    db_session=db_session,
                    cc_pair_id=cc_pair_id,

@@ -138,10 +172,28 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
                if not cc_pair:
                    continue

                if tenant_id in debug_tenants:
                    ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
                    task_logger.info(
                        f"check_for_indexing get_last_attempt_for_cc_pair: "
                        f"tenant={tenant_id} "
                        f"cc_pair={cc_pair_id} "
                        f"ttl={ttl}"
                    )

                last_attempt = get_last_attempt_for_cc_pair(
                    cc_pair.id, search_settings_instance.id, db_session
                )

                if tenant_id in debug_tenants:
                    ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
                    task_logger.info(
                        f"check_for_indexing cc_pair should index: "
                        f"tenant={tenant_id} "
                        f"cc_pair={cc_pair_id} "
                        f"ttl={ttl}"
                    )

                search_settings_primary = False
                if search_settings_instance.id == search_settings_list[0].id:
                    search_settings_primary = True

@@ -174,6 +226,15 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
                        cc_pair.id, None, db_session
                    )

                if tenant_id in debug_tenants:
                    ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
                    task_logger.info(
                        f"check_for_indexing cc_pair try_creating_indexing_task: "
                        f"tenant={tenant_id} "
                        f"cc_pair={cc_pair_id} "
                        f"ttl={ttl}"
                    )

                # using a task queue and only allowing one task per cc_pair/search_setting
                # prevents us from starving out certain attempts
                attempt_id = try_creating_indexing_task(

@@ -194,6 +255,24 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
                    )
                    tasks_created += 1

                if tenant_id in debug_tenants:
                    ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
                    task_logger.info(
                        f"check_for_indexing cc_pair try_creating_indexing_task finished: "
                        f"tenant={tenant_id} "
                        f"cc_pair={cc_pair_id} "
                        f"ttl={ttl}"
                    )

        # debugging logic - remove after we're done
        if tenant_id in debug_tenants:
            ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
            task_logger.info(
                f"check_for_indexing unfenced lock: "
                f"tenant={tenant_id} "
                f"ttl={ttl}"
            )

        lock_beat.reacquire()

        # Fail any index attempts in the DB that don't have fences

@@ -203,7 +282,24 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
            db_session, redis_client
        )

        if tenant_id in debug_tenants:
            ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
            task_logger.info(
                f"check_for_indexing after get unfenced lock: "
                f"tenant={tenant_id} "
                f"ttl={ttl}"
            )

        for attempt_id in unfenced_attempt_ids:
            # debugging logic - remove after we're done
            if tenant_id in debug_tenants:
                ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
                task_logger.info(
                    f"check_for_indexing unfenced attempt id lock: "
                    f"tenant={tenant_id} "
                    f"ttl={ttl}"
                )

            lock_beat.reacquire()

            attempt = get_index_attempt(db_session, attempt_id)

@@ -221,6 +317,15 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
                attempt.id, db_session, failure_reason=failure_reason
            )

        # debugging logic - remove after we're done
        if tenant_id in debug_tenants:
            ttl = redis_client.ttl(OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK)
            task_logger.info(
                f"check_for_indexing validate fences lock: "
                f"tenant={tenant_id} "
                f"ttl={ttl}"
            )

        lock_beat.reacquire()
        # we want to run this less frequently than the overall task
        if not redis_client.exists(OnyxRedisSignals.VALIDATE_INDEXING_FENCES):

@@ -229,7 +334,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
            # or be currently executing
            try:
                validate_indexing_fences(
                    tenant_id, redis_client_replica, redis_client_celery, lock_beat
                    tenant_id, self.app, redis_client, redis_client_celery, lock_beat
                )
            except Exception:
                task_logger.exception("Exception while validating indexing fences")

@@ -569,9 +674,6 @@ def connector_indexing_proxy_task(
    while True:
        sleep(5)

        # renew watchdog signal (this has a shorter timeout than set_active)
        redis_connector_index.set_watchdog(True)

        # renew active signal
        redis_connector_index.set_active()

@@ -678,10 +780,67 @@ def connector_indexing_proxy_task(
            )
            continue

    redis_connector_index.set_watchdog(False)
    task_logger.info(
        f"Indexing watchdog - finished: attempt={index_attempt_id} "
        f"cc_pair={cc_pair_id} "
        f"search_settings={search_settings_id}"
    )
    return


@shared_task(
    name=OnyxCeleryTask.CLOUD_CHECK_FOR_INDEXING,
    trail=False,
    bind=True,
)
def cloud_check_for_indexing(self: Task) -> bool | None:
    """a lightweight task used to kick off individual check tasks for each tenant."""
    time_start = time.monotonic()

    redis_client = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)

    lock_beat: RedisLock = redis_client.lock(
        OnyxRedisLocks.CLOUD_CHECK_INDEXING_BEAT_LOCK,
        timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
    )

    # these tasks should never overlap
    if not lock_beat.acquire(blocking=False):
        return None

    last_lock_time = time.monotonic()

    try:
        tenant_ids = get_all_tenant_ids()
        for tenant_id in tenant_ids:
            current_time = time.monotonic()
            if current_time - last_lock_time >= (CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4):
                lock_beat.reacquire()
                last_lock_time = current_time

            self.app.send_task(
                OnyxCeleryTask.CHECK_FOR_INDEXING,
                kwargs=dict(
                    tenant_id=tenant_id,
                ),
                priority=OnyxCeleryPriority.HIGH,
                expires=BEAT_EXPIRES_DEFAULT,
            )
    except SoftTimeLimitExceeded:
        task_logger.info(
            "Soft time limit exceeded, task is being terminated gracefully."
        )
    except Exception:
        task_logger.exception("Unexpected exception during cloud indexing check")
    finally:
        if lock_beat.owned():
            lock_beat.release()
        else:
            task_logger.error("cloud_check_for_indexing - Lock not owned on completion")
            redis_lock_dump(lock_beat, redis_client)

    time_elapsed = time.monotonic() - time_start
    task_logger.info(
        f"cloud_check_for_indexing finished: num_tenants={len(tenant_ids)} elapsed={time_elapsed:.2f}"
    )
    return True
@@ -291,20 +291,17 @@ def validate_indexing_fence(

def validate_indexing_fences(
    tenant_id: str | None,
    r_replica: Redis,
    celery_app: Celery,
    r: Redis,
    r_celery: Redis,
    lock_beat: RedisLock,
) -> None:
    """Validates all indexing fences for this tenant ... aka makes sure
    indexing tasks sent to celery are still in flight.
    """
    reserved_indexing_tasks = celery_get_unacked_task_ids(
        OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
    )

    # Use replica for this because the worst thing that happens
    # is that we don't run the validation on this pass
    for key_bytes in r_replica.scan_iter(
    # validate all existing indexing jobs
    for key_bytes in r.scan_iter(
        RedisConnectorIndex.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
    ):
        lock_beat.reacquire()
@@ -14,16 +14,8 @@ from onyx.db.models import LLMProvider

def _process_model_list_response(model_list_json: Any) -> list[str]:
    # Handle case where response is wrapped in a "data" field
    if isinstance(model_list_json, dict):
        if "data" in model_list_json:
            model_list_json = model_list_json["data"]
        elif "models" in model_list_json:
            model_list_json = model_list_json["models"]
        else:
            raise ValueError(
                "Invalid response from API - expected dict with 'data' or "
                f"'models' field, got {type(model_list_json)}"
            )
    if isinstance(model_list_json, dict) and "data" in model_list_json:
        model_list_json = model_list_json["data"]

    if not isinstance(model_list_json, list):
        raise ValueError(

@@ -35,18 +27,11 @@ def _process_model_list_response(model_list_json: Any) -> list[str]:
    for item in model_list_json:
        if isinstance(item, str):
            model_names.append(item)
        elif isinstance(item, dict):
            if "model_name" in item:
                model_names.append(item["model_name"])
            elif "id" in item:
                model_names.append(item["id"])
            else:
                raise ValueError(
                    f"Invalid item in model list - expected dict with model_name or id, got {type(item)}"
                )
        elif isinstance(item, dict) and "model_name" in item:
            model_names.append(item["model_name"])
        else:
            raise ValueError(
                f"Invalid item in model list - expected string or dict, got {type(item)}"
                f"Invalid item in model list - expected string or dict with model_name, got {type(item)}"
            )

    return model_names

@@ -54,7 +39,6 @@ def _process_model_list_response(model_list_json: Any) -> list[str]:

@shared_task(
    name=OnyxCeleryTask.CHECK_FOR_LLM_MODEL_UPDATE,
    ignore_result=True,
    soft_time_limit=JOB_TIMEOUT,
    trail=False,
    bind=True,
@@ -1,10 +1,7 @@
|
||||
import json
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from datetime import timedelta
|
||||
from itertools import islice
|
||||
from typing import Any
|
||||
from typing import Literal
|
||||
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
@@ -13,34 +10,26 @@ from pydantic import BaseModel
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.tasks.vespa.tasks import celery_get_queue_length
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.engine import get_all_tenant_ids
|
||||
from onyx.db.engine import get_db_current_time
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.enums import IndexingStatus
|
||||
from onyx.db.enums import SyncStatus
|
||||
from onyx.db.enums import SyncType
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import IndexAttempt
|
||||
from onyx.db.models import SyncRecord
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.search_settings import get_active_search_settings
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import redis_lock_dump
|
||||
from onyx.utils.telemetry import optional_telemetry
|
||||
from onyx.utils.telemetry import RecordType
|
||||
|
||||
|
||||
_MONITORING_SOFT_TIME_LIMIT = 60 * 5 # 5 minutes
_MONITORING_TIME_LIMIT = _MONITORING_SOFT_TIME_LIMIT + 60 # 6 minutes

@@ -52,12 +41,6 @@ _CONNECTOR_INDEX_ATTEMPT_RUN_SUCCESS_KEY_FMT = (
"monitoring_connector_index_attempt_run_success:{cc_pair_id}:{index_attempt_id}"
)

_FINAL_METRIC_KEY_FMT = "sync_final_metrics:{sync_type}:{entity_id}:{sync_record_id}"

_SYNC_START_LATENCY_KEY_FMT = (
"sync_start_latency:{sync_type}:{entity_id}:{sync_record_id}"
)


def _mark_metric_as_emitted(redis_std: Redis, key: str) -> None:
"""Mark a metric as having been emitted by setting a Redis key with expiration"""
@@ -120,7 +103,6 @@ class Metric(BaseModel):
}.items()
if v is not None
}
task_logger.info(f"Emitting metric: {data}")
optional_telemetry(
record_type=RecordType.METRIC,
data=data,
@@ -195,375 +177,225 @@ def _build_connector_start_latency_metric(

start_latency = (recent_attempt.time_started - desired_start_time).total_seconds()

task_logger.info(
f"Start latency for index attempt {recent_attempt.id}: {start_latency:.2f}s "
f"(desired: {desired_start_time}, actual: {recent_attempt.time_started})"
)

job_id = build_job_id("connector", str(cc_pair.id), str(recent_attempt.id))

return Metric(
key=metric_key,
name="connector_start_latency",
value=start_latency,
tags={
"job_id": job_id,
"connector_id": str(cc_pair.connector.id),
"source": str(cc_pair.connector.source),
},
tags={},
)


def _build_connector_final_metrics(
|
||||
def _build_run_success_metrics(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
recent_attempts: list[IndexAttempt],
|
||||
redis_std: Redis,
|
||||
) -> list[Metric]:
|
||||
"""
|
||||
Final metrics for connector index attempts:
|
||||
- Boolean success/fail metric
|
||||
- If success, emit:
|
||||
* duration (seconds)
|
||||
* doc_count
|
||||
"""
|
||||
metrics = []
|
||||
for attempt in recent_attempts:
|
||||
metric_key = _CONNECTOR_INDEX_ATTEMPT_RUN_SUCCESS_KEY_FMT.format(
|
||||
cc_pair_id=cc_pair.id,
|
||||
index_attempt_id=attempt.id,
|
||||
)
|
||||
|
||||
if _has_metric_been_emitted(redis_std, metric_key):
|
||||
task_logger.info(
|
||||
f"Skipping final metrics for connector {cc_pair.connector.id} "
|
||||
f"index attempt {attempt.id}, already emitted."
|
||||
f"Skipping metric for connector {cc_pair.connector.id} "
|
||||
f"index attempt {attempt.id} because it has already been "
|
||||
"emitted"
|
||||
)
|
||||
continue
|
||||
|
||||
# We only emit final metrics if the attempt is in a terminal state
|
||||
if attempt.status not in [
|
||||
if attempt.status in [
|
||||
IndexingStatus.SUCCESS,
|
||||
IndexingStatus.FAILED,
|
||||
IndexingStatus.CANCELED,
|
||||
]:
|
||||
# Not finished; skip
|
||||
continue
|
||||
|
||||
job_id = build_job_id("connector", str(cc_pair.id), str(attempt.id))
|
||||
success = attempt.status == IndexingStatus.SUCCESS
|
||||
metrics.append(
|
||||
Metric(
|
||||
key=metric_key, # We'll mark the same key for any final metrics
|
||||
name="connector_run_succeeded",
|
||||
value=success,
|
||||
tags={
|
||||
"job_id": job_id,
|
||||
"connector_id": str(cc_pair.connector.id),
|
||||
"source": str(cc_pair.connector.source),
|
||||
"status": attempt.status.value,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
if success:
|
||||
# Make sure we have valid time_started
|
||||
if attempt.time_started and attempt.time_updated:
|
||||
duration_seconds = (
|
||||
attempt.time_updated - attempt.time_started
|
||||
).total_seconds()
|
||||
metrics.append(
|
||||
Metric(
|
||||
key=None, # No need for a new key, or you can reuse the same if you prefer
|
||||
name="connector_index_duration_seconds",
|
||||
value=duration_seconds,
|
||||
tags={
|
||||
"job_id": job_id,
|
||||
"connector_id": str(cc_pair.connector.id),
|
||||
"source": str(cc_pair.connector.source),
|
||||
},
|
||||
)
|
||||
)
|
||||
else:
|
||||
task_logger.error(
|
||||
f"Index attempt {attempt.id} succeeded but has missing time "
|
||||
f"(time_started={attempt.time_started}, time_updated={attempt.time_updated})."
|
||||
)
|
||||
|
||||
# For doc counts, choose whichever field is more relevant
|
||||
doc_count = attempt.total_docs_indexed or 0
|
||||
metrics.append(
|
||||
Metric(
|
||||
key=None,
|
||||
name="connector_index_doc_count",
|
||||
value=doc_count,
|
||||
tags={
|
||||
"job_id": job_id,
|
||||
"connector_id": str(cc_pair.connector.id),
|
||||
"source": str(cc_pair.connector.source),
|
||||
},
|
||||
key=metric_key,
|
||||
name="connector_run_succeeded",
|
||||
value=attempt.status == IndexingStatus.SUCCESS,
|
||||
tags={"source": str(cc_pair.connector.source)},
|
||||
)
|
||||
)
|
||||
|
||||
_mark_metric_as_emitted(redis_std, metric_key)
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
def _collect_connector_metrics(db_session: Session, redis_std: Redis) -> list[Metric]:
|
||||
"""Collect metrics about connector runs from the past hour"""
|
||||
# NOTE: use get_db_current_time since the IndexAttempt times are set based on DB time
|
||||
one_hour_ago = get_db_current_time(db_session) - timedelta(hours=1)
|
||||
|
||||
# Get all connector credential pairs
|
||||
cc_pairs = db_session.scalars(select(ConnectorCredentialPair)).all()
|
||||
# Might be more than one search setting, or just one
|
||||
active_search_settings = get_active_search_settings(db_session)
|
||||
|
||||
metrics = []
|
||||
|
||||
# If you want to process each cc_pair against each search setting:
|
||||
for cc_pair in cc_pairs:
|
||||
for search_settings in active_search_settings:
|
||||
recent_attempts = (
|
||||
db_session.query(IndexAttempt)
|
||||
.filter(
|
||||
IndexAttempt.connector_credential_pair_id == cc_pair.id,
|
||||
IndexAttempt.search_settings_id == search_settings.id,
|
||||
)
|
||||
.order_by(IndexAttempt.time_created.desc())
|
||||
.limit(2)
|
||||
.all()
|
||||
# Get all attempts in the last hour
|
||||
recent_attempts = (
|
||||
db_session.query(IndexAttempt)
|
||||
.filter(
|
||||
IndexAttempt.connector_credential_pair_id == cc_pair.id,
|
||||
IndexAttempt.time_created >= one_hour_ago,
|
||||
)
|
||||
.order_by(IndexAttempt.time_created.desc())
|
||||
.all()
|
||||
)
|
||||
most_recent_attempt = recent_attempts[0] if recent_attempts else None
|
||||
second_most_recent_attempt = (
|
||||
recent_attempts[1] if len(recent_attempts) > 1 else None
|
||||
)
|
||||
|
||||
if not recent_attempts:
|
||||
continue
|
||||
# if no metric to emit, skip
|
||||
if most_recent_attempt is None:
|
||||
continue
|
||||
|
||||
most_recent_attempt = recent_attempts[0]
|
||||
second_most_recent_attempt = (
|
||||
recent_attempts[1] if len(recent_attempts) > 1 else None
|
||||
)
|
||||
# Connector start latency
|
||||
start_latency_metric = _build_connector_start_latency_metric(
|
||||
cc_pair, most_recent_attempt, second_most_recent_attempt, redis_std
|
||||
)
|
||||
if start_latency_metric:
|
||||
metrics.append(start_latency_metric)
|
||||
|
||||
if one_hour_ago > most_recent_attempt.time_created:
|
||||
continue
|
||||
|
||||
# Connector start latency
|
||||
start_latency_metric = _build_connector_start_latency_metric(
|
||||
cc_pair, most_recent_attempt, second_most_recent_attempt, redis_std
|
||||
)
|
||||
|
||||
if start_latency_metric:
|
||||
metrics.append(start_latency_metric)
|
||||
|
||||
# Connector run success/failure
|
||||
final_metrics = _build_connector_final_metrics(
|
||||
cc_pair, recent_attempts, redis_std
|
||||
)
|
||||
metrics.extend(final_metrics)
|
||||
# Connector run success/failure
|
||||
run_success_metrics = _build_run_success_metrics(
|
||||
cc_pair, recent_attempts, redis_std
|
||||
)
|
||||
metrics.extend(run_success_metrics)
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
def _collect_sync_metrics(db_session: Session, redis_std: Redis) -> list[Metric]:
|
||||
"""
|
||||
Collect metrics for document set and group syncing:
|
||||
- Success/failure status
|
||||
- Start latency (always)
|
||||
- Duration & doc count (only if success)
|
||||
- Throughput (docs/min) (only if success)
|
||||
"""
|
||||
"""Collect metrics about document set and group syncing speed"""
|
||||
# NOTE: use get_db_current_time since the SyncRecord times are set based on DB time
|
||||
one_hour_ago = get_db_current_time(db_session) - timedelta(hours=1)
|
||||
|
||||
# Get all sync records that ended in the last hour
|
||||
# Get all sync records from the last hour
|
||||
recent_sync_records = db_session.scalars(
|
||||
select(SyncRecord)
|
||||
.where(SyncRecord.sync_end_time.isnot(None))
|
||||
.where(SyncRecord.sync_end_time >= one_hour_ago)
|
||||
.order_by(SyncRecord.sync_end_time.desc())
|
||||
.where(SyncRecord.sync_start_time >= one_hour_ago)
|
||||
.order_by(SyncRecord.sync_start_time.desc())
|
||||
).all()
|
||||
|
||||
task_logger.info(
|
||||
f"Collecting sync metrics for {len(recent_sync_records)} sync records"
|
||||
)
|
||||
|
||||
metrics = []
|
||||
|
||||
for sync_record in recent_sync_records:
|
||||
# Build a job_id for correlation
|
||||
job_id = build_job_id("sync_record", str(sync_record.id))
|
||||
# Skip if no end time (sync still in progress)
|
||||
if not sync_record.sync_end_time:
|
||||
continue
|
||||
|
||||
# Emit a SUCCESS/FAIL boolean metric
|
||||
# Use a single Redis key to avoid re-emitting final metrics
|
||||
final_metric_key = _FINAL_METRIC_KEY_FMT.format(
|
||||
sync_type=sync_record.sync_type,
|
||||
entity_id=sync_record.entity_id,
|
||||
sync_record_id=sync_record.id,
|
||||
# Check if we already emitted a metric for this sync record
|
||||
metric_key = (
|
||||
f"sync_speed:{sync_record.sync_type}:"
|
||||
f"{sync_record.entity_id}:{sync_record.id}"
|
||||
)
|
||||
if not _has_metric_been_emitted(redis_std, final_metric_key):
|
||||
# Evaluate success
|
||||
sync_succeeded = sync_record.sync_status == SyncStatus.SUCCESS
|
||||
|
||||
metrics.append(
|
||||
Metric(
|
||||
key=final_metric_key,
|
||||
name="sync_run_succeeded",
|
||||
value=sync_succeeded,
|
||||
tags={
|
||||
"job_id": job_id,
|
||||
"sync_type": str(sync_record.sync_type),
|
||||
"status": str(sync_record.sync_status),
|
||||
},
|
||||
)
|
||||
if _has_metric_been_emitted(redis_std, metric_key):
|
||||
task_logger.debug(
|
||||
f"Skipping metric for sync record {sync_record.id} "
|
||||
"because it has already been emitted"
|
||||
)
|
||||
continue
|
||||
|
||||
# If successful, emit additional metrics
|
||||
if sync_succeeded:
|
||||
if sync_record.sync_end_time and sync_record.sync_start_time:
|
||||
duration_seconds = (
|
||||
sync_record.sync_end_time - sync_record.sync_start_time
|
||||
).total_seconds()
|
||||
else:
|
||||
task_logger.error(
|
||||
f"Invalid times for sync record {sync_record.id}: "
|
||||
f"start={sync_record.sync_start_time}, end={sync_record.sync_end_time}"
|
||||
)
|
||||
duration_seconds = None
|
||||
# Calculate sync duration in minutes
|
||||
sync_duration_mins = (
|
||||
sync_record.sync_end_time - sync_record.sync_start_time
|
||||
).total_seconds() / 60.0
|
||||
|
||||
doc_count = sync_record.num_docs_synced or 0
|
||||
|
||||
sync_speed = None
|
||||
if duration_seconds and duration_seconds > 0:
|
||||
duration_mins = duration_seconds / 60.0
|
||||
sync_speed = (
|
||||
doc_count / duration_mins if duration_mins > 0 else None
|
||||
)
|
||||
|
||||
# Emit duration, doc count, speed
|
||||
if duration_seconds is not None:
|
||||
metrics.append(
|
||||
Metric(
|
||||
key=None,
|
||||
name="sync_duration_seconds",
|
||||
value=duration_seconds,
|
||||
tags={
|
||||
"job_id": job_id,
|
||||
"sync_type": str(sync_record.sync_type),
|
||||
},
|
||||
)
|
||||
)
|
||||
else:
|
||||
task_logger.error(
|
||||
f"Invalid sync record {sync_record.id} with no duration"
|
||||
)
|
||||
|
||||
metrics.append(
|
||||
Metric(
|
||||
key=None,
|
||||
name="sync_doc_count",
|
||||
value=doc_count,
|
||||
tags={
|
||||
"job_id": job_id,
|
||||
"sync_type": str(sync_record.sync_type),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
if sync_speed is not None:
|
||||
metrics.append(
|
||||
Metric(
|
||||
key=None,
|
||||
name="sync_speed_docs_per_min",
|
||||
value=sync_speed,
|
||||
tags={
|
||||
"job_id": job_id,
|
||||
"sync_type": str(sync_record.sync_type),
|
||||
},
|
||||
)
|
||||
)
|
||||
else:
|
||||
task_logger.error(
|
||||
f"Invalid sync record {sync_record.id} with no duration"
|
||||
)
|
||||
|
||||
# Mark final metrics as emitted so we don't re-emit
|
||||
_mark_metric_as_emitted(redis_std, final_metric_key)
|
||||
|
||||
# Emit start latency
|
||||
start_latency_key = _SYNC_START_LATENCY_KEY_FMT.format(
|
||||
sync_type=sync_record.sync_type,
|
||||
entity_id=sync_record.entity_id,
|
||||
sync_record_id=sync_record.id,
|
||||
# Calculate sync speed (docs/min) - avoid division by zero
|
||||
sync_speed = (
|
||||
sync_record.num_docs_synced / sync_duration_mins
|
||||
if sync_duration_mins > 0
|
||||
else None
|
||||
)
|
||||
if not _has_metric_been_emitted(redis_std, start_latency_key):
|
||||
# Get the entity's last update time based on sync type
|
||||
entity: DocumentSet | UserGroup | None = None
|
||||
if sync_record.sync_type == SyncType.DOCUMENT_SET:
|
||||
entity = db_session.scalar(
|
||||
select(DocumentSet).where(DocumentSet.id == sync_record.entity_id)
|
||||
)
|
||||
elif sync_record.sync_type == SyncType.USER_GROUP:
|
||||
entity = db_session.scalar(
|
||||
select(UserGroup).where(UserGroup.id == sync_record.entity_id)
|
||||
)
|
||||
else:
|
||||
task_logger.info(
|
||||
f"Skipping sync record {sync_record.id} of type {sync_record.sync_type}."
|
||||
)
|
||||
continue
|
||||
|
||||
if entity is None:
|
||||
task_logger.error(
|
||||
f"Could not find entity for sync record {sync_record.id} "
|
||||
f"(type={sync_record.sync_type}, id={sync_record.entity_id})."
|
||||
)
|
||||
continue
|
||||
if sync_speed is None:
|
||||
task_logger.error(
|
||||
"Something went wrong with sync speed calculation. "
|
||||
f"Sync record: {sync_record.id}"
|
||||
)
|
||||
continue
|
||||
|
||||
# Calculate start latency in seconds:
|
||||
# (actual sync start) - (last modified time)
|
||||
if entity.time_last_modified_by_user and sync_record.sync_start_time:
|
||||
start_latency = (
|
||||
sync_record.sync_start_time - entity.time_last_modified_by_user
|
||||
).total_seconds()
|
||||
metrics.append(
|
||||
Metric(
|
||||
key=metric_key,
|
||||
name="sync_speed_docs_per_min",
|
||||
value=sync_speed,
|
||||
tags={
|
||||
"sync_type": str(sync_record.sync_type),
|
||||
"status": str(sync_record.sync_status),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
if start_latency < 0:
|
||||
task_logger.error(
|
||||
f"Negative start latency for sync record {sync_record.id} "
|
||||
f"(start={sync_record.sync_start_time}, entity_modified={entity.time_last_modified_by_user})"
|
||||
)
|
||||
continue
|
||||
# Add sync start latency metric
|
||||
start_latency_key = (
|
||||
f"sync_start_latency:{sync_record.sync_type}"
|
||||
f":{sync_record.entity_id}:{sync_record.id}"
|
||||
)
|
||||
if _has_metric_been_emitted(redis_std, start_latency_key):
|
||||
task_logger.debug(
|
||||
f"Skipping start latency metric for sync record {sync_record.id} "
|
||||
"because it has already been emitted"
|
||||
)
|
||||
continue
|
||||
|
||||
metrics.append(
|
||||
Metric(
|
||||
key=start_latency_key,
|
||||
name="sync_start_latency_seconds",
|
||||
value=start_latency,
|
||||
tags={
|
||||
"job_id": job_id,
|
||||
"sync_type": str(sync_record.sync_type),
|
||||
},
|
||||
)
|
||||
)
|
||||
# Get the entity's last update time based on sync type
|
||||
entity: DocumentSet | UserGroup | None = None
|
||||
if sync_record.sync_type == SyncType.DOCUMENT_SET:
|
||||
entity = db_session.scalar(
|
||||
select(DocumentSet).where(DocumentSet.id == sync_record.entity_id)
|
||||
)
|
||||
elif sync_record.sync_type == SyncType.USER_GROUP:
|
||||
entity = db_session.scalar(
|
||||
select(UserGroup).where(UserGroup.id == sync_record.entity_id)
|
||||
)
|
||||
else:
|
||||
# Skip other sync types
|
||||
task_logger.debug(
|
||||
f"Skipping sync record {sync_record.id} "
|
||||
f"with type {sync_record.sync_type} "
|
||||
f"and id {sync_record.entity_id} "
|
||||
"because it is not a document set or user group"
|
||||
)
|
||||
continue
|
||||
|
||||
_mark_metric_as_emitted(redis_std, start_latency_key)
|
||||
if entity is None:
|
||||
task_logger.error(
|
||||
f"Could not find entity for sync record {sync_record.id} "
|
||||
f"with type {sync_record.sync_type} and id {sync_record.entity_id}"
|
||||
)
|
||||
continue
|
||||
|
||||
# Calculate start latency in seconds
|
||||
start_latency = (
|
||||
sync_record.sync_start_time - entity.time_last_modified_by_user
|
||||
).total_seconds()
|
||||
if start_latency < 0:
|
||||
task_logger.error(
|
||||
f"Start latency is negative for sync record {sync_record.id} "
|
||||
f"with type {sync_record.sync_type} and id {sync_record.entity_id}."
|
||||
"This is likely because the entity was updated between the time the "
|
||||
"time the sync finished and this job ran. Skipping."
|
||||
)
|
||||
continue
|
||||
|
||||
metrics.append(
|
||||
Metric(
|
||||
key=start_latency_key,
|
||||
name="sync_start_latency_seconds",
|
||||
value=start_latency,
|
||||
tags={
|
||||
"sync_type": str(sync_record.sync_type),
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
return metrics
|
||||
|
||||
|
||||
def build_job_id(
job_type: Literal["connector", "sync_record"],
primary_id: str,
secondary_id: str | None = None,
) -> str:
if job_type == "connector":
if secondary_id is None:
raise ValueError(
"secondary_id (attempt_id) is required for connector job_type"
)
return f"connector:{primary_id}:attempt:{secondary_id}"
elif job_type == "sync_record":
return f"sync_record:{primary_id}"


@shared_task(
name=OnyxCeleryTask.MONITOR_BACKGROUND_PROCESSES,
ignore_result=True,
soft_time_limit=_MONITORING_SOFT_TIME_LIMIT,
time_limit=_MONITORING_TIME_LIMIT,
queue=OnyxCeleryQueues.MONITORING,
@@ -601,7 +433,6 @@ def monitor_background_processes(self: Task, *, tenant_id: str | None) -> None:
lambda: _collect_connector_metrics(db_session, redis_std),
lambda: _collect_sync_metrics(db_session, redis_std),
]

# Collect and log each metric
with get_session_with_tenant(tenant_id) as db_session:
for metric_fn in metric_functions:
@@ -625,116 +456,3 @@ def monitor_background_processes(self: Task, *, tenant_id: str | None) -> None:
lock_monitoring.release()

task_logger.info("Background monitoring task finished")


@shared_task(
name=OnyxCeleryTask.CLOUD_CHECK_ALEMBIC,
)
def cloud_check_alembic() -> bool | None:
"""A task to verify that all tenants are on the same alembic revision.

This check is expected to fail if a cloud alembic migration is currently running
across all tenants.

TODO: have the cloud migration script set an activity signal that this check
uses to know it doesn't make sense to run a check at the present time.
"""
time_start = time.monotonic()
|
||||
|
||||
redis_client = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
|
||||
|
||||
lock_beat: RedisLock = redis_client.lock(
|
||||
OnyxRedisLocks.CLOUD_CHECK_ALEMBIC_BEAT_LOCK,
|
||||
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return None
|
||||
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
tenant_to_revision: dict[str, str | None] = {}
|
||||
revision_counts: dict[str, int] = {}
|
||||
out_of_date_tenants: dict[str, str | None] = {}
|
||||
top_revision: str = ""
|
||||
|
||||
try:
|
||||
# map each tenant_id to its revision
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
for tenant_id in tenant_ids:
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4):
|
||||
lock_beat.reacquire()
|
||||
last_lock_time = current_time
|
||||
|
||||
if tenant_id is None:
|
||||
continue
|
||||
|
||||
with get_session_with_tenant(tenant_id=None) as session:
|
||||
result = session.execute(
|
||||
text(f'SELECT * FROM "{tenant_id}".alembic_version LIMIT 1')
|
||||
)
|
||||
|
||||
result_scalar: str | None = result.scalar_one_or_none()
|
||||
tenant_to_revision[tenant_id] = result_scalar
|
||||
|
||||
# get the total count of each revision
|
||||
for k, v in tenant_to_revision.items():
|
||||
if v is None:
|
||||
continue
|
||||
|
||||
revision_counts[v] = revision_counts.get(v, 0) + 1
|
||||
|
||||
# get the revision with the most counts
|
||||
sorted_revision_counts = sorted(
|
||||
revision_counts.items(), key=lambda item: item[1], reverse=True
|
||||
)
|
||||
|
||||
if len(sorted_revision_counts) == 0:
|
||||
task_logger.error(
|
||||
f"cloud_check_alembic - No revisions found for {len(tenant_ids)} tenant ids!"
|
||||
)
|
||||
else:
|
||||
top_revision, _ = sorted_revision_counts[0]
|
||||
|
||||
# build a list of out of date tenants
|
||||
for k, v in tenant_to_revision.items():
|
||||
if v == top_revision:
|
||||
continue
|
||||
|
||||
out_of_date_tenants[k] = v
|
||||
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception("Unexpected exception during cloud alembic check")
|
||||
raise
|
||||
finally:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
else:
|
||||
task_logger.error("cloud_check_alembic - Lock not owned on completion")
|
||||
redis_lock_dump(lock_beat, redis_client)
|
||||
|
||||
if len(out_of_date_tenants) > 0:
|
||||
task_logger.error(
|
||||
f"Found out of date tenants: "
|
||||
f"num_out_of_date_tenants={len(out_of_date_tenants)} "
|
||||
f"num_tenants={len(tenant_ids)} "
|
||||
f"revision={top_revision}"
|
||||
)
|
||||
for k, v in islice(out_of_date_tenants.items(), 5):
|
||||
task_logger.info(f"Out of date tenant: tenant={k} revision={v}")
|
||||
else:
|
||||
task_logger.info(
|
||||
f"All tenants are up to date: num_tenants={len(tenant_ids)} revision={top_revision}"
|
||||
)
|
||||
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
task_logger.info(
|
||||
f"cloud_check_alembic finished: num_tenants={len(tenant_ids)} elapsed={time_elapsed:.2f}"
|
||||
)
|
||||
return True
|
||||
|
||||
@@ -78,7 +78,6 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:

@shared_task(
name=OnyxCeleryTask.CHECK_FOR_PRUNING,
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
bind=True,
)

@@ -1,22 +1,15 @@
|
||||
import time
|
||||
from http import HTTPStatus
|
||||
|
||||
import httpx
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from redis.lock import Lock as RedisLock
|
||||
from tenacity import RetryError
|
||||
|
||||
from onyx.access.access import get_access_for_document
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.tasks.beat_schedule import BEAT_EXPIRES_DEFAULT
|
||||
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.document import delete_document_by_connector_credential_pair__no_commit
|
||||
from onyx.db.document import delete_documents_complete__no_commit
|
||||
from onyx.db.document import fetch_chunk_count_for_document
|
||||
@@ -25,15 +18,11 @@ from onyx.db.document import get_document_connector_count
|
||||
from onyx.db.document import mark_document_as_modified
|
||||
from onyx.db.document import mark_document_as_synced
|
||||
from onyx.db.document_set import fetch_document_sets_for_document
|
||||
from onyx.db.engine import get_all_tenant_ids
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.document_index.document_index_utils import get_both_index_names
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.document_index.interfaces import VespaDocumentFields
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import redis_lock_dump
|
||||
from onyx.server.documents.models import ConnectorCredentialPairIdentifier
|
||||
from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST
|
||||
|
||||
DOCUMENT_BY_CC_PAIR_CLEANUP_MAX_RETRIES = 3
|
||||
|
||||
@@ -210,78 +199,3 @@ def document_by_cc_pair_cleanup_task(
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@shared_task(
name=OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
ignore_result=True,
trail=False,
bind=True,
)
def cloud_beat_task_generator(
self: Task,
task_name: str,
queue: str = OnyxCeleryTask.DEFAULT,
priority: int = OnyxCeleryPriority.MEDIUM,
expires: int = BEAT_EXPIRES_DEFAULT,
) -> bool | None:
"""a lightweight task used to kick off individual beat tasks per tenant."""
time_start = time.monotonic()
|
||||
redis_client = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
|
||||
|
||||
lock_beat: RedisLock = redis_client.lock(
|
||||
f"{OnyxRedisLocks.CLOUD_BEAT_TASK_GENERATOR_LOCK}:{task_name}",
|
||||
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return None
|
||||
|
||||
last_lock_time = time.monotonic()
|
||||
|
||||
try:
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
for tenant_id in tenant_ids:
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4):
|
||||
lock_beat.reacquire()
|
||||
last_lock_time = current_time
|
||||
|
||||
# needed in the cloud
|
||||
if IGNORED_SYNCING_TENANT_LIST and tenant_id in IGNORED_SYNCING_TENANT_LIST:
|
||||
continue
|
||||
|
||||
self.app.send_task(
|
||||
task_name,
|
||||
kwargs=dict(
|
||||
tenant_id=tenant_id,
|
||||
),
|
||||
queue=queue,
|
||||
priority=priority,
|
||||
expires=expires,
|
||||
)
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception("Unexpected exception during cloud_beat_task_generator")
|
||||
finally:
|
||||
if not lock_beat.owned():
|
||||
task_logger.error(
|
||||
"cloud_beat_task_generator - Lock not owned on completion"
|
||||
)
|
||||
redis_lock_dump(lock_beat, redis_client)
|
||||
else:
|
||||
lock_beat.release()
|
||||
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
task_logger.info(
|
||||
f"cloud_beat_task_generator finished: "
|
||||
f"task={task_name} "
|
||||
f"num_tenants={len(tenant_ids)} "
|
||||
f"elapsed={time_elapsed:.2f}"
|
||||
)
|
||||
return True
|
||||
|
||||
@@ -78,7 +78,6 @@ from onyx.redis.redis_connector_index import RedisConnectorIndex
|
||||
from onyx.redis.redis_connector_prune import RedisConnectorPrune
|
||||
from onyx.redis.redis_document_set import RedisDocumentSet
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.redis.redis_pool import redis_lock_dump
|
||||
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
|
||||
from onyx.redis.redis_usergroup import RedisUserGroup
|
||||
@@ -98,7 +97,6 @@ logger = setup_logger()
|
||||
# which bloats the result metadata considerably. trail=False prevents this.
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
trail=False,
|
||||
bind=True,
|
||||
@@ -737,7 +735,7 @@ def monitor_ccpair_indexing_taskset(
|
||||
composite_id = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
if composite_id is None:
|
||||
task_logger.warning(
|
||||
f"Connector indexing: could not parse composite_id from {fence_key}"
|
||||
f"monitor_ccpair_indexing_taskset: could not parse composite_id from {fence_key}"
|
||||
)
|
||||
return
|
||||
|
||||
@@ -787,7 +785,6 @@ def monitor_ccpair_indexing_taskset(
|
||||
# inner/outer/inner double check pattern to avoid race conditions when checking for
|
||||
# bad state
|
||||
|
||||
# Verify: if the generator isn't complete, the task must not be in READY state
|
||||
# inner = get_completion / generator_complete not signaled
|
||||
# outer = result.state in READY state
|
||||
status_int = redis_connector_index.get_completion()
|
||||
@@ -833,7 +830,7 @@ def monitor_ccpair_indexing_taskset(
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
"Connector indexing - Transient exception marking index attempt as failed: "
|
||||
"monitor_ccpair_indexing_taskset - transient exception marking index attempt as failed: "
|
||||
f"attempt={payload.index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
@@ -843,20 +840,6 @@ def monitor_ccpair_indexing_taskset(
|
||||
redis_connector_index.reset()
|
||||
return
|
||||
|
||||
if redis_connector_index.watchdog_signaled():
|
||||
# if the generator is complete, don't clean up until the watchdog has exited
|
||||
task_logger.info(
|
||||
f"Connector indexing - Delaying finalization until watchdog has exited: "
|
||||
f"attempt={payload.index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"progress={progress} "
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
|
||||
f"elapsed_started={elapsed_started_str}"
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
status_enum = HTTPStatus(status_int)
|
||||
|
||||
task_logger.info(
|
||||
@@ -873,20 +856,11 @@ def monitor_ccpair_indexing_taskset(
|
||||
redis_connector_index.reset()
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.MONITOR_VESPA_SYNC,
|
||||
ignore_result=True,
|
||||
soft_time_limit=300,
|
||||
bind=True,
|
||||
)
|
||||
@shared_task(name=OnyxCeleryTask.MONITOR_VESPA_SYNC, soft_time_limit=300, bind=True)
|
||||
def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool | None:
|
||||
"""This is a celery beat task that monitors and finalizes various long running tasks.
|
||||
|
||||
The name monitor_vespa_sync is a bit of a misnomer since it checks many different tasks
|
||||
now. Should change that at some point.
|
||||
|
||||
"""This is a celery beat task that monitors and finalizes metadata sync tasksets.
|
||||
It scans for fence values and then gets the counts of any associated tasksets.
|
||||
For many tasks, the count is 0, that means all tasks finished and we should clean up.
|
||||
If the count is 0, that means all tasks finished and we should clean up.
|
||||
|
||||
This task lock timeout is CELERY_METADATA_SYNC_BEAT_LOCK_TIMEOUT seconds, so don't
|
||||
do anything too expensive in this function!
|
||||
@@ -902,17 +876,6 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool | None:
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# Replica usage notes
|
||||
#
|
||||
# False negatives are OK. (aka fail to to see a key that exists on the master).
|
||||
# We simply skip the monitoring work and it will be caught on the next pass.
|
||||
#
|
||||
# False positives are not OK, and are possible if we clear a fence on the master and
|
||||
# then read from the replica. In this case, monitoring work could be done on a fence
|
||||
# that no longer exists. To avoid this, we scan from the replica, but double check
|
||||
# the result on the master.
|
||||
r_replica = get_redis_replica_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
@@ -972,19 +935,17 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool | None:
|
||||
# scan and monitor activity to completion
|
||||
phase_start = time.monotonic()
|
||||
lock_beat.reacquire()
|
||||
if r_replica.exists(RedisConnectorCredentialPair.get_fence_key()):
|
||||
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
|
||||
monitor_connector_taskset(r)
|
||||
if r.exists(RedisConnectorCredentialPair.get_fence_key()):
|
||||
monitor_connector_taskset(r)
|
||||
timings["connector"] = time.monotonic() - phase_start
|
||||
timings["connector_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
|
||||
|
||||
phase_start = time.monotonic()
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r_replica.scan_iter(
|
||||
for key_bytes in r.scan_iter(
|
||||
RedisConnectorDelete.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
|
||||
):
|
||||
if r.exists(key_bytes):
|
||||
monitor_connector_deletion_taskset(tenant_id, key_bytes, r)
|
||||
monitor_connector_deletion_taskset(tenant_id, key_bytes, r)
|
||||
lock_beat.reacquire()
|
||||
|
||||
timings["connector_deletion"] = time.monotonic() - phase_start
|
||||
@@ -994,74 +955,66 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool | None:
|
||||
|
||||
phase_start = time.monotonic()
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r_replica.scan_iter(
|
||||
for key_bytes in r.scan_iter(
|
||||
RedisDocumentSet.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
|
||||
):
|
||||
if r.exists(key_bytes):
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_document_set_taskset(tenant_id, key_bytes, r, db_session)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_document_set_taskset(tenant_id, key_bytes, r, db_session)
|
||||
lock_beat.reacquire()
|
||||
timings["documentset"] = time.monotonic() - phase_start
|
||||
timings["documentset_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
|
||||
|
||||
phase_start = time.monotonic()
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r_replica.scan_iter(
|
||||
for key_bytes in r.scan_iter(
|
||||
RedisUserGroup.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
|
||||
):
|
||||
if r.exists(key_bytes):
|
||||
monitor_usergroup_taskset = (
|
||||
fetch_versioned_implementation_with_fallback(
|
||||
"onyx.background.celery.tasks.vespa.tasks",
|
||||
"monitor_usergroup_taskset",
|
||||
noop_fallback,
|
||||
)
|
||||
)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_usergroup_taskset(tenant_id, key_bytes, r, db_session)
|
||||
monitor_usergroup_taskset = fetch_versioned_implementation_with_fallback(
|
||||
"onyx.background.celery.tasks.vespa.tasks",
|
||||
"monitor_usergroup_taskset",
|
||||
noop_fallback,
|
||||
)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_usergroup_taskset(tenant_id, key_bytes, r, db_session)
|
||||
lock_beat.reacquire()
|
||||
timings["usergroup"] = time.monotonic() - phase_start
|
||||
timings["usergroup_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
|
||||
|
||||
phase_start = time.monotonic()
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r_replica.scan_iter(
|
||||
for key_bytes in r.scan_iter(
|
||||
RedisConnectorPrune.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
|
||||
):
|
||||
if r.exists(key_bytes):
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_pruning_taskset(tenant_id, key_bytes, r, db_session)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_pruning_taskset(tenant_id, key_bytes, r, db_session)
|
||||
lock_beat.reacquire()
|
||||
timings["pruning"] = time.monotonic() - phase_start
|
||||
timings["pruning_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
|
||||
|
||||
phase_start = time.monotonic()
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r_replica.scan_iter(
|
||||
for key_bytes in r.scan_iter(
|
||||
RedisConnectorIndex.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
|
||||
):
|
||||
if r.exists(key_bytes):
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_indexing_taskset(tenant_id, key_bytes, r, db_session)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_indexing_taskset(tenant_id, key_bytes, r, db_session)
|
||||
lock_beat.reacquire()
|
||||
timings["indexing"] = time.monotonic() - phase_start
|
||||
timings["indexing_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
|
||||
|
||||
phase_start = time.monotonic()
|
||||
lock_beat.reacquire()
|
||||
for key_bytes in r_replica.scan_iter(
|
||||
for key_bytes in r.scan_iter(
|
||||
RedisConnectorPermissionSync.FENCE_PREFIX + "*",
|
||||
count=SCAN_ITER_COUNT_DEFAULT,
|
||||
):
|
||||
if r.exists(key_bytes):
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_permissions_taskset(
|
||||
tenant_id, key_bytes, r, db_session
|
||||
)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_permissions_taskset(tenant_id, key_bytes, r, db_session)
|
||||
lock_beat.reacquire()
|
||||
|
||||
timings["permissions"] = time.monotonic() - phase_start
|
||||
timings["permissions_ttl"] = r.ttl(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
|
||||
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
@@ -1092,8 +1045,6 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool | None:
|
||||
def vespa_metadata_sync_task(
|
||||
self: Task, document_id: str, tenant_id: str | None
|
||||
) -> bool:
|
||||
start = time.monotonic()
|
||||
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
curr_ind_name, sec_ind_name = get_both_index_names(db_session)
|
||||
@@ -1144,13 +1095,7 @@ def vespa_metadata_sync_task(
|
||||
# r = get_redis_client(tenant_id=tenant_id)
|
||||
# r.delete(redis_syncing_key)
|
||||
|
||||
elapsed = time.monotonic() - start
|
||||
task_logger.info(
|
||||
f"doc={document_id} "
|
||||
f"action=sync "
|
||||
f"chunks={chunks_affected} "
|
||||
f"elapsed={elapsed:.2f}"
|
||||
)
|
||||
task_logger.info(f"doc={document_id} action=sync chunks={chunks_affected}")
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
|
||||
except Exception as ex:
|
||||
|
||||
@@ -15,12 +15,11 @@ from onyx.llm.models import PreviousMessage
|
||||
from onyx.llm.utils import build_content_with_imgs
|
||||
from onyx.llm.utils import check_message_tokens
|
||||
from onyx.llm.utils import message_to_prompt_and_imgs
|
||||
from onyx.llm.utils import model_supports_image_input
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
|
||||
from onyx.prompts.direct_qa_prompts import HISTORY_BLOCK
|
||||
from onyx.prompts.prompt_utils import add_date_time_to_prompt
|
||||
from onyx.prompts.prompt_utils import drop_messages_history_overflow
|
||||
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
|
||||
from onyx.tools.force import ForceUseTool
|
||||
from onyx.tools.models import ToolCallFinalResult
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
@@ -32,16 +31,15 @@ def default_build_system_message(
|
||||
prompt_config: PromptConfig,
|
||||
) -> SystemMessage | None:
|
||||
system_prompt = prompt_config.system_prompt.strip()
|
||||
tag_handled_prompt = handle_onyx_date_awareness(
|
||||
system_prompt,
|
||||
prompt_config,
|
||||
add_additional_info_if_no_tag=prompt_config.datetime_aware,
|
||||
)
|
||||
if prompt_config.datetime_aware:
|
||||
system_prompt = add_date_time_to_prompt(prompt_str=system_prompt)
|
||||
|
||||
if not tag_handled_prompt:
|
||||
if not system_prompt:
|
||||
return None
|
||||
|
||||
return SystemMessage(content=tag_handled_prompt)
|
||||
system_msg = SystemMessage(content=system_prompt)
|
||||
|
||||
return system_msg
|
||||
|
||||
|
||||
def default_build_user_message(
|
||||
@@ -66,11 +64,8 @@ def default_build_user_message(
|
||||
else user_query
|
||||
)
|
||||
user_prompt = user_prompt.strip()
|
||||
tag_handled_prompt = handle_onyx_date_awareness(user_prompt, prompt_config)
|
||||
user_msg = HumanMessage(
|
||||
content=build_content_with_imgs(tag_handled_prompt, files)
|
||||
if files
|
||||
else tag_handled_prompt
|
||||
content=build_content_with_imgs(user_prompt, files) if files else user_prompt
|
||||
)
|
||||
return user_msg
|
||||
|
||||
@@ -91,7 +86,6 @@ class AnswerPromptBuilder:
|
||||
provider_type=llm_config.model_provider,
|
||||
model_name=llm_config.model_name,
|
||||
)
|
||||
self.llm_config = llm_config
|
||||
self.llm_tokenizer_encode_func = cast(
|
||||
Callable[[str], list[int]], llm_tokenizer.encode
|
||||
)
|
||||
@@ -100,21 +94,12 @@ class AnswerPromptBuilder:
|
||||
(
|
||||
self.message_history,
|
||||
self.history_token_cnts,
|
||||
) = translate_history_to_basemessages(
|
||||
message_history,
|
||||
exclude_images=not model_supports_image_input(
|
||||
self.llm_config.model_name,
|
||||
self.llm_config.model_provider,
|
||||
),
|
||||
)
|
||||
) = translate_history_to_basemessages(message_history)
|
||||
|
||||
self.system_message_and_token_cnt: tuple[SystemMessage, int] | None = None
|
||||
self.user_message_and_token_cnt = (
|
||||
user_message,
|
||||
check_message_tokens(
|
||||
user_message,
|
||||
self.llm_tokenizer_encode_func,
|
||||
),
|
||||
check_message_tokens(user_message, self.llm_tokenizer_encode_func),
|
||||
)
|
||||
|
||||
self.new_messages_and_token_cnts: list[tuple[BaseMessage, int]] = []
|
||||
|
||||
@@ -21,9 +21,9 @@ from onyx.prompts.constants import DEFAULT_IGNORE_STATEMENT
|
||||
from onyx.prompts.direct_qa_prompts import CITATIONS_PROMPT
|
||||
from onyx.prompts.direct_qa_prompts import CITATIONS_PROMPT_FOR_TOOL_CALLING
|
||||
from onyx.prompts.direct_qa_prompts import HISTORY_BLOCK
|
||||
from onyx.prompts.prompt_utils import add_date_time_to_prompt
|
||||
from onyx.prompts.prompt_utils import build_complete_context_str
|
||||
from onyx.prompts.prompt_utils import build_task_prompt_reminders
|
||||
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
|
||||
from onyx.prompts.token_counts import ADDITIONAL_INFO_TOKEN_CNT
|
||||
from onyx.prompts.token_counts import (
|
||||
CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT,
|
||||
@@ -127,11 +127,10 @@ def build_citations_system_message(
|
||||
system_prompt = prompt_config.system_prompt.strip()
|
||||
if prompt_config.include_citations:
|
||||
system_prompt += REQUIRE_CITATION_STATEMENT
|
||||
tag_handled_prompt = handle_onyx_date_awareness(
|
||||
system_prompt, prompt_config, add_additional_info_if_no_tag=True
|
||||
)
|
||||
if prompt_config.datetime_aware:
|
||||
system_prompt = add_date_time_to_prompt(prompt_str=system_prompt)
|
||||
|
||||
return SystemMessage(content=tag_handled_prompt)
|
||||
return SystemMessage(content=system_prompt)
|
||||
|
||||
|
||||
def build_citations_user_message(
|
||||
|
||||
@@ -9,8 +9,8 @@ from onyx.llm.utils import message_to_prompt_and_imgs
|
||||
from onyx.prompts.direct_qa_prompts import CONTEXT_BLOCK
|
||||
from onyx.prompts.direct_qa_prompts import HISTORY_BLOCK
|
||||
from onyx.prompts.direct_qa_prompts import JSON_PROMPT
|
||||
from onyx.prompts.prompt_utils import add_date_time_to_prompt
|
||||
from onyx.prompts.prompt_utils import build_complete_context_str
|
||||
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
|
||||
|
||||
|
||||
def _build_strong_llm_quotes_prompt(
|
||||
@@ -39,11 +39,10 @@ def _build_strong_llm_quotes_prompt(
|
||||
language_hint_or_none=LANGUAGE_HINT.strip() if use_language_hint else "",
|
||||
).strip()
|
||||
|
||||
tag_handled_prompt = handle_onyx_date_awareness(
|
||||
full_prompt, prompt, add_additional_info_if_no_tag=True
|
||||
)
|
||||
if prompt.datetime_aware:
|
||||
full_prompt = add_date_time_to_prompt(prompt_str=full_prompt)
|
||||
|
||||
return HumanMessage(content=tag_handled_prompt)
|
||||
return HumanMessage(content=full_prompt)
|
||||
|
||||
|
||||
def build_quotes_user_message(
|
||||
|
||||
@@ -11,7 +11,6 @@ from onyx.llm.utils import build_content_with_imgs
|
||||
|
||||
def translate_onyx_msg_to_langchain(
|
||||
msg: ChatMessage | PreviousMessage,
|
||||
exclude_images: bool = False,
|
||||
) -> BaseMessage:
|
||||
files: list[InMemoryChatFile] = []
|
||||
|
||||
@@ -19,9 +18,7 @@ def translate_onyx_msg_to_langchain(
|
||||
# attached. Just ignore them for now.
|
||||
if not isinstance(msg, ChatMessage):
|
||||
files = msg.files
|
||||
content = build_content_with_imgs(
|
||||
msg.message, files, message_type=msg.message_type, exclude_images=exclude_images
|
||||
)
|
||||
content = build_content_with_imgs(msg.message, files, message_type=msg.message_type)
|
||||
|
||||
if msg.message_type == MessageType.SYSTEM:
|
||||
raise ValueError("System messages are not currently part of history")
|
||||
@@ -35,12 +32,9 @@ def translate_onyx_msg_to_langchain(
|
||||
|
||||
def translate_history_to_basemessages(
|
||||
history: list[ChatMessage] | list["PreviousMessage"],
|
||||
exclude_images: bool = False,
|
||||
) -> tuple[list[BaseMessage], list[int]]:
|
||||
history_basemessages = [
|
||||
translate_onyx_msg_to_langchain(msg, exclude_images)
|
||||
for msg in history
|
||||
if msg.token_count != 0
|
||||
translate_onyx_msg_to_langchain(msg) for msg in history if msg.token_count != 0
|
||||
]
|
||||
history_token_counts = [msg.token_count for msg in history if msg.token_count != 0]
|
||||
return history_basemessages, history_token_counts
|
||||
|
||||
@@ -3,7 +3,6 @@ import os
|
||||
import urllib.parse
|
||||
from typing import cast
|
||||
|
||||
from onyx.auth.schemas import AuthBackend
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.configs.constants import DocumentIndexType
|
||||
from onyx.file_processing.enums import HtmlBasedConnectorTransformLinksStrategy
|
||||
@@ -56,12 +55,12 @@ MASK_CREDENTIAL_PREFIX = (
|
||||
os.environ.get("MASK_CREDENTIAL_PREFIX", "True").lower() != "false"
|
||||
)
|
||||
|
||||
AUTH_BACKEND = AuthBackend(os.environ.get("AUTH_BACKEND") or AuthBackend.REDIS.value)
|
||||
REDIS_AUTH_EXPIRE_TIME_SECONDS = int(
|
||||
os.environ.get("REDIS_AUTH_EXPIRE_TIME_SECONDS") or 86400 * 7
|
||||
) # 7 days
|
||||
|
||||
SESSION_EXPIRE_TIME_SECONDS = int(
|
||||
os.environ.get("SESSION_EXPIRE_TIME_SECONDS")
|
||||
or os.environ.get("REDIS_AUTH_EXPIRE_TIME_SECONDS")
|
||||
or 86400 * 7
|
||||
os.environ.get("SESSION_EXPIRE_TIME_SECONDS") or 86400 * 7
|
||||
) # 7 days
|
||||
|
||||
# Default request timeout, mostly used by connectors
|
||||
@@ -93,12 +92,6 @@ OAUTH_CLIENT_SECRET = (
|
||||
|
||||
USER_AUTH_SECRET = os.environ.get("USER_AUTH_SECRET", "")
|
||||
|
||||
# Duration (in seconds) for which the FastAPI Users JWT token remains valid in the user's browser.
|
||||
# By default, this is set to match the Redis expiry time for consistency.
|
||||
AUTH_COOKIE_EXPIRE_TIME_SECONDS = int(
|
||||
os.environ.get("AUTH_COOKIE_EXPIRE_TIME_SECONDS") or 86400 * 7
|
||||
) # 7 days
|
||||
|
||||
# for basic auth
|
||||
REQUIRE_EMAIL_VERIFICATION = (
|
||||
os.environ.get("REQUIRE_EMAIL_VERIFICATION", "").lower() == "true"
|
||||
@@ -200,8 +193,6 @@ REDIS_HOST = os.environ.get("REDIS_HOST") or "localhost"
|
||||
REDIS_PORT = int(os.environ.get("REDIS_PORT", 6379))
|
||||
REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD") or ""
|
||||
|
||||
# this assumes that other redis settings remain the same as the primary
|
||||
REDIS_REPLICA_HOST = os.environ.get("REDIS_REPLICA_HOST") or REDIS_HOST
|
||||
|
||||
REDIS_AUTH_KEY_PREFIX = "fastapi_users_token:"
|
||||
|
||||
|
||||
@@ -294,8 +294,7 @@ class OnyxRedisLocks:
|
||||
SLACK_BOT_HEARTBEAT_PREFIX = "da_heartbeat:slack_bot"
|
||||
ANONYMOUS_USER_ENABLED = "anonymous_user_enabled"
|
||||
|
||||
CLOUD_BEAT_TASK_GENERATOR_LOCK = "da_lock:cloud_beat_task_generator"
|
||||
CLOUD_CHECK_ALEMBIC_BEAT_LOCK = "da_lock:cloud_check_alembic"
|
||||
CLOUD_CHECK_INDEXING_BEAT_LOCK = "da_lock:cloud_check_indexing_beat"
|
||||
|
||||
|
||||
class OnyxRedisSignals:
@@ -318,11 +317,6 @@ ONYX_CLOUD_TENANT_ID = "cloud"


class OnyxCeleryTask:
DEFAULT = "celery"

CLOUD_BEAT_TASK_GENERATOR = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_generate_beat_tasks"
CLOUD_CHECK_ALEMBIC = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check_alembic"

CHECK_FOR_CONNECTOR_DELETION = "check_for_connector_deletion_task"
CHECK_FOR_VESPA_SYNC_TASK = "check_for_vespa_sync_task"
CHECK_FOR_INDEXING = "check_for_indexing"
@@ -330,10 +324,8 @@ class OnyxCeleryTask:
|
||||
CHECK_FOR_DOC_PERMISSIONS_SYNC = "check_for_doc_permissions_sync"
|
||||
CHECK_FOR_EXTERNAL_GROUP_SYNC = "check_for_external_group_sync"
|
||||
CHECK_FOR_LLM_MODEL_UPDATE = "check_for_llm_model_update"
|
||||
|
||||
MONITOR_VESPA_SYNC = "monitor_vespa_sync"
|
||||
MONITOR_BACKGROUND_PROCESSES = "monitor_background_processes"
|
||||
|
||||
KOMBU_MESSAGE_CLEANUP_TASK = "kombu_message_cleanup_task"
|
||||
CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK = (
|
||||
"connector_permission_sync_generator_task"
|
||||
@@ -351,6 +343,8 @@ class OnyxCeleryTask:
|
||||
CHECK_TTL_MANAGEMENT_TASK = "check_ttl_management_task"
|
||||
AUTOGENERATE_USAGE_REPORT_TASK = "autogenerate_usage_report_task"
|
||||
|
||||
CLOUD_CHECK_FOR_INDEXING = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check_for_indexing"
|
||||
|
||||
|
||||
REDIS_SOCKET_KEEPALIVE_OPTIONS = {}
|
||||
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPINTVL] = 15
|
||||
|
||||
@@ -71,20 +71,10 @@ class AirtableConnector(LoadConnector):
|
||||
self.airtable_client = AirtableApi(credentials["airtable_access_token"])
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _extract_field_values(
|
||||
field_id: str,
|
||||
field_info: Any,
|
||||
field_type: str,
|
||||
base_id: str,
|
||||
table_id: str,
|
||||
view_id: str | None,
|
||||
record_id: str,
|
||||
) -> list[tuple[str, str]]:
|
||||
def _get_field_value(self, field_info: Any, field_type: str) -> list[str]:
|
||||
"""
|
||||
Extract value(s) + links from a field regardless of its type.
|
||||
Attachments are represented as multiple sections, and therefore
|
||||
returned as a list of tuples (value, link).
|
||||
Extract value(s) from a field regardless of its type.
|
||||
Returns either a single string or list of strings for attachments.
|
||||
"""
|
||||
if field_info is None:
|
||||
return []
|
||||
@@ -95,11 +85,8 @@ class AirtableConnector(LoadConnector):
|
||||
if field_type == "multipleRecordLinks":
|
||||
return []
|
||||
|
||||
# default link to use for non-attachment fields
|
||||
default_link = f"https://airtable.com/{base_id}/{table_id}/{record_id}"
|
||||
|
||||
if field_type == "multipleAttachments":
|
||||
attachment_texts: list[tuple[str, str]] = []
|
||||
attachment_texts: list[str] = []
|
||||
for attachment in field_info:
|
||||
url = attachment.get("url")
|
||||
filename = attachment.get("filename", "")
|
||||
@@ -122,7 +109,6 @@ class AirtableConnector(LoadConnector):
|
||||
if attachment_content:
|
||||
try:
|
||||
file_ext = get_file_ext(filename)
|
||||
attachment_id = attachment["id"]
|
||||
attachment_text = extract_file_text(
|
||||
BytesIO(attachment_content),
|
||||
filename,
|
||||
@@ -130,20 +116,7 @@ class AirtableConnector(LoadConnector):
|
||||
extension=file_ext,
|
||||
)
|
||||
if attachment_text:
|
||||
# slightly nicer loading experience if we can specify the view ID
|
||||
if view_id:
|
||||
attachment_link = (
|
||||
f"https://airtable.com/{base_id}/{table_id}/{view_id}/{record_id}"
|
||||
f"/{field_id}/{attachment_id}?blocks=hide"
|
||||
)
|
||||
else:
|
||||
attachment_link = (
|
||||
f"https://airtable.com/{base_id}/{table_id}/{record_id}"
|
||||
f"/{field_id}/{attachment_id}?blocks=hide"
|
||||
)
|
||||
attachment_texts.append(
|
||||
(f"{filename}:\n{attachment_text}", attachment_link)
|
||||
)
|
||||
attachment_texts.append(f"{filename}:\n{attachment_text}")
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to process attachment {filename}: {str(e)}"
|
||||
@@ -158,12 +131,12 @@ class AirtableConnector(LoadConnector):
|
||||
combined.append(collab_name)
|
||||
if collab_email:
|
||||
combined.append(f"({collab_email})")
|
||||
return [(" ".join(combined) if combined else str(field_info), default_link)]
|
||||
return [" ".join(combined) if combined else str(field_info)]
|
||||
|
||||
if isinstance(field_info, list):
|
||||
return [(item, default_link) for item in field_info]
|
||||
return [str(item) for item in field_info]
|
||||
|
||||
return [(str(field_info), default_link)]
|
||||
return [str(field_info)]
|
||||
|
||||
def _should_be_metadata(self, field_type: str) -> bool:
|
||||
"""Determine if a field type should be treated as metadata."""
|
||||
@@ -171,12 +144,10 @@ class AirtableConnector(LoadConnector):
|
||||
|
||||
def _process_field(
|
||||
self,
|
||||
field_id: str,
|
||||
field_name: str,
|
||||
field_info: Any,
|
||||
field_type: str,
|
||||
table_id: str,
|
||||
view_id: str | None,
|
||||
record_id: str,
|
||||
) -> tuple[list[Section], dict[str, Any]]:
|
||||
"""
|
||||
@@ -194,21 +165,12 @@ class AirtableConnector(LoadConnector):
|
||||
return [], {}
|
||||
|
||||
# Get the value(s) for the field
|
||||
field_value_and_links = self._extract_field_values(
|
||||
field_id=field_id,
|
||||
field_info=field_info,
|
||||
field_type=field_type,
|
||||
base_id=self.base_id,
|
||||
table_id=table_id,
|
||||
view_id=view_id,
|
||||
record_id=record_id,
|
||||
)
|
||||
if len(field_value_and_links) == 0:
|
||||
field_values = self._get_field_value(field_info, field_type)
|
||||
if len(field_values) == 0:
|
||||
return [], {}
|
||||
|
||||
# Determine if it should be metadata or a section
|
||||
if self._should_be_metadata(field_type):
|
||||
field_values = [value for value, _ in field_value_and_links]
|
||||
if len(field_values) > 1:
|
||||
return [], {field_name: field_values}
|
||||
return [], {field_name: field_values[0]}
|
||||
@@ -216,7 +178,7 @@ class AirtableConnector(LoadConnector):
|
||||
# Otherwise, create relevant sections
|
||||
sections = [
|
||||
Section(
|
||||
link=link,
|
||||
link=f"https://airtable.com/{self.base_id}/{table_id}/{record_id}",
|
||||
text=(
|
||||
f"{field_name}:\n"
|
||||
"------------------------\n"
|
||||
@@ -224,7 +186,7 @@ class AirtableConnector(LoadConnector):
|
||||
"------------------------"
|
||||
),
|
||||
)
|
||||
for text, link in field_value_and_links
|
||||
for text in field_values
|
||||
]
|
||||
return sections, {}
|
||||
|
||||
@@ -257,7 +219,6 @@ class AirtableConnector(LoadConnector):
|
||||
primary_field_value = (
|
||||
fields.get(primary_field_name) if primary_field_name else None
|
||||
)
|
||||
view_id = table_schema.views[0].id if table_schema.views else None
|
||||
|
||||
for field_schema in table_schema.fields:
|
||||
field_name = field_schema.name
|
||||
@@ -265,12 +226,10 @@ class AirtableConnector(LoadConnector):
|
||||
field_type = field_schema.type
|
||||
|
||||
field_sections, field_metadata = self._process_field(
|
||||
field_id=field_schema.id,
|
||||
field_name=field_name,
|
||||
field_info=field_val,
|
||||
field_type=field_type,
|
||||
table_id=table_id,
|
||||
view_id=view_id,
|
||||
record_id=record_id,
|
||||
)
|
||||
|
||||
|
||||
@@ -232,29 +232,20 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
}
|
||||
|
||||
# Get labels
|
||||
label_dicts = (
|
||||
confluence_object.get("metadata", {}).get("labels", {}).get("results", [])
|
||||
)
|
||||
page_labels = [label.get("name") for label in label_dicts if label.get("name")]
|
||||
label_dicts = confluence_object["metadata"]["labels"]["results"]
|
||||
page_labels = [label["name"] for label in label_dicts]
|
||||
if page_labels:
|
||||
doc_metadata["labels"] = page_labels
|
||||
|
||||
# Get last modified and author email
|
||||
version_dict = confluence_object.get("version", {})
|
||||
last_modified = (
|
||||
datetime_from_string(version_dict.get("when"))
|
||||
if version_dict.get("when")
|
||||
else None
|
||||
)
|
||||
author_email = version_dict.get("by", {}).get("email")
|
||||
|
||||
title = confluence_object.get("title", "Untitled Document")
|
||||
last_modified = datetime_from_string(confluence_object["version"]["when"])
|
||||
author_email = confluence_object["version"].get("by", {}).get("email")
|
||||
|
||||
return Document(
|
||||
id=object_url,
|
||||
sections=[Section(link=object_url, text=object_text)],
|
||||
source=DocumentSource.CONFLUENCE,
|
||||
semantic_identifier=title,
|
||||
semantic_identifier=confluence_object["title"],
|
||||
doc_updated_at=last_modified,
|
||||
primary_owners=(
|
||||
[BasicExpertInfo(email=author_email)] if author_email else None
|
||||
|
||||
@@ -6,7 +6,6 @@ from datetime import datetime
from datetime import timezone
from typing import Any
from typing import Optional
from urllib.parse import unquote

import msal  # type: ignore
from office365.graph_client import GraphClient  # type: ignore
@@ -83,13 +82,8 @@ class SharepointConnector(LoadConnector, PollConnector):
sites_index = parts.index("sites")
site_url = "/".join(parts[: sites_index + 2])
folder = (
"/".join(unquote(part) for part in parts[sites_index + 2 :])
if len(parts) > sites_index + 2
else None
parts[sites_index + 2] if len(parts) > sites_index + 2 else None
)
# Handling for new URL structure
if folder and folder.startswith("Shared Documents/"):
folder = folder[len("Shared Documents/") :]
site_data_list.append(
SiteData(url=site_url, folder=folder, sites=[], driveitems=[])
)
@@ -117,19 +111,11 @@ class SharepointConnector(LoadConnector, PollConnector):
query = query.filter(filter_str)
driveitems = query.execute_query()
if element.folder:
expected_path = f"/root:/{element.folder}"
filtered_driveitems = [
item
for item in driveitems
if item.parent_reference.path.endswith(expected_path)
if element.folder in item.parent_reference.path
]
if len(filtered_driveitems) == 0:
all_paths = [
item.parent_reference.path for item in driveitems
]
logger.warning(
f"Nothing found for folder '{expected_path}' in any of valid paths: {all_paths}"
)
element.driveitems.extend(filtered_driveitems)
else:
element.driveitems.extend(driveitems)

@@ -432,7 +432,7 @@ def get_paginated_index_attempts_for_cc_pair_id(
stmt = stmt.order_by(IndexAttempt.time_started.desc())

# Apply pagination
stmt = stmt.offset(page * page_size).limit(page_size)
stmt = stmt.offset((page - 1) * page_size).limit(page_size)

return list(db_session.execute(stmt).scalars().all())

@@ -193,13 +193,13 @@ def fetch_input_prompts_by_user(
"""
Returns all prompts belonging to the user or public prompts,
excluding those the user has specifically disabled.
Also, if `user_id` is None and AUTH_TYPE is DISABLED, then all prompts are returned.
"""

# Start with a basic query for InputPrompt
query = select(InputPrompt)

# If we have a user, left join to InputPrompt__User so we can check "disabled"
if user_id is not None:
# If we have a user, left join to InputPrompt__User to check "disabled"
IPU = aliased(InputPrompt__User)
query = query.join(
IPU,
@@ -208,30 +208,25 @@ def fetch_input_prompts_by_user(
)

# Exclude disabled prompts
# i.e. keep only those where (IPU.disabled is NULL or False)
query = query.where(or_(IPU.disabled.is_(None), IPU.disabled.is_(False)))

if include_public:
# Return both user-owned and public prompts
# user-owned or public
query = query.where(
or_(
InputPrompt.user_id == user_id,
InputPrompt.is_public,
)
(InputPrompt.user_id == user_id) | (InputPrompt.is_public)
)
else:
# Return only user-owned prompts
# only user-owned prompts
query = query.where(InputPrompt.user_id == user_id)

else:
# user_id is None
if AUTH_TYPE == AuthType.DISABLED:
# If auth is disabled, return all prompts
query = query.where(True)  # type: ignore
elif include_public:
# Anonymous usage
query = query.where(InputPrompt.is_public)
# If no user is logged in, get all prompts (public and private)
if user_id is None and AUTH_TYPE == AuthType.DISABLED:
query = query.where(True)  # type: ignore

# Default to returning all prompts
# If no user is logged in but we want to include public prompts
elif include_public:
query = query.where(InputPrompt.is_public)

if active is not None:
query = query.where(InputPrompt.active == active)

@@ -205,11 +205,6 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
primaryjoin="User.id == foreign(ConnectorCredentialPair.creator_id)",
)

folders: Mapped[list["UserFolder"]] = relationship(
"UserFolder", back_populates="user"
)
files: Mapped[list["UserFile"]] = relationship("UserFile", back_populates="user")

class AccessToken(SQLAlchemyBaseAccessTokenTableUUID, Base):
pass
@@ -1435,8 +1430,6 @@ class Tool(Base):
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
)
# whether to pass through the user's OAuth token as Authorization header
passthrough_auth: Mapped[bool] = mapped_column(Boolean, default=False)

user: Mapped[User | None] = relationship("User", back_populates="custom_tools")
# Relationship to Persona through the association table
@@ -1563,12 +1556,6 @@ class Persona(Base):
secondary="persona__user_group",
viewonly=True,
)
# Relationship to UserFile
user_files: Mapped[list["UserFile"]] = relationship(
"UserFile",
secondary="persona__user_file",
back_populates="assistants",
)
labels: Mapped[list["PersonaLabel"]] = relationship(
"PersonaLabel",
secondary=Persona__PersonaLabel.__table__,
@@ -1585,15 +1572,6 @@ class Persona(Base):
)

class Persona__UserFile(Base):
__tablename__ = "persona__user_file"

persona_id: Mapped[int] = mapped_column(ForeignKey("persona.id"), primary_key=True)
user_file_id: Mapped[int] = mapped_column(
ForeignKey("user_file.id"), primary_key=True
)

class PersonaLabel(Base):
__tablename__ = "persona_label"

@@ -2053,51 +2031,6 @@ class InputPrompt__User(Base):
disabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)

class UserFolder(Base):
__tablename__ = "user_folder"

id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
user_id: Mapped[int] = mapped_column(ForeignKey("user.id"), nullable=False)
name: Mapped[str] = mapped_column(nullable=False)
description: Mapped[str] = mapped_column(nullable=False)
created_at: Mapped[datetime.datetime] = mapped_column(
default=datetime.datetime.utcnow
)

user: Mapped["User"] = relationship(back_populates="folders")
files: Mapped[list["UserFile"]] = relationship(back_populates="folder")

class UserDocument(str, Enum):
CHAT = "chat"
RECENT = "recent"
FILE = "file"

class UserFile(Base):
__tablename__ = "user_file"

id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
user_id: Mapped[int | None] = mapped_column(ForeignKey("user.id"), nullable=False)
assistants: Mapped[list["Persona"]] = relationship(
"Persona",
secondary=Persona__UserFile.__table__,
back_populates="user_files",
)
folder_id: Mapped[int | None] = mapped_column(
ForeignKey("user_folder.id"), nullable=True
)

file_id: Mapped[str] = mapped_column(nullable=False)
document_id: Mapped[str] = mapped_column(nullable=False)
name: Mapped[str] = mapped_column(nullable=False)
created_at: Mapped[datetime.datetime] = mapped_column(
default=datetime.datetime.utcnow
)
user: Mapped["User"] = relationship(back_populates="files")
folder: Mapped["UserFolder"] = relationship(back_populates="files")

"""
Multi-tenancy related tables
"""

@@ -15,7 +15,6 @@ from onyx.db.models import User
from onyx.db.persona import mark_persona_as_deleted
from onyx.db.persona import upsert_persona
from onyx.db.prompts import get_default_prompt
from onyx.tools.built_in_tools import get_search_tool
from onyx.utils.errors import EERequiredError
from onyx.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,
@@ -48,10 +47,6 @@ def create_slack_channel_persona(
) -> Persona:
"""NOTE: does not commit changes"""

search_tool = get_search_tool(db_session)
if search_tool is None:
raise ValueError("Search tool not found")

# create/update persona associated with the Slack channel
persona_name = _build_persona_name(channel_name)
default_prompt = get_default_prompt(db_session)
@@ -65,7 +60,6 @@ def create_slack_channel_persona(
llm_filter_extraction=enable_auto_filters,
recency_bias=RecencyBiasSetting.AUTO,
prompt_ids=[default_prompt.id],
tool_ids=[search_tool.id],
document_set_ids=document_set_ids,
llm_model_provider_override=None,
llm_model_version_override=None,

@@ -8,64 +8,20 @@ from sqlalchemy.orm import Session
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.models import SyncRecord
from onyx.setup import setup_logger

logger = setup_logger()

def insert_sync_record(
db_session: Session,
entity_id: int,
entity_id: int | None,
sync_type: SyncType,
) -> SyncRecord:
"""Insert a new sync record into the database, cancelling any existing in-progress records.
"""Insert a new sync record into the database.

Args:
db_session: The database session to use
entity_id: The ID of the entity being synced (document set ID, user group ID, etc.)
sync_type: The type of sync operation
"""
# If an existing in-progress sync record exists, mark as cancelled
existing_in_progress_sync_record = fetch_latest_sync_record(
db_session, entity_id, sync_type, sync_status=SyncStatus.IN_PROGRESS
)

if existing_in_progress_sync_record is not None:
logger.info(
f"Cancelling existing in-progress sync record {existing_in_progress_sync_record.id} "
f"for entity_id={entity_id} sync_type={sync_type}"
)
mark_sync_records_as_cancelled(db_session, entity_id, sync_type)

return _create_sync_record(db_session, entity_id, sync_type)

def mark_sync_records_as_cancelled(
db_session: Session,
entity_id: int | None,
sync_type: SyncType,
) -> None:
stmt = (
update(SyncRecord)
.where(
and_(
SyncRecord.entity_id == entity_id,
SyncRecord.sync_type == sync_type,
SyncRecord.sync_status == SyncStatus.IN_PROGRESS,
)
)
.values(sync_status=SyncStatus.CANCELED)
)
db_session.execute(stmt)
db_session.commit()

def _create_sync_record(
db_session: Session,
entity_id: int | None,
sync_type: SyncType,
) -> SyncRecord:
"""Create and insert a new sync record into the database."""
sync_record = SyncRecord(
entity_id=entity_id,
sync_type=sync_type,
@@ -83,7 +39,6 @@ def fetch_latest_sync_record(
db_session: Session,
entity_id: int,
sync_type: SyncType,
sync_status: SyncStatus | None = None,
) -> SyncRecord | None:
"""Fetch the most recent sync record for a given entity ID and status.

@@ -104,9 +59,6 @@ def fetch_latest_sync_record(
.limit(1)
)

if sync_status is not None:
stmt = stmt.where(SyncRecord.sync_status == sync_status)

result = db_session.execute(stmt)
return result.scalar_one_or_none()

@@ -38,7 +38,6 @@ def create_tool(
custom_headers: list[Header] | None,
user_id: UUID | None,
db_session: Session,
passthrough_auth: bool,
) -> Tool:
new_tool = Tool(
name=name,
@@ -49,7 +48,6 @@ def create_tool(
if custom_headers
else [],
user_id=user_id,
passthrough_auth=passthrough_auth,
)
db_session.add(new_tool)
db_session.commit()
@@ -64,7 +62,6 @@ def update_tool(
custom_headers: list[Header] | None,
user_id: UUID | None,
db_session: Session,
passthrough_auth: bool | None,
) -> Tool:
tool = get_tool_by_id(tool_id, db_session)
if tool is None:
@@ -82,8 +79,6 @@ def update_tool(
tool.custom_headers = [
cast(HeaderItemDict, header.model_dump()) for header in custom_headers
]
if passthrough_auth is not None:
tool.passthrough_auth = passthrough_auth
db_session.commit()

return tool

@@ -1,29 +0,0 @@
from typing import List

from fastapi import UploadFile
from sqlalchemy.orm import Session

from onyx.db.models import User
from onyx.db.models import UserFile
from onyx.server.documents.connector import upload_files
from onyx.server.documents.models import FileUploadResponse

def create_user_files(
files: List[UploadFile],
folder_id: int | None,
user: User,
db_session: Session,
) -> FileUploadResponse:
upload_response = upload_files(files, db_session)
for file_path, file in zip(upload_response.file_paths, files):
new_file = UserFile(
user_id=user.id if user else None,
folder_id=folder_id if folder_id != -1 else None,
file_id=file_path,
document_id=file_path,
name=file.filename,
)
db_session.add(new_file)
db_session.commit()
return upload_response
@@ -594,7 +594,6 @@ class VespaIndex(DocumentIndex):
primary_index=index_name == self.index_name,
)
large_chunks_enabled = multipass_config.enable_large_chunks

enriched_doc_infos = VespaIndex.enrich_basic_chunk_info(
index_name=index_name,
http_client=http_client,
@@ -663,7 +662,6 @@ class VespaIndex(DocumentIndex):
tenant_id=tenant_id,
large_chunks_enabled=large_chunks_enabled,
)

for doc_chunk_ids_batch in batch_generator(
chunks_to_delete, BATCH_SIZE
):

@@ -1,5 +1,4 @@
import re
import time
from typing import cast

import httpx
@@ -8,10 +7,6 @@ from onyx.configs.app_configs import MANAGED_VESPA
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
from onyx.configs.app_configs import VESPA_REQUEST_TIMEOUT
from onyx.document_index.vespa_constants import VESPA_APP_CONTAINER_URL
from onyx.utils.logger import setup_logger

logger = setup_logger()

# NOTE: This does not seem to be used in reality despite the Vespa Docs pointing to this code
# See here for reference: https://docs.vespa.ai/en/documents.html
@@ -74,37 +69,3 @@ def get_vespa_http_client(no_timeout: bool = False, http2: bool = True) -> httpx
timeout=None if no_timeout else VESPA_REQUEST_TIMEOUT,
http2=http2,
)

def wait_for_vespa_with_timeout(wait_interval: int = 5, wait_limit: int = 60) -> bool:
"""Waits for Vespa to become ready subject to a timeout.
Returns True if Vespa is ready, False otherwise."""

time_start = time.monotonic()
logger.info("Vespa: Readiness probe starting.")
while True:
try:
client = get_vespa_http_client()
response = client.get(f"{VESPA_APP_CONTAINER_URL}/state/v1/health")
response.raise_for_status()

response_dict = response.json()
if response_dict["status"]["code"] == "up":
logger.info("Vespa: Readiness probe succeeded. Continuing...")
return True
except Exception:
pass

time_elapsed = time.monotonic() - time_start
if time_elapsed > wait_limit:
logger.info(
f"Vespa: Readiness probe did not succeed within the timeout "
f"({wait_limit} seconds)."
)
return False

logger.info(
f"Vespa: Readiness probe ongoing. elapsed={time_elapsed:.1f} timeout={wait_limit:.1f}"
)

time.sleep(wait_interval)

@@ -2,7 +2,7 @@ from onyx.key_value_store.interface import KeyValueStore
from onyx.key_value_store.store import PgRedisKVStore

def get_kv_store(tenant_id: str | None = None) -> KeyValueStore:
def get_kv_store() -> KeyValueStore:
# In the Multi Tenant case, the tenant context is picked up automatically, it does not need to be passed in
# It's read from the global thread level variable
return PgRedisKVStore(tenant_id=tenant_id)
return PgRedisKVStore()

@@ -31,27 +31,27 @@ class PgRedisKVStore(KeyValueStore):
def __init__(
self, redis_client: Redis | None = None, tenant_id: str | None = None
) -> None:
self.tenant_id = tenant_id or CURRENT_TENANT_ID_CONTEXTVAR.get()

# If no redis_client is provided, fall back to the context var
if redis_client is not None:
self.redis_client = redis_client
else:
self.redis_client = get_redis_client(tenant_id=self.tenant_id)
tenant_id = tenant_id or CURRENT_TENANT_ID_CONTEXTVAR.get()
self.redis_client = get_redis_client(tenant_id=tenant_id)

@contextmanager
def _get_session(self) -> Iterator[Session]:
def get_session(self) -> Iterator[Session]:
engine = get_sqlalchemy_engine()
with Session(engine, expire_on_commit=False) as session:
if MULTI_TENANT:
if self.tenant_id == POSTGRES_DEFAULT_SCHEMA:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
if tenant_id == POSTGRES_DEFAULT_SCHEMA:
raise HTTPException(
status_code=401, detail="User must authenticate"
)
if not is_valid_schema_name(self.tenant_id):
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID")
# Set the search_path to the tenant's schema
session.execute(text(f'SET search_path = "{self.tenant_id}"'))
session.execute(text(f'SET search_path = "{tenant_id}"'))
yield session

def store(self, key: str, val: JSON_ro, encrypt: bool = False) -> None:
@@ -66,7 +66,7 @@ class PgRedisKVStore(KeyValueStore):

encrypted_val = val if encrypt else None
plain_val = val if not encrypt else None
with self._get_session() as session:
with self.get_session() as session:
obj = session.query(KVStore).filter_by(key=key).first()
if obj:
obj.value = plain_val
@@ -88,7 +88,7 @@ class PgRedisKVStore(KeyValueStore):
except Exception as e:
logger.error(f"Failed to get value from Redis for key '{key}': {str(e)}")

with self._get_session() as session:
with self.get_session() as session:
obj = session.query(KVStore).filter_by(key=key).first()
if not obj:
raise KvKeyNotFoundError
@@ -113,7 +113,7 @@ class PgRedisKVStore(KeyValueStore):
except Exception as e:
logger.error(f"Failed to delete value from Redis for key '{key}': {str(e)}")

with self._get_session() as session:
with self.get_session() as session:
result = session.query(KVStore).filter_by(key=key).delete()  # type: ignore
if result == 0:
raise KvKeyNotFoundError

@@ -275,22 +275,17 @@ class DefaultMultiLLM(LLM):
# addtional kwargs (and some kwargs MUST be passed in rather than set as
# env variables)
if custom_config:
# Specifically pass in "vertex_credentials" / "vertex_location" as a
# model_kwarg to the completion call for vertex AI. More details here:
# Specifically pass in "vertex_credentials" as a model_kwarg to the
# completion call for vertex AI. More details here:
# https://docs.litellm.ai/docs/providers/vertex
vertex_credentials_key = "vertex_credentials"
vertex_location_key = "vertex_location"
for k, v in custom_config.items():
if model_provider == "vertex_ai":
if k == vertex_credentials_key:
model_kwargs[k] = v
continue
elif k == vertex_location_key:
model_kwargs[k] = v
continue

# for all values, set them as env variables
os.environ[k] = v
vertex_credentials = custom_config.get(vertex_credentials_key)
if vertex_credentials and model_provider == "vertex_ai":
model_kwargs[vertex_credentials_key] = vertex_credentials
else:
# standard case
for k, v in custom_config.items():
os.environ[k] = v

if extra_headers:
model_kwargs.update({"extra_headers": extra_headers})

@@ -142,7 +142,6 @@ def build_content_with_imgs(
img_urls: list[str] | None = None,
b64_imgs: list[str] | None = None,
message_type: MessageType = MessageType.USER,
exclude_images: bool = False,
) -> str | list[str | dict[str, Any]]:  # matching Langchain's BaseMessage content type
files = files or []

@@ -158,7 +157,7 @@ def build_content_with_imgs(

message_main_content = _build_content(message, files)

if exclude_images or (not img_files and not img_urls):
if not img_files and not img_urls:
return message_main_content

return cast(
@@ -383,19 +382,9 @@ def _strip_colon_from_model_name(model_name: str) -> str:
return ":".join(model_name.split(":")[:-1]) if ":" in model_name else model_name

def _find_model_obj(model_map: dict, provider: str, model_name: str) -> dict | None:
stripped_model_name = _strip_extra_provider_from_model_name(model_name)

model_names = [
model_name,
_strip_extra_provider_from_model_name(model_name),
# Remove leading extra provider. Usually for cases where user has a
# customer model proxy which appends another prefix
# remove :XXXX from the end, if present. Needed for ollama.
_strip_colon_from_model_name(model_name),
_strip_colon_from_model_name(stripped_model_name),
]

def _find_model_obj(
model_map: dict, provider: str, model_names: list[str | None]
) -> dict | None:
# Filter out None values and deduplicate model names
filtered_model_names = [name for name in model_names if name]

@@ -428,10 +417,21 @@ def get_llm_max_tokens(
return GEN_AI_MAX_TOKENS

try:
extra_provider_stripped_model_name = _strip_extra_provider_from_model_name(
model_name
)
model_obj = _find_model_obj(
model_map,
model_provider,
model_name,
[
model_name,
# Remove leading extra provider. Usually for cases where user has a
# customer model proxy which appends another prefix
extra_provider_stripped_model_name,
# remove :XXXX from the end, if present. Needed for ollama.
_strip_colon_from_model_name(model_name),
_strip_colon_from_model_name(extra_provider_stripped_model_name),
],
)
if not model_obj:
raise RuntimeError(
@@ -523,23 +523,3 @@ def get_max_input_tokens(
raise RuntimeError("No tokens for input for the LLM given settings")

return input_toks

def model_supports_image_input(model_name: str, model_provider: str) -> bool:
model_map = get_model_map()
try:
model_obj = _find_model_obj(
model_map,
model_provider,
model_name,
)
if not model_obj:
raise RuntimeError(
f"No litellm entry found for {model_provider}/{model_name}"
)
return model_obj.get("supports_vision", False)
except Exception:
logger.exception(
f"Failed to get model object for {model_provider}/{model_name}"
)
return False

@@ -97,7 +97,6 @@ from onyx.server.settings.api import basic_router as settings_router
from onyx.server.token_rate_limits.api import (
router as token_rate_limit_settings_router,
)
from onyx.server.user_documents.api import router as user_documents_router
from onyx.server.utils import BasicAuthenticationError
from onyx.setup import setup_multitenant_onyx
from onyx.setup import setup_onyx
@@ -213,7 +212,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:

if not MULTI_TENANT:
# We cache this at the beginning so there is no delay in the first telemetry
get_or_generate_uuid(tenant_id=None)
get_or_generate_uuid()

# If we are multi-tenant, we need to only set up initial public tables
with Session(engine) as db_session:
@@ -287,7 +286,6 @@ def get_application() -> FastAPI:
include_router_with_global_prefix_prepended(application, input_prompt_router)
include_router_with_global_prefix_prepended(application, admin_input_prompt_router)
include_router_with_global_prefix_prepended(application, cc_pair_router)
include_router_with_global_prefix_prepended(application, user_documents_router)
include_router_with_global_prefix_prepended(application, folder_router)
include_router_with_global_prefix_prepended(application, document_set_router)
include_router_with_global_prefix_prepended(application, search_settings_router)

@@ -14,7 +14,6 @@ from typing import Set

from prometheus_client import Gauge
from prometheus_client import start_http_server
from redis.lock import Lock
from slack_sdk import WebClient
from slack_sdk.socket_mode.request import SocketModeRequest
from slack_sdk.socket_mode.response import SocketModeResponse
@@ -123,9 +122,6 @@ class SlackbotHandler:
self.socket_clients: Dict[tuple[str | None, int], TenantSocketModeClient] = {}
self.slack_bot_tokens: Dict[tuple[str | None, int], SlackBotTokens] = {}

# Store Redis lock objects here so we can release them properly
self.redis_locks: Dict[str | None, Lock] = {}

self.running = True
self.pod_id = self.get_pod_id()
self._shutdown_event = Event()
@@ -163,15 +159,10 @@ class SlackbotHandler:
while not self._shutdown_event.is_set():
try:
self.acquire_tenants()

# After we finish acquiring and managing Slack bots,
# set the gauge to the number of active tenants (those with Slack bots).
active_tenants_gauge.labels(namespace=POD_NAMESPACE, pod=POD_NAME).set(
len(self.tenant_ids)
)
logger.debug(
f"Current active tenants with Slack bots: {len(self.tenant_ids)}"
)
logger.debug(f"Current active tenants: {len(self.tenant_ids)}")
except Exception as e:
logger.exception(f"Error in Slack acquisition: {e}")
self._shutdown_event.wait(timeout=TENANT_ACQUISITION_INTERVAL)
@@ -180,9 +171,7 @@ class SlackbotHandler:
while not self._shutdown_event.is_set():
try:
self.send_heartbeats()
logger.debug(
f"Sent heartbeats for {len(self.tenant_ids)} active tenants"
)
logger.debug(f"Sent heartbeats for {len(self.tenant_ids)} tenants")
except Exception as e:
logger.exception(f"Error in heartbeat loop: {e}")
self._shutdown_event.wait(timeout=TENANT_HEARTBEAT_INTERVAL)
@@ -190,21 +179,17 @@ class SlackbotHandler:
def _manage_clients_per_tenant(
self, db_session: Session, tenant_id: str | None, bot: SlackBot
) -> None:
"""
- If the tokens are missing or empty, close the socket client and remove them.
- If the tokens have changed, close the existing socket client and reconnect.
- If the tokens are new, warm up the model and start a new socket client.
"""
slack_bot_tokens = SlackBotTokens(
bot_token=bot.bot_token,
app_token=bot.app_token,
)
tenant_bot_pair = (tenant_id, bot.id)

# If the tokens are missing or empty, close the socket client and remove them.
# If the tokens are not set, we need to close the socket client and delete the tokens
# for the tenant and app
if not slack_bot_tokens:
logger.debug(
f"No Slack bot tokens found for tenant={tenant_id}, bot {bot.id}"
f"No Slack bot token found for tenant {tenant_id}, bot {bot.id}"
)
if tenant_bot_pair in self.socket_clients:
asyncio.run(self.socket_clients[tenant_bot_pair].close())
@@ -219,10 +204,9 @@ class SlackbotHandler:
if not tokens_exist or tokens_changed:
if tokens_exist:
logger.info(
f"Slack Bot tokens changed for tenant={tenant_id}, bot {bot.id}; reconnecting"
f"Slack Bot tokens have changed for tenant {tenant_id}, bot {bot.id} - reconnecting"
)
else:
# Warm up the model if needed
search_settings = get_current_search_settings(db_session)
embedding_model = EmbeddingModel.from_db_model(
search_settings=search_settings,
@@ -233,168 +217,77 @@ class SlackbotHandler:

self.slack_bot_tokens[tenant_bot_pair] = slack_bot_tokens

# Close any existing connection first
if tenant_bot_pair in self.socket_clients:
asyncio.run(self.socket_clients[tenant_bot_pair].close())

self.start_socket_client(bot.id, tenant_id, slack_bot_tokens)

def acquire_tenants(self) -> None:
"""
- Attempt to acquire a Redis lock for each tenant.
- If acquired, check if that tenant actually has Slack bots.
- If yes, store them in self.tenant_ids and manage the socket connections.
- If a tenant in self.tenant_ids no longer has Slack bots, remove it (and release the lock in this scope).
"""
all_tenants = get_all_tenant_ids()
tenant_ids = get_all_tenant_ids()

# 1) Try to acquire locks for new tenants
for tenant_id in all_tenants:
for tenant_id in tenant_ids:
if (
DISALLOWED_SLACK_BOT_TENANT_LIST is not None
and tenant_id in DISALLOWED_SLACK_BOT_TENANT_LIST
):
logger.debug(f"Tenant {tenant_id} is disallowed; skipping.")
logger.debug(f"Tenant {tenant_id} is in the disallowed list, skipping")
continue

# Already acquired in a previous loop iteration?
if tenant_id in self.tenant_ids:
logger.debug(f"Tenant {tenant_id} already in self.tenant_ids")
continue

# Respect max tenant limit per pod
if len(self.tenant_ids) >= MAX_TENANTS_PER_POD:
logger.info(
f"Max tenants per pod reached ({MAX_TENANTS_PER_POD}); not acquiring more."
f"Max tenants per pod reached ({MAX_TENANTS_PER_POD}) Not acquiring any more tenants"
)
break

redis_client = get_redis_client(tenant_id=tenant_id)
# Acquire a Redis lock (non-blocking)
rlock = redis_client.lock(
OnyxRedisLocks.SLACK_BOT_LOCK, timeout=TENANT_LOCK_EXPIRATION
pod_id = self.pod_id
acquired = redis_client.set(
OnyxRedisLocks.SLACK_BOT_LOCK,
pod_id,
nx=True,
ex=TENANT_LOCK_EXPIRATION,
)
lock_acquired = rlock.acquire(blocking=False)

if not lock_acquired and not DEV_MODE:
logger.debug(
f"Another pod holds the lock for tenant {tenant_id}, skipping."
)
if not acquired and not DEV_MODE:
logger.debug(f"Another pod holds the lock for tenant {tenant_id}")
continue

if lock_acquired:
logger.debug(f"Acquired lock for tenant {tenant_id}.")
self.redis_locks[tenant_id] = rlock
else:
# DEV_MODE will skip the lock acquisition guard
logger.debug(
f"Running in DEV_MODE. Not enforcing lock for {tenant_id}."
)
logger.debug(f"Acquired lock for tenant {tenant_id}")

# Now check if this tenant actually has Slack bots
self.tenant_ids.add(tenant_id)

for tenant_id in self.tenant_ids:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(
tenant_id or POSTGRES_DEFAULT_SCHEMA
)
try:
with get_session_with_tenant(tenant_id) as db_session:
bots: list[SlackBot] = []
try:
bots = list(fetch_slack_bots(db_session=db_session))
except KvKeyNotFoundError:
# No Slackbot tokens, pass
pass
except Exception as e:
logger.exception(
f"Error fetching Slack bots for tenant {tenant_id}: {e}"
)

if bots:
# Mark as active tenant
self.tenant_ids.add(tenant_id)
bots = fetch_slack_bots(db_session=db_session)
for bot in bots:
self._manage_clients_per_tenant(
db_session=db_session,
tenant_id=tenant_id,
bot=bot,
)
else:
# If no Slack bots, release lock immediately (unless in DEV_MODE)
if lock_acquired and not DEV_MODE:
rlock.release()
del self.redis_locks[tenant_id]
logger.debug(
f"No Slack bots for tenant {tenant_id}; lock released (if held)."
)
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)

# 2) Make sure tenants we're handling still have Slack bots
for tenant_id in list(self.tenant_ids):
token = CURRENT_TENANT_ID_CONTEXTVAR.set(
tenant_id or POSTGRES_DEFAULT_SCHEMA
)
redis_client = get_redis_client(tenant_id=tenant_id)

try:
with get_session_with_tenant(tenant_id) as db_session:
# Attempt to fetch Slack bots
try:
bots = list(fetch_slack_bots(db_session=db_session))
except KvKeyNotFoundError:
# No Slackbot tokens, pass (and remove below)
bots = []
logger.debug(f"Missing Slack Bot tokens for tenant {tenant_id}")
if (tenant_id, bot.id) in self.socket_clients:
asyncio.run(self.socket_clients[tenant_id, bot.id].close())
del self.socket_clients[tenant_id, bot.id]
del self.slack_bot_tokens[tenant_id, bot.id]
except Exception as e:
logger.exception(f"Error handling tenant {tenant_id}: {e}")
bots = []

if not bots:
logger.info(
f"Tenant {tenant_id} no longer has Slack bots. Removing."
)
self._remove_tenant(tenant_id)

# NOTE: We release the lock here (in the same scope it was acquired)
if tenant_id in self.redis_locks and not DEV_MODE:
try:
self.redis_locks[tenant_id].release()
del self.redis_locks[tenant_id]
logger.info(f"Released lock for tenant {tenant_id}")
except Exception as e:
logger.error(
f"Error releasing lock for tenant {tenant_id}: {e}"
)
else:
# Manage or reconnect Slack bot sockets
for bot in bots:
self._manage_clients_per_tenant(
db_session=db_session,
tenant_id=tenant_id,
bot=bot,
)
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)

def _remove_tenant(self, tenant_id: str | None) -> None:
"""
Helper to remove a tenant from `self.tenant_ids` and close any socket clients.
(Lock release now happens in `acquire_tenants()`, not here.)
"""
# Close all socket clients for this tenant
for (t_id, slack_bot_id), client in list(self.socket_clients.items()):
if t_id == tenant_id:
asyncio.run(client.close())
del self.socket_clients[(t_id, slack_bot_id)]
del self.slack_bot_tokens[(t_id, slack_bot_id)]
logger.info(
f"Stopped SocketModeClient for tenant: {t_id}, app: {slack_bot_id}"
)

# Remove from active set
if tenant_id in self.tenant_ids:
self.tenant_ids.remove(tenant_id)

def send_heartbeats(self) -> None:
current_time = int(time.time())
logger.debug(f"Sending heartbeats for {len(self.tenant_ids)} active tenants")
logger.debug(f"Sending heartbeats for {len(self.tenant_ids)} tenants")
for tenant_id in self.tenant_ids:
redis_client = get_redis_client(tenant_id=tenant_id)
heartbeat_key = f"{OnyxRedisLocks.SLACK_BOT_HEARTBEAT_PREFIX}:{self.pod_id}"
@@ -422,7 +315,6 @@ class SlackbotHandler:
)
socket_client.connect()
self.socket_clients[tenant_id, slack_bot_id] = socket_client
# Ensure tenant is tracked as active
self.tenant_ids.add(tenant_id)
logger.info(
f"Started SocketModeClient for tenant: {tenant_id}, app: {slack_bot_id}"
@@ -430,7 +322,7 @@ class SlackbotHandler:

def stop_socket_clients(self) -> None:
logger.info(f"Stopping {len(self.socket_clients)} socket clients")
for (tenant_id, slack_bot_id), client in list(self.socket_clients.items()):
for (tenant_id, slack_bot_id), client in self.socket_clients.items():
asyncio.run(client.close())
logger.info(
f"Stopped SocketModeClient for tenant: {tenant_id}, app: {slack_bot_id}"
@@ -448,19 +340,17 @@ class SlackbotHandler:
logger.info(f"Stopping {len(self.socket_clients)} socket clients")
self.stop_socket_clients()

# Release locks for all tenants we currently hold
# Release locks for all tenants
logger.info(f"Releasing locks for {len(self.tenant_ids)} tenants")
for tenant_id in list(self.tenant_ids):
if tenant_id in self.redis_locks:
try:
self.redis_locks[tenant_id].release()
logger.info(f"Released lock for tenant {tenant_id}")
except Exception as e:
logger.error(f"Error releasing lock for tenant {tenant_id}: {e}")
finally:
del self.redis_locks[tenant_id]
for tenant_id in self.tenant_ids:
try:
redis_client = get_redis_client(tenant_id=tenant_id)
redis_client.delete(OnyxRedisLocks.SLACK_BOT_LOCK)
logger.info(f"Released lock for tenant {tenant_id}")
except Exception as e:
logger.error(f"Error releasing lock for tenant {tenant_id}: {e}")

# Wait for background threads to finish (with a timeout)
# Wait for background threads to finish (with timeout)
logger.info("Waiting for background threads to finish...")
self.acquire_thread.join(timeout=5)
self.heartbeat_thread.join(timeout=5)

@@ -19,8 +19,9 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()

_DANSWER_DATETIME_REPLACEMENT_PAT = "[[CURRENT_DATETIME]]"
_BASIC_TIME_STR = "The current date is {datetime_info}."
MOST_BASIC_PROMPT = "You are a helpful AI assistant."
DANSWER_DATETIME_REPLACEMENT = "DANSWER_DATETIME_REPLACEMENT"
BASIC_TIME_STR = "The current date is {datetime_info}."

def get_current_llm_day_time(
@@ -37,36 +38,23 @@ def get_current_llm_day_time(
return f"{formatted_datetime}"

def build_date_time_string() -> str:
return ADDITIONAL_INFO.format(
datetime_info=_BASIC_TIME_STR.format(datetime_info=get_current_llm_day_time())
)

def handle_onyx_date_awareness(
prompt_str: str,
prompt_config: PromptConfig,
add_additional_info_if_no_tag: bool = False,
) -> str:
"""
If there is a [[CURRENT_DATETIME]] tag, replace it with the current date and time no matter what.
If the prompt is datetime aware, and there are no [[CURRENT_DATETIME]] tags, add it to the prompt.
do nothing otherwise.
This can later be expanded to support other tags.
"""

if _DANSWER_DATETIME_REPLACEMENT_PAT in prompt_str:
def add_date_time_to_prompt(prompt_str: str) -> str:
if DANSWER_DATETIME_REPLACEMENT in prompt_str:
return prompt_str.replace(
_DANSWER_DATETIME_REPLACEMENT_PAT,
DANSWER_DATETIME_REPLACEMENT,
get_current_llm_day_time(full_sentence=False, include_day_of_week=True),
)
any_tag_present = any(
_DANSWER_DATETIME_REPLACEMENT_PAT in text
for text in [prompt_str, prompt_config.system_prompt, prompt_config.task_prompt]
)
if add_additional_info_if_no_tag and not any_tag_present:
return prompt_str + build_date_time_string()
return prompt_str

if prompt_str:
return prompt_str + ADDITIONAL_INFO.format(
datetime_info=get_current_llm_day_time()
)
else:
return (
MOST_BASIC_PROMPT
+ " "
+ BASIC_TIME_STR.format(datetime_info=get_current_llm_day_time())
)

def build_task_prompt_reminders(

@@ -30,17 +30,10 @@ class RedisConnectorIndex:
GENERATOR_LOCK_PREFIX = "da_lock:indexing"

TERMINATE_PREFIX = PREFIX + "_terminate"  # connectorindexing_terminate
TERMINATE_TTL = 600

# used to signal the overall workflow is still active
# there are gaps in time between states where we need some slack
# to correctly transition
# it's difficult to prevent
ACTIVE_PREFIX = PREFIX + "_active"
ACTIVE_TTL = 3600

# used to signal that the watchdog is running
WATCHDOG_PREFIX = PREFIX + "_watchdog"
WATCHDOG_TTL = 300

def __init__(
self,
@@ -66,7 +59,6 @@ class RedisConnectorIndex:
)
self.terminate_key = f"{self.TERMINATE_PREFIX}_{id}/{search_settings_id}"
self.active_key = f"{self.ACTIVE_PREFIX}_{id}/{search_settings_id}"
self.watchdog_key = f"{self.WATCHDOG_PREFIX}_{id}/{search_settings_id}"

@classmethod
def fence_key_with_ids(cls, cc_pair_id: int, search_settings_id: int) -> str:
@@ -118,24 +110,7 @@ class RedisConnectorIndex:
"""This sets a signal. It does not block!"""
# We shouldn't need very long to terminate the spawned task.
# 10 minute TTL is good.
self.redis.set(
f"{self.terminate_key}_{celery_task_id}", 0, ex=self.TERMINATE_TTL
)

def set_watchdog(self, value: bool) -> None:
"""Signal the state of the watchdog."""
if not value:
self.redis.delete(self.watchdog_key)
return

self.redis.set(self.watchdog_key, 0, ex=self.WATCHDOG_TTL)

def watchdog_signaled(self) -> bool:
"""Check the state of the watchdog."""
if self.redis.exists(self.watchdog_key):
return True

return False
self.redis.set(f"{self.terminate_key}_{celery_task_id}", 0, ex=600)

def set_active(self) -> None:
"""This sets a signal to keep the indexing flow from getting cleaned up within
@@ -143,7 +118,7 @@ class RedisConnectorIndex:

The slack in timing is needed to avoid race conditions where simply checking
the celery queue and task status could result in race conditions."""
self.redis.set(self.active_key, 0, ex=self.ACTIVE_TTL)
self.redis.set(self.active_key, 0, ex=3600)

def active(self) -> bool:
if self.redis.exists(self.active_key):

@@ -21,7 +21,6 @@ from onyx.configs.app_configs import REDIS_HOST
from onyx.configs.app_configs import REDIS_PASSWORD
from onyx.configs.app_configs import REDIS_POOL_MAX_CONNECTIONS
from onyx.configs.app_configs import REDIS_PORT
from onyx.configs.app_configs import REDIS_REPLICA_HOST
from onyx.configs.app_configs import REDIS_SSL
from onyx.configs.app_configs import REDIS_SSL_CA_CERTS
from onyx.configs.app_configs import REDIS_SSL_CERT_REQS
@@ -133,32 +132,23 @@ class RedisPool:
_instance: Optional["RedisPool"] = None
_lock: threading.Lock = threading.Lock()
_pool: redis.BlockingConnectionPool
_replica_pool: redis.BlockingConnectionPool

def __new__(cls) -> "RedisPool":
if not cls._instance:
with cls._lock:
if not cls._instance:
cls._instance = super(RedisPool, cls).__new__(cls)
cls._instance._init_pools()
cls._instance._init_pool()
return cls._instance

def _init_pools(self) -> None:
def _init_pool(self) -> None:
self._pool = RedisPool.create_pool(ssl=REDIS_SSL)
self._replica_pool = RedisPool.create_pool(
host=REDIS_REPLICA_HOST, ssl=REDIS_SSL
)

def get_client(self, tenant_id: str | None) -> Redis:
if tenant_id is None:
tenant_id = "public"
return TenantRedis(tenant_id, connection_pool=self._pool)

def get_replica_client(self, tenant_id: str | None) -> Redis:
if tenant_id is None:
tenant_id = "public"
return TenantRedis(tenant_id, connection_pool=self._replica_pool)

@staticmethod
def create_pool(
host: str = REDIS_HOST,
@@ -222,10 +212,6 @@ def get_redis_client(*, tenant_id: str | None) -> Redis:
return redis_pool.get_client(tenant_id)

def get_redis_replica_client(*, tenant_id: str | None) -> Redis:
return redis_pool.get_replica_client(tenant_id)

SSL_CERT_REQS_MAP = {
"none": ssl.CERT_NONE,
"optional": ssl.CERT_OPTIONAL,

@@ -6184,7 +6184,7 @@
"chunk_ind": 0
},
{
"url": "https://docs.onyx.app/more/use_cases/support",
"url": "https://docs.onyx.app/more/use_cases/customer_support",
"title": "Customer Support",
"content": "Help your customer support team instantly answer any question across your entire product.\n\nAI Enabled Support\nCustomer support agents have one of the highest breadth jobs. They field requests that cover the entire surface area of the product and need to help your users find success on extremely short timelines. Because they're not the same people who designed or built the system, they often lack the depth of understanding needed - resulting in delays and escalations to other teams. Modern teams are leveraging AI to help their CS team optimize the speed and quality of these critical customer-facing interactions.\n\nThe Importance of Context\nThere are two critical components of AI copilots for customer support. The first is that the AI system needs to be connected with as much information as possible (not just support tools like Zendesk or Intercom) and that the knowledge needs to be as fresh as possible. Sometimes a fix might even be in places rarely checked by CS such as pull requests in a code repository. The second critical component is the ability of the AI system to break down difficult concepts and convoluted processes into more digestible descriptions and for your team members to be able to chat back and forth with the system to build a better understanding.\n\nOnyx takes care of both of these. The system connects up to over 30+ different applications and the knowledge is pulled in constantly so that the information access is always up to date.",
"title_embedding": [

@@ -24,7 +24,7 @@
"chunk_ind": 0
},
{
"url": "https://docs.onyx.app/more/use_cases/support",
"url": "https://docs.onyx.app/more/use_cases/customer_support",
"title": "Customer Support",
"content": "Help your customer support team instantly answer any question across your entire product.\n\nAI Enabled Support\nCustomer support agents have one of the highest breadth jobs. They field requests that cover the entire surface area of the product and need to help your users find success on extremely short timelines. Because they're not the same people who designed or built the system, they often lack the depth of understanding needed - resulting in delays and escalations to other teams. Modern teams are leveraging AI to help their CS team optimize the speed and quality of these critical customer-facing interactions.\n\nThe Importance of Context\nThere are two critical components of AI copilots for customer support. The first is that the AI system needs to be connected with as much information as possible (not just support tools like Zendesk or Intercom) and that the knowledge needs to be as fresh as possible. Sometimes a fix might even be in places rarely checked by CS such as pull requests in a code repository. The second critical component is the ability of the AI system to break down difficult concepts and convoluted processes into more digestible descriptions and for your team members to be able to chat back and forth with the system to build a better understanding.\n\nOnyx takes care of both of these. The system connects up to over 30+ different applications and the knowledge is pulled in constantly so that the information access is always up to date.",
"chunk_ind": 0

@@ -26,7 +26,6 @@ from onyx.db.index_attempt import mock_successful_index_attempt
from onyx.db.search_settings import get_current_search_settings
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.interfaces import IndexBatchParams
from onyx.document_index.vespa.shared_utils.utils import wait_for_vespa_with_timeout
from onyx.indexing.indexing_pipeline import index_doc_batch_prepare
from onyx.indexing.models import ChunkEmbedding
from onyx.indexing.models import DocMetadataAwareIndexChunk
@@ -34,6 +33,7 @@ from onyx.key_value_store.factory import get_kv_store
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.server.documents.models import ConnectorBase
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder
from onyx.utils.variable_functionality import fetch_versioned_implementation

logger = setup_logger()
@@ -218,11 +218,9 @@ def seed_initial_documents(

# Retries here because the index may take a few seconds to become ready
# as we just sent over the Vespa schema and there is a slight delay
if not wait_for_vespa_with_timeout():
logger.error("Vespa did not become ready within the timeout")
raise ValueError("Vespa failed to become ready within the timeout")

document_index.index(
index_with_retries = retry_builder(tries=15)(document_index.index)
index_with_retries(
chunks=chunks,
index_batch_params=IndexBatchParams(
doc_id_to_previous_chunk_cnt={},

@@ -8,7 +8,7 @@ prompts:
# System Prompt (as shown in UI)
system: >
You are a question answering system that is constantly learning and improving.
The current date is [[CURRENT_DATETIME]].
The current date is DANSWER_DATETIME_REPLACEMENT.

You can process and comprehend vast amounts of text and utilize this knowledge to provide
grounded, accurate, and concise answers to diverse queries.
@@ -24,7 +24,7 @@ prompts:

If there are no relevant documents, refer to the chat history and your internal knowledge.
# Inject a statement at the end of system prompt to inform the LLM of the current date/time
# If the [[CURRENT_DATETIME]] is set, the date/time is inserted there instead
# If the DANSWER_DATETIME_REPLACEMENT is set, the date/time is inserted there instead
# Format looks like: "October 16, 2023 14:30"
datetime_aware: true
# Prompts the LLM to include citations in the for [1], [2] etc.
@@ -51,7 +51,7 @@ prompts:
- name: "OnlyLLM"
description: "Chat directly with the LLM!"
system: >
You are a helpful AI assistant. The current date is [[CURRENT_DATETIME]]
You are a helpful AI assistant. The current date is DANSWER_DATETIME_REPLACEMENT

You give concise responses to very simple questions, but provide more thorough responses to
@@ -69,7 +69,7 @@ prompts:
system: >
You are a text summarizing assistant that highlights the most important knowledge from the
context provided, prioritizing the information that relates to the user query.
The current date is [[CURRENT_DATETIME]].
The current date is DANSWER_DATETIME_REPLACEMENT.

You ARE NOT creative and always stick to the provided documents.
If there are no documents, refer to the conversation history.
@@ -87,7 +87,7 @@ prompts:
description: "Recites information from retrieved context! Least creative but most safe!"
system: >
Quote and cite relevant information from provided context based on the user query.
The current date is [[CURRENT_DATETIME]].
The current date is DANSWER_DATETIME_REPLACEMENT.

You only provide quotes that are EXACT substrings from provided documents!

@@ -62,7 +62,7 @@ router = APIRouter(prefix="/manage")
@router.get("/admin/cc-pair/{cc_pair_id}/index-attempts")
def get_cc_pair_index_attempts(
cc_pair_id: int,
page_num: int = Query(0, ge=0),
page: int = Query(1, ge=1),
page_size: int = Query(10, ge=1, le=1000),
user: User | None = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
@@ -81,7 +81,7 @@ def get_cc_pair_index_attempts(
index_attempts = get_paginated_index_attempts_for_cc_pair_id(
db_session=db_session,
connector_id=cc_pair.connector_id,
page=page_num,
page=page,
page_size=page_size,
)
return PaginatedReturn(

@@ -1,8 +1,5 @@
|
||||
import mimetypes
|
||||
import os
|
||||
import uuid
|
||||
import zipfile
|
||||
from io import BytesIO
|
||||
from typing import cast
|
||||
|
||||
from fastapi import APIRouter
|
||||
@@ -380,47 +377,19 @@ def check_drive_tokens(
|
||||
return AuthStatus(authenticated=True)
|
||||
|
||||
|
||||
def upload_files(files: list[UploadFile], db_session: Session) -> FileUploadResponse:
|
||||
@router.post("/admin/connector/file/upload")
|
||||
def upload_files(
|
||||
files: list[UploadFile],
|
||||
_: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> FileUploadResponse:
|
||||
for file in files:
|
||||
if not file.filename:
|
||||
raise HTTPException(status_code=400, detail="File name cannot be empty")
|
||||
|
||||
# Skip directories and known macOS metadata entries
|
||||
def should_process_file(file_path: str) -> bool:
|
||||
normalized_path = os.path.normpath(file_path)
|
||||
return not any(part.startswith(".") for part in normalized_path.split(os.sep))
|
||||
|
||||
try:
|
||||
file_store = get_default_file_store(db_session)
|
||||
deduped_file_paths = []
|
||||
|
||||
for file in files:
|
||||
if file.content_type and file.content_type.startswith("application/zip"):
|
||||
with zipfile.ZipFile(file.file, "r") as zf:
|
||||
for file_info in zf.namelist():
|
||||
if zf.getinfo(file_info).is_dir():
|
||||
continue
|
||||
|
||||
if not should_process_file(file_info):
|
||||
continue
|
||||
|
||||
sub_file_bytes = zf.read(file_info)
|
||||
sub_file_name = os.path.join(str(uuid.uuid4()), file_info)
|
||||
deduped_file_paths.append(sub_file_name)
|
||||
|
||||
mime_type, __ = mimetypes.guess_type(file_info)
|
||||
if mime_type is None:
|
||||
mime_type = "application/octet-stream"
|
||||
|
||||
file_store.save_file(
|
||||
file_name=sub_file_name,
|
||||
content=BytesIO(sub_file_bytes),
|
||||
display_name=os.path.basename(file_info),
|
||||
file_origin=FileOrigin.CONNECTOR,
|
||||
file_type=mime_type,
|
||||
)
|
||||
continue
|
||||
|
||||
file_path = os.path.join(str(uuid.uuid4()), cast(str, file.filename))
|
||||
deduped_file_paths.append(file_path)
|
||||
file_store.save_file(
|
||||
@@ -441,15 +410,6 @@ def upload_files(files: list[UploadFile], db_session: Session) -> FileUploadResp
|
||||
return FileUploadResponse(file_paths=deduped_file_paths)
|
||||
|
||||
|
||||
@router.post("/admin/connector/file/upload")
|
||||
def upload_files_api(
|
||||
files: list[UploadFile],
|
||||
_: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> FileUploadResponse:
|
||||
return upload_files(files, db_session)
|
||||
|
||||
|
||||
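The zip-upload path above skips hidden entries and known macOS metadata files. A standalone copy of that filter, runnable on its own, shows its behavior on a few made-up example paths:

```python
import os


def should_process_file(file_path: str) -> bool:
    # Skip any path that contains a dot-prefixed component (e.g. __MACOSX/._foo, .DS_Store).
    normalized_path = os.path.normpath(file_path)
    return not any(part.startswith(".") for part in normalized_path.split(os.sep))


assert should_process_file("docs/readme.txt") is True
assert should_process_file("__MACOSX/._readme.txt") is False
assert should_process_file(".DS_Store") is False
```
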
@router.get("/admin/connector")
|
||||
def get_connectors_by_credential(
|
||||
_: User = Depends(current_curator_or_admin_user),
|
||||
@@ -933,21 +893,81 @@ def connector_run_once(
|
||||
connector_id = run_info.connector_id
|
||||
specified_credential_ids = run_info.credential_ids
|
||||
|
||||
if not specified_credential_ids:
|
||||
try:
|
||||
possible_credential_ids = get_connector_credential_ids(
|
||||
run_info.connector_id, db_session
|
||||
)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="No credentials specified for indexing"
|
||||
status_code=404,
|
||||
detail=f"Connector by id {connector_id} does not exist.",
|
||||
)
|
||||
|
||||
try:
|
||||
num_triggers = trigger_indexing_for_cc_pair(
|
||||
specified_credential_ids,
|
||||
connector_id,
|
||||
run_info.from_beginning,
|
||||
tenant_id,
|
||||
db_session,
|
||||
if not specified_credential_ids:
|
||||
credential_ids = possible_credential_ids
|
||||
else:
|
||||
if set(specified_credential_ids).issubset(set(possible_credential_ids)):
|
||||
credential_ids = specified_credential_ids
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Not all specified credentials are associated with connector",
|
||||
)
|
||||
|
||||
if not credential_ids:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Connector has no valid credentials, cannot create index attempts.",
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
# Prevents index attempts for cc pairs that already have an index attempt currently running
|
||||
skipped_credentials = [
|
||||
credential_id
|
||||
for credential_id in credential_ids
|
||||
if get_index_attempts_for_cc_pair(
|
||||
cc_pair_identifier=ConnectorCredentialPairIdentifier(
|
||||
connector_id=run_info.connector_id,
|
||||
credential_id=credential_id,
|
||||
),
|
||||
only_current=True,
|
||||
db_session=db_session,
|
||||
disinclude_finished=True,
|
||||
)
|
||||
]
|
||||
|
||||
connector_credential_pairs = [
|
||||
get_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
for credential_id in credential_ids
|
||||
if credential_id not in skipped_credentials
|
||||
]
|
||||
|
||||
num_triggers = 0
|
||||
for cc_pair in connector_credential_pairs:
|
||||
if cc_pair is not None:
|
||||
indexing_mode = IndexingMode.UPDATE
|
||||
if run_info.from_beginning:
|
||||
indexing_mode = IndexingMode.REINDEX
|
||||
|
||||
mark_ccpair_with_indexing_trigger(cc_pair.id, indexing_mode, db_session)
|
||||
num_triggers += 1
|
||||
|
||||
logger.info(
|
||||
f"connector_run_once - marking cc_pair with indexing trigger: "
|
||||
f"connector={run_info.connector_id} "
|
||||
f"cc_pair={cc_pair.id} "
|
||||
f"indexing_trigger={indexing_mode}"
|
||||
)
|
||||
|
||||
# run the beat task to pick up the triggers immediately
|
||||
primary_app.send_task(
|
||||
OnyxCeleryTask.CHECK_FOR_INDEXING,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
kwargs={"tenant_id": tenant_id},
|
||||
)
|
||||
|
||||
msg = f"Marked {num_triggers} index attempts with indexing triggers."
|
||||
return StatusResponse(
|
||||
@@ -1119,82 +1139,3 @@ def get_basic_connector_indexing_status(
|
||||
for cc_pair in cc_pairs
|
||||
if cc_pair.connector.source != DocumentSource.INGESTION_API
|
||||
]
|
||||
|
||||
|
||||
def trigger_indexing_for_cc_pair(
|
||||
specified_credential_ids: list[int],
|
||||
connector_id: int,
|
||||
from_beginning: bool,
|
||||
tenant_id: str,
|
||||
db_session: Session,
|
||||
) -> int:
|
||||
try:
|
||||
possible_credential_ids = get_connector_credential_ids(connector_id, db_session)
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Connector by id {connector_id} does not exist: {str(e)}")
|
||||
|
||||
if not specified_credential_ids:
|
||||
credential_ids = possible_credential_ids
|
||||
else:
|
||||
if set(specified_credential_ids).issubset(set(possible_credential_ids)):
|
||||
credential_ids = specified_credential_ids
|
||||
else:
|
||||
raise ValueError(
|
||||
"Not all specified credentials are associated with connector"
|
||||
)
|
||||
|
||||
if not credential_ids:
|
||||
raise ValueError(
|
||||
"Connector has no valid credentials, cannot create index attempts."
|
||||
)
|
||||
|
||||
# Prevents index attempts for cc pairs that already have an index attempt currently running
|
||||
skipped_credentials = [
|
||||
credential_id
|
||||
for credential_id in credential_ids
|
||||
if get_index_attempts_for_cc_pair(
|
||||
cc_pair_identifier=ConnectorCredentialPairIdentifier(
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
),
|
||||
only_current=True,
|
||||
db_session=db_session,
|
||||
disinclude_finished=True,
|
||||
)
|
||||
]
|
||||
|
||||
connector_credential_pairs = [
|
||||
get_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
for credential_id in credential_ids
|
||||
if credential_id not in skipped_credentials
|
||||
]
|
||||
|
||||
num_triggers = 0
|
||||
for cc_pair in connector_credential_pairs:
|
||||
if cc_pair is not None:
|
||||
indexing_mode = IndexingMode.UPDATE
|
||||
if from_beginning:
|
||||
indexing_mode = IndexingMode.REINDEX
|
||||
|
||||
mark_ccpair_with_indexing_trigger(cc_pair.id, indexing_mode, db_session)
|
||||
num_triggers += 1
|
||||
|
||||
logger.info(
|
||||
f"connector_run_once - marking cc_pair with indexing trigger: "
|
||||
f"connector={connector_id} "
|
||||
f"cc_pair={cc_pair.id} "
|
||||
f"indexing_trigger={indexing_mode}"
|
||||
)
|
||||
|
||||
# run the beat task to pick up the triggers immediately
|
||||
primary_app.send_task(
|
||||
OnyxCeleryTask.CHECK_FOR_INDEXING,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
kwargs={"tenant_id": tenant_id},
|
||||
)
|
||||
|
||||
return num_triggers
|
||||
|
||||
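A hedged usage sketch of the `trigger_indexing_for_cc_pair` helper shown above; the module path in the second import is an assumption and may not match the actual file layout.

```python
from onyx.db.engine import get_session_with_tenant
from onyx.server.documents.connector import (  # module path assumed
    trigger_indexing_for_cc_pair,
)


def reindex_connector(connector_id: int, tenant_id: str) -> int:
    """Trigger a from-scratch reindex for every credential attached to a connector."""
    with get_session_with_tenant(tenant_id=tenant_id) as db_session:
        return trigger_indexing_for_cc_pair(
            specified_credential_ids=[],  # empty list -> use all credentials of the connector
            connector_id=connector_id,
            from_beginning=True,  # marks cc-pairs with IndexingMode.REINDEX instead of UPDATE
            tenant_id=tenant_id,
            db_session=db_session,
        )
```
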
@@ -411,7 +411,7 @@ class FileUploadResponse(BaseModel):
|
||||
|
||||
|
||||
class ObjectCreationIdResponse(BaseModel):
|
||||
id: int
|
||||
id: int | str
|
||||
credential: CredentialSnapshot | None = None
|
||||
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ from fastapi import HTTPException
|
||||
from fastapi import Query
|
||||
from fastapi import UploadFile
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
@@ -192,7 +191,8 @@ def create_persona(
|
||||
name=build_prompt_name_from_persona_name(persona_upsert_request.name),
|
||||
system_prompt=persona_upsert_request.system_prompt,
|
||||
task_prompt=persona_upsert_request.task_prompt,
|
||||
datetime_aware=persona_upsert_request.datetime_aware,
|
||||
# TODO: The PersonaUpsertRequest should provide the value for datetime_aware
|
||||
datetime_aware=False,
|
||||
include_citations=persona_upsert_request.include_citations,
|
||||
prompt_id=prompt_id,
|
||||
)
|
||||
@@ -236,7 +236,8 @@ def update_persona(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
name=build_prompt_name_from_persona_name(persona_upsert_request.name),
|
||||
datetime_aware=persona_upsert_request.datetime_aware,
|
||||
# TODO: The PersonaUpsertRequest should provide the value for datetime_aware
|
||||
datetime_aware=False,
|
||||
system_prompt=persona_upsert_request.system_prompt,
|
||||
task_prompt=persona_upsert_request.task_prompt,
|
||||
include_citations=persona_upsert_request.include_citations,
|
||||
@@ -276,14 +277,8 @@ def create_label(
|
||||
_: User | None = Depends(current_user),
|
||||
) -> PersonaLabelResponse:
|
||||
"""Create a new assistant label"""
|
||||
try:
|
||||
label_model = create_assistant_label(name=label.name, db_session=db)
|
||||
return PersonaLabelResponse.from_model(label_model)
|
||||
except IntegrityError:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Label with name '{label.name}' already exists. Please choose a different name.",
|
||||
)
|
||||
label_model = create_assistant_label(name=label.name, db_session=db)
|
||||
return PersonaLabelResponse.from_model(label_model)
|
||||
|
||||
|
||||
@admin_router.patch("/label/{label_id}")
|
||||
|
||||
@@ -60,7 +60,6 @@ class PersonaUpsertRequest(BaseModel):
|
||||
description: str
|
||||
system_prompt: str
|
||||
task_prompt: str
|
||||
datetime_aware: bool
|
||||
document_set_ids: list[int]
|
||||
num_chunks: float
|
||||
include_citations: bool
|
||||
|
||||
@@ -41,16 +41,6 @@ def _validate_tool_definition(definition: dict[str, Any]) -> None:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
def _validate_auth_settings(tool_data: CustomToolCreate | CustomToolUpdate) -> None:
|
||||
if tool_data.passthrough_auth and tool_data.custom_headers:
|
||||
for header in tool_data.custom_headers:
|
||||
if header.key.lower() == "authorization":
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Cannot use passthrough auth with custom authorization headers",
|
||||
)
|
||||
|
||||
|
||||
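A self-contained restatement of the `_validate_auth_settings` rule above (passthrough auth and a custom `Authorization` header are mutually exclusive), with the FastAPI pieces stripped out so it runs standalone; the `Header` type here is a simplified stand-in.

```python
from dataclasses import dataclass


@dataclass
class Header:
    key: str
    value: str


def has_conflicting_auth(passthrough_auth: bool, custom_headers: list[Header] | None) -> bool:
    # Mirrors the check above: reject passthrough auth combined with a custom Authorization header.
    if passthrough_auth and custom_headers:
        return any(header.key.lower() == "authorization" for header in custom_headers)
    return False


assert has_conflicting_auth(True, [Header("Authorization", "Bearer abc")]) is True
assert has_conflicting_auth(True, [Header("X-Api-Key", "abc")]) is False
assert has_conflicting_auth(False, [Header("Authorization", "Bearer abc")]) is False
```
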
@admin_router.post("/custom")
|
||||
def create_custom_tool(
|
||||
tool_data: CustomToolCreate,
|
||||
@@ -58,7 +48,6 @@ def create_custom_tool(
|
||||
user: User | None = Depends(current_admin_user),
|
||||
) -> ToolSnapshot:
|
||||
_validate_tool_definition(tool_data.definition)
|
||||
_validate_auth_settings(tool_data)
|
||||
tool = create_tool(
|
||||
name=tool_data.name,
|
||||
description=tool_data.description,
|
||||
@@ -66,7 +55,6 @@ def create_custom_tool(
|
||||
custom_headers=tool_data.custom_headers,
|
||||
user_id=user.id if user else None,
|
||||
db_session=db_session,
|
||||
passthrough_auth=tool_data.passthrough_auth,
|
||||
)
|
||||
return ToolSnapshot.from_model(tool)
|
||||
|
||||
@@ -80,7 +68,6 @@ def update_custom_tool(
|
||||
) -> ToolSnapshot:
|
||||
if tool_data.definition:
|
||||
_validate_tool_definition(tool_data.definition)
|
||||
_validate_auth_settings(tool_data)
|
||||
updated_tool = update_tool(
|
||||
tool_id=tool_id,
|
||||
name=tool_data.name,
|
||||
@@ -89,7 +76,6 @@ def update_custom_tool(
|
||||
custom_headers=tool_data.custom_headers,
|
||||
user_id=user.id if user else None,
|
||||
db_session=db_session,
|
||||
passthrough_auth=tool_data.passthrough_auth,
|
||||
)
|
||||
return ToolSnapshot.from_model(updated_tool)
|
||||
|
||||
|
||||
@@ -13,7 +13,6 @@ class ToolSnapshot(BaseModel):
|
||||
display_name: str
|
||||
in_code_tool_id: str | None
|
||||
custom_headers: list[Any] | None
|
||||
passthrough_auth: bool
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, tool: Tool) -> "ToolSnapshot":
|
||||
@@ -25,7 +24,6 @@ class ToolSnapshot(BaseModel):
|
||||
display_name=tool.display_name or tool.name,
|
||||
in_code_tool_id=tool.in_code_tool_id,
|
||||
custom_headers=tool.custom_headers,
|
||||
passthrough_auth=tool.passthrough_auth,
|
||||
)
|
||||
|
||||
|
||||
@@ -39,7 +37,6 @@ class CustomToolCreate(BaseModel):
|
||||
description: str | None = None
|
||||
definition: dict[str, Any]
|
||||
custom_headers: list[Header] | None = None
|
||||
passthrough_auth: bool
|
||||
|
||||
|
||||
class CustomToolUpdate(BaseModel):
|
||||
@@ -47,4 +44,3 @@ class CustomToolUpdate(BaseModel):
|
||||
description: str | None = None
|
||||
definition: dict[str, Any] | None = None
|
||||
custom_headers: list[Header] | None = None
|
||||
passthrough_auth: bool | None = None
|
||||
|
||||
@@ -714,6 +714,7 @@ def update_user_pinned_assistants(
|
||||
store = get_kv_store()
|
||||
no_auth_user = fetch_no_auth_user(store)
|
||||
no_auth_user.preferences.pinned_assistants = ordered_assistant_ids
|
||||
print("ordered_assistant_ids", ordered_assistant_ids)
|
||||
set_no_auth_user_preferences(store, no_auth_user.preferences)
|
||||
return
|
||||
else:
|
||||
|
||||
@@ -5,6 +5,7 @@ import os
|
||||
import uuid
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from typing import Tuple
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter
|
||||
@@ -14,6 +15,7 @@ from fastapi import Request
|
||||
from fastapi import Response
|
||||
from fastapi import UploadFile
|
||||
from fastapi.responses import StreamingResponse
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -593,6 +595,21 @@ def seed_chat_from_slack(
|
||||
"""File upload"""
|
||||
|
||||
|
||||
def convert_to_jpeg(file: UploadFile) -> Tuple[io.BytesIO, str]:
|
||||
try:
|
||||
with Image.open(file.file) as img:
|
||||
if img.mode != "RGB":
|
||||
img = img.convert("RGB")
|
||||
jpeg_io = io.BytesIO()
|
||||
img.save(jpeg_io, format="JPEG", quality=85)
|
||||
jpeg_io.seek(0)
|
||||
return jpeg_io, "image/jpeg"
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=400, detail=f"Failed to convert image: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
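The `convert_to_jpeg` helper above re-encodes an uploaded image as an RGB JPEG. A minimal, runnable illustration of the same Pillow conversion on an in-memory PNG, with no FastAPI `UploadFile` involved:

```python
import io

from PIL import Image


def to_jpeg_bytes(raw: io.BytesIO) -> tuple[io.BytesIO, str]:
    with Image.open(raw) as img:
        if img.mode != "RGB":
            img = img.convert("RGB")  # JPEG has no alpha channel
        jpeg_io = io.BytesIO()
        img.save(jpeg_io, format="JPEG", quality=85)
        jpeg_io.seek(0)
        return jpeg_io, "image/jpeg"


png = io.BytesIO()
Image.new("RGBA", (4, 4), (255, 0, 0, 128)).save(png, format="PNG")
png.seek(0)
jpeg_io, content_type = to_jpeg_bytes(png)
assert content_type == "image/jpeg"
```
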
@router.post("/file")
|
||||
def upload_files_for_chat(
|
||||
files: list[UploadFile],
|
||||
@@ -628,9 +645,6 @@ def upload_files_for_chat(
|
||||
)
|
||||
|
||||
for file in files:
|
||||
if not file.content_type:
|
||||
raise HTTPException(status_code=400, detail="File content type is required")
|
||||
|
||||
if file.content_type not in allowed_content_types:
|
||||
if file.content_type in image_content_types:
|
||||
error_detail = "Unsupported image file type. Supported image types include .jpg, .jpeg, .png, .webp."
|
||||
@@ -662,27 +676,22 @@ def upload_files_for_chat(
|
||||
|
||||
file_info: list[tuple[str, str | None, ChatFileType]] = []
|
||||
for file in files:
|
||||
file_type = (
|
||||
ChatFileType.IMAGE
|
||||
if file.content_type in image_content_types
|
||||
else ChatFileType.CSV
|
||||
if file.content_type in csv_content_types
|
||||
else ChatFileType.DOC
|
||||
if file.content_type in document_content_types
|
||||
else ChatFileType.PLAIN_TEXT
|
||||
)
|
||||
|
||||
if file_type == ChatFileType.IMAGE:
|
||||
file_content = file.file
|
||||
# NOTE: Image conversion to JPEG used to be enforced here.
|
||||
# This was removed to:
|
||||
# 1. Preserve original file content for downloads
|
||||
# 2. Maintain transparency in formats like PNG
|
||||
# 3. Ameliorate issue with file conversion
|
||||
else:
|
||||
if file.content_type in image_content_types:
|
||||
file_type = ChatFileType.IMAGE
|
||||
# Convert image to JPEG
|
||||
file_content, new_content_type = convert_to_jpeg(file)
|
||||
elif file.content_type in csv_content_types:
|
||||
file_type = ChatFileType.CSV
|
||||
file_content = io.BytesIO(file.file.read())
|
||||
|
||||
new_content_type = file.content_type
|
||||
new_content_type = file.content_type or ""
|
||||
elif file.content_type in document_content_types:
|
||||
file_type = ChatFileType.DOC
|
||||
file_content = io.BytesIO(file.file.read())
|
||||
new_content_type = file.content_type or ""
|
||||
else:
|
||||
file_type = ChatFileType.PLAIN_TEXT
|
||||
file_content = io.BytesIO(file.file.read())
|
||||
new_content_type = file.content_type or ""
|
||||
|
||||
# store the file (now JPEG for images)
|
||||
file_id = str(uuid.uuid4())
|
||||
|
||||
@@ -1,269 +0,0 @@
|
||||
import time
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import File
|
||||
from fastapi import Form
|
||||
from fastapi import HTTPException
|
||||
from fastapi import UploadFile
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import InputType
|
||||
from onyx.db.connector import create_connector
|
||||
from onyx.db.connector_credential_pair import add_credential_to_connector
|
||||
from onyx.db.credentials import create_credential
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.models import UserFolder
|
||||
from onyx.db.user_documents import create_user_files
|
||||
from onyx.server.documents.models import ConnectorBase
|
||||
from onyx.server.documents.models import CredentialBase
|
||||
from onyx.server.documents.models import FileUploadResponse
|
||||
from onyx.server.user_documents.models import FileResponse
|
||||
from onyx.server.user_documents.models import FileSystemResponse
|
||||
from onyx.server.user_documents.models import FolderDetailResponse
|
||||
from onyx.server.user_documents.models import FolderResponse
|
||||
from onyx.server.user_documents.models import MessageResponse
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class FolderCreationRequest(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
|
||||
|
||||
@router.post("/user/folder")
|
||||
def create_folder(
|
||||
request: FolderCreationRequest,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> FolderDetailResponse:
|
||||
new_folder = UserFolder(
|
||||
user_id=user.id if user else None,
|
||||
name=request.name,
|
||||
description=request.description,
|
||||
)
|
||||
db_session.add(new_folder)
|
||||
db_session.commit()
|
||||
return FolderDetailResponse(
|
||||
id=new_folder.id,
|
||||
name=new_folder.name,
|
||||
description=new_folder.description,
|
||||
files=[],
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/user/folder",
|
||||
)
|
||||
def get_folders(
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> List[FolderResponse]:
|
||||
user_id = user.id if user else None
|
||||
folders = db_session.query(UserFolder).filter(UserFolder.user_id == user_id).all()
|
||||
return [FolderResponse.from_model(folder) for folder in folders]
|
||||
|
||||
|
||||
@router.get("/user/folder/{folder_id}")
|
||||
def get_folder(
|
||||
folder_id: int,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> FolderDetailResponse:
|
||||
user_id = user.id if user else None
|
||||
folder = (
|
||||
db_session.query(UserFolder)
|
||||
.filter(UserFolder.id == folder_id, UserFolder.user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
if not folder:
|
||||
raise HTTPException(status_code=404, detail="Folder not found")
|
||||
|
||||
return FolderDetailResponse(
|
||||
id=folder.id,
|
||||
name=folder.name,
|
||||
files=[FileResponse.from_model(file) for file in folder.files],
|
||||
)
|
||||
|
||||
|
||||
@router.post("/user/file/upload")
|
||||
def upload_user_files(
|
||||
files: List[UploadFile] = File(...),
|
||||
folder_id: int | None = Form(None),
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> FileUploadResponse:
|
||||
file_upload_response = FileUploadResponse(
|
||||
file_paths=create_user_files(files, folder_id, user, db_session).file_paths
|
||||
)
|
||||
for path in file_upload_response.file_paths:
|
||||
connector_base = ConnectorBase(
|
||||
name=f"UserFile-{int(time.time())}",
|
||||
source=DocumentSource.FILE,
|
||||
input_type=InputType.LOAD_STATE,
|
||||
connector_specific_config={
|
||||
"file_locations": [path],
|
||||
},
|
||||
refresh_freq=None,
|
||||
prune_freq=None,
|
||||
indexing_start=None,
|
||||
)
|
||||
connector = create_connector(
|
||||
db_session=db_session,
|
||||
connector_data=connector_base,
|
||||
)
|
||||
|
||||
credential_info = CredentialBase(
|
||||
credential_json={},
|
||||
admin_public=True,
|
||||
source=DocumentSource.FILE,
|
||||
curator_public=True,
|
||||
groups=[],
|
||||
name=f"UserFileCredential-{int(time.time())}",
|
||||
)
|
||||
credential = create_credential(credential_info, user, db_session)
|
||||
|
||||
add_credential_to_connector(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
connector_id=connector.id,
|
||||
credential_id=credential.id,
|
||||
cc_pair_name=f"UserFileCCPair-{int(time.time())}",
|
||||
access_type=AccessType.PUBLIC,
|
||||
auto_sync_options=None,
|
||||
groups=[],
|
||||
)
|
||||
|
||||
# TODO: functional document indexing
|
||||
# trigger_document_indexing(db_session, user.id)
|
||||
return file_upload_response
|
||||
|
||||
|
||||
@router.put("/user/folder/{folder_id}")
|
||||
def update_folder(
|
||||
folder_id: int,
|
||||
name: str,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> FolderDetailResponse:
|
||||
user_id = user.id if user else None
|
||||
folder = (
|
||||
db_session.query(UserFolder)
|
||||
.filter(UserFolder.id == folder_id, UserFolder.user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
if not folder:
|
||||
raise HTTPException(status_code=404, detail="Folder not found")
|
||||
folder.name = name
|
||||
|
||||
db_session.commit()
|
||||
|
||||
return FolderDetailResponse(
|
||||
id=folder.id,
|
||||
name=folder.name,
|
||||
files=[FileResponse.from_model(file) for file in folder.files],
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/user/folder/{folder_id}")
|
||||
def delete_folder(
|
||||
folder_id: int,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> MessageResponse:
|
||||
user_id = user.id if user else None
|
||||
folder = (
|
||||
db_session.query(UserFolder)
|
||||
.filter(UserFolder.id == folder_id, UserFolder.user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
if not folder:
|
||||
raise HTTPException(status_code=404, detail="Folder not found")
|
||||
db_session.delete(folder)
|
||||
db_session.commit()
|
||||
return MessageResponse(message="Folder deleted successfully")
|
||||
|
||||
|
||||
@router.delete("/user/file/{file_id}")
|
||||
def delete_file(
|
||||
file_id: int,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> MessageResponse:
|
||||
user_id = user.id if user else None
|
||||
file = (
|
||||
db_session.query(UserFile)
|
||||
.filter(UserFile.id == file_id, UserFile.user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
if not file:
|
||||
raise HTTPException(status_code=404, detail="File not found")
|
||||
db_session.delete(file)
|
||||
db_session.commit()
|
||||
return MessageResponse(message="File deleted successfully")
|
||||
|
||||
|
||||
class FileMoveRequest(BaseModel):
|
||||
file_id: int
|
||||
new_folder_id: int | None
|
||||
|
||||
|
||||
@router.put("/user/file/{file_id}/move")
|
||||
def move_file(
|
||||
request: FileMoveRequest,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> FileResponse:
|
||||
user_id = user.id if user else None
|
||||
file = (
|
||||
db_session.query(UserFile)
|
||||
.filter(UserFile.id == request.file_id, UserFile.user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
if not file:
|
||||
raise HTTPException(status_code=404, detail="File not found")
|
||||
file.folder_id = request.new_folder_id
|
||||
db_session.commit()
|
||||
return FileResponse.from_model(file)
|
||||
|
||||
|
||||
@router.get("/user/file-system")
|
||||
def get_file_system(
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> FileSystemResponse:
|
||||
user_id = user.id if user else None
|
||||
folders = db_session.query(UserFolder).filter(UserFolder.user_id == user_id).all()
|
||||
files = db_session.query(UserFile).filter(UserFile.user_id == user_id).all()
|
||||
return FileSystemResponse(
|
||||
folders=[FolderResponse.from_model(folder) for folder in folders],
|
||||
files=[FileResponse.from_model(file) for file in files],
|
||||
)
|
||||
|
||||
|
||||
@router.put("/user/file/{file_id}/rename")
|
||||
def rename_file(
|
||||
file_id: int,
|
||||
name: str,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> FileResponse:
|
||||
user_id = user.id if user else None
|
||||
file = (
|
||||
db_session.query(UserFile)
|
||||
.filter(UserFile.id == file_id, UserFile.user_id == user_id)
|
||||
.first()
|
||||
)
|
||||
if not file:
|
||||
raise HTTPException(status_code=404, detail="File not found")
|
||||
file.name = name
|
||||
db_session.commit()
|
||||
return FileResponse.from_model(file)
|
||||
@@ -1,49 +0,0 @@
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.models import UserFolder
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class FolderResponse(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
description: str
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, model: UserFolder) -> "FolderResponse":
|
||||
return cls(id=model.id, name=model.name, description=model.description)
|
||||
|
||||
|
||||
class FileResponse(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
document_id: str
|
||||
folder_id: int | None = None
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, model: UserFile) -> "FileResponse":
|
||||
return cls(
|
||||
id=model.id,
|
||||
name=model.name,
|
||||
folder_id=model.folder_id,
|
||||
document_id=model.document_id,
|
||||
)
|
||||
|
||||
|
||||
class FolderDetailResponse(FolderResponse):
|
||||
files: List[FileResponse]
|
||||
|
||||
|
||||
class MessageResponse(BaseModel):
|
||||
message: str
|
||||
|
||||
|
||||
class FileSystemResponse(BaseModel):
|
||||
folders: list[FolderResponse]
|
||||
files: list[FileResponse]
|
||||
@@ -104,10 +104,13 @@ def load_builtin_tools(db_session: Session) -> None:
|
||||
logger.notice("All built-in tools are loaded/verified.")
|
||||
|
||||
|
||||
def get_search_tool(db_session: Session) -> ToolDBModel | None:
|
||||
def auto_add_search_tool_to_personas(db_session: Session) -> None:
|
||||
"""
|
||||
Retrieves the SearchTool from the BUILT_IN_TOOLS list.
|
||||
Automatically adds the SearchTool to all Persona objects in the database that have
|
||||
`num_chunks` either unset or set to a value that isn't 0. This is done to migrate
|
||||
Persona objects that were created before the concept of Tools were added.
|
||||
"""
|
||||
# Fetch the SearchTool from the database based on in_code_tool_id from BUILT_IN_TOOLS
|
||||
search_tool_id = next(
|
||||
(
|
||||
tool["in_code_tool_id"]
|
||||
@@ -116,7 +119,6 @@ def get_search_tool(db_session: Session) -> ToolDBModel | None:
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
if not search_tool_id:
|
||||
raise RuntimeError("SearchTool not found in the BUILT_IN_TOOLS list.")
|
||||
|
||||
@@ -124,18 +126,6 @@ def get_search_tool(db_session: Session) -> ToolDBModel | None:
|
||||
select(ToolDBModel).where(ToolDBModel.in_code_tool_id == search_tool_id)
|
||||
).scalar_one_or_none()
|
||||
|
||||
return search_tool
|
||||
|
||||
|
||||
def auto_add_search_tool_to_personas(db_session: Session) -> None:
|
||||
"""
|
||||
Automatically adds the SearchTool to all Persona objects in the database that have
|
||||
`num_chunks` either unset or set to a value that isn't 0. This is done to migrate
|
||||
Persona objects that were created before the concept of Tools were added.
|
||||
"""
|
||||
# Fetch the SearchTool from the database based on in_code_tool_id from BUILT_IN_TOOLS
|
||||
search_tool = get_search_tool(db_session)
|
||||
|
||||
if not search_tool:
|
||||
raise RuntimeError("SearchTool not found in the database.")
|
||||
|
||||
|
||||
@@ -146,11 +146,6 @@ def construct_tools(
|
||||
"""Constructs tools based on persona configuration and available APIs"""
|
||||
tool_dict: dict[int, list[Tool]] = {}
|
||||
|
||||
# Get user's OAuth token if available
|
||||
user_oauth_token = None
|
||||
if user and user.oauth_accounts:
|
||||
user_oauth_token = user.oauth_accounts[0].access_token
|
||||
|
||||
for db_tool_model in persona.tools:
|
||||
if db_tool_model.in_code_tool_id:
|
||||
tool_cls = get_built_in_tool_by_id(
|
||||
@@ -241,9 +236,6 @@ def construct_tools(
|
||||
custom_tool_config.additional_headers or {}
|
||||
)
|
||||
),
|
||||
user_oauth_token=(
|
||||
user_oauth_token if db_tool_model.passthrough_auth else None
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -80,12 +80,10 @@ class CustomTool(BaseTool):
|
||||
method_spec: MethodSpec,
|
||||
base_url: str,
|
||||
custom_headers: list[HeaderItemDict] | None = None,
|
||||
user_oauth_token: str | None = None,
|
||||
) -> None:
|
||||
self._base_url = base_url
|
||||
self._method_spec = method_spec
|
||||
self._tool_definition = self._method_spec.to_tool_definition()
|
||||
self._user_oauth_token = user_oauth_token
|
||||
|
||||
self._name = self._method_spec.name
|
||||
self._description = self._method_spec.summary
|
||||
@@ -93,20 +91,6 @@ class CustomTool(BaseTool):
|
||||
header_list_to_header_dict(custom_headers) if custom_headers else {}
|
||||
)
|
||||
|
||||
# Check for both Authorization header and OAuth token
|
||||
has_auth_header = any(
|
||||
key.lower() == "authorization" for key in self.headers.keys()
|
||||
)
|
||||
if has_auth_header and self._user_oauth_token:
|
||||
logger.warning(
|
||||
f"Tool '{self._name}' has both an Authorization "
|
||||
"header and OAuth token set. This is likely a configuration "
|
||||
"error as the OAuth token will override the custom header."
|
||||
)
|
||||
|
||||
if self._user_oauth_token:
|
||||
self.headers["Authorization"] = f"Bearer {self._user_oauth_token}"
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
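A standalone sketch of the header behavior in the `CustomTool` hunk above: when a per-user OAuth token is available it overrides any configured `Authorization` header, and the combination is logged as a likely misconfiguration. The names below are illustrative, not the tool's actual attributes.

```python
import logging

logger = logging.getLogger(__name__)


def resolve_headers(custom_headers: dict[str, str], user_oauth_token: str | None) -> dict[str, str]:
    headers = dict(custom_headers)
    has_auth_header = any(key.lower() == "authorization" for key in headers)
    if has_auth_header and user_oauth_token:
        # Likely a configuration error: the OAuth token will override the custom header.
        logger.warning("Both an Authorization header and an OAuth token are set; the token wins.")
    if user_oauth_token:
        headers["Authorization"] = f"Bearer {user_oauth_token}"
    return headers


assert resolve_headers({"X-Team": "eng"}, "tok123")["Authorization"] == "Bearer tok123"
assert "Authorization" not in resolve_headers({"X-Team": "eng"}, None)
```
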
@@ -364,7 +348,6 @@ def build_custom_tools_from_openapi_schema_and_headers(
|
||||
openapi_schema: dict[str, Any],
|
||||
custom_headers: list[HeaderItemDict] | None = None,
|
||||
dynamic_schema_info: DynamicSchemaInfo | None = None,
|
||||
user_oauth_token: str | None = None,
|
||||
) -> list[CustomTool]:
|
||||
if dynamic_schema_info:
|
||||
# Process dynamic schema information
|
||||
@@ -383,13 +366,7 @@ def build_custom_tools_from_openapi_schema_and_headers(
|
||||
url = openapi_to_url(openapi_schema)
|
||||
method_specs = openapi_to_method_specs(openapi_schema)
|
||||
return [
|
||||
CustomTool(
|
||||
method_spec,
|
||||
url,
|
||||
custom_headers,
|
||||
user_oauth_token=user_oauth_token,
|
||||
)
|
||||
for method_spec in method_specs
|
||||
CustomTool(method_spec, url, custom_headers) for method_spec in method_specs
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -16,7 +16,6 @@ from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.models import PreviousMessage
|
||||
from onyx.llm.utils import build_content_with_imgs
|
||||
from onyx.llm.utils import message_to_string
|
||||
from onyx.llm.utils import model_supports_image_input
|
||||
from onyx.prompts.constants import GENERAL_SEP_PAT
|
||||
from onyx.tools.message import ToolCallSummary
|
||||
from onyx.tools.models import ToolResponse
|
||||
@@ -317,22 +316,12 @@ class ImageGenerationTool(Tool):
|
||||
for img in img_generation_response
|
||||
if img.image_data is not None
|
||||
]
|
||||
|
||||
user_prompt = build_image_generation_user_prompt(
|
||||
query=prompt_builder.get_user_message_content(),
|
||||
supports_image_input=model_supports_image_input(
|
||||
prompt_builder.llm_config.model_name,
|
||||
prompt_builder.llm_config.model_provider,
|
||||
),
|
||||
prompts=[
|
||||
prompt
|
||||
for response in img_generation_response
|
||||
for prompt in response.revised_prompt
|
||||
],
|
||||
img_urls=img_urls,
|
||||
b64_imgs=b64_imgs,
|
||||
prompt_builder.update_user_prompt(
|
||||
build_image_generation_user_prompt(
|
||||
query=prompt_builder.get_user_message_content(),
|
||||
img_urls=img_urls,
|
||||
b64_imgs=b64_imgs,
|
||||
)
|
||||
)
|
||||
|
||||
prompt_builder.update_user_prompt(user_prompt)
|
||||
|
||||
return prompt_builder
|
||||
|
||||
@@ -9,34 +9,16 @@ You have just created the attached images in response to the following query: "{
|
||||
Can you please summarize them in a sentence or two? Do NOT include image urls or bulleted lists.
|
||||
"""
|
||||
|
||||
IMG_GENERATION_SUMMARY_PROMPT_NO_IMAGES = """
|
||||
You have generated images based on the following query: "{query}".
|
||||
The prompts used to create these images were: {prompts}
|
||||
|
||||
Describe the two images you generated, summarizing the key elements and content in a sentence or two.
|
||||
Be specific about what was generated and respond as if you have seen them,
|
||||
without including any disclaimers or speculations.
|
||||
"""
|
||||
|
||||
|
||||
def build_image_generation_user_prompt(
|
||||
query: str,
|
||||
supports_image_input: bool,
|
||||
img_urls: list[str] | None = None,
|
||||
b64_imgs: list[str] | None = None,
|
||||
prompts: list[str] | None = None,
|
||||
) -> HumanMessage:
|
||||
if supports_image_input:
|
||||
return HumanMessage(
|
||||
content=build_content_with_imgs(
|
||||
message=IMG_GENERATION_SUMMARY_PROMPT.format(query=query).strip(),
|
||||
b64_imgs=b64_imgs,
|
||||
img_urls=img_urls,
|
||||
)
|
||||
)
|
||||
else:
|
||||
return HumanMessage(
|
||||
content=IMG_GENERATION_SUMMARY_PROMPT_NO_IMAGES.format(
|
||||
query=query, prompts=prompts
|
||||
).strip()
|
||||
return HumanMessage(
|
||||
content=build_content_with_imgs(
|
||||
message=IMG_GENERATION_SUMMARY_PROMPT.format(query=query).strip(),
|
||||
b64_imgs=b64_imgs,
|
||||
img_urls=img_urls,
|
||||
)
|
||||
)
|
||||
|
||||
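The hunk above branches on whether the chat model accepts image input: vision-capable models get the generated images attached, while others get a text-only summary built from the revised prompts. A simplified, dependency-free sketch of that branching (the returned dict is a stand-in for the real `HumanMessage`, and the prompt wording is paraphrased):

```python
def build_summary_prompt(
    query: str,
    supports_image_input: bool,
    b64_imgs: list[str] | None = None,
    prompts: list[str] | None = None,
) -> dict:
    if supports_image_input:
        return {
            "text": f'You have just created the attached images in response to: "{query}". Summarize them briefly.',
            "images": b64_imgs or [],
        }
    return {
        "text": f'You generated images for: "{query}" using the prompts {prompts}. Describe them briefly.',
        "images": [],
    }


print(build_summary_prompt("a red fox", supports_image_input=False, prompts=["a photoreal red fox"]))
```
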
@@ -35,16 +35,12 @@ class LongTermLogger:
|
||||
def _cleanup_old_files(self, category_path: Path) -> None:
|
||||
try:
|
||||
files = sorted(
|
||||
[f for f in category_path.glob("*.json")],
|
||||
[f for f in category_path.glob("*.json") if f.is_file()],
|
||||
key=lambda x: x.stat().st_mtime, # Sort by modification time
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
# Delete oldest files that exceed the limit
|
||||
for file in files[self.max_files_per_category :]:
|
||||
if not file.is_file():
|
||||
logger.debug(f"File already deleted: {file}")
|
||||
continue
|
||||
try:
|
||||
file.unlink()
|
||||
except Exception as e:
|
||||
|
||||
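The `LongTermLogger` change above hardens the cleanup against files that disappear mid-scan. A self-contained version of the same keep-the-newest-N retention logic, exercised against a temporary directory:

```python
import tempfile
import time
from pathlib import Path


def cleanup_old_json_files(category_path: Path, max_files: int) -> None:
    files = sorted(
        (f for f in category_path.glob("*.json") if f.is_file()),
        key=lambda f: f.stat().st_mtime,  # newest first
        reverse=True,
    )
    for stale in files[max_files:]:
        stale.unlink(missing_ok=True)  # tolerate files already removed by another worker


with tempfile.TemporaryDirectory() as tmp:
    for i in range(5):
        Path(tmp, f"log_{i}.json").write_text("{}")
        time.sleep(0.01)  # make modification times distinct
    cleanup_old_json_files(Path(tmp), max_files=2)
    assert len(list(Path(tmp).glob("*.json"))) == 2
```
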
@@ -11,7 +11,7 @@ from onyx.configs.app_configs import ENTERPRISE_EDITION_ENABLED
|
||||
from onyx.configs.constants import KV_CUSTOMER_UUID_KEY
|
||||
from onyx.configs.constants import KV_INSTANCE_DOMAIN_KEY
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.engine import get_sqlalchemy_engine
|
||||
from onyx.db.milestone import create_milestone_if_not_exists
|
||||
from onyx.db.models import User
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
@@ -41,7 +41,7 @@ def _get_or_generate_customer_id_mt(tenant_id: str) -> str:
|
||||
return str(uuid.uuid5(uuid.NAMESPACE_X500, tenant_id))
|
||||
|
||||
|
||||
def get_or_generate_uuid(tenant_id: str | None) -> str:
|
||||
def get_or_generate_uuid(tenant_id: str | None = None) -> str:
|
||||
# TODO: split out the whole "instance UUID" generation logic into a separate
|
||||
# utility function. Telemetry should not be aware at all of how the UUID is
|
||||
# generated/stored.
|
||||
@@ -52,7 +52,7 @@ def get_or_generate_uuid(tenant_id: str | None) -> str:
|
||||
if _CACHED_UUID is not None:
|
||||
return _CACHED_UUID
|
||||
|
||||
kv_store = get_kv_store(tenant_id=tenant_id)
|
||||
kv_store = get_kv_store()
|
||||
|
||||
try:
|
||||
_CACHED_UUID = cast(str, kv_store.load(KV_CUSTOMER_UUID_KEY))
|
||||
@@ -63,18 +63,18 @@ def get_or_generate_uuid(tenant_id: str | None) -> str:
|
||||
return _CACHED_UUID
|
||||
|
||||
|
||||
def _get_or_generate_instance_domain(tenant_id: str | None = None) -> str | None: #
|
||||
def _get_or_generate_instance_domain() -> str | None: #
|
||||
global _CACHED_INSTANCE_DOMAIN
|
||||
|
||||
if _CACHED_INSTANCE_DOMAIN is not None:
|
||||
return _CACHED_INSTANCE_DOMAIN
|
||||
|
||||
kv_store = get_kv_store(tenant_id=tenant_id)
|
||||
kv_store = get_kv_store()
|
||||
|
||||
try:
|
||||
_CACHED_INSTANCE_DOMAIN = cast(str, kv_store.load(KV_INSTANCE_DOMAIN_KEY))
|
||||
except KvKeyNotFoundError:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
first_user = db_session.query(User).first()
|
||||
if first_user:
|
||||
_CACHED_INSTANCE_DOMAIN = first_user.email.split("@")[-1]
|
||||
@@ -94,16 +94,16 @@ def optional_telemetry(
|
||||
if DISABLE_TELEMETRY:
|
||||
return
|
||||
|
||||
tenant_id = tenant_id or CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
|
||||
try:
|
||||
|
||||
def telemetry_logic() -> None:
|
||||
try:
|
||||
customer_uuid = (
|
||||
_get_or_generate_customer_id_mt(tenant_id)
|
||||
_get_or_generate_customer_id_mt(
|
||||
tenant_id or CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
)
|
||||
if MULTI_TENANT
|
||||
else get_or_generate_uuid(tenant_id)
|
||||
else get_or_generate_uuid()
|
||||
)
|
||||
payload = {
|
||||
"data": data,
|
||||
@@ -115,15 +115,12 @@ def optional_telemetry(
|
||||
"is_cloud": MULTI_TENANT,
|
||||
}
|
||||
if ENTERPRISE_EDITION_ENABLED:
|
||||
payload["instance_domain"] = _get_or_generate_instance_domain(
|
||||
tenant_id
|
||||
)
|
||||
payload["instance_domain"] = _get_or_generate_instance_domain()
|
||||
requests.post(
|
||||
_DANSWER_TELEMETRY_ENDPOINT,
|
||||
headers={"Content-Type": "application/json"},
|
||||
json=payload,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
# This way it silences all thread level logging as well
|
||||
pass
|
||||
|
||||
@@ -72,19 +72,6 @@ def run_jobs() -> None:
|
||||
"--queues=connector_indexing",
|
||||
]
|
||||
|
||||
cmd_worker_monitoring = [
|
||||
"celery",
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.monitoring",
|
||||
"worker",
|
||||
"--pool=threads",
|
||||
"--concurrency=1",
|
||||
"--prefetch-multiplier=1",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=monitoring@%n",
|
||||
"--queues=monitoring",
|
||||
]
|
||||
|
||||
cmd_beat = [
|
||||
"celery",
|
||||
"-A",
|
||||
@@ -110,13 +97,6 @@ def run_jobs() -> None:
|
||||
cmd_worker_indexing, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
|
||||
)
|
||||
|
||||
worker_monitoring_process = subprocess.Popen(
|
||||
cmd_worker_monitoring,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True,
|
||||
)
|
||||
|
||||
beat_process = subprocess.Popen(
|
||||
cmd_beat, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
|
||||
)
|
||||
@@ -134,23 +114,18 @@ def run_jobs() -> None:
|
||||
worker_indexing_thread = threading.Thread(
|
||||
target=monitor_process, args=("INDEX", worker_indexing_process)
|
||||
)
|
||||
worker_monitoring_thread = threading.Thread(
|
||||
target=monitor_process, args=("MONITORING", worker_monitoring_process)
|
||||
)
|
||||
beat_thread = threading.Thread(target=monitor_process, args=("BEAT", beat_process))
|
||||
|
||||
worker_primary_thread.start()
|
||||
worker_light_thread.start()
|
||||
worker_heavy_thread.start()
|
||||
worker_indexing_thread.start()
|
||||
worker_monitoring_thread.start()
|
||||
beat_thread.start()
|
||||
|
||||
worker_primary_thread.join()
|
||||
worker_light_thread.join()
|
||||
worker_heavy_thread.join()
|
||||
worker_indexing_thread.join()
|
||||
worker_monitoring_thread.join()
|
||||
beat_thread.join()
|
||||
|
||||
|
||||
|
||||
@@ -197,7 +197,7 @@ ai_platform_doc = SeedPresaveDocument(
|
||||
)
|
||||
|
||||
customer_support_doc = SeedPresaveDocument(
|
||||
url="https://docs.onyx.app/more/use_cases/support",
|
||||
url="https://docs.onyx.app/more/use_cases/customer_support",
|
||||
title=customer_support_title,
|
||||
content=customer_support,
|
||||
title_embedding=model.encode(f"search_document: {customer_support_title}"),
|
||||
|
||||
@@ -1,30 +1,20 @@
|
||||
# Tool to run helpful operations on Redis in production
|
||||
# This is targeted for internal usage and may not have all the necessary parameters
|
||||
# for general usage across custom deployments
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
from logging import getLogger
|
||||
from typing import cast
|
||||
from uuid import UUID
|
||||
|
||||
from redis import Redis
|
||||
|
||||
from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
|
||||
from onyx.configs.app_configs import REDIS_AUTH_KEY_PREFIX
|
||||
from onyx.configs.app_configs import REDIS_DB_NUMBER
|
||||
from onyx.configs.app_configs import REDIS_HOST
|
||||
from onyx.configs.app_configs import REDIS_PASSWORD
|
||||
from onyx.configs.app_configs import REDIS_PORT
|
||||
from onyx.configs.app_configs import REDIS_SSL
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.redis.redis_pool import RedisPool
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
# Tool to run helpful operations on Redis in production
|
||||
# This is targeted for internal usage and may not have all the necessary parameters
|
||||
# for general usage across custom deployments
|
||||
|
||||
# Configure the logger
|
||||
logging.basicConfig(
|
||||
@@ -39,18 +29,6 @@ SCAN_ITER_COUNT = 10000
|
||||
BATCH_DEFAULT = 1000
|
||||
|
||||
|
||||
def get_user_id(user_email: str) -> tuple[UUID, str]:
|
||||
tenant_id = (
|
||||
get_tenant_id_for_email(user_email) if MULTI_TENANT else POSTGRES_DEFAULT_SCHEMA
|
||||
)
|
||||
|
||||
with get_session_with_tenant(tenant_id) as session:
|
||||
user = get_user_by_email(user_email, session)
|
||||
if user is None:
|
||||
raise ValueError(f"User not found for email: {user_email}")
|
||||
return user.id, tenant_id
|
||||
|
||||
|
||||
def onyx_redis(
|
||||
command: str,
|
||||
batch: int,
|
||||
@@ -59,14 +37,13 @@ def onyx_redis(
|
||||
port: int,
|
||||
db: int,
|
||||
password: str | None,
|
||||
user_email: str | None = None,
|
||||
) -> int:
|
||||
pool = RedisPool.create_pool(
|
||||
host=host,
|
||||
port=port,
|
||||
db=db,
|
||||
password=password if password else "",
|
||||
ssl=REDIS_SSL,
|
||||
ssl=True,
|
||||
ssl_cert_reqs="optional",
|
||||
ssl_ca_certs=None,
|
||||
)
|
||||
@@ -95,25 +72,6 @@ def onyx_redis(
|
||||
return purge_by_match_and_type(
|
||||
"*connectorsync:vespa_syncing*", "string", batch, dry_run, r
|
||||
)
|
||||
elif command == "get_user_token":
|
||||
if not user_email:
|
||||
logger.error("You must specify --user-email with get_user_token")
|
||||
return 1
|
||||
token_key = get_user_token_from_redis(r, user_email)
|
||||
if token_key:
|
||||
print(f"Token key for user {user_email}: {token_key}")
|
||||
return 0
|
||||
else:
|
||||
print(f"No token found for user {user_email}")
|
||||
return 2
|
||||
elif command == "delete_user_token":
|
||||
if not user_email:
|
||||
logger.error("You must specify --user-email with delete_user_token")
|
||||
return 1
|
||||
if delete_user_token_from_redis(r, user_email, dry_run):
|
||||
return 0
|
||||
else:
|
||||
return 2
|
||||
else:
|
||||
pass
|
||||
|
||||
@@ -176,104 +134,6 @@ def purge_by_match_and_type(
|
||||
return 0
|
||||
|
||||
|
||||
def get_user_token_from_redis(r: Redis, user_email: str) -> str | None:
|
||||
"""
|
||||
Scans Redis keys for a user token that matches user_email or user_id fields.
|
||||
Returns the token key if found, else None.
|
||||
"""
|
||||
user_id, tenant_id = get_user_id(user_email)
|
||||
|
||||
# Scan for keys matching the auth key prefix
|
||||
auth_keys = r.scan_iter(f"{REDIS_AUTH_KEY_PREFIX}*", count=SCAN_ITER_COUNT)
|
||||
|
||||
matching_key = None
|
||||
|
||||
for key in auth_keys:
|
||||
key_str = key.decode("utf-8")
|
||||
jwt_token = r.get(key_str)
|
||||
|
||||
if not jwt_token:
|
||||
continue
|
||||
|
||||
try:
|
||||
jwt_token_str = (
|
||||
jwt_token.decode("utf-8")
|
||||
if isinstance(jwt_token, bytes)
|
||||
else str(jwt_token)
|
||||
)
|
||||
|
||||
if jwt_token_str.startswith("b'") and jwt_token_str.endswith("'"):
|
||||
jwt_token_str = jwt_token_str[2:-1] # Remove b'' wrapper
|
||||
|
||||
jwt_data = json.loads(jwt_token_str)
|
||||
if jwt_data.get("tenant_id") == tenant_id and str(
|
||||
jwt_data.get("sub")
|
||||
) == str(user_id):
|
||||
matching_key = key_str
|
||||
break
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"Failed to decode JSON for key: {key_str}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing JWT for key: {key_str}. Error: {str(e)}")
|
||||
|
||||
if matching_key:
|
||||
return matching_key[len(REDIS_AUTH_KEY_PREFIX) :]
|
||||
return None
|
||||
|
||||
|
||||
def delete_user_token_from_redis(
|
||||
r: Redis, user_email: str, dry_run: bool = False
|
||||
) -> bool:
|
||||
"""
|
||||
Scans Redis keys for a user token matching user_email and deletes it if found.
|
||||
Returns True if something was deleted, otherwise False.
|
||||
"""
|
||||
user_id, tenant_id = get_user_id(user_email)
|
||||
|
||||
# Scan for keys matching the auth key prefix
|
||||
auth_keys = r.scan_iter(f"{REDIS_AUTH_KEY_PREFIX}*", count=SCAN_ITER_COUNT)
|
||||
matching_key = None
|
||||
|
||||
for key in auth_keys:
|
||||
key_str = key.decode("utf-8")
|
||||
jwt_token = r.get(key_str)
|
||||
|
||||
if not jwt_token:
|
||||
continue
|
||||
|
||||
try:
|
||||
jwt_token_str = (
|
||||
jwt_token.decode("utf-8")
|
||||
if isinstance(jwt_token, bytes)
|
||||
else str(jwt_token)
|
||||
)
|
||||
|
||||
if jwt_token_str.startswith("b'") and jwt_token_str.endswith("'"):
|
||||
jwt_token_str = jwt_token_str[2:-1] # Remove b'' wrapper
|
||||
|
||||
jwt_data = json.loads(jwt_token_str)
|
||||
if jwt_data.get("tenant_id") == tenant_id and str(
|
||||
jwt_data.get("sub")
|
||||
) == str(user_id):
|
||||
matching_key = key_str
|
||||
break
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"Failed to decode JSON for key: {key_str}")
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing JWT for key: {key_str}. Error: {str(e)}")
|
||||
|
||||
if matching_key:
|
||||
if dry_run:
|
||||
logger.info(f"(DRY-RUN) Would delete token key: {matching_key}")
|
||||
else:
|
||||
r.delete(matching_key)
|
||||
logger.info(f"Deleted token for user: {user_email}")
|
||||
return True
|
||||
else:
|
||||
logger.info(f"No token found for user: {user_email}")
|
||||
return False
|
||||
|
||||
|
||||
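Both Redis helpers above (the token lookup and the token delete) share the same matching rule: the JSON stored under a fastapi-users auth key must carry the expected `tenant_id` and a `sub` equal to the user's id. That rule, extracted into a runnable predicate for illustration:

```python
import json
from uuid import UUID


def token_matches_user(raw_value: bytes | str, user_id: UUID, tenant_id: str) -> bool:
    value = raw_value.decode("utf-8") if isinstance(raw_value, bytes) else str(raw_value)
    if value.startswith("b'") and value.endswith("'"):
        value = value[2:-1]  # strip a stringified bytes wrapper
    try:
        data = json.loads(value)
    except json.JSONDecodeError:
        return False
    return data.get("tenant_id") == tenant_id and str(data.get("sub")) == str(user_id)


uid = UUID("12345678-1234-5678-1234-567812345678")
assert token_matches_user(json.dumps({"sub": str(uid), "tenant_id": "public"}), uid, "public")
assert not token_matches_user(json.dumps({"sub": str(uid), "tenant_id": "other"}), uid, "public")
```
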
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Onyx Redis Manager")
|
||||
parser.add_argument("--command", type=str, help="Operation to run", required=True)
|
||||
@@ -325,13 +185,6 @@ if __name__ == "__main__":
|
||||
required=False,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--user-email",
|
||||
type=str,
|
||||
help="User email for get or delete user token",
|
||||
required=False,
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
exitcode = onyx_redis(
|
||||
command=args.command,
|
||||
@@ -341,6 +194,5 @@ if __name__ == "__main__":
|
||||
port=args.port,
|
||||
db=args.db,
|
||||
password=args.password,
|
||||
user_email=args.user_email,
|
||||
)
|
||||
sys.exit(exitcode)
|
||||
|
||||
@@ -1,544 +0,0 @@
|
||||
"""
|
||||
Vespa Debugging Tool!
|
||||
|
||||
Usage:
|
||||
python vespa_debug_tool.py --action <action> [options]
|
||||
|
||||
Actions:
|
||||
config : Print Vespa configuration
|
||||
connect : Check Vespa connectivity
|
||||
list_docs : List documents
|
||||
search : Search documents
|
||||
update : Update a document
|
||||
delete : Delete a document
|
||||
get_acls : Get document ACLs
|
||||
|
||||
Options:
|
||||
--tenant-id : Tenant ID
|
||||
--connector-id : Connector ID
|
||||
--n : Number of documents (default 10)
|
||||
--query : Search query
|
||||
--doc-id : Document ID
|
||||
--fields : Fields to update (JSON)
|
||||
|
||||
Example:
|
||||
python vespa_debug_tool.py --action list_docs --tenant-id my_tenant --connector-id 1 --n 5
|
||||
"""
|
||||
import argparse
|
||||
import json
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import and_
|
||||
|
||||
from onyx.configs.constants import INDEX_SEPARATOR
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import SearchRequest
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import Document
|
||||
from onyx.db.models import DocumentByConnectorCredentialPair
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.document_index.document_index_utils import get_document_chunk_ids
|
||||
from onyx.document_index.interfaces import EnrichedDocumentIndexingInfo
|
||||
from onyx.document_index.vespa.index import VespaIndex
|
||||
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
|
||||
from onyx.document_index.vespa_constants import ACCESS_CONTROL_LIST
|
||||
from onyx.document_index.vespa_constants import DOC_UPDATED_AT
|
||||
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
|
||||
from onyx.document_index.vespa_constants import DOCUMENT_SETS
|
||||
from onyx.document_index.vespa_constants import HIDDEN
|
||||
from onyx.document_index.vespa_constants import METADATA_LIST
|
||||
from onyx.document_index.vespa_constants import SEARCH_ENDPOINT
|
||||
from onyx.document_index.vespa_constants import SOURCE_TYPE
|
||||
from onyx.document_index.vespa_constants import TENANT_ID
|
||||
from onyx.document_index.vespa_constants import VESPA_APP_CONTAINER_URL
|
||||
from onyx.document_index.vespa_constants import VESPA_APPLICATION_ENDPOINT
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class DocumentFilter(BaseModel):
|
||||
# Document filter for link matching.
|
||||
link: str | None = None
|
||||
|
||||
|
||||
def build_vespa_filters(
|
||||
filters: IndexFilters,
|
||||
*,
|
||||
include_hidden: bool = False,
|
||||
remove_trailing_and: bool = False,
|
||||
) -> str:
|
||||
# Build a combined Vespa filter string from the given IndexFilters.
|
||||
def _build_or_filters(key: str, vals: list[str] | None) -> str:
|
||||
if vals is None:
|
||||
return ""
|
||||
valid_vals = [val for val in vals if val]
|
||||
if not key or not valid_vals:
|
||||
return ""
|
||||
eq_elems = [f'{key} contains "{elem}"' for elem in valid_vals]
|
||||
or_clause = " or ".join(eq_elems)
|
||||
return f"({or_clause})"
|
||||
|
||||
def _build_time_filter(
|
||||
cutoff: datetime | None,
|
||||
untimed_doc_cutoff: timedelta = timedelta(days=92),
|
||||
) -> str:
|
||||
if not cutoff:
|
||||
return ""
|
||||
include_untimed = datetime.now(timezone.utc) - untimed_doc_cutoff > cutoff
|
||||
cutoff_secs = int(cutoff.timestamp())
|
||||
if include_untimed:
|
||||
return f"!({DOC_UPDATED_AT} < {cutoff_secs})"
|
||||
return f"({DOC_UPDATED_AT} >= {cutoff_secs})"
|
||||
|
||||
filter_str = ""
|
||||
if not include_hidden:
|
||||
filter_str += f"AND !({HIDDEN}=true) "
|
||||
|
||||
if filters.tenant_id and MULTI_TENANT:
|
||||
filter_str += f'AND ({TENANT_ID} contains "{filters.tenant_id}") '
|
||||
|
||||
if filters.access_control_list is not None:
|
||||
acl_str = _build_or_filters(ACCESS_CONTROL_LIST, filters.access_control_list)
|
||||
if acl_str:
|
||||
filter_str += f"AND {acl_str} "
|
||||
|
||||
source_strs = (
|
||||
[s.value for s in filters.source_type] if filters.source_type else None
|
||||
)
|
||||
source_str = _build_or_filters(SOURCE_TYPE, source_strs)
|
||||
if source_str:
|
||||
filter_str += f"AND {source_str} "
|
||||
|
||||
tags = filters.tags
|
||||
if tags:
|
||||
tag_attributes = [tag.tag_key + INDEX_SEPARATOR + tag.tag_value for tag in tags]
|
||||
else:
|
||||
tag_attributes = None
|
||||
tag_str = _build_or_filters(METADATA_LIST, tag_attributes)
|
||||
if tag_str:
|
||||
filter_str += f"AND {tag_str} "
|
||||
|
||||
doc_set_str = _build_or_filters(DOCUMENT_SETS, filters.document_set)
|
||||
if doc_set_str:
|
||||
filter_str += f"AND {doc_set_str} "
|
||||
|
||||
time_filter = _build_time_filter(filters.time_cutoff)
|
||||
if time_filter:
|
||||
filter_str += f"AND {time_filter} "
|
||||
|
||||
if remove_trailing_and:
|
||||
while filter_str.endswith(" and "):
|
||||
filter_str = filter_str[:-5]
|
||||
while filter_str.endswith("AND "):
|
||||
filter_str = filter_str[:-4]
|
||||
|
||||
return filter_str.strip()
|
||||
|
||||
|
||||
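For reference, the OR-filter helper inside `build_vespa_filters` above produces YQL fragments like the one asserted below (standalone copy, no Vespa needed):

```python
def build_or_filters(key: str, vals: list[str] | None) -> str:
    if vals is None:
        return ""
    valid_vals = [val for val in vals if val]
    if not key or not valid_vals:
        return ""
    eq_elems = [f'{key} contains "{elem}"' for elem in valid_vals]
    return f"({' or '.join(eq_elems)})"


assert (
    build_or_filters("source_type", ["web", "slack"])
    == '(source_type contains "web" or source_type contains "slack")'
)
```
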
def print_vespa_config() -> None:
|
||||
# Print Vespa configuration.
|
||||
logger.info("Printing Vespa configuration.")
|
||||
print(f"Vespa Application Endpoint: {VESPA_APPLICATION_ENDPOINT}")
|
||||
print(f"Vespa App Container URL: {VESPA_APP_CONTAINER_URL}")
|
||||
print(f"Vespa Search Endpoint: {SEARCH_ENDPOINT}")
|
||||
print(f"Vespa Document ID Endpoint: {DOCUMENT_ID_ENDPOINT}")
|
||||
|
||||
|
||||
def check_vespa_connectivity() -> None:
|
||||
# Check connectivity to Vespa endpoints.
|
||||
logger.info("Checking Vespa connectivity.")
|
||||
endpoints = [
|
||||
f"{VESPA_APPLICATION_ENDPOINT}/ApplicationStatus",
|
||||
f"{VESPA_APPLICATION_ENDPOINT}/tenant",
|
||||
f"{VESPA_APPLICATION_ENDPOINT}/tenant/default/application/",
|
||||
f"{VESPA_APPLICATION_ENDPOINT}/tenant/default/application/default",
|
||||
]
|
||||
|
||||
for endpoint in endpoints:
|
||||
try:
|
||||
with get_vespa_http_client() as client:
|
||||
response = client.get(endpoint)
|
||||
logger.info(
|
||||
f"Connected to Vespa at {endpoint}, status code {response.status_code}"
|
||||
)
|
||||
print(f"Successfully connected to Vespa at {endpoint}")
|
||||
print(f"Status code: {response.status_code}")
|
||||
print(f"Response: {response.text[:200]}...")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to Vespa at {endpoint}: {str(e)}")
|
||||
print(f"Failed to connect to Vespa at {endpoint}: {str(e)}")
|
||||
|
||||
print("Vespa connectivity check completed.")
|
||||
|
||||
|
||||
def get_vespa_info() -> Dict[str, Any]:
|
||||
# Get info about the default Vespa application.
|
||||
url = f"{VESPA_APPLICATION_ENDPOINT}/tenant/default/application/default"
|
||||
with get_vespa_http_client() as client:
|
||||
response = client.get(url)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
|
||||
def get_index_name(tenant_id: str) -> str:
|
||||
# Return the index name for a given tenant.
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
if not search_settings:
|
||||
raise ValueError(f"No search settings found for tenant {tenant_id}")
|
||||
return search_settings.index_name
|
||||
|
||||
|
||||
def query_vespa(
|
||||
yql: str, tenant_id: Optional[str] = None, limit: int = 10
|
||||
) -> List[Dict[str, Any]]:
|
||||
# Perform a Vespa query using YQL syntax.
|
||||
filters = IndexFilters(tenant_id=tenant_id, access_control_list=[])
|
||||
filter_string = build_vespa_filters(filters, remove_trailing_and=True)
|
||||
full_yql = yql.strip()
|
||||
if filter_string:
|
||||
full_yql = f"{full_yql} {filter_string}"
|
||||
full_yql = f"{full_yql} limit {limit}"
|
||||
|
||||
params = {"yql": full_yql, "timeout": "10s"}
|
||||
search_request = SearchRequest(query="", limit=limit, offset=0)
|
||||
params.update(search_request.model_dump())
|
||||
|
||||
logger.info(f"Executing Vespa query: {full_yql}")
|
||||
with get_vespa_http_client() as client:
|
||||
response = client.get(SEARCH_ENDPOINT, params=params)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
documents = result.get("root", {}).get("children", [])
|
||||
logger.info(f"Found {len(documents)} documents from query.")
|
||||
return documents
|
||||
|
||||
|
||||
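And a small standalone sketch of the request that `query_vespa` above assembles: the base YQL gets a filter string and a limit appended before being sent to the search endpoint. The filter string here is a simplified example in the same shape as what `build_vespa_filters` emits.

```python
def build_yql(base_yql: str, filter_string: str, limit: int) -> str:
    yql = base_yql.strip()
    if filter_string:
        yql = f"{yql} {filter_string}"
    return f"{yql} limit {limit}"


params = {
    "yql": build_yql(
        "select * from sources * where true",
        'AND (tenant_id contains "my_tenant")',
        limit=5,
    ),
    "timeout": "10s",
}
print(params)
```
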
def get_first_n_documents(n: int = 10) -> List[Dict[str, Any]]:
|
||||
# Get the first n documents from any source.
|
||||
yql = "select * from sources * where true"
|
||||
return query_vespa(yql, limit=n)
|
||||
|
||||
|
||||
def print_documents(documents: List[Dict[str, Any]]) -> None:
|
||||
# Pretty-print a list of documents.
|
||||
for doc in documents:
|
||||
print(json.dumps(doc, indent=2))
|
||||
print("-" * 80)
|
||||
|
||||
|
||||
def get_documents_for_tenant_connector(
    tenant_id: str, connector_id: int, n: int = 10
) -> None:
    # Get and print documents for a specific tenant and connector.
    index_name = get_index_name(tenant_id)
    logger.info(
        f"Fetching documents for tenant={tenant_id}, connector_id={connector_id}"
    )
    yql = f"select * from sources {index_name} where true"
    documents = query_vespa(yql, tenant_id, limit=n)
    print(
        f"First {len(documents)} documents for tenant {tenant_id}, connector {connector_id}:"
    )
    print_documents(documents)


def search_documents(
    tenant_id: str, connector_id: int, query: str, n: int = 10
) -> None:
    # Search documents for a specific tenant and connector.
    index_name = get_index_name(tenant_id)
    logger.info(
        f"Searching documents for tenant={tenant_id}, connector_id={connector_id}, query='{query}'"
    )
    yql = f"select * from sources {index_name} where userInput(@query)"
    documents = query_vespa(yql, tenant_id, limit=n)
    print(f"Search results for query '{query}' in tenant {tenant_id}:")
    print_documents(documents)


def update_document(
    tenant_id: str, connector_id: int, doc_id: str, fields: Dict[str, Any]
) -> None:
    # Update a specific document.
    index_name = get_index_name(tenant_id)
    logger.info(
        f"Updating document doc_id={doc_id} in tenant={tenant_id}, connector_id={connector_id}"
    )
    url = DOCUMENT_ID_ENDPOINT.format(index_name=index_name) + f"/{doc_id}"
    update_request = {"fields": {k: {"assign": v} for k, v in fields.items()}}
    with get_vespa_http_client() as client:
        response = client.put(url, json=update_request)
        response.raise_for_status()
        logger.info(f"Document {doc_id} updated successfully.")
        print(f"Document {doc_id} updated successfully")


def delete_document(tenant_id: str, connector_id: int, doc_id: str) -> None:
    # Delete a specific document.
    index_name = get_index_name(tenant_id)
    logger.info(
        f"Deleting document doc_id={doc_id} in tenant={tenant_id}, connector_id={connector_id}"
    )
    url = DOCUMENT_ID_ENDPOINT.format(index_name=index_name) + f"/{doc_id}"
    with get_vespa_http_client() as client:
        response = client.delete(url)
        response.raise_for_status()
        logger.info(f"Document {doc_id} deleted successfully.")
        print(f"Document {doc_id} deleted successfully")


def list_documents(n: int = 10, tenant_id: Optional[str] = None) -> None:
    # List documents from any source, filtered by tenant if provided.
    logger.info(f"Listing up to {n} documents for tenant={tenant_id or 'ALL'}")
    yql = "select * from sources * where true"
    if tenant_id:
        yql += f" and tenant_id contains '{tenant_id}'"
    documents = query_vespa(yql, tenant_id=tenant_id, limit=n)
    print(f"Total documents found: {len(documents)}")
    logger.info(f"Total documents found: {len(documents)}")
    print(f"First {min(n, len(documents))} documents:")
    for doc in documents[:n]:
        print(json.dumps(doc, indent=2))
        print("-" * 80)
def get_document_and_chunk_counts(
    tenant_id: str, cc_pair_id: int, filter_doc: DocumentFilter | None = None
) -> Dict[str, int]:
    # Return a dict mapping each document ID to its chunk count for a given connector.
    with get_session_with_tenant(tenant_id=tenant_id) as session:
        doc_ids_data = (
            session.query(DocumentByConnectorCredentialPair.id, Document.link)
            .join(
                ConnectorCredentialPair,
                and_(
                    DocumentByConnectorCredentialPair.connector_id
                    == ConnectorCredentialPair.connector_id,
                    DocumentByConnectorCredentialPair.credential_id
                    == ConnectorCredentialPair.credential_id,
                ),
            )
            .join(Document, DocumentByConnectorCredentialPair.id == Document.id)
            .filter(ConnectorCredentialPair.id == cc_pair_id)
            .distinct()
            .all()
        )
        doc_ids = []
        for doc_id, link in doc_ids_data:
            if filter_doc and filter_doc.link:
                if link and filter_doc.link.lower() in link.lower():
                    doc_ids.append(doc_id)
            else:
                doc_ids.append(doc_id)
        chunk_counts_data = (
            session.query(Document.id, Document.chunk_count)
            .filter(Document.id.in_(doc_ids))
            .all()
        )
        return {
            doc_id: chunk_count
            for doc_id, chunk_count in chunk_counts_data
            if chunk_count is not None
        }


def get_chunk_ids_for_connector(
    tenant_id: str,
    cc_pair_id: int,
    index_name: str,
    filter_doc: DocumentFilter | None = None,
) -> List[UUID]:
    # Return chunk IDs for a given connector.
    doc_id_to_new_chunk_cnt = get_document_and_chunk_counts(
        tenant_id, cc_pair_id, filter_doc
    )
    doc_infos: List[EnrichedDocumentIndexingInfo] = [
        VespaIndex.enrich_basic_chunk_info(
            index_name=index_name,
            http_client=get_vespa_http_client(),
            document_id=doc_id,
            previous_chunk_count=doc_id_to_new_chunk_cnt.get(doc_id, 0),
            new_chunk_count=0,
        )
        for doc_id in doc_id_to_new_chunk_cnt.keys()
    ]
    chunk_ids = get_document_chunk_ids(
        enriched_document_info_list=doc_infos,
        tenant_id=tenant_id,
        large_chunks_enabled=False,
    )
    if not isinstance(chunk_ids, list):
        raise ValueError(f"Expected list of chunk IDs, got {type(chunk_ids)}")
    return chunk_ids


def get_document_acls(
    tenant_id: str,
    cc_pair_id: int,
    n: int | None = 10,
    filter_doc: DocumentFilter | None = None,
) -> None:
    # Fetch document ACLs for the given tenant and connector pair.
    index_name = get_index_name(tenant_id)
    logger.info(
        f"Fetching document ACLs for tenant={tenant_id}, cc_pair_id={cc_pair_id}"
    )
    chunk_ids: List[UUID] = get_chunk_ids_for_connector(
        tenant_id, cc_pair_id, index_name, filter_doc
    )
    vespa_client = get_vespa_http_client()

    target_ids = chunk_ids if n is None else chunk_ids[:n]
    logger.info(
        f"Found {len(chunk_ids)} chunk IDs, showing ACLs for {len(target_ids)}."
    )
    for doc_chunk_id in target_ids:
        document_url = (
            f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{str(doc_chunk_id)}"
        )
        response = vespa_client.get(document_url)
        if response.status_code == 200:
            fields = response.json().get("fields", {})
            document_id = fields.get("document_id") or fields.get(
                "documentid", "Unknown"
            )
            acls = fields.get("access_control_list", {})
            title = fields.get("title", "")
            source_type = fields.get("source_type", "")
            source_links_raw = fields.get("source_links", "{}")
            try:
                source_links = json.loads(source_links_raw)
            except json.JSONDecodeError:
                source_links = {}

            print(f"Document Chunk ID: {doc_chunk_id}")
            print(f"Document ID: {document_id}")
            print(f"ACLs:\n{json.dumps(acls, indent=2)}")
            print(f"Source Links: {source_links}")
            print(f"Title: {title}")
            print(f"Source Type: {source_type}")
            if MULTI_TENANT:
                print(f"Tenant ID: {fields.get('tenant_id', 'N/A')}")
            print("-" * 80)
        else:
            logger.error(f"Failed to fetch document for chunk ID: {doc_chunk_id}")
            print(f"Failed to fetch document for chunk ID: {doc_chunk_id}")
            print(f"Status Code: {response.status_code}")
            print("-" * 80)
class VespaDebugging:
    # Class for managing Vespa debugging actions.
    def __init__(self, tenant_id: str | None = None):
        self.tenant_id = POSTGRES_DEFAULT_SCHEMA if not tenant_id else tenant_id

    def print_config(self) -> None:
        # Print Vespa config.
        print_vespa_config()

    def check_connectivity(self) -> None:
        # Check Vespa connectivity.
        check_vespa_connectivity()

    def list_documents(self, n: int = 10) -> None:
        # List documents for a tenant.
        list_documents(n, self.tenant_id)

    def search_documents(self, connector_id: int, query: str, n: int = 10) -> None:
        # Search documents for a tenant and connector.
        search_documents(self.tenant_id, connector_id, query, n)

    def update_document(
        self, connector_id: int, doc_id: str, fields: Dict[str, Any]
    ) -> None:
        # Update a document.
        update_document(self.tenant_id, connector_id, doc_id, fields)

    def delete_document(self, connector_id: int, doc_id: str) -> None:
        # Delete a document.
        delete_document(self.tenant_id, connector_id, doc_id)

    def acls_by_link(self, cc_pair_id: int, link: str) -> None:
        # Get ACLs for a document matching a link.
        get_document_acls(
            self.tenant_id, cc_pair_id, n=None, filter_doc=DocumentFilter(link=link)
        )

    def acls(self, cc_pair_id: int, n: int | None = 10) -> None:
        # Get ACLs for a connector.
        get_document_acls(self.tenant_id, cc_pair_id, n)


def main() -> None:
    # Main CLI entry point.
    parser = argparse.ArgumentParser(description="Vespa debugging tool")
    parser.add_argument(
        "--action",
        choices=[
            "config",
            "connect",
            "list_docs",
            "search",
            "update",
            "delete",
            "get_acls",
        ],
        required=True,
        help="Action to perform",
    )
    parser.add_argument("--tenant-id", help="Tenant ID")
    parser.add_argument("--connector-id", type=int, help="Connector ID")
    parser.add_argument(
        "--n", type=int, default=10, help="Number of documents to retrieve"
    )
    parser.add_argument("--query", help="Search query (for search action)")
    parser.add_argument("--doc-id", help="Document ID (for update and delete actions)")
    parser.add_argument(
        "--fields", help="Fields to update, in JSON format (for update)"
    )

    args = parser.parse_args()
    vespa_debug = VespaDebugging(args.tenant_id)

    if args.action == "config":
        vespa_debug.print_config()
    elif args.action == "connect":
        vespa_debug.check_connectivity()
    elif args.action == "list_docs":
        vespa_debug.list_documents(args.n)
    elif args.action == "search":
        if not args.query or args.connector_id is None:
            parser.error("--query and --connector-id are required for search action")
        vespa_debug.search_documents(args.connector_id, args.query, args.n)
    elif args.action == "update":
        if not args.doc_id or not args.fields or args.connector_id is None:
            parser.error(
                "--doc-id, --fields, and --connector-id are required for update action"
            )
        fields = json.loads(args.fields)
        vespa_debug.update_document(args.connector_id, args.doc_id, fields)
    elif args.action == "delete":
        if not args.doc_id or args.connector_id is None:
            parser.error("--doc-id and --connector-id are required for delete action")
        vespa_debug.delete_document(args.connector_id, args.doc_id)
    elif args.action == "get_acls":
        if args.connector_id is None:
            parser.error("--connector-id is required for get_acls action")
        vespa_debug.acls(args.connector_id, args.n)


if __name__ == "__main__":
    main()
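A minimal usage sketch of the debugging helpers above, assuming the script's onyx imports resolve inside the backend environment; the connector-credential pair ID below is a placeholder, not a value from this change:

# Hedged example: drive the same actions programmatically instead of via the CLI.
debugger = VespaDebugging(tenant_id=None)  # falls back to the default Postgres schema
debugger.check_connectivity()              # probe each Vespa application endpoint
debugger.list_documents(n=5)               # print the first five indexed documents
debugger.acls(cc_pair_id=1, n=3)           # placeholder cc_pair_id; show ACLs for a few chunks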
0
backend/test
Normal file
0
backend/test
Normal file
@@ -45,7 +45,7 @@ def create_test_document(
    submitted_by: str,
    assignee: str,
    days_since_status_change: int | None,
    attachments: list[tuple[str, str]] | None = None,
    attachments: list | None = None,
) -> Document:
    link_base = f"https://airtable.com/{os.environ['AIRTABLE_TEST_BASE_ID']}/{os.environ['AIRTABLE_TEST_TABLE_ID']}"
    sections = [
@@ -60,11 +60,11 @@ def create_test_document(
    ]

    if attachments:
        for attachment_text, attachment_link in attachments:
        for attachment in attachments:
            sections.append(
                Section(
                    text=f"Attachment:\n------------------------\n{attachment_text}\n------------------------",
                    link=attachment_link,
                    text=f"Attachment:\n------------------------\n{attachment}\n------------------------",
                    link=f"{link_base}/{id}",
                ),
            )

@@ -142,13 +142,7 @@ def test_airtable_connector_basic(
            days_since_status_change=0,
            assignee="Chris Weaver (chris@onyx.app)",
            submitted_by="Chris Weaver (chris@onyx.app)",
            attachments=[
                (
                    "Test.pdf:\ntesting!!!",
                    # hard code link for now
                    "https://airtable.com/appCXJqDFS4gea8tn/tblRxFQsTlBBZdRY1/viwVUEJjWPd8XYjh8/reccSlIA4pZEFxPBg/fld1u21zkJACIvAEF/attlj2UBWNEDZngCc?blocks=hide",
                )
            ],
            attachments=["Test.pdf:\ntesting!!!"],
        ),
    ]
@@ -38,12 +38,6 @@ def get_credentials() -> dict[str, str]:
    }


@pytest.mark.xfail(
    reason=(
        "Cannot get Zendesk developer account to ensure zendesk account does not "
        "expire after 2 weeks"
    )
)
@pytest.mark.parametrize(
    "connector_fixture", ["zendesk_article_connector", "zendesk_ticket_connector"]
)
@@ -102,12 +96,6 @@ def test_zendesk_connector_basic(
    )


@pytest.mark.xfail(
    reason=(
        "Cannot get Zendesk developer account to ensure zendesk account does not "
        "expire after 2 weeks"
    )
)
def test_zendesk_connector_slim(zendesk_article_connector: ZendeskConnector) -> None:
    # Get full doc IDs
    all_full_doc_ids = set()
@@ -1,106 +0,0 @@
from datetime import datetime
from datetime import timedelta
from urllib.parse import urlencode

import requests

from onyx.db.engine import get_session_context_manager
from onyx.db.enums import IndexModelStatus
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexingStatus
from onyx.db.search_settings import get_current_search_settings
from onyx.server.documents.models import IndexAttemptSnapshot
from onyx.server.documents.models import PaginatedReturn
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestIndexAttempt
from tests.integration.common_utils.test_models import DATestUser


class IndexAttemptManager:
    @staticmethod
    def create_test_index_attempts(
        num_attempts: int,
        cc_pair_id: int,
        from_beginning: bool = False,
        status: IndexingStatus = IndexingStatus.SUCCESS,
        new_docs_indexed: int = 10,
        total_docs_indexed: int = 10,
        docs_removed_from_index: int = 0,
        error_msg: str | None = None,
        base_time: datetime | None = None,
    ) -> list[DATestIndexAttempt]:
        if base_time is None:
            base_time = datetime.now()

        attempts = []
        with get_session_context_manager() as db_session:
            # Get the current search settings
            search_settings = get_current_search_settings(db_session)
            if (
                not search_settings
                or search_settings.status != IndexModelStatus.PRESENT
            ):
                raise ValueError("No current search settings found with PRESENT status")

            for i in range(num_attempts):
                time_created = base_time - timedelta(hours=i)

                index_attempt = IndexAttempt(
                    connector_credential_pair_id=cc_pair_id,
                    from_beginning=from_beginning,
                    status=status,
                    new_docs_indexed=new_docs_indexed,
                    total_docs_indexed=total_docs_indexed,
                    docs_removed_from_index=docs_removed_from_index,
                    error_msg=error_msg,
                    time_created=time_created,
                    time_started=time_created,
                    time_updated=time_created,
                    search_settings_id=search_settings.id,
                )

                db_session.add(index_attempt)
                db_session.flush()  # To get the ID

                attempts.append(
                    DATestIndexAttempt(
                        id=index_attempt.id,
                        status=index_attempt.status,
                        new_docs_indexed=index_attempt.new_docs_indexed,
                        total_docs_indexed=index_attempt.total_docs_indexed,
                        docs_removed_from_index=index_attempt.docs_removed_from_index,
                        error_msg=index_attempt.error_msg,
                        time_started=index_attempt.time_started,
                        time_updated=index_attempt.time_updated,
                    )
                )

            db_session.commit()

        return attempts

    @staticmethod
    def get_index_attempt_page(
        cc_pair_id: int,
        page: int = 0,
        page_size: int = 10,
        user_performing_action: DATestUser | None = None,
    ) -> PaginatedReturn[IndexAttemptSnapshot]:
        query_params: dict[str, str | int] = {
            "page_num": page,
            "page_size": page_size,
        }

        response = requests.get(
            url=f"{API_SERVER_URL}/manage/admin/cc-pair/{cc_pair_id}/index-attempts?{urlencode(query_params, doseq=True)}",
            headers=user_performing_action.headers
            if user_performing_action
            else GENERAL_HEADERS,
        )
        response.raise_for_status()
        data = response.json()
        return PaginatedReturn(
            items=[IndexAttemptSnapshot(**item) for item in data["items"]],
            total_items=data["total_items"],
        )
@@ -26,7 +26,6 @@ class PersonaManager:
        is_public: bool = True,
        llm_filter_extraction: bool = True,
        recency_bias: RecencyBiasSetting = RecencyBiasSetting.AUTO,
        datetime_aware: bool = False,
        prompt_ids: list[int] | None = None,
        document_set_ids: list[int] | None = None,
        tool_ids: list[int] | None = None,
@@ -47,7 +46,6 @@ class PersonaManager:
            description=description,
            system_prompt=system_prompt,
            task_prompt=task_prompt,
            datetime_aware=datetime_aware,
            include_citations=include_citations,
            num_chunks=num_chunks,
            llm_relevance_filter=llm_relevance_filter,
@@ -106,7 +104,6 @@ class PersonaManager:
        is_public: bool | None = None,
        llm_filter_extraction: bool | None = None,
        recency_bias: RecencyBiasSetting | None = None,
        datetime_aware: bool = False,
        prompt_ids: list[int] | None = None,
        document_set_ids: list[int] | None = None,
        tool_ids: list[int] | None = None,
@@ -124,7 +121,6 @@ class PersonaManager:
            description=description or persona.description,
            system_prompt=system_prompt,
            task_prompt=task_prompt,
            datetime_aware=datetime_aware,
            include_citations=include_citations,
            num_chunks=num_chunks or persona.num_chunks,
            llm_relevance_filter=llm_relevance_filter or persona.llm_relevance_filter,
@@ -1,5 +1,3 @@
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import Any
from uuid import UUID
@@ -12,8 +10,6 @@ from onyx.configs.constants import QAFeedbackType
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.enums import AccessType
from onyx.server.documents.models import DocumentSource
from onyx.server.documents.models import IndexAttemptSnapshot
from onyx.server.documents.models import IndexingStatus
from onyx.server.documents.models import InputType

"""
@@ -175,32 +171,3 @@ class DATestSettings(BaseModel):
    gpu_enabled: bool | None = None
    product_gating: DATestGatingType = DATestGatingType.NONE
    anonymous_user_enabled: bool | None = None


@dataclass
class DATestIndexAttempt:
    id: int
    status: IndexingStatus | None
    new_docs_indexed: int | None
    total_docs_indexed: int | None
    docs_removed_from_index: int | None
    error_msg: str | None
    time_started: datetime | None
    time_updated: datetime | None

    @classmethod
    def from_index_attempt_snapshot(
        cls, index_attempt: IndexAttemptSnapshot
    ) -> "DATestIndexAttempt":
        return cls(
            id=index_attempt.id,
            status=index_attempt.status,
            new_docs_indexed=index_attempt.new_docs_indexed,
            total_docs_indexed=index_attempt.total_docs_indexed,
            docs_removed_from_index=index_attempt.docs_removed_from_index,
            error_msg=index_attempt.error_msg,
            time_started=datetime.fromisoformat(index_attempt.time_started)
            if index_attempt.time_started
            else None,
            time_updated=datetime.fromisoformat(index_attempt.time_updated),
        )
@@ -1,90 +0,0 @@
from datetime import datetime

from onyx.db.models import IndexingStatus
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.index_attempt import IndexAttemptManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestIndexAttempt
from tests.integration.common_utils.test_models import DATestUser


def _verify_index_attempt_pagination(
    cc_pair_id: int,
    index_attempts: list[DATestIndexAttempt],
    page_size: int = 5,
    user_performing_action: DATestUser | None = None,
) -> None:
    retrieved_attempts: list[int] = []
    last_time_started = None  # Track the last time_started seen

    for i in range(0, len(index_attempts), page_size):
        paginated_result = IndexAttemptManager.get_index_attempt_page(
            cc_pair_id=cc_pair_id,
            page=(i // page_size),
            page_size=page_size,
            user_performing_action=user_performing_action,
        )

        # Verify that the total items is equal to the length of the index attempts list
        assert paginated_result.total_items == len(index_attempts)
        # Verify that the number of items in the page is equal to the page size
        assert len(paginated_result.items) == min(page_size, len(index_attempts) - i)

        # Verify time ordering within the page (descending order)
        for attempt in paginated_result.items:
            if last_time_started is not None:
                assert (
                    attempt.time_started <= last_time_started
                ), "Index attempts not in descending time order"
            last_time_started = attempt.time_started

        # Add the retrieved index attempts to the list of retrieved attempts
        retrieved_attempts.extend([attempt.id for attempt in paginated_result.items])

    # Create a set of all the expected index attempt IDs
    all_expected_attempts = set(attempt.id for attempt in index_attempts)
    # Create a set of all the retrieved index attempt IDs
    all_retrieved_attempts = set(retrieved_attempts)

    # Verify that the set of retrieved attempts is equal to the set of expected attempts
    assert all_expected_attempts == all_retrieved_attempts


def test_index_attempt_pagination(reset: None) -> None:
    # Create an admin user to perform actions
    user_performing_action: DATestUser = UserManager.create(
        name="admin_performing_action",
        is_first_user=True,
    )

    # Create a CC pair to attach index attempts to
    cc_pair = CCPairManager.create_from_scratch(
        user_performing_action=user_performing_action,
    )

    # Create 300 successful index attempts
    base_time = datetime.now()
    all_attempts = IndexAttemptManager.create_test_index_attempts(
        num_attempts=300,
        cc_pair_id=cc_pair.id,
        status=IndexingStatus.SUCCESS,
        base_time=base_time,
    )

    # Verify basic pagination with different page sizes
    print("Verifying basic pagination with page size 5")
    _verify_index_attempt_pagination(
        cc_pair_id=cc_pair.id,
        index_attempts=all_attempts,
        page_size=5,
        user_performing_action=user_performing_action,
    )

    # Test with a larger page size
    print("Verifying pagination with page size 100")
    _verify_index_attempt_pagination(
        cc_pair_id=cc_pair.id,
        index_attempts=all_attempts,
        page_size=100,
        user_performing_action=user_performing_action,
    )
@@ -3,7 +3,6 @@ import uuid
import requests

from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestUser


@@ -21,15 +20,14 @@ def _get_provider_by_id(admin_user: DATestUser, provider_id: str) -> dict | None
    return next((p for p in providers if p["id"] == provider_id), None)


def test_create_llm_provider_without_display_model_names(reset: None) -> None:
def test_create_llm_provider_without_display_model_names(
    admin_user: DATestUser,
) -> None:
    """Test creating an LLM provider without specifying
    display_model_names and verify it's null in response"""
    # Create admin user
    admin_user = UserManager.create(name="admin_user")

    # Create LLM provider without model_names
    response = requests.put(
        f"{API_SERVER_URL}/admin/llm/provider?is_creation=true",
        f"{API_SERVER_URL}/admin/llm/provider",
        headers=admin_user.headers,
        json={
            "name": str(uuid.uuid4()),
@@ -51,15 +49,12 @@ def test_create_llm_provider_without_display_model_names(reset: None) -> None:
    assert provider_data["display_model_names"] is None


def test_update_llm_provider_model_names(reset: None) -> None:
def test_update_llm_provider_model_names(admin_user: DATestUser) -> None:
    """Test updating an LLM provider's model_names"""
    # Create admin user
    admin_user = UserManager.create(name="admin_user")

    # First create provider without model_names
    name = str(uuid.uuid4())
    response = requests.put(
        f"{API_SERVER_URL}/admin/llm/provider?is_creation=true",
        f"{API_SERVER_URL}/admin/llm/provider",
        headers=admin_user.headers,
        json={
            "name": name,
@@ -95,14 +90,11 @@ def test_update_llm_provider_model_names(reset: None) -> None:
    assert provider_data["model_names"] == _DEFAULT_MODELS


def test_delete_llm_provider(reset: None) -> None:
def test_delete_llm_provider(admin_user: DATestUser) -> None:
    """Test deleting an LLM provider"""
    # Create admin user
    admin_user = UserManager.create(name="admin_user")

    # Create a provider
    response = requests.put(
        f"{API_SERVER_URL}/admin/llm/provider?is_creation=true",
        f"{API_SERVER_URL}/admin/llm/provider",
        headers=admin_user.headers,
        json={
            "name": "test-provider-delete",
@@ -1,92 +0,0 @@
import pytest

from onyx.background.celery.tasks.llm_model_update.tasks import (
    _process_model_list_response,
)


@pytest.mark.parametrize(
    "input_data,expected_result,expected_error,error_match",
    [
        # Success cases
        (
            ["gpt-4", "gpt-3.5-turbo", "claude-2"],
            ["gpt-4", "gpt-3.5-turbo", "claude-2"],
            None,
            None,
        ),
        (
            [
                {"model_name": "gpt-4", "other_field": "value"},
                {"model_name": "gpt-3.5-turbo", "other_field": "value"},
            ],
            ["gpt-4", "gpt-3.5-turbo"],
            None,
            None,
        ),
        (
            [
                {"id": "gpt-4", "other_field": "value"},
                {"id": "gpt-3.5-turbo", "other_field": "value"},
            ],
            ["gpt-4", "gpt-3.5-turbo"],
            None,
            None,
        ),
        (
            {"data": ["gpt-4", "gpt-3.5-turbo"]},
            ["gpt-4", "gpt-3.5-turbo"],
            None,
            None,
        ),
        (
            {"models": ["gpt-4", "gpt-3.5-turbo"]},
            ["gpt-4", "gpt-3.5-turbo"],
            None,
            None,
        ),
        (
            {"models": [{"id": "gpt-4"}, {"id": "gpt-3.5-turbo"}]},
            ["gpt-4", "gpt-3.5-turbo"],
            None,
            None,
        ),
        # Error cases
        (
            "not a list",
            None,
            ValueError,
            "Invalid response from API - expected list",
        ),
        (
            {"wrong_field": []},
            None,
            ValueError,
            "Invalid response from API - expected dict with 'data' or 'models' field",
        ),
        (
            [{"wrong_field": "value"}],
            None,
            ValueError,
            "Invalid item in model list - expected dict with model_name or id",
        ),
        (
            [42],
            None,
            ValueError,
            "Invalid item in model list - expected string or dict",
        ),
    ],
)
def test_process_model_list_response(
    input_data: dict | list,
    expected_result: list[str] | None,
    expected_error: type[Exception] | None,
    error_match: str | None,
) -> None:
    if expected_error:
        with pytest.raises(expected_error, match=error_match):
            _process_model_list_response(input_data)
    else:
        result = _process_model_list_response(input_data)
        assert result == expected_result
@@ -18,9 +18,6 @@ FROM base AS builder
RUN apk add --no-cache libc6-compat
WORKDIR /app

# Add NODE_OPTIONS argument
ARG NODE_OPTIONS

# pull in source code / package.json / package-lock.json
COPY . .

@@ -81,8 +78,7 @@ ENV NEXT_PUBLIC_GTM_ENABLED=${NEXT_PUBLIC_GTM_ENABLED}
ARG NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED
ENV NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=${NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED}

# Use NODE_OPTIONS in the build command
RUN NODE_OPTIONS="${NODE_OPTIONS}" npx next build
RUN npx next build

# Step 2. Production image, copy all the files and run next
FROM base AS runner
@@ -22,7 +22,6 @@ const cspHeader = `

/** @type {import('next').NextConfig} */
const nextConfig = {
  productionBrowserSourceMaps: false,
  output: "standalone",
  publicRuntimeConfig: {
    version,
@@ -72,7 +71,7 @@ const nextConfig = {

// Sentry configuration for error monitoring:
// - Without SENTRY_AUTH_TOKEN and NEXT_PUBLIC_SENTRY_DSN: Sentry is completely disabled
// - With both configured: Capture errors and limited performance data
// - With both configured: Only unhandled errors are captured (no performance/session tracking)

// Determine if Sentry should be enabled
const sentryEnabled = Boolean(
@@ -86,16 +85,12 @@ const sentryWebpackPluginOptions = {
  authToken: process.env.SENTRY_AUTH_TOKEN,
  silent: !sentryEnabled, // Silence output when Sentry is disabled
  dryRun: !sentryEnabled, // Don't upload source maps when Sentry is disabled
  ...(sentryEnabled && {
    sourceMaps: {
      include: ["./.next"],
      ignore: ["node_modules"],
      urlPrefix: "~/_next",
      stripPrefix: ["webpack://_N_E/"],
      validate: true,
      cleanArtifacts: true,
    },
  }),
  sourceMaps: {
    include: ["./.next"],
    validate: false,
    urlPrefix: "~/_next",
    skip: !sentryEnabled,
  },
};

// Export the module with conditional Sentry configuration
1214
web/package-lock.json
generated
1214
web/package-lock.json
generated
File diff suppressed because it is too large
@@ -24,15 +24,13 @@
    "@radix-ui/react-label": "^2.1.1",
    "@radix-ui/react-popover": "^1.1.2",
    "@radix-ui/react-radio-group": "^1.2.2",
    "@radix-ui/react-scroll-area": "^1.2.2",
    "@radix-ui/react-select": "^2.1.2",
    "@radix-ui/react-separator": "^1.1.0",
    "@radix-ui/react-slot": "^1.1.0",
    "@radix-ui/react-switch": "^1.1.1",
    "@radix-ui/react-tabs": "^1.1.1",
    "@radix-ui/react-tooltip": "^1.1.3",
    "@sentry/nextjs": "^8.50.0",
    "@sentry/tracing": "^7.120.3",
    "@sentry/nextjs": "^8.34.0",
    "@stripe/stripe-js": "^4.6.0",
    "@types/js-cookie": "^3.0.3",
    "@types/lodash": "^4.17.0",
Some files were not shown because too many files have changed in this diff