mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-19 00:35:46 +00:00
Compare commits
41 Commits
bb
...
bugfix/ext
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4bf05c1947 | ||
|
|
b70db15622 | ||
|
|
e9492ce9ec | ||
|
|
35574369ed | ||
|
|
eff433bdc5 | ||
|
|
3260d793d1 | ||
|
|
1a7aca06b9 | ||
|
|
c6434db7eb | ||
|
|
667b9e04c5 | ||
|
|
29c84d7707 | ||
|
|
17c915b11b | ||
|
|
95ca592d6d | ||
|
|
e39a27fd6b | ||
|
|
26d3c952c6 | ||
|
|
53683e2f3c | ||
|
|
0c0113a481 | ||
|
|
c0f381e471 | ||
|
|
5ed83f1148 | ||
|
|
9db7b67a6c | ||
|
|
2850048c6b | ||
|
|
61058e5fcd | ||
|
|
c87261cda7 | ||
|
|
e030b0a6fc | ||
|
|
61136975ad | ||
|
|
0c74bbf9ed | ||
|
|
12b2126e69 | ||
|
|
037943c6ff | ||
|
|
f9485b1325 | ||
|
|
552a0630fe | ||
|
|
5bf520d8b8 | ||
|
|
7dc5a77946 | ||
|
|
03abd4a1bc | ||
|
|
16d6d708f6 | ||
|
|
9740ed32b5 | ||
|
|
b56877cc2e | ||
|
|
da5c83a96d | ||
|
|
818225c60e | ||
|
|
5a4d007cf9 | ||
|
|
5e32f9d922 | ||
|
|
fb931ee4de | ||
|
|
bc2c56dfb6 |
@@ -65,6 +65,7 @@ jobs:
|
||||
NEXT_PUBLIC_POSTHOG_KEY=${{ secrets.POSTHOG_KEY }}
|
||||
NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
|
||||
NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
|
||||
NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY=${{ secrets.STRIPE_PUBLISHABLE_KEY }}
|
||||
NEXT_PUBLIC_GTM_ENABLED=true
|
||||
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
|
||||
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
|
||||
|
||||
@@ -12,7 +12,32 @@ env:
|
||||
BUILDKIT_PROGRESS: plain
|
||||
|
||||
jobs:
|
||||
# 1) Preliminary job to check if the changed files are relevant
|
||||
check_model_server_changes:
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
changed: ${{ steps.check.outputs.changed }}
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Check if relevant files changed
|
||||
id: check
|
||||
run: |
|
||||
# Default to "false"
|
||||
echo "changed=false" >> $GITHUB_OUTPUT
|
||||
|
||||
# Compare the previous commit (github.event.before) to the current one (github.sha)
|
||||
# If any file in backend/model_server/** or backend/Dockerfile.model_server is changed,
|
||||
# set changed=true
|
||||
if git diff --name-only ${{ github.event.before }} ${{ github.sha }} \
|
||||
| grep -E '^backend/model_server/|^backend/Dockerfile.model_server'; then
|
||||
echo "changed=true" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
build-amd64:
|
||||
needs: [check_model_server_changes]
|
||||
if: needs.check_model_server_changes.outputs.changed == 'true'
|
||||
runs-on:
|
||||
[runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}-amd64"]
|
||||
steps:
|
||||
@@ -52,6 +77,8 @@ jobs:
|
||||
provenance: false
|
||||
|
||||
build-arm64:
|
||||
needs: [check_model_server_changes]
|
||||
if: needs.check_model_server_changes.outputs.changed == 'true'
|
||||
runs-on:
|
||||
[runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}-arm64"]
|
||||
steps:
|
||||
@@ -91,7 +118,8 @@ jobs:
|
||||
provenance: false
|
||||
|
||||
merge-and-scan:
|
||||
needs: [build-amd64, build-arm64]
|
||||
needs: [build-amd64, build-arm64, check_model_server_changes]
|
||||
if: needs.check_model_server_changes.outputs.changed == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Login to Docker Hub
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
"""set built in to default
|
||||
|
||||
Revision ID: 2cdeff6d8c93
|
||||
Revises: f5437cc136c5
|
||||
Create Date: 2025-02-11 14:57:51.308775
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "2cdeff6d8c93"
|
||||
down_revision = "f5437cc136c5"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Prior to this migration / point in the codebase history,
|
||||
# built in personas were implicitly treated as default personas (with no option to change this)
|
||||
# This migration makes that explicit
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE persona
|
||||
SET is_default_persona = TRUE
|
||||
WHERE builtin_persona = TRUE
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
pass
|
||||
@@ -0,0 +1,40 @@
|
||||
"""Add background errors table
|
||||
|
||||
Revision ID: f39c5794c10a
|
||||
Revises: 2cdeff6d8c93
|
||||
Create Date: 2025-02-12 17:11:14.527876
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "f39c5794c10a"
|
||||
down_revision = "2cdeff6d8c93"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"background_error",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column("message", sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
"time_created",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("cc_pair_id", sa.Integer(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(
|
||||
["cc_pair_id"],
|
||||
["connector_credential_pair.id"],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("background_error")
|
||||
@@ -1,44 +1,46 @@
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
from onyx.background.celery.tasks.beat_schedule import (
|
||||
beat_cloud_tasks as base_beat_system_tasks,
|
||||
)
|
||||
from onyx.background.celery.tasks.beat_schedule import BEAT_EXPIRES_DEFAULT
|
||||
from onyx.background.celery.tasks.beat_schedule import (
|
||||
cloud_tasks_to_schedule as base_cloud_tasks_to_schedule,
|
||||
beat_task_templates as base_beat_task_templates,
|
||||
)
|
||||
from onyx.background.celery.tasks.beat_schedule import generate_cloud_tasks
|
||||
from onyx.background.celery.tasks.beat_schedule import (
|
||||
tasks_to_schedule as base_tasks_to_schedule,
|
||||
get_tasks_to_schedule as base_get_tasks_to_schedule,
|
||||
)
|
||||
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
ee_cloud_tasks_to_schedule = [
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_autogenerate-usage-report",
|
||||
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
"schedule": timedelta(days=30),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGHEST,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
ee_beat_system_tasks: list[dict] = []
|
||||
|
||||
ee_beat_task_templates: list[dict] = []
|
||||
ee_beat_task_templates.extend(
|
||||
[
|
||||
{
|
||||
"name": "autogenerate-usage-report",
|
||||
"task": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
|
||||
"schedule": timedelta(days=30),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
"kwargs": {
|
||||
"task_name": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
|
||||
{
|
||||
"name": "check-ttl-management",
|
||||
"task": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
|
||||
"schedule": timedelta(hours=1),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-ttl-management",
|
||||
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
"schedule": timedelta(hours=1),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGHEST,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
"kwargs": {
|
||||
"task_name": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
|
||||
},
|
||||
},
|
||||
]
|
||||
]
|
||||
)
|
||||
|
||||
ee_tasks_to_schedule: list[dict] = []
|
||||
|
||||
@@ -65,9 +67,14 @@ if not MULTI_TENANT:
|
||||
]
|
||||
|
||||
|
||||
def get_cloud_tasks_to_schedule() -> list[dict[str, Any]]:
|
||||
return ee_cloud_tasks_to_schedule + base_cloud_tasks_to_schedule
|
||||
def get_cloud_tasks_to_schedule(beat_multiplier: float) -> list[dict[str, Any]]:
|
||||
beat_system_tasks = ee_beat_system_tasks + base_beat_system_tasks
|
||||
beat_task_templates = ee_beat_task_templates + base_beat_task_templates
|
||||
cloud_tasks = generate_cloud_tasks(
|
||||
beat_system_tasks, beat_task_templates, beat_multiplier
|
||||
)
|
||||
return cloud_tasks
|
||||
|
||||
|
||||
def get_tasks_to_schedule() -> list[dict[str, Any]]:
|
||||
return ee_tasks_to_schedule + base_tasks_to_schedule
|
||||
return ee_tasks_to_schedule + base_get_tasks_to_schedule()
|
||||
|
||||
@@ -77,3 +77,5 @@ POSTHOG_HOST = os.environ.get("POSTHOG_HOST") or "https://us.i.posthog.com"
|
||||
HUBSPOT_TRACKING_URL = os.environ.get("HUBSPOT_TRACKING_URL")
|
||||
|
||||
ANONYMOUS_USER_COOKIE_NAME = "onyx_anonymous_user"
|
||||
|
||||
GATED_TENANTS_KEY = "gated_tenants"
|
||||
|
||||
@@ -15,6 +15,9 @@ def make_persona_private(
|
||||
group_ids: list[int] | None,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""NOTE(rkuo): This function batches all updates into a single commit. If we don't
|
||||
dedupe the inputs, the commit will exception."""
|
||||
|
||||
db_session.query(Persona__User).filter(
|
||||
Persona__User.persona_id == persona_id
|
||||
).delete(synchronize_session="fetch")
|
||||
@@ -23,19 +26,22 @@ def make_persona_private(
|
||||
).delete(synchronize_session="fetch")
|
||||
|
||||
if user_ids:
|
||||
for user_uuid in user_ids:
|
||||
db_session.add(Persona__User(persona_id=persona_id, user_id=user_uuid))
|
||||
user_ids_set = set(user_ids)
|
||||
for user_id in user_ids_set:
|
||||
db_session.add(Persona__User(persona_id=persona_id, user_id=user_id))
|
||||
|
||||
create_notification(
|
||||
user_id=user_uuid,
|
||||
user_id=user_id,
|
||||
notif_type=NotificationType.PERSONA_SHARED,
|
||||
db_session=db_session,
|
||||
additional_data=PersonaSharedNotificationData(
|
||||
persona_id=persona_id,
|
||||
).model_dump(),
|
||||
)
|
||||
|
||||
if group_ids:
|
||||
for group_id in group_ids:
|
||||
group_ids_set = set(group_ids)
|
||||
for group_id in group_ids_set:
|
||||
db_session.add(
|
||||
Persona__UserGroup(persona_id=persona_id, user_group_id=group_id)
|
||||
)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from ee.onyx.db.external_perm import ExternalUserGroup
|
||||
from ee.onyx.external_permissions.confluence.constants import ALL_CONF_EMAILS_GROUP_NAME
|
||||
from onyx.background.error_logging import emit_background_error
|
||||
from onyx.connectors.confluence.onyx_confluence import build_confluence_client
|
||||
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
|
||||
from onyx.connectors.confluence.utils import get_user_email_from_username__server
|
||||
@@ -10,7 +11,7 @@ logger = setup_logger()
|
||||
|
||||
|
||||
def _build_group_member_email_map(
|
||||
confluence_client: OnyxConfluence,
|
||||
confluence_client: OnyxConfluence, cc_pair_id: int
|
||||
) -> dict[str, set[str]]:
|
||||
group_member_emails: dict[str, set[str]] = {}
|
||||
for user_result in confluence_client.paginated_cql_user_retrieval():
|
||||
@@ -18,8 +19,11 @@ def _build_group_member_email_map(
|
||||
|
||||
user = user_result.get("user", {})
|
||||
if not user:
|
||||
logger.warning(f"user result missing user field: {user_result}")
|
||||
msg = f"user result missing user field: {user_result}"
|
||||
emit_background_error(msg, cc_pair_id=cc_pair_id)
|
||||
logger.error(msg)
|
||||
continue
|
||||
|
||||
email = user.get("email")
|
||||
if not email:
|
||||
# This field is only present in Confluence Server
|
||||
@@ -32,7 +36,12 @@ def _build_group_member_email_map(
|
||||
)
|
||||
if not email:
|
||||
# If we still don't have an email, skip this user
|
||||
logger.warning(f"user result missing email field: {user_result}")
|
||||
msg = f"user result missing email field: {user_result}"
|
||||
if user.get("type") == "app":
|
||||
logger.warning(msg)
|
||||
else:
|
||||
emit_background_error(msg, cc_pair_id=cc_pair_id)
|
||||
logger.error(msg)
|
||||
continue
|
||||
|
||||
all_users_groups: set[str] = set()
|
||||
@@ -42,11 +51,18 @@ def _build_group_member_email_map(
|
||||
group_member_emails.setdefault(group_id, set()).add(email)
|
||||
all_users_groups.add(group_id)
|
||||
|
||||
if not group_member_emails:
|
||||
logger.warning(f"No groups found for user with email: {email}")
|
||||
if not all_users_groups:
|
||||
msg = f"No groups found for user with email: {email}"
|
||||
emit_background_error(msg, cc_pair_id=cc_pair_id)
|
||||
logger.error(msg)
|
||||
else:
|
||||
logger.debug(f"Found groups {all_users_groups} for user with email {email}")
|
||||
|
||||
if not group_member_emails:
|
||||
msg = "No groups found for any users."
|
||||
emit_background_error(msg, cc_pair_id=cc_pair_id)
|
||||
logger.error(msg)
|
||||
|
||||
return group_member_emails
|
||||
|
||||
|
||||
@@ -61,6 +77,7 @@ def confluence_group_sync(
|
||||
|
||||
group_member_email_map = _build_group_member_email_map(
|
||||
confluence_client=confluence_client,
|
||||
cc_pair_id=cc_pair.id,
|
||||
)
|
||||
onyx_groups: list[ExternalUserGroup] = []
|
||||
all_found_emails = set()
|
||||
|
||||
@@ -83,6 +83,7 @@ def handle_search_request(
|
||||
user=user,
|
||||
llm=llm,
|
||||
fast_llm=fast_llm,
|
||||
skip_query_analysis=False,
|
||||
db_session=db_session,
|
||||
bypass_acl=False,
|
||||
)
|
||||
|
||||
@@ -18,11 +18,16 @@ from ee.onyx.server.tenants.anonymous_user_path import (
|
||||
from ee.onyx.server.tenants.anonymous_user_path import modify_anonymous_user_path
|
||||
from ee.onyx.server.tenants.anonymous_user_path import validate_anonymous_user_path
|
||||
from ee.onyx.server.tenants.billing import fetch_billing_information
|
||||
from ee.onyx.server.tenants.billing import fetch_stripe_checkout_session
|
||||
from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
|
||||
from ee.onyx.server.tenants.models import AnonymousUserPath
|
||||
from ee.onyx.server.tenants.models import BillingInformation
|
||||
from ee.onyx.server.tenants.models import ImpersonateRequest
|
||||
from ee.onyx.server.tenants.models import ProductGatingRequest
|
||||
from ee.onyx.server.tenants.models import ProductGatingResponse
|
||||
from ee.onyx.server.tenants.models import SubscriptionSessionResponse
|
||||
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
|
||||
from ee.onyx.server.tenants.product_gating import store_product_gating
|
||||
from ee.onyx.server.tenants.provisioning import delete_user_from_control_plane
|
||||
from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
|
||||
from ee.onyx.server.tenants.user_mapping import remove_all_users_from_tenant
|
||||
@@ -39,12 +44,9 @@ from onyx.db.auth import get_user_count
|
||||
from onyx.db.engine import get_current_tenant_id
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.notification import create_notification
|
||||
from onyx.db.users import delete_user_from_db
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.server.manage.models import UserByEmail
|
||||
from onyx.server.settings.store import load_settings
|
||||
from onyx.server.settings.store import store_settings
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
@@ -126,37 +128,29 @@ async def login_as_anonymous_user(
|
||||
@router.post("/product-gating")
|
||||
def gate_product(
|
||||
product_gating_request: ProductGatingRequest, _: None = Depends(control_plane_dep)
|
||||
) -> None:
|
||||
) -> ProductGatingResponse:
|
||||
"""
|
||||
Gating the product means that the product is not available to the tenant.
|
||||
They will be directed to the billing page.
|
||||
We gate the product when
|
||||
1) User has ended free trial without adding payment method
|
||||
2) User's card has declined
|
||||
We gate the product when their subscription has ended.
|
||||
"""
|
||||
tenant_id = product_gating_request.tenant_id
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
try:
|
||||
store_product_gating(
|
||||
product_gating_request.tenant_id, product_gating_request.application_status
|
||||
)
|
||||
return ProductGatingResponse(updated=True, error=None)
|
||||
|
||||
settings = load_settings()
|
||||
settings.product_gating = product_gating_request.product_gating
|
||||
store_settings(settings)
|
||||
|
||||
if product_gating_request.notification:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
create_notification(None, product_gating_request.notification, db_session)
|
||||
|
||||
if token is not None:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
except Exception as e:
|
||||
logger.exception("Failed to gate product")
|
||||
return ProductGatingResponse(updated=False, error=str(e))
|
||||
|
||||
|
||||
@router.get("/billing-information", response_model=BillingInformation)
|
||||
@router.get("/billing-information")
|
||||
async def billing_information(
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> BillingInformation:
|
||||
) -> BillingInformation | SubscriptionStatusResponse:
|
||||
logger.info("Fetching billing information")
|
||||
return BillingInformation(
|
||||
**fetch_billing_information(CURRENT_TENANT_ID_CONTEXTVAR.get())
|
||||
)
|
||||
return fetch_billing_information(CURRENT_TENANT_ID_CONTEXTVAR.get())
|
||||
|
||||
|
||||
@router.post("/create-customer-portal-session")
|
||||
@@ -169,9 +163,10 @@ async def create_customer_portal_session(_: User = Depends(current_admin_user))
|
||||
if not stripe_customer_id:
|
||||
raise HTTPException(status_code=400, detail="Stripe customer ID not found")
|
||||
logger.info(stripe_customer_id)
|
||||
|
||||
portal_session = stripe.billing_portal.Session.create(
|
||||
customer=stripe_customer_id,
|
||||
return_url=f"{WEB_DOMAIN}/admin/cloud-settings",
|
||||
return_url=f"{WEB_DOMAIN}/admin/billing",
|
||||
)
|
||||
logger.info(portal_session)
|
||||
return {"url": portal_session.url}
|
||||
@@ -180,6 +175,20 @@ async def create_customer_portal_session(_: User = Depends(current_admin_user))
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/create-subscription-session")
|
||||
async def create_subscription_session(
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> SubscriptionSessionResponse:
|
||||
try:
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
session_id = fetch_stripe_checkout_session(tenant_id)
|
||||
return SubscriptionSessionResponse(sessionId=session_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to create resubscription session")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/impersonate")
|
||||
async def impersonate_user(
|
||||
impersonate_request: ImpersonateRequest,
|
||||
|
||||
@@ -6,6 +6,7 @@ import stripe
|
||||
from ee.onyx.configs.app_configs import STRIPE_PRICE_ID
|
||||
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
|
||||
from ee.onyx.server.tenants.access import generate_data_plane_token
|
||||
from ee.onyx.server.tenants.models import BillingInformation
|
||||
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -14,6 +15,19 @@ stripe.api_key = STRIPE_SECRET_KEY
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def fetch_stripe_checkout_session(tenant_id: str) -> str:
|
||||
token = generate_data_plane_token()
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
url = f"{CONTROL_PLANE_API_BASE_URL}/create-checkout-session"
|
||||
params = {"tenant_id": tenant_id}
|
||||
response = requests.post(url, headers=headers, params=params)
|
||||
response.raise_for_status()
|
||||
return response.json()["sessionId"]
|
||||
|
||||
|
||||
def fetch_tenant_stripe_information(tenant_id: str) -> dict:
|
||||
token = generate_data_plane_token()
|
||||
headers = {
|
||||
@@ -27,7 +41,7 @@ def fetch_tenant_stripe_information(tenant_id: str) -> dict:
|
||||
return response.json()
|
||||
|
||||
|
||||
def fetch_billing_information(tenant_id: str) -> dict:
|
||||
def fetch_billing_information(tenant_id: str) -> BillingInformation:
|
||||
logger.info("Fetching billing information")
|
||||
token = generate_data_plane_token()
|
||||
headers = {
|
||||
@@ -38,7 +52,7 @@ def fetch_billing_information(tenant_id: str) -> dict:
|
||||
params = {"tenant_id": tenant_id}
|
||||
response = requests.get(url, headers=headers, params=params)
|
||||
response.raise_for_status()
|
||||
billing_info = response.json()
|
||||
billing_info = BillingInformation(**response.json())
|
||||
return billing_info
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.configs.constants import NotificationType
|
||||
from onyx.server.settings.models import GatingType
|
||||
from onyx.server.settings.models import ApplicationStatus
|
||||
|
||||
|
||||
class CheckoutSessionCreationRequest(BaseModel):
|
||||
@@ -15,15 +16,24 @@ class CreateTenantRequest(BaseModel):
|
||||
|
||||
class ProductGatingRequest(BaseModel):
|
||||
tenant_id: str
|
||||
product_gating: GatingType
|
||||
notification: NotificationType | None = None
|
||||
application_status: ApplicationStatus
|
||||
|
||||
|
||||
class SubscriptionStatusResponse(BaseModel):
|
||||
subscribed: bool
|
||||
|
||||
|
||||
class BillingInformation(BaseModel):
|
||||
stripe_subscription_id: str
|
||||
status: str
|
||||
current_period_start: datetime
|
||||
current_period_end: datetime
|
||||
number_of_seats: int
|
||||
cancel_at_period_end: bool
|
||||
canceled_at: datetime | None
|
||||
trial_start: datetime | None
|
||||
trial_end: datetime | None
|
||||
seats: int
|
||||
subscription_status: str
|
||||
billing_start: str
|
||||
billing_end: str
|
||||
payment_method_enabled: bool
|
||||
|
||||
|
||||
@@ -48,3 +58,12 @@ class TenantDeletionPayload(BaseModel):
|
||||
|
||||
class AnonymousUserPath(BaseModel):
|
||||
anonymous_user_path: str | None
|
||||
|
||||
|
||||
class ProductGatingResponse(BaseModel):
|
||||
updated: bool
|
||||
error: str | None
|
||||
|
||||
|
||||
class SubscriptionSessionResponse(BaseModel):
|
||||
sessionId: str
|
||||
|
||||
51
backend/ee/onyx/server/tenants/product_gating.py
Normal file
51
backend/ee/onyx/server/tenants/product_gating.py
Normal file
@@ -0,0 +1,51 @@
|
||||
from typing import cast
|
||||
|
||||
from ee.onyx.configs.app_configs import GATED_TENANTS_KEY
|
||||
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.server.settings.models import ApplicationStatus
|
||||
from onyx.server.settings.store import load_settings
|
||||
from onyx.server.settings.store import store_settings
|
||||
from onyx.setup import setup_logger
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def update_tenant_gating(tenant_id: str, status: ApplicationStatus) -> None:
|
||||
redis_client = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
|
||||
|
||||
# Store the full status
|
||||
status_key = f"tenant:{tenant_id}:status"
|
||||
redis_client.set(status_key, status.value)
|
||||
|
||||
# Maintain the GATED_ACCESS set
|
||||
if status == ApplicationStatus.GATED_ACCESS:
|
||||
redis_client.sadd(GATED_TENANTS_KEY, tenant_id)
|
||||
else:
|
||||
redis_client.srem(GATED_TENANTS_KEY, tenant_id)
|
||||
|
||||
|
||||
def store_product_gating(tenant_id: str, application_status: ApplicationStatus) -> None:
|
||||
try:
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
settings = load_settings()
|
||||
settings.application_status = application_status
|
||||
store_settings(settings)
|
||||
|
||||
# Store gated tenant information in Redis
|
||||
update_tenant_gating(tenant_id, application_status)
|
||||
|
||||
if token is not None:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
except Exception:
|
||||
logger.exception("Failed to gate product")
|
||||
raise
|
||||
|
||||
|
||||
def get_gated_tenants() -> set[str]:
|
||||
redis_client = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
|
||||
return cast(set[str], redis_client.smembers(GATED_TENANTS_KEY))
|
||||
@@ -28,3 +28,9 @@ class EmbeddingModelTextType:
|
||||
@staticmethod
|
||||
def get_type(provider: EmbeddingProvider, text_type: EmbedTextType) -> str:
|
||||
return EmbeddingModelTextType.PROVIDER_TEXT_TYPE_MAP[provider][text_type]
|
||||
|
||||
|
||||
class GPUStatus:
|
||||
CUDA = "cuda"
|
||||
MAC_MPS = "mps"
|
||||
NONE = "none"
|
||||
|
||||
@@ -12,6 +12,7 @@ import voyageai # type: ignore
|
||||
from cohere import AsyncClient as CohereAsyncClient
|
||||
from fastapi import APIRouter
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
from google.oauth2 import service_account # type: ignore
|
||||
from litellm import aembedding
|
||||
from litellm.exceptions import RateLimitError
|
||||
@@ -320,6 +321,7 @@ async def embed_text(
|
||||
prefix: str | None,
|
||||
api_url: str | None,
|
||||
api_version: str | None,
|
||||
gpu_type: str = "UNKNOWN",
|
||||
) -> list[Embedding]:
|
||||
if not all(texts):
|
||||
logger.error("Empty strings provided for embedding")
|
||||
@@ -373,8 +375,11 @@ async def embed_text(
|
||||
|
||||
elapsed = time.monotonic() - start
|
||||
logger.info(
|
||||
f"Successfully embedded {len(texts)} texts with {total_chars} total characters "
|
||||
f"with provider {provider_type} in {elapsed:.2f}"
|
||||
f"event=embedding_provider "
|
||||
f"texts={len(texts)} "
|
||||
f"chars={total_chars} "
|
||||
f"provider={provider_type} "
|
||||
f"elapsed={elapsed:.2f}"
|
||||
)
|
||||
elif model_name is not None:
|
||||
logger.info(
|
||||
@@ -403,6 +408,14 @@ async def embed_text(
|
||||
f"Successfully embedded {len(texts)} texts with {total_chars} total characters "
|
||||
f"with local model {model_name} in {elapsed:.2f}"
|
||||
)
|
||||
logger.info(
|
||||
f"event=embedding_model "
|
||||
f"texts={len(texts)} "
|
||||
f"chars={total_chars} "
|
||||
f"model={model_name} "
|
||||
f"gpu={gpu_type} "
|
||||
f"elapsed={elapsed:.2f}"
|
||||
)
|
||||
else:
|
||||
logger.error("Neither model name nor provider specified for embedding")
|
||||
raise ValueError(
|
||||
@@ -455,8 +468,15 @@ async def litellm_rerank(
|
||||
|
||||
|
||||
@router.post("/bi-encoder-embed")
|
||||
async def process_embed_request(
|
||||
async def route_bi_encoder_embed(
|
||||
request: Request,
|
||||
embed_request: EmbedRequest,
|
||||
) -> EmbedResponse:
|
||||
return await process_embed_request(embed_request, request.app.state.gpu_type)
|
||||
|
||||
|
||||
async def process_embed_request(
|
||||
embed_request: EmbedRequest, gpu_type: str = "UNKNOWN"
|
||||
) -> EmbedResponse:
|
||||
if not embed_request.texts:
|
||||
raise HTTPException(status_code=400, detail="No texts to be embedded")
|
||||
@@ -484,6 +504,7 @@ async def process_embed_request(
|
||||
api_url=embed_request.api_url,
|
||||
api_version=embed_request.api_version,
|
||||
prefix=prefix,
|
||||
gpu_type=gpu_type,
|
||||
)
|
||||
return EmbedResponse(embeddings=embeddings)
|
||||
except RateLimitError as e:
|
||||
|
||||
@@ -16,6 +16,7 @@ from model_server.custom_models import router as custom_models_router
|
||||
from model_server.custom_models import warm_up_intent_model
|
||||
from model_server.encoders import router as encoders_router
|
||||
from model_server.management_endpoints import router as management_router
|
||||
from model_server.utils import get_gpu_type
|
||||
from onyx import __version__
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import INDEXING_ONLY
|
||||
@@ -58,12 +59,10 @@ def _move_files_recursively(source: Path, dest: Path, overwrite: bool = False) -
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator:
|
||||
if torch.cuda.is_available():
|
||||
logger.notice("CUDA GPU is available")
|
||||
elif torch.backends.mps.is_available():
|
||||
logger.notice("Mac MPS is available")
|
||||
else:
|
||||
logger.notice("GPU is not available, using CPU")
|
||||
gpu_type = get_gpu_type()
|
||||
logger.notice(f"Torch GPU Detection: gpu_type={gpu_type}")
|
||||
|
||||
app.state.gpu_type = gpu_type
|
||||
|
||||
if TEMP_HF_CACHE_PATH.is_dir():
|
||||
logger.notice("Moving contents of temp_huggingface to huggingface cache.")
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import torch
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Response
|
||||
|
||||
from model_server.constants import GPUStatus
|
||||
from model_server.utils import get_gpu_type
|
||||
|
||||
router = APIRouter(prefix="/api")
|
||||
|
||||
|
||||
@@ -11,10 +13,7 @@ async def healthcheck() -> Response:
|
||||
|
||||
|
||||
@router.get("/gpu-status")
|
||||
async def gpu_status() -> dict[str, bool | str]:
|
||||
if torch.cuda.is_available():
|
||||
return {"gpu_available": True, "type": "cuda"}
|
||||
elif torch.backends.mps.is_available():
|
||||
return {"gpu_available": True, "type": "mps"}
|
||||
else:
|
||||
return {"gpu_available": False, "type": "none"}
|
||||
async def route_gpu_status() -> dict[str, bool | str]:
|
||||
gpu_type = get_gpu_type()
|
||||
gpu_available = gpu_type != GPUStatus.NONE
|
||||
return {"gpu_available": gpu_available, "type": gpu_type}
|
||||
|
||||
@@ -8,6 +8,9 @@ from typing import Any
|
||||
from typing import cast
|
||||
from typing import TypeVar
|
||||
|
||||
import torch
|
||||
|
||||
from model_server.constants import GPUStatus
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -58,3 +61,12 @@ def simple_log_function_time(
|
||||
return cast(F, wrapped_sync_func)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def get_gpu_type() -> str:
|
||||
if torch.cuda.is_available():
|
||||
return GPUStatus.CUDA
|
||||
if torch.backends.mps.is_available():
|
||||
return GPUStatus.MAC_MPS
|
||||
|
||||
return GPUStatus.NONE
|
||||
|
||||
@@ -21,10 +21,11 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
from onyx.configs.agent_configs import AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS
|
||||
from onyx.configs.agent_configs import AGENT_RERANKING_STATS
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.context.search.models import SearchRequest
|
||||
from onyx.context.search.pipeline import retrieval_preprocessing
|
||||
from onyx.context.search.models import RerankingDetails
|
||||
from onyx.context.search.postprocessing.postprocessing import rerank_sections
|
||||
from onyx.context.search.postprocessing.postprocessing import should_rerank
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
|
||||
|
||||
def rerank_documents(
|
||||
@@ -39,6 +40,8 @@ def rerank_documents(
|
||||
|
||||
# Rerank post retrieval and verification. First, create a search query
|
||||
# then create the list of reranked sections
|
||||
# If no question defined/question is None in the state, use the original
|
||||
# question from the search request as query
|
||||
|
||||
graph_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
question = (
|
||||
@@ -47,39 +50,28 @@ def rerank_documents(
|
||||
assert (
|
||||
graph_config.tooling.search_tool
|
||||
), "search_tool must be provided for agentic search"
|
||||
with get_session_context_manager() as db_session:
|
||||
# we ignore some of the user specified fields since this search is
|
||||
# internal to agentic search, but we still want to pass through
|
||||
# persona (for stuff like document sets) and rerank settings
|
||||
# (to not make an unnecessary db call).
|
||||
search_request = SearchRequest(
|
||||
query=question,
|
||||
persona=graph_config.inputs.search_request.persona,
|
||||
rerank_settings=graph_config.inputs.search_request.rerank_settings,
|
||||
)
|
||||
_search_query = retrieval_preprocessing(
|
||||
search_request=search_request,
|
||||
user=graph_config.tooling.search_tool.user, # bit of a hack
|
||||
llm=graph_config.tooling.fast_llm,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# skip section filtering
|
||||
# Note that these are passed in values from the API and are overrides which are typically None
|
||||
rerank_settings = graph_config.inputs.search_request.rerank_settings
|
||||
|
||||
if (
|
||||
_search_query.rerank_settings
|
||||
and _search_query.rerank_settings.rerank_model_name
|
||||
and _search_query.rerank_settings.num_rerank > 0
|
||||
and len(verified_documents) > 0
|
||||
):
|
||||
if rerank_settings is None:
|
||||
with get_session_context_manager() as db_session:
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
if not search_settings.disable_rerank_for_streaming:
|
||||
rerank_settings = RerankingDetails.from_db_model(search_settings)
|
||||
|
||||
if should_rerank(rerank_settings) and len(verified_documents) > 0:
|
||||
if len(verified_documents) > 1:
|
||||
reranked_documents = rerank_sections(
|
||||
_search_query,
|
||||
verified_documents,
|
||||
query_str=question,
|
||||
# if runnable, then rerank_settings is not None
|
||||
rerank_settings=cast(RerankingDetails, rerank_settings),
|
||||
sections_to_rerank=verified_documents,
|
||||
)
|
||||
else:
|
||||
num = "No" if len(verified_documents) == 0 else "One"
|
||||
logger.warning(f"{num} verified document(s) found, skipping reranking")
|
||||
logger.warning(
|
||||
f"{len(verified_documents)} verified document(s) found, skipping reranking"
|
||||
)
|
||||
reranked_documents = verified_documents
|
||||
else:
|
||||
logger.warning("No reranking settings found, using unranked documents")
|
||||
|
||||
@@ -23,6 +23,7 @@ from onyx.configs.agent_configs import AGENT_RETRIEVAL_STATS
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.tools.models import SearchQueryInfo
|
||||
from onyx.tools.models import SearchToolOverrideKwargs
|
||||
from onyx.tools.tool_implementations.search.search_tool import (
|
||||
SEARCH_RESPONSE_SUMMARY_ID,
|
||||
)
|
||||
@@ -67,9 +68,12 @@ def retrieve_documents(
|
||||
with get_session_context_manager() as db_session:
|
||||
for tool_response in search_tool.run(
|
||||
query=query_to_retrieve,
|
||||
force_no_rerank=True,
|
||||
alternate_db_session=db_session,
|
||||
retrieved_sections_callback=callback_container.append,
|
||||
override_kwargs=SearchToolOverrideKwargs(
|
||||
force_no_rerank=True,
|
||||
alternate_db_session=db_session,
|
||||
retrieved_sections_callback=callback_container.append,
|
||||
skip_query_analysis=not state.base_search,
|
||||
),
|
||||
):
|
||||
# get retrieved docs to send to the rest of the graph
|
||||
if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
|
||||
|
||||
@@ -58,6 +58,7 @@ from onyx.prompts.agent_search import (
|
||||
)
|
||||
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
|
||||
from onyx.tools.force import ForceUseTool
|
||||
from onyx.tools.models import SearchToolOverrideKwargs
|
||||
from onyx.tools.tool_constructor import SearchToolConfig
|
||||
from onyx.tools.tool_implementations.search.search_tool import (
|
||||
SEARCH_RESPONSE_SUMMARY_ID,
|
||||
@@ -218,7 +219,10 @@ def get_test_config(
|
||||
using_tool_calling_llm=using_tool_calling_llm,
|
||||
)
|
||||
|
||||
chat_session_id = os.environ.get("ONYX_AS_CHAT_SESSION_ID")
|
||||
chat_session_id = (
|
||||
os.environ.get("ONYX_AS_CHAT_SESSION_ID")
|
||||
or "00000000-0000-0000-0000-000000000000"
|
||||
)
|
||||
assert (
|
||||
chat_session_id is not None
|
||||
), "ONYX_AS_CHAT_SESSION_ID must be set for backend tests"
|
||||
@@ -341,8 +345,12 @@ def retrieve_search_docs(
|
||||
with get_session_context_manager() as db_session:
|
||||
for tool_response in search_tool.run(
|
||||
query=question,
|
||||
force_no_rerank=True,
|
||||
alternate_db_session=db_session,
|
||||
override_kwargs=SearchToolOverrideKwargs(
|
||||
force_no_rerank=True,
|
||||
alternate_db_session=db_session,
|
||||
retrieved_sections_callback=None,
|
||||
skip_query_analysis=False,
|
||||
),
|
||||
):
|
||||
# get retrieved docs to send to the rest of the graph
|
||||
if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import smtplib
|
||||
from datetime import datetime
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from email.mime.text import MIMEText
|
||||
from textwrap import dedent
|
||||
|
||||
from onyx.configs.app_configs import EMAIL_CONFIGURED
|
||||
from onyx.configs.app_configs import EMAIL_FROM
|
||||
@@ -13,23 +13,150 @@ from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.constants import TENANT_ID_COOKIE_NAME
|
||||
from onyx.db.models import User
|
||||
|
||||
HTML_EMAIL_TEMPLATE = """\
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<meta name="viewport" content="width=device-width" />
|
||||
<title>{title}</title>
|
||||
<style>
|
||||
body, table, td, a {{
|
||||
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif;
|
||||
text-size-adjust: 100%;
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
-webkit-font-smoothing: antialiased;
|
||||
-webkit-text-size-adjust: none;
|
||||
}}
|
||||
body {{
|
||||
background-color: #f7f7f7;
|
||||
color: #333;
|
||||
}}
|
||||
.body-content {{
|
||||
color: #333;
|
||||
}}
|
||||
.email-container {{
|
||||
width: 100%;
|
||||
max-width: 600px;
|
||||
margin: 0 auto;
|
||||
background-color: #ffffff;
|
||||
border-radius: 6px;
|
||||
overflow: hidden;
|
||||
border: 1px solid #eaeaea;
|
||||
}}
|
||||
.header {{
|
||||
background-color: #000000;
|
||||
padding: 20px;
|
||||
text-align: center;
|
||||
}}
|
||||
.header img {{
|
||||
max-width: 140px;
|
||||
}}
|
||||
.body-content {{
|
||||
padding: 20px 30px;
|
||||
}}
|
||||
.title {{
|
||||
font-size: 20px;
|
||||
font-weight: bold;
|
||||
margin: 0 0 10px;
|
||||
}}
|
||||
.message {{
|
||||
font-size: 16px;
|
||||
line-height: 1.5;
|
||||
margin: 0 0 20px;
|
||||
}}
|
||||
.cta-button {{
|
||||
display: inline-block;
|
||||
padding: 12px 20px;
|
||||
background-color: #000000;
|
||||
color: #ffffff !important;
|
||||
text-decoration: none;
|
||||
border-radius: 4px;
|
||||
font-weight: 500;
|
||||
}}
|
||||
.footer {{
|
||||
font-size: 13px;
|
||||
color: #6A7280;
|
||||
text-align: center;
|
||||
padding: 20px;
|
||||
}}
|
||||
.footer a {{
|
||||
color: #6b7280;
|
||||
text-decoration: underline;
|
||||
}}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<table role="presentation" class="email-container" cellpadding="0" cellspacing="0">
|
||||
<tr>
|
||||
<td class="header">
|
||||
<img
|
||||
style="background-color: #ffffff; border-radius: 8px;"
|
||||
src="https://www.onyx.app/logos/customer/onyx.png"
|
||||
alt="Onyx Logo"
|
||||
>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="body-content">
|
||||
<h1 class="title">{heading}</h1>
|
||||
<div class="message">
|
||||
{message}
|
||||
</div>
|
||||
{cta_block}
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td class="footer">
|
||||
© {year} Onyx. All rights reserved.
|
||||
<br>
|
||||
Have questions? Join our Slack community <a href="https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA">here</a>.
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
|
||||
def build_html_email(
|
||||
heading: str, message: str, cta_text: str | None = None, cta_link: str | None = None
|
||||
) -> str:
|
||||
if cta_text and cta_link:
|
||||
cta_block = f'<a class="cta-button" href="{cta_link}">{cta_text}</a>'
|
||||
else:
|
||||
cta_block = ""
|
||||
return HTML_EMAIL_TEMPLATE.format(
|
||||
title=heading,
|
||||
heading=heading,
|
||||
message=message,
|
||||
cta_block=cta_block,
|
||||
year=datetime.now().year,
|
||||
)
|
||||
|
||||
|
||||
def send_email(
|
||||
user_email: str,
|
||||
subject: str,
|
||||
body: str,
|
||||
html_body: str,
|
||||
text_body: str,
|
||||
mail_from: str = EMAIL_FROM,
|
||||
) -> None:
|
||||
if not EMAIL_CONFIGURED:
|
||||
raise ValueError("Email is not configured.")
|
||||
|
||||
msg = MIMEMultipart()
|
||||
msg = MIMEMultipart("alternative")
|
||||
msg["Subject"] = subject
|
||||
msg["To"] = user_email
|
||||
if mail_from:
|
||||
msg["From"] = mail_from
|
||||
|
||||
msg.attach(MIMEText(body))
|
||||
part_text = MIMEText(text_body, "plain")
|
||||
part_html = MIMEText(html_body, "html")
|
||||
|
||||
msg.attach(part_text)
|
||||
msg.attach(part_html)
|
||||
|
||||
try:
|
||||
with smtplib.SMTP(SMTP_SERVER, SMTP_PORT) as s:
|
||||
@@ -40,26 +167,44 @@ def send_email(
|
||||
raise e
|
||||
|
||||
|
||||
def send_subscription_cancellation_email(user_email: str) -> None:
|
||||
# Example usage of the reusable HTML
|
||||
subject = "Your Onyx Subscription Has Been Canceled"
|
||||
heading = "Subscription Canceled"
|
||||
message = (
|
||||
"<p>We’re sorry to see you go.</p>"
|
||||
"<p>Your subscription has been canceled and will end on your next billing date.</p>"
|
||||
"<p>If you change your mind, you can always come back!</p>"
|
||||
)
|
||||
cta_text = "Renew Subscription"
|
||||
cta_link = "https://www.onyx.app/pricing"
|
||||
html_content = build_html_email(heading, message, cta_text, cta_link)
|
||||
text_content = (
|
||||
"We're sorry to see you go.\n"
|
||||
"Your subscription has been canceled and will end on your next billing date.\n"
|
||||
"If you change your mind, visit https://www.onyx.app/pricing"
|
||||
)
|
||||
send_email(user_email, subject, html_content, text_content)
|
||||
|
||||
|
||||
def send_user_email_invite(user_email: str, current_user: User) -> None:
|
||||
subject = "Invitation to Join Onyx Organization"
|
||||
body = dedent(
|
||||
f"""\
|
||||
Hello,
|
||||
|
||||
You have been invited to join an organization on Onyx.
|
||||
|
||||
To join the organization, please visit the following link:
|
||||
|
||||
{WEB_DOMAIN}/auth/signup?email={user_email}
|
||||
|
||||
You'll be asked to set a password or login with Google to complete your registration.
|
||||
|
||||
Best regards,
|
||||
The Onyx Team
|
||||
"""
|
||||
heading = "You've Been Invited!"
|
||||
message = (
|
||||
f"<p>You have been invited by {current_user.email} to join an organization on Onyx.</p>"
|
||||
"<p>To join the organization, please click the button below to set a password "
|
||||
"or login with Google and complete your registration.</p>"
|
||||
)
|
||||
|
||||
send_email(user_email, subject, body, current_user.email)
|
||||
cta_text = "Join Organization"
|
||||
cta_link = f"{WEB_DOMAIN}/auth/signup?email={user_email}"
|
||||
html_content = build_html_email(heading, message, cta_text, cta_link)
|
||||
text_content = (
|
||||
f"You have been invited by {current_user.email} to join an organization on Onyx.\n"
|
||||
"To join the organization, please visit the following link:\n"
|
||||
f"{WEB_DOMAIN}/auth/signup?email={user_email}\n"
|
||||
"You'll be asked to set a password or login with Google to complete your registration."
|
||||
)
|
||||
send_email(user_email, subject, html_content, text_content)
|
||||
|
||||
|
||||
def send_forgot_password_email(
|
||||
@@ -68,13 +213,15 @@ def send_forgot_password_email(
|
||||
mail_from: str = EMAIL_FROM,
|
||||
tenant_id: str | None = None,
|
||||
) -> None:
|
||||
# Builds a forgot password email with or without fancy HTML
|
||||
subject = "Onyx Forgot Password"
|
||||
link = f"{WEB_DOMAIN}/auth/reset-password?token={token}"
|
||||
if tenant_id:
|
||||
link += f"&{TENANT_ID_COOKIE_NAME}={tenant_id}"
|
||||
# Keep search param same name as cookie for simplicity
|
||||
body = f"Click the following link to reset your password: {link}"
|
||||
send_email(user_email, subject, body, mail_from)
|
||||
message = f"<p>Click the following link to reset your password:</p><p>{link}</p>"
|
||||
html_content = build_html_email("Reset Your Password", message)
|
||||
text_content = f"Click the following link to reset your password: {link}"
|
||||
send_email(user_email, subject, html_content, text_content, mail_from)
|
||||
|
||||
|
||||
def send_user_verification_email(
|
||||
@@ -82,7 +229,12 @@ def send_user_verification_email(
|
||||
token: str,
|
||||
mail_from: str = EMAIL_FROM,
|
||||
) -> None:
|
||||
# Builds a verification email
|
||||
subject = "Onyx Email Verification"
|
||||
link = f"{WEB_DOMAIN}/auth/verify-email?token={token}"
|
||||
body = f"Click the following link to verify your email address: {link}"
|
||||
send_email(user_email, subject, body, mail_from)
|
||||
message = (
|
||||
f"<p>Click the following link to verify your email address:</p><p>{link}</p>"
|
||||
)
|
||||
html_content = build_html_email("Verify Your Email", message)
|
||||
text_content = f"Click the following link to verify your email address: {link}"
|
||||
send_email(user_email, subject, html_content, text_content, mail_from)
|
||||
|
||||
@@ -1,41 +1,56 @@
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from celery import Celery
|
||||
from celery import signals
|
||||
from celery.beat import PersistentScheduler # type: ignore
|
||||
from celery.signals import beat_init
|
||||
from celery.utils.log import get_task_logger
|
||||
|
||||
import onyx.background.celery.apps.app_base as app_base
|
||||
from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
|
||||
from onyx.configs.constants import ONYX_CLOUD_REDIS_RUNTIME
|
||||
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
|
||||
from onyx.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
|
||||
from onyx.db.engine import get_all_tenant_ids
|
||||
from onyx.db.engine import SqlEngine
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
task_logger = get_task_logger(__name__)
|
||||
|
||||
celery_app = Celery(__name__)
|
||||
celery_app.config_from_object("onyx.background.celery.configs.beat")
|
||||
|
||||
|
||||
class DynamicTenantScheduler(PersistentScheduler):
|
||||
"""This scheduler is useful because we can dynamically adjust task generation rates
|
||||
through it."""
|
||||
|
||||
RELOAD_INTERVAL = 60
|
||||
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
logger.info("Initializing DynamicTenantScheduler")
|
||||
super().__init__(*args, **kwargs)
|
||||
self._reload_interval = timedelta(minutes=2)
|
||||
|
||||
self.last_beat_multiplier = CLOUD_BEAT_MULTIPLIER_DEFAULT
|
||||
|
||||
self._reload_interval = timedelta(
|
||||
seconds=DynamicTenantScheduler.RELOAD_INTERVAL
|
||||
)
|
||||
self._last_reload = self.app.now() - self._reload_interval
|
||||
|
||||
# Let the parent class handle store initialization
|
||||
self.setup_schedule()
|
||||
self._try_updating_schedule()
|
||||
logger.info(f"Set reload interval to {self._reload_interval}")
|
||||
task_logger.info(
|
||||
f"DynamicTenantScheduler initialized: reload_interval={self._reload_interval}"
|
||||
)
|
||||
|
||||
def setup_schedule(self) -> None:
|
||||
logger.info("Setting up initial schedule")
|
||||
super().setup_schedule()
|
||||
logger.info("Initial schedule setup complete")
|
||||
|
||||
def tick(self) -> float:
|
||||
retval = super().tick()
|
||||
@@ -44,36 +59,35 @@ class DynamicTenantScheduler(PersistentScheduler):
|
||||
self._last_reload is None
|
||||
or (now - self._last_reload) > self._reload_interval
|
||||
):
|
||||
logger.info("Reload interval reached, initiating task update")
|
||||
task_logger.debug("Reload interval reached, initiating task update")
|
||||
try:
|
||||
self._try_updating_schedule()
|
||||
except (AttributeError, KeyError) as e:
|
||||
logger.exception(f"Failed to process task configuration: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.exception(f"Unexpected error updating tasks: {str(e)}")
|
||||
except (AttributeError, KeyError):
|
||||
task_logger.exception("Failed to process task configuration")
|
||||
except Exception:
|
||||
task_logger.exception("Unexpected error updating tasks")
|
||||
|
||||
self._last_reload = now
|
||||
logger.info("Task update completed, reset reload timer")
|
||||
|
||||
return retval
|
||||
|
||||
def _generate_schedule(
|
||||
self, tenant_ids: list[str] | list[None]
|
||||
self, tenant_ids: list[str] | list[None], beat_multiplier: float
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""Given a list of tenant id's, generates a new beat schedule for celery."""
|
||||
logger.info("Fetching tasks to schedule")
|
||||
|
||||
new_schedule: dict[str, dict[str, Any]] = {}
|
||||
|
||||
if MULTI_TENANT:
|
||||
# cloud tasks only need the single task beat across all tenants
|
||||
# cloud tasks are system wide and thus only need to be on the beat schedule
|
||||
# once for all tenants
|
||||
get_cloud_tasks_to_schedule = fetch_versioned_implementation(
|
||||
"onyx.background.celery.tasks.beat_schedule",
|
||||
"get_cloud_tasks_to_schedule",
|
||||
)
|
||||
|
||||
cloud_tasks_to_schedule: list[
|
||||
dict[str, Any]
|
||||
] = get_cloud_tasks_to_schedule()
|
||||
cloud_tasks_to_schedule: list[dict[str, Any]] = get_cloud_tasks_to_schedule(
|
||||
beat_multiplier
|
||||
)
|
||||
for task in cloud_tasks_to_schedule:
|
||||
task_name = task["name"]
|
||||
cloud_task = {
|
||||
@@ -82,11 +96,14 @@ class DynamicTenantScheduler(PersistentScheduler):
|
||||
"kwargs": task.get("kwargs", {}),
|
||||
}
|
||||
if options := task.get("options"):
|
||||
logger.debug(f"Adding options to task {task_name}: {options}")
|
||||
task_logger.debug(f"Adding options to task {task_name}: {options}")
|
||||
cloud_task["options"] = options
|
||||
new_schedule[task_name] = cloud_task
|
||||
|
||||
# regular task beats are multiplied across all tenants
|
||||
# note that currently this just schedules for a single tenant in self hosted
|
||||
# and doesn't do anything in the cloud because it's much more scalable
|
||||
# to schedule a single cloud beat task to dispatch per tenant tasks.
|
||||
get_tasks_to_schedule = fetch_versioned_implementation(
|
||||
"onyx.background.celery.tasks.beat_schedule", "get_tasks_to_schedule"
|
||||
)
|
||||
@@ -95,7 +112,7 @@ class DynamicTenantScheduler(PersistentScheduler):
|
||||
|
||||
for tenant_id in tenant_ids:
|
||||
if IGNORED_SYNCING_TENANT_LIST and tenant_id in IGNORED_SYNCING_TENANT_LIST:
|
||||
logger.info(
|
||||
task_logger.debug(
|
||||
f"Skipping tenant {tenant_id} as it is in the ignored syncing list"
|
||||
)
|
||||
continue
|
||||
@@ -104,14 +121,14 @@ class DynamicTenantScheduler(PersistentScheduler):
|
||||
task_name = task["name"]
|
||||
tenant_task_name = f"{task['name']}-{tenant_id}"
|
||||
|
||||
logger.debug(f"Creating task configuration for {tenant_task_name}")
|
||||
task_logger.debug(f"Creating task configuration for {tenant_task_name}")
|
||||
tenant_task = {
|
||||
"task": task["task"],
|
||||
"schedule": task["schedule"],
|
||||
"kwargs": {"tenant_id": tenant_id},
|
||||
}
|
||||
if options := task.get("options"):
|
||||
logger.debug(
|
||||
task_logger.debug(
|
||||
f"Adding options to task {tenant_task_name}: {options}"
|
||||
)
|
||||
tenant_task["options"] = options
|
||||
@@ -121,44 +138,57 @@ class DynamicTenantScheduler(PersistentScheduler):
|
||||
|
||||
def _try_updating_schedule(self) -> None:
|
||||
"""Only updates the actual beat schedule on the celery app when it changes"""
|
||||
do_update = False
|
||||
|
||||
logger.info("_try_updating_schedule starting")
|
||||
r = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
|
||||
|
||||
task_logger.debug("_try_updating_schedule starting")
|
||||
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
logger.info(f"Found {len(tenant_ids)} IDs")
|
||||
task_logger.debug(f"Found {len(tenant_ids)} IDs")
|
||||
|
||||
# get current schedule and extract current tenants
|
||||
current_schedule = self.schedule.items()
|
||||
|
||||
# there are no more per tenant beat tasks, so comment this out
|
||||
# NOTE: we may not actualy need this scheduler any more and should
|
||||
# test reverting to a regular beat schedule implementation
|
||||
# get potential new state
|
||||
beat_multiplier = CLOUD_BEAT_MULTIPLIER_DEFAULT
|
||||
beat_multiplier_raw = r.get(f"{ONYX_CLOUD_REDIS_RUNTIME}:beat_multiplier")
|
||||
if beat_multiplier_raw is not None:
|
||||
try:
|
||||
beat_multiplier_bytes = cast(bytes, beat_multiplier_raw)
|
||||
beat_multiplier = float(beat_multiplier_bytes.decode())
|
||||
except ValueError:
|
||||
task_logger.error(
|
||||
f"Invalid beat_multiplier value: {beat_multiplier_raw}"
|
||||
)
|
||||
|
||||
# current_tenants = set()
|
||||
# for task_name, _ in current_schedule:
|
||||
# task_name = cast(str, task_name)
|
||||
# if task_name.startswith(ONYX_CLOUD_CELERY_TASK_PREFIX):
|
||||
# continue
|
||||
new_schedule = self._generate_schedule(tenant_ids, beat_multiplier)
|
||||
|
||||
# if "_" in task_name:
|
||||
# # example: "check-for-condition-tenant_12345678-abcd-efgh-ijkl-12345678"
|
||||
# # -> "12345678-abcd-efgh-ijkl-12345678"
|
||||
# current_tenants.add(task_name.split("_")[-1])
|
||||
# logger.info(f"Found {len(current_tenants)} existing items in schedule")
|
||||
# if the schedule or beat multiplier has changed, update
|
||||
while True:
|
||||
if beat_multiplier != self.last_beat_multiplier:
|
||||
do_update = True
|
||||
break
|
||||
|
||||
# for tenant_id in tenant_ids:
|
||||
# if tenant_id not in current_tenants:
|
||||
# logger.info(f"Processing new tenant: {tenant_id}")
|
||||
if not DynamicTenantScheduler._compare_schedules(
|
||||
current_schedule, new_schedule
|
||||
):
|
||||
do_update = True
|
||||
break
|
||||
|
||||
new_schedule = self._generate_schedule(tenant_ids)
|
||||
break
|
||||
|
||||
if DynamicTenantScheduler._compare_schedules(current_schedule, new_schedule):
|
||||
logger.info(
|
||||
"_try_updating_schedule: Current schedule is up to date, no changes needed"
|
||||
if not do_update:
|
||||
# exit early if nothing changed
|
||||
task_logger.info(
|
||||
f"_try_updating_schedule - Schedule unchanged: "
|
||||
f"tasks={len(new_schedule)} "
|
||||
f"beat_multiplier={beat_multiplier}"
|
||||
)
|
||||
return
|
||||
|
||||
logger.info(
|
||||
# schedule needs updating
|
||||
task_logger.debug(
|
||||
"Schedule update required",
|
||||
extra={
|
||||
"new_tasks": len(new_schedule),
|
||||
@@ -185,11 +215,19 @@ class DynamicTenantScheduler(PersistentScheduler):
|
||||
# Ensure changes are persisted
|
||||
self.sync()
|
||||
|
||||
logger.info("_try_updating_schedule: Schedule updated successfully")
|
||||
task_logger.info(
|
||||
f"_try_updating_schedule - Schedule updated: "
|
||||
f"prev_num_tasks={len(current_schedule)} "
|
||||
f"prev_beat_multiplier={self.last_beat_multiplier} "
|
||||
f"tasks={len(new_schedule)} "
|
||||
f"beat_multiplier={beat_multiplier}"
|
||||
)
|
||||
|
||||
self.last_beat_multiplier = beat_multiplier
|
||||
|
||||
@staticmethod
|
||||
def _compare_schedules(schedule1: dict, schedule2: dict) -> bool:
|
||||
"""Compare schedules to determine if an update is needed.
|
||||
"""Compare schedules by task name only to determine if an update is needed.
|
||||
True if equivalent, False if not."""
|
||||
current_tasks = set(name for name, _ in schedule1)
|
||||
new_tasks = set(schedule2.keys())
|
||||
@@ -201,7 +239,7 @@ class DynamicTenantScheduler(PersistentScheduler):
|
||||
|
||||
@beat_init.connect
|
||||
def on_beat_init(sender: Any, **kwargs: Any) -> None:
|
||||
logger.info("beat_init signal received.")
|
||||
task_logger.info("beat_init signal received.")
|
||||
|
||||
# Celery beat shouldn't touch the db at all. But just setting a low minimum here.
|
||||
SqlEngine.set_app_name(POSTGRES_CELERY_BEAT_APP_NAME)
|
||||
|
||||
@@ -144,7 +144,6 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
# As currently designed, when this worker starts as "primary", we reinitialize redis
|
||||
# to a clean state (for our purposes, anyway)
|
||||
r.delete(OnyxRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK)
|
||||
r.delete(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)
|
||||
|
||||
r.delete(OnyxRedisConstants.ACTIVE_FENCES)
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import copy
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
@@ -18,7 +19,8 @@ BEAT_EXPIRES_DEFAULT = 15 * 60 # 15 minutes (in seconds)
|
||||
|
||||
# hack to slow down task dispatch in the cloud until
|
||||
# we have a better implementation (backpressure, etc)
|
||||
CLOUD_BEAT_SCHEDULE_MULTIPLIER = 8
|
||||
# Note that DynamicTenantScheduler can adjust the runtime value for this via Redis
|
||||
CLOUD_BEAT_MULTIPLIER_DEFAULT = 8.0
|
||||
|
||||
# tasks that run in either self-hosted on cloud
|
||||
beat_task_templates: list[dict] = []
|
||||
@@ -55,16 +57,7 @@ beat_task_templates.extend(
|
||||
{
|
||||
"name": "check-for-pruning",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_PRUNING,
|
||||
"schedule": timedelta(hours=1),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "monitor-vespa-sync",
|
||||
"task": OnyxCeleryTask.MONITOR_VESPA_SYNC,
|
||||
"schedule": timedelta(seconds=5),
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
@@ -121,7 +114,7 @@ def make_cloud_generator_task(task: dict[str, Any]) -> dict[str, Any]:
|
||||
|
||||
# constant options for cloud beat task generators
|
||||
task_schedule: timedelta = task["schedule"]
|
||||
cloud_task["schedule"] = task_schedule * CLOUD_BEAT_SCHEDULE_MULTIPLIER
|
||||
cloud_task["schedule"] = task_schedule
|
||||
cloud_task["options"] = {}
|
||||
cloud_task["options"]["priority"] = OnyxCeleryPriority.HIGHEST
|
||||
cloud_task["options"]["expires"] = BEAT_EXPIRES_DEFAULT
|
||||
@@ -140,14 +133,14 @@ def make_cloud_generator_task(task: dict[str, Any]) -> dict[str, Any]:
|
||||
return cloud_task
|
||||
|
||||
|
||||
# tasks that only run in the cloud
|
||||
# the name attribute must start with ONYX_CLOUD_CELERY_TASK_PREFIX = "cloud" to be filtered
|
||||
# by the DynamicTenantScheduler
|
||||
cloud_tasks_to_schedule: list[dict] = [
|
||||
# tasks that only run in the cloud and are system wide
|
||||
# the name attribute must start with ONYX_CLOUD_CELERY_TASK_PREFIX = "cloud" to be seen
|
||||
# by the DynamicTenantScheduler as system wide task and not a per tenant task
|
||||
beat_cloud_tasks: list[dict] = [
|
||||
# cloud specific tasks
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-alembic",
|
||||
"task": OnyxCeleryTask.CLOUD_CHECK_ALEMBIC,
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor-alembic",
|
||||
"task": OnyxCeleryTask.CLOUD_MONITOR_ALEMBIC,
|
||||
"schedule": timedelta(hours=1),
|
||||
"options": {
|
||||
"queue": OnyxCeleryQueues.MONITORING,
|
||||
@@ -155,20 +148,74 @@ cloud_tasks_to_schedule: list[dict] = [
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor-celery-queues",
|
||||
"task": OnyxCeleryTask.CLOUD_MONITOR_CELERY_QUEUES,
|
||||
"schedule": timedelta(seconds=30),
|
||||
"options": {
|
||||
"queue": OnyxCeleryQueues.MONITORING,
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
# generate our cloud and self-hosted beat tasks from the templates
|
||||
for beat_task_template in beat_task_templates:
|
||||
cloud_task = make_cloud_generator_task(beat_task_template)
|
||||
cloud_tasks_to_schedule.append(cloud_task)
|
||||
|
||||
# tasks that only run self hosted
|
||||
tasks_to_schedule: list[dict] = []
|
||||
if not MULTI_TENANT:
|
||||
tasks_to_schedule = beat_task_templates
|
||||
tasks_to_schedule.extend(
|
||||
[
|
||||
{
|
||||
"name": "monitor-celery-queues",
|
||||
"task": OnyxCeleryTask.MONITOR_CELERY_QUEUES,
|
||||
"schedule": timedelta(seconds=10),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
"queue": OnyxCeleryQueues.MONITORING,
|
||||
},
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
tasks_to_schedule.extend(beat_task_templates)
|
||||
|
||||
|
||||
def get_cloud_tasks_to_schedule() -> list[dict[str, Any]]:
|
||||
return cloud_tasks_to_schedule
|
||||
def generate_cloud_tasks(
|
||||
beat_tasks: list[dict], beat_templates: list[dict], beat_multiplier: float
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
beat_tasks: system wide tasks that can be sent as is
|
||||
beat_templates: task templates that will be transformed into per tenant tasks via
|
||||
the cloud_beat_task_generator
|
||||
beat_multiplier: a multiplier that can be applied on top of the task schedule
|
||||
to speed up or slow down the task generation rate. useful in production.
|
||||
|
||||
Returns a list of cloud tasks, which consists of incoming tasks + tasks generated
|
||||
from incoming templates.
|
||||
"""
|
||||
|
||||
if beat_multiplier <= 0:
|
||||
raise ValueError("beat_multiplier must be positive!")
|
||||
|
||||
cloud_tasks: list[dict] = []
|
||||
|
||||
# generate our tenant aware cloud tasks from the templates
|
||||
for beat_template in beat_templates:
|
||||
cloud_task = make_cloud_generator_task(beat_template)
|
||||
cloud_tasks.append(cloud_task)
|
||||
|
||||
# factor in the cloud multiplier for the above
|
||||
for cloud_task in cloud_tasks:
|
||||
cloud_task["schedule"] = cloud_task["schedule"] * beat_multiplier
|
||||
|
||||
# add the fixed cloud/system beat tasks. No multiplier for these.
|
||||
cloud_tasks.extend(copy.deepcopy(beat_tasks))
|
||||
return cloud_tasks
|
||||
|
||||
|
||||
def get_cloud_tasks_to_schedule(beat_multiplier: float) -> list[dict[str, Any]]:
|
||||
return generate_cloud_tasks(beat_cloud_tasks, beat_task_templates, beat_multiplier)
|
||||
|
||||
|
||||
def get_tasks_to_schedule() -> list[dict[str, Any]]:
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from celery import Celery
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -12,18 +16,35 @@ from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisConstants
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.connector import fetch_connector_by_id
|
||||
from onyx.db.connector_credential_pair import add_deletion_failure_message
|
||||
from onyx.db.connector_credential_pair import (
|
||||
delete_connector_credential_pair__no_commit,
|
||||
)
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from onyx.db.document import get_document_ids_for_connector_credential_pair
|
||||
from onyx.db.document_set import delete_document_set_cc_pair_relationship__no_commit
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import SyncStatus
|
||||
from onyx.db.enums import SyncType
|
||||
from onyx.db.index_attempt import delete_index_attempts
|
||||
from onyx.db.search_settings import get_all_search_settings
|
||||
from onyx.db.sync_record import cleanup_sync_records
|
||||
from onyx.db.sync_record import insert_sync_record
|
||||
from onyx.db.sync_record import update_sync_record_status
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.redis.redis_connector_delete import RedisConnectorDelete
|
||||
from onyx.redis.redis_connector_delete import RedisConnectorDeletePayload
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.utils.variable_functionality import (
|
||||
fetch_versioned_implementation_with_fallback,
|
||||
)
|
||||
from onyx.utils.variable_functionality import noop_fallback
|
||||
|
||||
|
||||
class TaskDependencyError(RuntimeError):
|
||||
@@ -42,6 +63,7 @@ def check_for_connector_deletion_task(
|
||||
self: Task, *, tenant_id: str | None
|
||||
) -> bool | None:
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
r_replica = get_redis_replica_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
|
||||
@@ -77,6 +99,18 @@ def check_for_connector_deletion_task(
|
||||
# clear the stop signal if it exists ... no longer needed
|
||||
redis_connector.stop.set_fence(False)
|
||||
|
||||
lock_beat.reacquire()
|
||||
keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES))
|
||||
for key in keys:
|
||||
key_bytes = cast(bytes, key)
|
||||
|
||||
if not r.exists(key_bytes):
|
||||
r.srem(OnyxRedisConstants.ACTIVE_FENCES, key_bytes)
|
||||
continue
|
||||
|
||||
key_str = key_bytes.decode("utf-8")
|
||||
if key_str.startswith(RedisConnectorDelete.FENCE_PREFIX):
|
||||
monitor_connector_deletion_taskset(tenant_id, key_bytes, r)
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
@@ -212,3 +246,158 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
redis_connector.delete.set_fence(fence_payload)
|
||||
|
||||
return tasks_generated
|
||||
|
||||
|
||||
def monitor_connector_deletion_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis
|
||||
) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
if cc_pair_id_str is None:
|
||||
task_logger.warning(f"could not parse cc_pair_id from {fence_key}")
|
||||
return
|
||||
|
||||
cc_pair_id = int(cc_pair_id_str)
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
|
||||
fence_data = redis_connector.delete.payload
|
||||
if not fence_data:
|
||||
task_logger.warning(
|
||||
f"Connector deletion - fence payload invalid: cc_pair={cc_pair_id}"
|
||||
)
|
||||
return
|
||||
|
||||
if fence_data.num_tasks is None:
|
||||
# the fence is setting up but isn't ready yet
|
||||
return
|
||||
|
||||
remaining = redis_connector.delete.get_remaining()
|
||||
task_logger.info(
|
||||
f"Connector deletion progress: cc_pair={cc_pair_id} remaining={remaining} initial={fence_data.num_tasks}"
|
||||
)
|
||||
if remaining > 0:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
sync_type=SyncType.CONNECTOR_DELETION,
|
||||
sync_status=SyncStatus.IN_PROGRESS,
|
||||
num_docs_synced=remaining,
|
||||
)
|
||||
return
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
if not cc_pair:
|
||||
task_logger.warning(
|
||||
f"Connector deletion - cc_pair not found: cc_pair={cc_pair_id}"
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
doc_ids = get_document_ids_for_connector_credential_pair(
|
||||
db_session, cc_pair.connector_id, cc_pair.credential_id
|
||||
)
|
||||
if len(doc_ids) > 0:
|
||||
# NOTE(rkuo): if this happens, documents somehow got added while
|
||||
# deletion was in progress. Likely a bug gating off pruning and indexing
|
||||
# work before deletion starts.
|
||||
task_logger.warning(
|
||||
"Connector deletion - documents still found after taskset completion. "
|
||||
"Clearing the current deletion attempt and allowing deletion to restart: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"docs_deleted={fence_data.num_tasks} "
|
||||
f"docs_remaining={len(doc_ids)}"
|
||||
)
|
||||
|
||||
# We don't want to waive off why we get into this state, but resetting
|
||||
# our attempt and letting the deletion restart is a good way to recover
|
||||
redis_connector.delete.reset()
|
||||
raise RuntimeError(
|
||||
"Connector deletion - documents still found after taskset completion"
|
||||
)
|
||||
|
||||
# clean up the rest of the related Postgres entities
|
||||
# index attempts
|
||||
delete_index_attempts(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
|
||||
# document sets
|
||||
delete_document_set_cc_pair_relationship__no_commit(
|
||||
db_session=db_session,
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
)
|
||||
|
||||
# user groups
|
||||
cleanup_user_groups = fetch_versioned_implementation_with_fallback(
|
||||
"onyx.db.user_group",
|
||||
"delete_user_group_cc_pair_relationship__no_commit",
|
||||
noop_fallback,
|
||||
)
|
||||
cleanup_user_groups(
|
||||
cc_pair_id=cc_pair_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# finally, delete the cc-pair
|
||||
delete_connector_credential_pair__no_commit(
|
||||
db_session=db_session,
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
)
|
||||
# if there are no credentials left, delete the connector
|
||||
connector = fetch_connector_by_id(
|
||||
db_session=db_session,
|
||||
connector_id=cc_pair.connector_id,
|
||||
)
|
||||
if not connector or not len(connector.credentials):
|
||||
task_logger.info(
|
||||
"Connector deletion - Found no credentials left for connector, deleting connector"
|
||||
)
|
||||
db_session.delete(connector)
|
||||
db_session.commit()
|
||||
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
sync_type=SyncType.CONNECTOR_DELETION,
|
||||
sync_status=SyncStatus.SUCCESS,
|
||||
num_docs_synced=fence_data.num_tasks,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
db_session.rollback()
|
||||
stack_trace = traceback.format_exc()
|
||||
error_message = f"Error: {str(e)}\n\nStack Trace:\n{stack_trace}"
|
||||
add_deletion_failure_message(db_session, cc_pair_id, error_message)
|
||||
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
sync_type=SyncType.CONNECTOR_DELETION,
|
||||
sync_status=SyncStatus.FAILED,
|
||||
num_docs_synced=fence_data.num_tasks,
|
||||
)
|
||||
|
||||
task_logger.exception(
|
||||
f"Connector deletion exceptioned: "
|
||||
f"cc_pair={cc_pair_id} connector={cc_pair.connector_id} credential={cc_pair.credential_id}"
|
||||
)
|
||||
raise e
|
||||
|
||||
task_logger.info(
|
||||
f"Connector deletion succeeded: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"connector={cc_pair.connector_id} "
|
||||
f"credential={cc_pair.credential_id} "
|
||||
f"docs_deleted={fence_data.num_tasks}"
|
||||
)
|
||||
|
||||
redis_connector.delete.reset()
|
||||
|
||||
@@ -175,6 +175,24 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> bool
|
||||
)
|
||||
|
||||
r.set(OnyxRedisSignals.BLOCK_VALIDATE_PERMISSION_SYNC_FENCES, 1, ex=300)
|
||||
|
||||
# use a lookup table to find active fences. We still have to verify the fence
|
||||
# exists since it is an optimization and not the source of truth.
|
||||
lock_beat.reacquire()
|
||||
keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES))
|
||||
for key in keys:
|
||||
key_bytes = cast(bytes, key)
|
||||
|
||||
if not r.exists(key_bytes):
|
||||
r.srem(OnyxRedisConstants.ACTIVE_FENCES, key_bytes)
|
||||
continue
|
||||
|
||||
key_str = key_bytes.decode("utf-8")
|
||||
if key_str.startswith(RedisConnectorPermissionSync.FENCE_PREFIX):
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_permissions_taskset(
|
||||
tenant_id, key_bytes, r, db_session
|
||||
)
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
@@ -349,6 +367,7 @@ def connector_permission_sync_generator_task(
|
||||
OnyxRedisLocks.CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX
|
||||
+ f"_{redis_connector.id}",
|
||||
timeout=CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT,
|
||||
thread_local=False,
|
||||
)
|
||||
|
||||
acquired = lock.acquire(blocking=False)
|
||||
@@ -459,14 +478,15 @@ def update_external_document_permissions_task(
|
||||
)
|
||||
doc_id = document_external_access.doc_id
|
||||
external_access = document_external_access.external_access
|
||||
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
# Add the users to the DB if they don't exist
|
||||
batch_add_ext_perm_user_if_not_exists(
|
||||
db_session=db_session,
|
||||
emails=list(external_access.external_user_emails),
|
||||
continue_on_error=True,
|
||||
)
|
||||
# Then we upsert the document's external permissions in postgres
|
||||
# Then upsert the document's external permissions
|
||||
created_new_doc = upsert_document_external_perms(
|
||||
db_session=db_session,
|
||||
doc_id=doc_id,
|
||||
@@ -490,11 +510,11 @@ def update_external_document_permissions_task(
|
||||
f"action=update_permissions "
|
||||
f"elapsed={elapsed:.2f}"
|
||||
)
|
||||
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
f"Exception in update_external_document_permissions_task: "
|
||||
f"connector_id={connector_id} "
|
||||
f"doc_id={doc_id}"
|
||||
f"connector_id={connector_id} doc_id={doc_id}"
|
||||
)
|
||||
return False
|
||||
|
||||
@@ -755,7 +775,7 @@ class PermissionSyncCallback(IndexingHeartbeatInterface):
|
||||
raise
|
||||
|
||||
|
||||
"""Monitoring CCPair permissions utils, called in monitor_vespa_sync"""
|
||||
"""Monitoring CCPair permissions utils"""
|
||||
|
||||
|
||||
def monitor_ccpair_permissions_taskset(
|
||||
|
||||
@@ -26,11 +26,11 @@ from ee.onyx.external_permissions.sync_params import (
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_find_task
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.background.error_logging import emit_background_error
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
|
||||
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
@@ -72,18 +72,26 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
"""Returns boolean indicating if external group sync is due."""
|
||||
|
||||
if cc_pair.access_type != AccessType.SYNC:
|
||||
return False
|
||||
|
||||
# skip external group sync if not active
|
||||
if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
|
||||
task_logger.error(
|
||||
f"Recieved non-sync CC Pair {cc_pair.id} for external "
|
||||
f"group sync. Actual access type: {cc_pair.access_type}"
|
||||
)
|
||||
return False
|
||||
|
||||
if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
|
||||
task_logger.debug(
|
||||
f"Skipping group sync for CC Pair {cc_pair.id} - "
|
||||
f"CC Pair is being deleted"
|
||||
)
|
||||
return False
|
||||
|
||||
# If there is not group sync function for the connector, we don't run the sync
|
||||
# This is fine because all sources dont necessarily have a concept of groups
|
||||
if not GROUP_PERMISSIONS_FUNC_MAP.get(cc_pair.connector.source):
|
||||
task_logger.debug(
|
||||
f"Skipping group sync for CC Pair {cc_pair.id} - "
|
||||
f"no group sync function for {cc_pair.connector.source}"
|
||||
)
|
||||
return False
|
||||
|
||||
# If the last sync is None, it has never been run so we run the sync
|
||||
@@ -125,6 +133,9 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool
|
||||
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
task_logger.warning(
|
||||
f"Failed to acquire beat lock for external group sync: {tenant_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
@@ -205,20 +216,12 @@ def try_creating_external_group_sync_task(
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
|
||||
LOCK_TIMEOUT = 30
|
||||
|
||||
lock: RedisLock = r.lock(
|
||||
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_external_group_sync_tasks",
|
||||
timeout=LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2)
|
||||
if not acquired:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Dont kick off a new sync if the previous one is still running
|
||||
if redis_connector.external_group_sync.fenced:
|
||||
logger.warning(
|
||||
f"Skipping external group sync for CC Pair {cc_pair_id} - already running."
|
||||
)
|
||||
return None
|
||||
|
||||
redis_connector.external_group_sync.generator_clear()
|
||||
@@ -269,9 +272,6 @@ def try_creating_external_group_sync_task(
|
||||
f"Unexpected exception while trying to create external group sync task: cc_pair={cc_pair_id}"
|
||||
)
|
||||
return None
|
||||
finally:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
|
||||
return payload_id
|
||||
|
||||
@@ -304,22 +304,26 @@ def connector_external_group_sync_generator_task(
|
||||
start = time.monotonic()
|
||||
while True:
|
||||
if time.monotonic() - start > CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT:
|
||||
raise ValueError(
|
||||
msg = (
|
||||
f"connector_external_group_sync_generator_task - timed out waiting for fence to be ready: "
|
||||
f"fence={redis_connector.external_group_sync.fence_key}"
|
||||
)
|
||||
emit_background_error(msg, cc_pair_id=cc_pair_id)
|
||||
raise ValueError(msg)
|
||||
|
||||
if not redis_connector.external_group_sync.fenced: # The fence must exist
|
||||
raise ValueError(
|
||||
msg = (
|
||||
f"connector_external_group_sync_generator_task - fence not found: "
|
||||
f"fence={redis_connector.external_group_sync.fence_key}"
|
||||
)
|
||||
emit_background_error(msg, cc_pair_id=cc_pair_id)
|
||||
raise ValueError(msg)
|
||||
|
||||
payload = redis_connector.external_group_sync.payload # The payload must exist
|
||||
if not payload:
|
||||
raise ValueError(
|
||||
"connector_external_group_sync_generator_task: payload invalid or not found"
|
||||
)
|
||||
msg = "connector_external_group_sync_generator_task: payload invalid or not found"
|
||||
emit_background_error(msg, cc_pair_id=cc_pair_id)
|
||||
raise ValueError(msg)
|
||||
|
||||
if payload.celery_task_id is None:
|
||||
logger.info(
|
||||
@@ -344,9 +348,9 @@ def connector_external_group_sync_generator_task(
|
||||
|
||||
acquired = lock.acquire(blocking=False)
|
||||
if not acquired:
|
||||
task_logger.warning(
|
||||
f"External group sync task already running, exiting...: cc_pair={cc_pair_id}"
|
||||
)
|
||||
msg = f"External group sync task already running, exiting...: cc_pair={cc_pair_id}"
|
||||
emit_background_error(msg, cc_pair_id=cc_pair_id)
|
||||
task_logger.error(msg)
|
||||
return None
|
||||
|
||||
try:
|
||||
@@ -367,9 +371,9 @@ def connector_external_group_sync_generator_task(
|
||||
|
||||
ext_group_sync_func = GROUP_PERMISSIONS_FUNC_MAP.get(source_type)
|
||||
if ext_group_sync_func is None:
|
||||
raise ValueError(
|
||||
f"No external group sync func found for {source_type} for cc_pair: {cc_pair_id}"
|
||||
)
|
||||
msg = f"No external group sync func found for {source_type} for cc_pair: {cc_pair_id}"
|
||||
emit_background_error(msg, cc_pair_id=cc_pair_id)
|
||||
raise ValueError(msg)
|
||||
|
||||
logger.info(
|
||||
f"Syncing external groups for {source_type} for cc_pair: {cc_pair_id}"
|
||||
@@ -400,9 +404,9 @@ def connector_external_group_sync_generator_task(
|
||||
sync_status=SyncStatus.SUCCESS,
|
||||
)
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
f"External group sync exceptioned: cc_pair={cc_pair_id} payload_id={payload.id}"
|
||||
)
|
||||
msg = f"External group sync exceptioned: cc_pair={cc_pair_id} payload_id={payload.id}"
|
||||
task_logger.exception(msg)
|
||||
emit_background_error(msg + f"\n\n{e}", cc_pair_id=cc_pair_id)
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
update_sync_record_status(
|
||||
@@ -492,9 +496,11 @@ def validate_external_group_sync_fence(
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
if cc_pair_id_str is None:
|
||||
task_logger.warning(
|
||||
msg = (
|
||||
f"validate_external_group_sync_fence - could not parse id from {fence_key}"
|
||||
)
|
||||
emit_background_error(msg)
|
||||
task_logger.error(msg)
|
||||
return
|
||||
|
||||
cc_pair_id = int(cc_pair_id_str)
|
||||
@@ -509,12 +515,14 @@ def validate_external_group_sync_fence(
|
||||
try:
|
||||
payload = redis_connector.external_group_sync.payload
|
||||
except ValidationError:
|
||||
task_logger.exception(
|
||||
msg = (
|
||||
"validate_external_group_sync_fence - "
|
||||
"Resetting fence because fence schema is out of date: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"fence={fence_key}"
|
||||
)
|
||||
task_logger.exception(msg)
|
||||
emit_background_error(msg, cc_pair_id=cc_pair_id)
|
||||
|
||||
redis_connector.external_group_sync.reset()
|
||||
return
|
||||
@@ -551,12 +559,15 @@ def validate_external_group_sync_fence(
|
||||
# return
|
||||
|
||||
# celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
|
||||
logger.warning(
|
||||
"validate_external_group_sync_fence - "
|
||||
"Resetting fence because no associated celery tasks were found: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"fence={fence_key} "
|
||||
f"payload_id={payload.id}"
|
||||
emit_background_error(
|
||||
message=(
|
||||
"validate_external_group_sync_fence - "
|
||||
"Resetting fence because no associated celery tasks were found: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"fence={fence_key} "
|
||||
f"payload_id={payload.id}"
|
||||
),
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
|
||||
redis_connector.external_group_sync.reset()
|
||||
|
||||
@@ -6,13 +6,18 @@ from datetime import datetime
|
||||
from datetime import timezone
|
||||
from http import HTTPStatus
|
||||
from time import sleep
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import sentry_sdk
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from celery.result import AsyncResult
|
||||
from celery.states import READY_STATES
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
|
||||
@@ -30,6 +35,7 @@ from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisConstants
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.configs.constants import OnyxRedisSignals
|
||||
from onyx.db.connector import mark_ccpair_with_indexing_trigger
|
||||
@@ -37,6 +43,7 @@ from onyx.db.connector_credential_pair import fetch_connector_credential_pairs
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.enums import IndexingMode
|
||||
from onyx.db.enums import IndexingStatus
|
||||
from onyx.db.index_attempt import get_index_attempt
|
||||
from onyx.db.index_attempt import get_last_attempt_for_cc_pair
|
||||
from onyx.db.index_attempt import mark_attempt_canceled
|
||||
@@ -47,9 +54,12 @@ from onyx.db.swap_index import check_index_swap
|
||||
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from onyx.natural_language_processing.search_nlp_models import warm_up_bi_encoder
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.redis.redis_connector_index import RedisConnectorIndex
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.redis.redis_pool import redis_lock_dump
|
||||
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
|
||||
from onyx.redis.redis_utils import is_fence
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
|
||||
@@ -60,6 +70,150 @@ from shared_configs.configs import SENTRY_DSN
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def monitor_ccpair_indexing_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
# if the fence doesn't exist, there's nothing to do
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
composite_id = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
if composite_id is None:
|
||||
task_logger.warning(
|
||||
f"Connector indexing: could not parse composite_id from {fence_key}"
|
||||
)
|
||||
return
|
||||
|
||||
# parse out metadata and initialize the helper class with it
|
||||
parts = composite_id.split("/")
|
||||
if len(parts) != 2:
|
||||
return
|
||||
|
||||
cc_pair_id = int(parts[0])
|
||||
search_settings_id = int(parts[1])
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
redis_connector_index = redis_connector.new_index(search_settings_id)
|
||||
if not redis_connector_index.fenced:
|
||||
return
|
||||
|
||||
payload = redis_connector_index.payload
|
||||
if not payload:
|
||||
return
|
||||
|
||||
elapsed_started_str = None
|
||||
if payload.started:
|
||||
elapsed_started = datetime.now(timezone.utc) - payload.started
|
||||
elapsed_started_str = f"{elapsed_started.total_seconds():.2f}"
|
||||
|
||||
elapsed_submitted = datetime.now(timezone.utc) - payload.submitted
|
||||
|
||||
progress = redis_connector_index.get_progress()
|
||||
if progress is not None:
|
||||
task_logger.info(
|
||||
f"Connector indexing progress: "
|
||||
f"attempt={payload.index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"progress={progress} "
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
|
||||
f"elapsed_started={elapsed_started_str}"
|
||||
)
|
||||
|
||||
if payload.index_attempt_id is None or payload.celery_task_id is None:
|
||||
# the task is still setting up
|
||||
return
|
||||
|
||||
# never use any blocking methods on the result from inside a task!
|
||||
result: AsyncResult = AsyncResult(payload.celery_task_id)
|
||||
|
||||
# inner/outer/inner double check pattern to avoid race conditions when checking for
|
||||
# bad state
|
||||
|
||||
# Verify: if the generator isn't complete, the task must not be in READY state
|
||||
# inner = get_completion / generator_complete not signaled
|
||||
# outer = result.state in READY state
|
||||
status_int = redis_connector_index.get_completion()
|
||||
if status_int is None: # inner signal not set ... possible error
|
||||
task_state = result.state
|
||||
if (
|
||||
task_state in READY_STATES
|
||||
): # outer signal in terminal state ... possible error
|
||||
# Now double check!
|
||||
if redis_connector_index.get_completion() is None:
|
||||
# inner signal still not set (and cannot change when outer result_state is READY)
|
||||
# Task is finished but generator complete isn't set.
|
||||
# We have a problem! Worker may have crashed.
|
||||
task_result = str(result.result)
|
||||
task_traceback = str(result.traceback)
|
||||
|
||||
msg = (
|
||||
f"Connector indexing aborted or exceptioned: "
|
||||
f"attempt={payload.index_attempt_id} "
|
||||
f"celery_task={payload.celery_task_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
|
||||
f"result.state={task_state} "
|
||||
f"result.result={task_result} "
|
||||
f"result.traceback={task_traceback}"
|
||||
)
|
||||
task_logger.warning(msg)
|
||||
|
||||
try:
|
||||
index_attempt = get_index_attempt(
|
||||
db_session, payload.index_attempt_id
|
||||
)
|
||||
if index_attempt:
|
||||
if (
|
||||
index_attempt.status != IndexingStatus.CANCELED
|
||||
and index_attempt.status != IndexingStatus.FAILED
|
||||
):
|
||||
mark_attempt_failed(
|
||||
index_attempt_id=payload.index_attempt_id,
|
||||
db_session=db_session,
|
||||
failure_reason=msg,
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
"Connector indexing - Transient exception marking index attempt as failed: "
|
||||
f"attempt={payload.index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
redis_connector_index.reset()
|
||||
return
|
||||
|
||||
if redis_connector_index.watchdog_signaled():
|
||||
# if the generator is complete, don't clean up until the watchdog has exited
|
||||
task_logger.info(
|
||||
f"Connector indexing - Delaying finalization until watchdog has exited: "
|
||||
f"attempt={payload.index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"progress={progress} "
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
|
||||
f"elapsed_started={elapsed_started_str}"
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
status_enum = HTTPStatus(status_int)
|
||||
|
||||
task_logger.info(
|
||||
f"Connector indexing finished: "
|
||||
f"attempt={payload.index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"progress={progress} "
|
||||
f"status={status_enum.name} "
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
|
||||
f"elapsed_started={elapsed_started_str}"
|
||||
)
|
||||
|
||||
redis_connector_index.reset()
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CHECK_FOR_INDEXING,
|
||||
soft_time_limit=300,
|
||||
@@ -91,6 +245,25 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
try:
|
||||
locked = True
|
||||
|
||||
# SPECIAL 0/3: sync lookup table for active fences
|
||||
# we want to run this less frequently than the overall task
|
||||
if not redis_client.exists(OnyxRedisSignals.BLOCK_BUILD_FENCE_LOOKUP_TABLE):
|
||||
# build a lookup table of existing fences
|
||||
# this is just a migration concern and should be unnecessary once
|
||||
# lookup tables are rolled out
|
||||
for key_bytes in redis_client_replica.scan_iter(
|
||||
count=SCAN_ITER_COUNT_DEFAULT
|
||||
):
|
||||
if is_fence(key_bytes) and not redis_client.sismember(
|
||||
OnyxRedisConstants.ACTIVE_FENCES, key_bytes
|
||||
):
|
||||
logger.warning(f"Adding {key_bytes} to the lookup table.")
|
||||
redis_client.sadd(OnyxRedisConstants.ACTIVE_FENCES, key_bytes)
|
||||
|
||||
redis_client.set(OnyxRedisSignals.BLOCK_BUILD_FENCE_LOOKUP_TABLE, 1, ex=300)
|
||||
|
||||
# 1/3: KICKOFF
|
||||
|
||||
# check for search settings swap
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
old_search_settings = check_index_swap(db_session=db_session)
|
||||
@@ -197,6 +370,8 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
|
||||
lock_beat.reacquire()
|
||||
|
||||
# 2/3: VALIDATE
|
||||
|
||||
# Fail any index attempts in the DB that don't have fences
|
||||
# This shouldn't ever happen!
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
@@ -236,6 +411,26 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
task_logger.exception("Exception while validating indexing fences")
|
||||
|
||||
redis_client.set(OnyxRedisSignals.BLOCK_VALIDATE_INDEXING_FENCES, 1, ex=60)
|
||||
|
||||
# 3/3: FINALIZE
|
||||
lock_beat.reacquire()
|
||||
keys = cast(
|
||||
set[Any], redis_client_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES)
|
||||
)
|
||||
for key in keys:
|
||||
key_bytes = cast(bytes, key)
|
||||
|
||||
if not redis_client.exists(key_bytes):
|
||||
redis_client.srem(OnyxRedisConstants.ACTIVE_FENCES, key_bytes)
|
||||
continue
|
||||
|
||||
key_str = key_bytes.decode("utf-8")
|
||||
if key_str.startswith(RedisConnectorIndex.FENCE_PREFIX):
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_indexing_taskset(
|
||||
tenant_id, key_bytes, redis_client_replica, db_session
|
||||
)
|
||||
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
|
||||
@@ -17,7 +17,8 @@ from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.tasks.vespa.tasks import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
@@ -420,6 +421,7 @@ def _collect_sync_metrics(db_session: Session, redis_std: Redis) -> list[Metric]
|
||||
- Throughput (docs/min) (only if success)
|
||||
- Raw start/end times for each sync
|
||||
"""
|
||||
|
||||
one_hour_ago = get_db_current_time(db_session) - timedelta(hours=1)
|
||||
|
||||
# Get all sync records that ended in the last hour
|
||||
@@ -587,6 +589,10 @@ def _collect_sync_metrics(db_session: Session, redis_std: Redis) -> list[Metric]
|
||||
entity = db_session.scalar(
|
||||
select(UserGroup).where(UserGroup.id == sync_record.entity_id)
|
||||
)
|
||||
else:
|
||||
# Only user groups and document set sync records have
|
||||
# an associated entity we can use for latency metrics
|
||||
continue
|
||||
|
||||
if entity is None:
|
||||
task_logger.error(
|
||||
@@ -717,7 +723,7 @@ def monitor_background_processes(self: Task, *, tenant_id: str | None) -> None:
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CLOUD_CHECK_ALEMBIC,
|
||||
name=OnyxCeleryTask.CLOUD_MONITOR_ALEMBIC,
|
||||
)
|
||||
def cloud_check_alembic() -> bool | None:
|
||||
"""A task to verify that all tenants are on the same alembic revision.
|
||||
@@ -777,7 +783,7 @@ def cloud_check_alembic() -> bool | None:
|
||||
|
||||
tenant_to_revision[tenant_id] = result_scalar
|
||||
except Exception:
|
||||
task_logger.warning(f"Tenant {tenant_id} has no revision!")
|
||||
task_logger.error(f"Tenant {tenant_id} has no revision!")
|
||||
tenant_to_revision[tenant_id] = ALEMBIC_NULL_REVISION
|
||||
|
||||
# get the total count of each revision
|
||||
@@ -847,3 +853,55 @@ def cloud_check_alembic() -> bool | None:
|
||||
f"cloud_check_alembic finished: num_tenants={len(tenant_ids)} elapsed={time_elapsed:.2f}"
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CLOUD_MONITOR_CELERY_QUEUES, ignore_result=True, bind=True
|
||||
)
|
||||
def cloud_monitor_celery_queues(
|
||||
self: Task,
|
||||
) -> None:
|
||||
return monitor_celery_queues_helper(self)
|
||||
|
||||
|
||||
@shared_task(name=OnyxCeleryTask.MONITOR_CELERY_QUEUES, ignore_result=True, bind=True)
|
||||
def monitor_celery_queues(self: Task, *, tenant_id: str | None) -> None:
|
||||
return monitor_celery_queues_helper(self)
|
||||
|
||||
|
||||
def monitor_celery_queues_helper(
|
||||
task: Task,
|
||||
) -> None:
|
||||
"""A task to monitor all celery queue lengths."""
|
||||
|
||||
r_celery = task.app.broker_connection().channel().client # type: ignore
|
||||
n_celery = celery_get_queue_length("celery", r_celery)
|
||||
n_indexing = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery)
|
||||
n_sync = celery_get_queue_length(OnyxCeleryQueues.VESPA_METADATA_SYNC, r_celery)
|
||||
n_deletion = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_DELETION, r_celery)
|
||||
n_pruning = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_PRUNING, r_celery)
|
||||
n_permissions_sync = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC, r_celery
|
||||
)
|
||||
n_external_group_sync = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC, r_celery
|
||||
)
|
||||
n_permissions_upsert = celery_get_queue_length(
|
||||
OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT, r_celery
|
||||
)
|
||||
|
||||
n_indexing_prefetched = celery_get_unacked_task_ids(
|
||||
OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"Queue lengths: celery={n_celery} "
|
||||
f"indexing={n_indexing} "
|
||||
f"indexing_prefetched={len(n_indexing_prefetched)} "
|
||||
f"sync={n_sync} "
|
||||
f"deletion={n_deletion} "
|
||||
f"pruning={n_pruning} "
|
||||
f"permissions_sync={n_permissions_sync} "
|
||||
f"external_group_sync={n_external_group_sync} "
|
||||
f"permissions_upsert={n_permissions_upsert} "
|
||||
)
|
||||
|
||||
@@ -122,34 +122,39 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
|
||||
return None
|
||||
|
||||
try:
|
||||
cc_pair_ids: list[int] = []
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pairs = get_connector_credential_pairs(db_session)
|
||||
for cc_pair_entry in cc_pairs:
|
||||
cc_pair_ids.append(cc_pair_entry.id)
|
||||
# the entire task needs to run frequently in order to finalize pruning
|
||||
|
||||
for cc_pair_id in cc_pair_ids:
|
||||
lock_beat.reacquire()
|
||||
# but pruning only kicks off once per hour
|
||||
if not r.exists(OnyxRedisSignals.BLOCK_PRUNING):
|
||||
cc_pair_ids: list[int] = []
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
if not cc_pair:
|
||||
continue
|
||||
cc_pairs = get_connector_credential_pairs(db_session)
|
||||
for cc_pair_entry in cc_pairs:
|
||||
cc_pair_ids.append(cc_pair_entry.id)
|
||||
|
||||
if not _is_pruning_due(cc_pair):
|
||||
continue
|
||||
for cc_pair_id in cc_pair_ids:
|
||||
lock_beat.reacquire()
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
if not cc_pair:
|
||||
continue
|
||||
|
||||
payload_id = try_creating_prune_generator_task(
|
||||
self.app, cc_pair, db_session, r, tenant_id
|
||||
)
|
||||
if not payload_id:
|
||||
continue
|
||||
if not _is_pruning_due(cc_pair):
|
||||
continue
|
||||
|
||||
task_logger.info(
|
||||
f"Pruning queued: cc_pair={cc_pair.id} id={payload_id}"
|
||||
)
|
||||
payload_id = try_creating_prune_generator_task(
|
||||
self.app, cc_pair, db_session, r, tenant_id
|
||||
)
|
||||
if not payload_id:
|
||||
continue
|
||||
|
||||
task_logger.info(
|
||||
f"Pruning queued: cc_pair={cc_pair.id} id={payload_id}"
|
||||
)
|
||||
r.set(OnyxRedisSignals.BLOCK_PRUNING, 1, ex=3600)
|
||||
|
||||
# we want to run this less frequently than the overall task
|
||||
lock_beat.reacquire()
|
||||
@@ -163,6 +168,22 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
|
||||
task_logger.exception("Exception while validating pruning fences")
|
||||
|
||||
r.set(OnyxRedisSignals.BLOCK_VALIDATE_PRUNING_FENCES, 1, ex=300)
|
||||
|
||||
# use a lookup table to find active fences. We still have to verify the fence
|
||||
# exists since it is an optimization and not the source of truth.
|
||||
lock_beat.reacquire()
|
||||
keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES))
|
||||
for key in keys:
|
||||
key_bytes = cast(bytes, key)
|
||||
|
||||
if not r.exists(key_bytes):
|
||||
r.srem(OnyxRedisConstants.ACTIVE_FENCES, key_bytes)
|
||||
continue
|
||||
|
||||
key_str = key_bytes.decode("utf-8")
|
||||
if key_str.startswith(RedisConnectorPrune.FENCE_PREFIX):
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_pruning_taskset(tenant_id, key_bytes, r, db_session)
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
@@ -481,7 +502,7 @@ def connector_pruning_generator_task(
|
||||
)
|
||||
|
||||
|
||||
"""Monitoring pruning utils, called in monitor_vespa_sync"""
|
||||
"""Monitoring pruning utils"""
|
||||
|
||||
|
||||
def monitor_ccpair_pruning_taskset(
|
||||
|
||||
@@ -8,6 +8,7 @@ from celery.exceptions import SoftTimeLimitExceeded
|
||||
from redis.lock import Lock as RedisLock
|
||||
from tenacity import RetryError
|
||||
|
||||
from ee.onyx.server.tenants.product_gating import get_gated_tenants
|
||||
from onyx.access.access import get_access_for_document
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.tasks.beat_schedule import BEAT_EXPIRES_DEFAULT
|
||||
@@ -252,7 +253,11 @@ def cloud_beat_task_generator(
|
||||
|
||||
try:
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
gated_tenants = get_gated_tenants()
|
||||
for tenant_id in tenant_ids:
|
||||
if tenant_id in gated_tenants:
|
||||
continue
|
||||
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_lock_time >= (CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4):
|
||||
lock_beat.reacquire()
|
||||
@@ -270,6 +275,7 @@ def cloud_beat_task_generator(
|
||||
queue=queue,
|
||||
priority=priority,
|
||||
expires=expires,
|
||||
ignore_result=True,
|
||||
)
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
|
||||
@@ -1,9 +1,5 @@
|
||||
import random
|
||||
import time
|
||||
import traceback
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from http import HTTPStatus
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
@@ -13,8 +9,6 @@ from celery import Celery
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from celery.result import AsyncResult
|
||||
from celery.states import READY_STATES
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -22,47 +16,27 @@ from tenacity import RetryError
|
||||
|
||||
from onyx.access.access import get_access_for_document
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.background.celery.tasks.doc_permission_syncing.tasks import (
|
||||
monitor_ccpair_permissions_taskset,
|
||||
)
|
||||
from onyx.background.celery.tasks.pruning.tasks import monitor_ccpair_pruning_taskset
|
||||
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
|
||||
from onyx.background.celery.tasks.shared.tasks import LIGHT_SOFT_TIME_LIMIT
|
||||
from onyx.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.app_configs import VESPA_SYNC_MAX_TASKS
|
||||
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisConstants
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.configs.constants import OnyxRedisSignals
|
||||
from onyx.db.connector import fetch_connector_by_id
|
||||
from onyx.db.connector_credential_pair import add_deletion_failure_message
|
||||
from onyx.db.connector_credential_pair import (
|
||||
delete_connector_credential_pair__no_commit,
|
||||
)
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from onyx.db.document import count_documents_by_needs_sync
|
||||
from onyx.db.document import get_document
|
||||
from onyx.db.document import get_document_ids_for_connector_credential_pair
|
||||
from onyx.db.document import mark_document_as_synced
|
||||
from onyx.db.document_set import delete_document_set
|
||||
from onyx.db.document_set import delete_document_set_cc_pair_relationship__no_commit
|
||||
from onyx.db.document_set import fetch_document_sets
|
||||
from onyx.db.document_set import fetch_document_sets_for_document
|
||||
from onyx.db.document_set import get_document_set_by_id
|
||||
from onyx.db.document_set import mark_document_set_as_synced
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.enums import IndexingStatus
|
||||
from onyx.db.enums import SyncStatus
|
||||
from onyx.db.enums import SyncType
|
||||
from onyx.db.index_attempt import delete_index_attempts
|
||||
from onyx.db.index_attempt import get_index_attempt
|
||||
from onyx.db.index_attempt import mark_attempt_failed
|
||||
from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.search_settings import get_active_search_settings
|
||||
@@ -72,20 +46,14 @@ from onyx.db.sync_record import update_sync_record_status
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.document_index.interfaces import VespaDocumentFields
|
||||
from onyx.httpx.httpx_pool import HttpxPool
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
|
||||
from onyx.redis.redis_connector_credential_pair import (
|
||||
RedisGlobalConnectorCredentialPair,
|
||||
)
|
||||
from onyx.redis.redis_connector_delete import RedisConnectorDelete
|
||||
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
|
||||
from onyx.redis.redis_connector_index import RedisConnectorIndex
|
||||
from onyx.redis.redis_connector_prune import RedisConnectorPrune
|
||||
from onyx.redis.redis_document_set import RedisDocumentSet
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.redis.redis_pool import redis_lock_dump
|
||||
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
|
||||
from onyx.redis.redis_usergroup import RedisUserGroup
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
@@ -94,7 +62,6 @@ from onyx.utils.variable_functionality import (
|
||||
)
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
from onyx.utils.variable_functionality import noop_fallback
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -114,6 +81,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
|
||||
time_start = time.monotonic()
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
r_replica = get_redis_replica_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK,
|
||||
@@ -125,6 +93,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
|
||||
return None
|
||||
|
||||
try:
|
||||
# 1/3: KICKOFF
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
try_generate_stale_document_sync_tasks(
|
||||
self.app, VESPA_SYNC_MAX_TASKS, db_session, r, lock_beat, tenant_id
|
||||
@@ -151,9 +120,8 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
|
||||
# endregion
|
||||
|
||||
# check if any user groups are not synced
|
||||
lock_beat.reacquire()
|
||||
if global_version.is_ee_version():
|
||||
lock_beat.reacquire()
|
||||
|
||||
try:
|
||||
fetch_user_groups = fetch_versioned_implementation(
|
||||
"onyx.db.user_group", "fetch_user_groups"
|
||||
@@ -179,6 +147,35 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
|
||||
self.app, usergroup_id, db_session, r, lock_beat, tenant_id
|
||||
)
|
||||
|
||||
# 2/3: VALIDATE: TODO
|
||||
|
||||
# 3/3: FINALIZE
|
||||
lock_beat.reacquire()
|
||||
keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES))
|
||||
for key in keys:
|
||||
key_bytes = cast(bytes, key)
|
||||
|
||||
if not r.exists(key_bytes):
|
||||
r.srem(OnyxRedisConstants.ACTIVE_FENCES, key_bytes)
|
||||
continue
|
||||
|
||||
key_str = key_bytes.decode("utf-8")
|
||||
if key_str == RedisGlobalConnectorCredentialPair.FENCE_KEY:
|
||||
monitor_connector_taskset(r)
|
||||
elif key_str.startswith(RedisDocumentSet.FENCE_PREFIX):
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_document_set_taskset(tenant_id, key_bytes, r, db_session)
|
||||
elif key_str.startswith(RedisUserGroup.FENCE_PREFIX):
|
||||
monitor_usergroup_taskset = (
|
||||
fetch_versioned_implementation_with_fallback(
|
||||
"onyx.background.celery.tasks.vespa.tasks",
|
||||
"monitor_usergroup_taskset",
|
||||
noop_fallback,
|
||||
)
|
||||
)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_usergroup_taskset(tenant_id, key_bytes, r, db_session)
|
||||
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
@@ -506,475 +503,6 @@ def monitor_document_set_taskset(
|
||||
rds.reset()
|
||||
|
||||
|
||||
def monitor_connector_deletion_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis
|
||||
) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
if cc_pair_id_str is None:
|
||||
task_logger.warning(f"could not parse cc_pair_id from {fence_key}")
|
||||
return
|
||||
|
||||
cc_pair_id = int(cc_pair_id_str)
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
|
||||
fence_data = redis_connector.delete.payload
|
||||
if not fence_data:
|
||||
task_logger.warning(
|
||||
f"Connector deletion - fence payload invalid: cc_pair={cc_pair_id}"
|
||||
)
|
||||
return
|
||||
|
||||
if fence_data.num_tasks is None:
|
||||
# the fence is setting up but isn't ready yet
|
||||
return
|
||||
|
||||
remaining = redis_connector.delete.get_remaining()
|
||||
task_logger.info(
|
||||
f"Connector deletion progress: cc_pair={cc_pair_id} remaining={remaining} initial={fence_data.num_tasks}"
|
||||
)
|
||||
if remaining > 0:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
sync_type=SyncType.CONNECTOR_DELETION,
|
||||
sync_status=SyncStatus.IN_PROGRESS,
|
||||
num_docs_synced=remaining,
|
||||
)
|
||||
return
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
if not cc_pair:
|
||||
task_logger.warning(
|
||||
f"Connector deletion - cc_pair not found: cc_pair={cc_pair_id}"
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
doc_ids = get_document_ids_for_connector_credential_pair(
|
||||
db_session, cc_pair.connector_id, cc_pair.credential_id
|
||||
)
|
||||
if len(doc_ids) > 0:
|
||||
# NOTE(rkuo): if this happens, documents somehow got added while
|
||||
# deletion was in progress. Likely a bug gating off pruning and indexing
|
||||
# work before deletion starts.
|
||||
task_logger.warning(
|
||||
"Connector deletion - documents still found after taskset completion. "
|
||||
"Clearing the current deletion attempt and allowing deletion to restart: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"docs_deleted={fence_data.num_tasks} "
|
||||
f"docs_remaining={len(doc_ids)}"
|
||||
)
|
||||
|
||||
# We don't want to waive off why we get into this state, but resetting
|
||||
# our attempt and letting the deletion restart is a good way to recover
|
||||
redis_connector.delete.reset()
|
||||
raise RuntimeError(
|
||||
"Connector deletion - documents still found after taskset completion"
|
||||
)
|
||||
|
||||
# clean up the rest of the related Postgres entities
|
||||
# index attempts
|
||||
delete_index_attempts(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
|
||||
# document sets
|
||||
delete_document_set_cc_pair_relationship__no_commit(
|
||||
db_session=db_session,
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
)
|
||||
|
||||
# user groups
|
||||
cleanup_user_groups = fetch_versioned_implementation_with_fallback(
|
||||
"onyx.db.user_group",
|
||||
"delete_user_group_cc_pair_relationship__no_commit",
|
||||
noop_fallback,
|
||||
)
|
||||
cleanup_user_groups(
|
||||
cc_pair_id=cc_pair_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# finally, delete the cc-pair
|
||||
delete_connector_credential_pair__no_commit(
|
||||
db_session=db_session,
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
)
|
||||
# if there are no credentials left, delete the connector
|
||||
connector = fetch_connector_by_id(
|
||||
db_session=db_session,
|
||||
connector_id=cc_pair.connector_id,
|
||||
)
|
||||
if not connector or not len(connector.credentials):
|
||||
task_logger.info(
|
||||
"Connector deletion - Found no credentials left for connector, deleting connector"
|
||||
)
|
||||
db_session.delete(connector)
|
||||
db_session.commit()
|
||||
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
sync_type=SyncType.CONNECTOR_DELETION,
|
||||
sync_status=SyncStatus.SUCCESS,
|
||||
num_docs_synced=fence_data.num_tasks,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
db_session.rollback()
|
||||
stack_trace = traceback.format_exc()
|
||||
error_message = f"Error: {str(e)}\n\nStack Trace:\n{stack_trace}"
|
||||
add_deletion_failure_message(db_session, cc_pair_id, error_message)
|
||||
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
sync_type=SyncType.CONNECTOR_DELETION,
|
||||
sync_status=SyncStatus.FAILED,
|
||||
num_docs_synced=fence_data.num_tasks,
|
||||
)
|
||||
|
||||
task_logger.exception(
|
||||
f"Connector deletion exceptioned: "
|
||||
f"cc_pair={cc_pair_id} connector={cc_pair.connector_id} credential={cc_pair.credential_id}"
|
||||
)
|
||||
raise e
|
||||
|
||||
task_logger.info(
|
||||
f"Connector deletion succeeded: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"connector={cc_pair.connector_id} "
|
||||
f"credential={cc_pair.credential_id} "
|
||||
f"docs_deleted={fence_data.num_tasks}"
|
||||
)
|
||||
|
||||
redis_connector.delete.reset()
|
||||
|
||||
|
||||
def monitor_ccpair_indexing_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
# if the fence doesn't exist, there's nothing to do
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
composite_id = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
if composite_id is None:
|
||||
task_logger.warning(
|
||||
f"Connector indexing: could not parse composite_id from {fence_key}"
|
||||
)
|
||||
return
|
||||
|
||||
# parse out metadata and initialize the helper class with it
|
||||
parts = composite_id.split("/")
|
||||
if len(parts) != 2:
|
||||
return
|
||||
|
||||
cc_pair_id = int(parts[0])
|
||||
search_settings_id = int(parts[1])
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
redis_connector_index = redis_connector.new_index(search_settings_id)
|
||||
if not redis_connector_index.fenced:
|
||||
return
|
||||
|
||||
payload = redis_connector_index.payload
|
||||
if not payload:
|
||||
return
|
||||
|
||||
elapsed_started_str = None
|
||||
if payload.started:
|
||||
elapsed_started = datetime.now(timezone.utc) - payload.started
|
||||
elapsed_started_str = f"{elapsed_started.total_seconds():.2f}"
|
||||
|
||||
elapsed_submitted = datetime.now(timezone.utc) - payload.submitted
|
||||
|
||||
progress = redis_connector_index.get_progress()
|
||||
if progress is not None:
|
||||
task_logger.info(
|
||||
f"Connector indexing progress: "
|
||||
f"attempt={payload.index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"progress={progress} "
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
|
||||
f"elapsed_started={elapsed_started_str}"
|
||||
)
|
||||
|
||||
if payload.index_attempt_id is None or payload.celery_task_id is None:
|
||||
# the task is still setting up
|
||||
return
|
||||
|
||||
# never use any blocking methods on the result from inside a task!
|
||||
result: AsyncResult = AsyncResult(payload.celery_task_id)
|
||||
|
||||
# inner/outer/inner double check pattern to avoid race conditions when checking for
|
||||
# bad state
|
||||
|
||||
# Verify: if the generator isn't complete, the task must not be in READY state
|
||||
# inner = get_completion / generator_complete not signaled
|
||||
# outer = result.state in READY state
|
||||
status_int = redis_connector_index.get_completion()
|
||||
if status_int is None: # inner signal not set ... possible error
|
||||
task_state = result.state
|
||||
if (
|
||||
task_state in READY_STATES
|
||||
): # outer signal in terminal state ... possible error
|
||||
# Now double check!
|
||||
if redis_connector_index.get_completion() is None:
|
||||
# inner signal still not set (and cannot change when outer result_state is READY)
|
||||
# Task is finished but generator complete isn't set.
|
||||
# We have a problem! Worker may have crashed.
|
||||
task_result = str(result.result)
|
||||
task_traceback = str(result.traceback)
|
||||
|
||||
msg = (
|
||||
f"Connector indexing aborted or exceptioned: "
|
||||
f"attempt={payload.index_attempt_id} "
|
||||
f"celery_task={payload.celery_task_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
|
||||
f"result.state={task_state} "
|
||||
f"result.result={task_result} "
|
||||
f"result.traceback={task_traceback}"
|
||||
)
|
||||
task_logger.warning(msg)
|
||||
|
||||
try:
|
||||
index_attempt = get_index_attempt(
|
||||
db_session, payload.index_attempt_id
|
||||
)
|
||||
if index_attempt:
|
||||
if (
|
||||
index_attempt.status != IndexingStatus.CANCELED
|
||||
and index_attempt.status != IndexingStatus.FAILED
|
||||
):
|
||||
mark_attempt_failed(
|
||||
index_attempt_id=payload.index_attempt_id,
|
||||
db_session=db_session,
|
||||
failure_reason=msg,
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
"Connector indexing - Transient exception marking index attempt as failed: "
|
||||
f"attempt={payload.index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
redis_connector_index.reset()
|
||||
return
|
||||
|
||||
if redis_connector_index.watchdog_signaled():
|
||||
# if the generator is complete, don't clean up until the watchdog has exited
|
||||
task_logger.info(
|
||||
f"Connector indexing - Delaying finalization until watchdog has exited: "
|
||||
f"attempt={payload.index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"progress={progress} "
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
|
||||
f"elapsed_started={elapsed_started_str}"
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
status_enum = HTTPStatus(status_int)
|
||||
|
||||
task_logger.info(
|
||||
f"Connector indexing finished: "
|
||||
f"attempt={payload.index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"progress={progress} "
|
||||
f"status={status_enum.name} "
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
|
||||
f"elapsed_started={elapsed_started_str}"
|
||||
)
|
||||
|
||||
redis_connector_index.reset()
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.MONITOR_VESPA_SYNC,
|
||||
ignore_result=True,
|
||||
soft_time_limit=300,
|
||||
bind=True,
|
||||
)
|
||||
def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool | None:
|
||||
"""This is a celery beat task that monitors and finalizes various long running tasks.
|
||||
|
||||
The name monitor_vespa_sync is a bit of a misnomer since it checks many different tasks
|
||||
now. Should change that at some point.
|
||||
|
||||
It scans for fence values and then gets the counts of any associated tasksets.
|
||||
For many tasks, the count is 0, that means all tasks finished and we should clean up.
|
||||
|
||||
This task lock timeout is CELERY_METADATA_SYNC_BEAT_LOCK_TIMEOUT seconds, so don't
|
||||
do anything too expensive in this function!
|
||||
|
||||
Returns True if the task actually did work, False if it exited early to prevent overlap
|
||||
"""
|
||||
task_logger.info(f"monitor_vespa_sync starting: tenant={tenant_id}")
|
||||
|
||||
time_start = time.monotonic()
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# Replica usage notes
|
||||
#
|
||||
# False negatives are OK. (aka fail to to see a key that exists on the master).
|
||||
# We simply skip the monitoring work and it will be caught on the next pass.
|
||||
#
|
||||
# False positives are not OK, and are possible if we clear a fence on the master and
|
||||
# then read from the replica. In this case, monitoring work could be done on a fence
|
||||
# that no longer exists. To avoid this, we scan from the replica, but double check
|
||||
# the result on the master.
|
||||
r_replica = get_redis_replica_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
# prevent overlapping tasks
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return None
|
||||
|
||||
try:
|
||||
# print current queue lengths
|
||||
time.monotonic()
|
||||
# we don't need every tenant polling redis for this info.
|
||||
if not MULTI_TENANT or random.randint(1, 10) == 10:
|
||||
r_celery = self.app.broker_connection().channel().client # type: ignore
|
||||
n_celery = celery_get_queue_length("celery", r_celery)
|
||||
n_indexing = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
|
||||
)
|
||||
n_sync = celery_get_queue_length(
|
||||
OnyxCeleryQueues.VESPA_METADATA_SYNC, r_celery
|
||||
)
|
||||
n_deletion = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CONNECTOR_DELETION, r_celery
|
||||
)
|
||||
n_pruning = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CONNECTOR_PRUNING, r_celery
|
||||
)
|
||||
n_permissions_sync = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC, r_celery
|
||||
)
|
||||
n_external_group_sync = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC, r_celery
|
||||
)
|
||||
n_permissions_upsert = celery_get_queue_length(
|
||||
OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT, r_celery
|
||||
)
|
||||
|
||||
prefetched = celery_get_unacked_task_ids(
|
||||
OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"Queue lengths: celery={n_celery} "
|
||||
f"indexing={n_indexing} "
|
||||
f"indexing_prefetched={len(prefetched)} "
|
||||
f"sync={n_sync} "
|
||||
f"deletion={n_deletion} "
|
||||
f"pruning={n_pruning} "
|
||||
f"permissions_sync={n_permissions_sync} "
|
||||
f"external_group_sync={n_external_group_sync} "
|
||||
f"permissions_upsert={n_permissions_upsert} "
|
||||
)
|
||||
|
||||
# we want to run this less frequently than the overall task
|
||||
if not r.exists(OnyxRedisSignals.BLOCK_BUILD_FENCE_LOOKUP_TABLE):
|
||||
# build a lookup table of existing fences
|
||||
# this is just a migration concern and should be unnecessary once
|
||||
# lookup tables are rolled out
|
||||
for key_bytes in r_replica.scan_iter(count=SCAN_ITER_COUNT_DEFAULT):
|
||||
if is_fence(key_bytes) and not r.sismember(
|
||||
OnyxRedisConstants.ACTIVE_FENCES, key_bytes
|
||||
):
|
||||
logger.warning(f"Adding {key_bytes} to the lookup table.")
|
||||
r.sadd(OnyxRedisConstants.ACTIVE_FENCES, key_bytes)
|
||||
|
||||
r.set(OnyxRedisSignals.BLOCK_BUILD_FENCE_LOOKUP_TABLE, 1, ex=300)
|
||||
|
||||
# use a lookup table to find active fences. We still have to verify the fence
|
||||
# exists since it is an optimization and not the source of truth.
|
||||
keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES))
|
||||
for key in keys:
|
||||
key_bytes = cast(bytes, key)
|
||||
|
||||
if not r.exists(key_bytes):
|
||||
r.srem(OnyxRedisConstants.ACTIVE_FENCES, key_bytes)
|
||||
continue
|
||||
|
||||
key_str = key_bytes.decode("utf-8")
|
||||
if key_str == RedisGlobalConnectorCredentialPair.FENCE_KEY:
|
||||
monitor_connector_taskset(r)
|
||||
elif key_str.startswith(RedisDocumentSet.FENCE_PREFIX):
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_document_set_taskset(tenant_id, key_bytes, r, db_session)
|
||||
elif key_str.startswith(RedisUserGroup.FENCE_PREFIX):
|
||||
monitor_usergroup_taskset = (
|
||||
fetch_versioned_implementation_with_fallback(
|
||||
"onyx.background.celery.tasks.vespa.tasks",
|
||||
"monitor_usergroup_taskset",
|
||||
noop_fallback,
|
||||
)
|
||||
)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_usergroup_taskset(tenant_id, key_bytes, r, db_session)
|
||||
elif key_str.startswith(RedisConnectorDelete.FENCE_PREFIX):
|
||||
monitor_connector_deletion_taskset(tenant_id, key_bytes, r)
|
||||
elif key_str.startswith(RedisConnectorPrune.FENCE_PREFIX):
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_pruning_taskset(tenant_id, key_bytes, r, db_session)
|
||||
elif key_str.startswith(RedisConnectorIndex.FENCE_PREFIX):
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_indexing_taskset(tenant_id, key_bytes, r, db_session)
|
||||
elif key_str.startswith(RedisConnectorPermissionSync.FENCE_PREFIX):
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_permissions_taskset(
|
||||
tenant_id, key_bytes, r, db_session
|
||||
)
|
||||
else:
|
||||
pass
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
)
|
||||
return False
|
||||
except Exception:
|
||||
task_logger.exception("monitor_vespa_sync exceptioned.")
|
||||
return False
|
||||
finally:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
else:
|
||||
task_logger.error(
|
||||
"monitor_vespa_sync - Lock not owned on completion: "
|
||||
f"tenant={tenant_id}"
|
||||
# f"timings={timings}"
|
||||
)
|
||||
redis_lock_dump(lock_beat, r)
|
||||
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
task_logger.info(f"monitor_vespa_sync finished: elapsed={time_elapsed:.2f}")
|
||||
return True
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.VESPA_METADATA_SYNC_TASK,
|
||||
bind=True,
|
||||
@@ -1072,23 +600,3 @@ def vespa_metadata_sync_task(
|
||||
self.retry(exc=e, countdown=countdown)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def is_fence(key_bytes: bytes) -> bool:
|
||||
key_str = key_bytes.decode("utf-8")
|
||||
if key_str == RedisGlobalConnectorCredentialPair.FENCE_KEY:
|
||||
return True
|
||||
if key_str.startswith(RedisDocumentSet.FENCE_PREFIX):
|
||||
return True
|
||||
if key_str.startswith(RedisUserGroup.FENCE_PREFIX):
|
||||
return True
|
||||
if key_str.startswith(RedisConnectorDelete.FENCE_PREFIX):
|
||||
return True
|
||||
if key_str.startswith(RedisConnectorPrune.FENCE_PREFIX):
|
||||
return True
|
||||
if key_str.startswith(RedisConnectorIndex.FENCE_PREFIX):
|
||||
return True
|
||||
if key_str.startswith(RedisConnectorPermissionSync.FENCE_PREFIX):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
13
backend/onyx/background/error_logging.py
Normal file
13
backend/onyx/background/error_logging.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from onyx.db.background_error import create_background_error
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
|
||||
|
||||
def emit_background_error(
|
||||
message: str,
|
||||
cc_pair_id: int | None = None,
|
||||
) -> None:
|
||||
"""Currently just saves a row in the background_errors table.
|
||||
|
||||
In the future, could create notifications based on the severity."""
|
||||
with get_session_with_tenant() as db_session:
|
||||
create_background_error(db_session, message, cc_pair_id)
|
||||
@@ -107,9 +107,9 @@ CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT = 5 * 60 # 5 min
|
||||
|
||||
# needs to be long enough to cover the maximum time it takes to download an object
|
||||
# if we can get callbacks as object bytes download, we could lower this a lot.
|
||||
CELERY_PRUNING_LOCK_TIMEOUT = 300 # 5 min
|
||||
CELERY_PRUNING_LOCK_TIMEOUT = 3600 # 1 hour (in seconds)
|
||||
|
||||
CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT = 300 # 5 min
|
||||
CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT = 3600 # 1 hour (in seconds)
|
||||
|
||||
CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT = 300 # 5 min
|
||||
|
||||
@@ -298,7 +298,6 @@ class OnyxRedisLocks:
|
||||
CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK = (
|
||||
"da_lock:check_connector_external_group_sync_beat"
|
||||
)
|
||||
MONITOR_VESPA_SYNC_BEAT_LOCK = "da_lock:monitor_vespa_sync_beat"
|
||||
MONITOR_BACKGROUND_PROCESSES_LOCK = "da_lock:monitor_background_processes"
|
||||
|
||||
CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX = (
|
||||
@@ -324,6 +323,7 @@ class OnyxRedisSignals:
|
||||
BLOCK_VALIDATE_PERMISSION_SYNC_FENCES = (
|
||||
"signal:block_validate_permission_sync_fences"
|
||||
)
|
||||
BLOCK_PRUNING = "signal:block_pruning"
|
||||
BLOCK_VALIDATE_PRUNING_FENCES = "signal:block_validate_pruning_fences"
|
||||
BLOCK_BUILD_FENCE_LOOKUP_TABLE = "signal:block_build_fence_lookup_table"
|
||||
|
||||
@@ -346,12 +346,18 @@ ONYX_CLOUD_CELERY_TASK_PREFIX = "cloud"
|
||||
# the tenant id we use for system level redis operations
|
||||
ONYX_CLOUD_TENANT_ID = "cloud"
|
||||
|
||||
# the redis namespace for runtime variables
|
||||
ONYX_CLOUD_REDIS_RUNTIME = "runtime"
|
||||
|
||||
|
||||
class OnyxCeleryTask:
|
||||
DEFAULT = "celery"
|
||||
|
||||
CLOUD_BEAT_TASK_GENERATOR = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_generate_beat_tasks"
|
||||
CLOUD_CHECK_ALEMBIC = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check_alembic"
|
||||
CLOUD_MONITOR_ALEMBIC = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor_alembic"
|
||||
CLOUD_MONITOR_CELERY_QUEUES = (
|
||||
f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor_celery_queues"
|
||||
)
|
||||
|
||||
CHECK_FOR_CONNECTOR_DELETION = "check_for_connector_deletion_task"
|
||||
CHECK_FOR_VESPA_SYNC_TASK = "check_for_vespa_sync_task"
|
||||
@@ -361,8 +367,8 @@ class OnyxCeleryTask:
|
||||
CHECK_FOR_EXTERNAL_GROUP_SYNC = "check_for_external_group_sync"
|
||||
CHECK_FOR_LLM_MODEL_UPDATE = "check_for_llm_model_update"
|
||||
|
||||
MONITOR_VESPA_SYNC = "monitor_vespa_sync"
|
||||
MONITOR_BACKGROUND_PROCESSES = "monitor_background_processes"
|
||||
MONITOR_CELERY_QUEUES = "monitor_celery_queues"
|
||||
|
||||
KOMBU_MESSAGE_CLEANUP_TASK = "kombu_message_cleanup_task"
|
||||
CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK = (
|
||||
|
||||
@@ -65,10 +65,25 @@ class AirtableConnector(LoadConnector):
|
||||
base_id: str,
|
||||
table_name_or_id: str,
|
||||
treat_all_non_attachment_fields_as_metadata: bool = False,
|
||||
view_id: str | None = None,
|
||||
share_id: str | None = None,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
) -> None:
|
||||
"""Initialize an AirtableConnector.
|
||||
|
||||
Args:
|
||||
base_id: The ID of the Airtable base to connect to
|
||||
table_name_or_id: The name or ID of the table to index
|
||||
treat_all_non_attachment_fields_as_metadata: If True, all fields except attachments will be treated as metadata.
|
||||
If False, only fields with types in DEFAULT_METADATA_FIELD_TYPES will be treated as metadata.
|
||||
view_id: Optional ID of a specific view to use
|
||||
share_id: Optional ID of a "share" to use for generating record URLs (https://airtable.com/developers/web/api/list-shares)
|
||||
batch_size: Number of records to process in each batch
|
||||
"""
|
||||
self.base_id = base_id
|
||||
self.table_name_or_id = table_name_or_id
|
||||
self.view_id = view_id
|
||||
self.share_id = share_id
|
||||
self.batch_size = batch_size
|
||||
self._airtable_client: AirtableApi | None = None
|
||||
self.treat_all_non_attachment_fields_as_metadata = (
|
||||
@@ -85,6 +100,39 @@ class AirtableConnector(LoadConnector):
|
||||
raise AirtableClientNotSetUpError()
|
||||
return self._airtable_client
|
||||
|
||||
@classmethod
|
||||
def _get_record_url(
|
||||
cls,
|
||||
base_id: str,
|
||||
table_id: str,
|
||||
record_id: str,
|
||||
share_id: str | None,
|
||||
view_id: str | None,
|
||||
field_id: str | None = None,
|
||||
attachment_id: str | None = None,
|
||||
) -> str:
|
||||
"""Constructs the URL for a record, optionally including field and attachment IDs
|
||||
|
||||
Full possible structure is:
|
||||
|
||||
https://airtable.com/BASE_ID/SHARE_ID/TABLE_ID/VIEW_ID/RECORD_ID/FIELD_ID/ATTACHMENT_ID
|
||||
"""
|
||||
# If we have a shared link, use that view for better UX
|
||||
if share_id:
|
||||
base_url = f"https://airtable.com/{base_id}/{share_id}/{table_id}"
|
||||
else:
|
||||
base_url = f"https://airtable.com/{base_id}/{table_id}"
|
||||
|
||||
if view_id:
|
||||
base_url = f"{base_url}/{view_id}"
|
||||
|
||||
base_url = f"{base_url}/{record_id}"
|
||||
|
||||
if field_id and attachment_id:
|
||||
return f"{base_url}/{field_id}/{attachment_id}?blocks=hide"
|
||||
|
||||
return base_url
|
||||
|
||||
def _extract_field_values(
|
||||
self,
|
||||
field_id: str,
|
||||
@@ -110,8 +158,10 @@ class AirtableConnector(LoadConnector):
|
||||
if field_type == "multipleRecordLinks":
|
||||
return []
|
||||
|
||||
# default link to use for non-attachment fields
|
||||
default_link = f"https://airtable.com/{base_id}/{table_id}/{record_id}"
|
||||
# Get the base URL for this record
|
||||
default_link = self._get_record_url(
|
||||
base_id, table_id, record_id, self.share_id, self.view_id or view_id
|
||||
)
|
||||
|
||||
if field_type == "multipleAttachments":
|
||||
attachment_texts: list[tuple[str, str]] = []
|
||||
@@ -165,17 +215,16 @@ class AirtableConnector(LoadConnector):
|
||||
extension=file_ext,
|
||||
)
|
||||
if attachment_text:
|
||||
# slightly nicer loading experience if we can specify the view ID
|
||||
if view_id:
|
||||
attachment_link = (
|
||||
f"https://airtable.com/{base_id}/{table_id}/{view_id}/{record_id}"
|
||||
f"/{field_id}/{attachment_id}?blocks=hide"
|
||||
)
|
||||
else:
|
||||
attachment_link = (
|
||||
f"https://airtable.com/{base_id}/{table_id}/{record_id}"
|
||||
f"/{field_id}/{attachment_id}?blocks=hide"
|
||||
)
|
||||
# Use the helper method to construct attachment URLs
|
||||
attachment_link = self._get_record_url(
|
||||
base_id,
|
||||
table_id,
|
||||
record_id,
|
||||
self.share_id,
|
||||
self.view_id or view_id,
|
||||
field_id,
|
||||
attachment_id,
|
||||
)
|
||||
attachment_texts.append(
|
||||
(f"{filename}:\n{attachment_text}", attachment_link)
|
||||
)
|
||||
|
||||
@@ -145,7 +145,8 @@ def fetch_jira_issues_batch(
|
||||
id=page_url,
|
||||
sections=[Section(link=page_url, text=ticket_content)],
|
||||
source=DocumentSource.JIRA,
|
||||
semantic_identifier=issue.fields.summary,
|
||||
semantic_identifier=f"{issue.key}: {issue.fields.summary}",
|
||||
title=f"{issue.key} {issue.fields.summary}",
|
||||
doc_updated_at=time_str_to_utc(issue.fields.updated),
|
||||
primary_owners=list(people) or None,
|
||||
# TODO add secondary_owners (commenters) if needed
|
||||
|
||||
@@ -39,19 +39,6 @@ def get_message_link(
|
||||
return permalink
|
||||
|
||||
|
||||
def _make_slack_api_call_logged(
|
||||
call: Callable[..., SlackResponse],
|
||||
) -> Callable[..., SlackResponse]:
|
||||
@wraps(call)
|
||||
def logged_call(**kwargs: Any) -> SlackResponse:
|
||||
logger.debug(f"Making call to Slack API '{call.__name__}' with args '{kwargs}'")
|
||||
result = call(**kwargs)
|
||||
logger.debug(f"Call to Slack API '{call.__name__}' returned '{result}'")
|
||||
return result
|
||||
|
||||
return logged_call
|
||||
|
||||
|
||||
def _make_slack_api_call_paginated(
|
||||
call: Callable[..., SlackResponse],
|
||||
) -> Callable[..., Generator[dict[str, Any], None, None]]:
|
||||
@@ -127,18 +114,14 @@ def make_slack_api_rate_limited(
|
||||
def make_slack_api_call_w_retries(
|
||||
call: Callable[..., SlackResponse], **kwargs: Any
|
||||
) -> SlackResponse:
|
||||
return basic_retry_wrapper(
|
||||
make_slack_api_rate_limited(_make_slack_api_call_logged(call))
|
||||
)(**kwargs)
|
||||
return basic_retry_wrapper(make_slack_api_rate_limited(call))(**kwargs)
|
||||
|
||||
|
||||
def make_paginated_slack_api_call_w_retries(
|
||||
call: Callable[..., SlackResponse], **kwargs: Any
|
||||
) -> Generator[dict[str, Any], None, None]:
|
||||
return _make_slack_api_call_paginated(
|
||||
basic_retry_wrapper(
|
||||
make_slack_api_rate_limited(_make_slack_api_call_logged(call))
|
||||
)
|
||||
basic_retry_wrapper(make_slack_api_rate_limited(call))
|
||||
)(**kwargs)
|
||||
|
||||
|
||||
|
||||
@@ -51,6 +51,7 @@ class SearchPipeline:
|
||||
user: User | None,
|
||||
llm: LLM,
|
||||
fast_llm: LLM,
|
||||
skip_query_analysis: bool,
|
||||
db_session: Session,
|
||||
bypass_acl: bool = False, # NOTE: VERY DANGEROUS, USE WITH CAUTION
|
||||
retrieval_metrics_callback: (
|
||||
@@ -61,10 +62,13 @@ class SearchPipeline:
|
||||
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
|
||||
prompt_config: PromptConfig | None = None,
|
||||
):
|
||||
# NOTE: The Search Request contains a lot of fields that are overrides, many of them can be None
|
||||
# and typically are None. The preprocessing will fetch default values to replace these empty overrides.
|
||||
self.search_request = search_request
|
||||
self.user = user
|
||||
self.llm = llm
|
||||
self.fast_llm = fast_llm
|
||||
self.skip_query_analysis = skip_query_analysis
|
||||
self.db_session = db_session
|
||||
self.bypass_acl = bypass_acl
|
||||
self.retrieval_metrics_callback = retrieval_metrics_callback
|
||||
@@ -106,6 +110,7 @@ class SearchPipeline:
|
||||
search_request=self.search_request,
|
||||
user=self.user,
|
||||
llm=self.llm,
|
||||
skip_query_analysis=self.skip_query_analysis,
|
||||
db_session=self.db_session,
|
||||
bypass_acl=self.bypass_acl,
|
||||
)
|
||||
@@ -160,6 +165,12 @@ class SearchPipeline:
|
||||
that have a corresponding chunk.
|
||||
|
||||
This step should be fast for any document index implementation.
|
||||
|
||||
Current implementation timing is approximately broken down in timing as:
|
||||
- 200 ms to get the embedding of the query
|
||||
- 15 ms to get chunks from the document index
|
||||
- possibly more to get additional surrounding chunks
|
||||
- possibly more for query expansion (multilingual)
|
||||
"""
|
||||
if self._retrieved_sections is not None:
|
||||
return self._retrieved_sections
|
||||
|
||||
@@ -15,6 +15,7 @@ from onyx.context.search.models import InferenceChunk
|
||||
from onyx.context.search.models import InferenceChunkUncleaned
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.context.search.models import MAX_METRICS_CONTENT
|
||||
from onyx.context.search.models import RerankingDetails
|
||||
from onyx.context.search.models import RerankMetricsContainer
|
||||
from onyx.context.search.models import SearchQuery
|
||||
from onyx.document_index.document_index_utils import (
|
||||
@@ -77,7 +78,8 @@ def cleanup_chunks(chunks: list[InferenceChunkUncleaned]) -> list[InferenceChunk
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def semantic_reranking(
|
||||
query: SearchQuery,
|
||||
query_str: str,
|
||||
rerank_settings: RerankingDetails,
|
||||
chunks: list[InferenceChunk],
|
||||
model_min: int = CROSS_ENCODER_RANGE_MIN,
|
||||
model_max: int = CROSS_ENCODER_RANGE_MAX,
|
||||
@@ -88,11 +90,9 @@ def semantic_reranking(
|
||||
|
||||
Note: this updates the chunks in place, it updates the chunk scores which came from retrieval
|
||||
"""
|
||||
rerank_settings = query.rerank_settings
|
||||
|
||||
if not rerank_settings or not rerank_settings.rerank_model_name:
|
||||
# Should never reach this part of the flow without reranking settings
|
||||
raise RuntimeError("Reranking flow should not be running")
|
||||
assert (
|
||||
rerank_settings.rerank_model_name
|
||||
), "Reranking flow cannot run without a specific model"
|
||||
|
||||
chunks_to_rerank = chunks[: rerank_settings.num_rerank]
|
||||
|
||||
@@ -107,7 +107,7 @@ def semantic_reranking(
|
||||
f"{chunk.semantic_identifier or chunk.title or ''}\n{chunk.content}"
|
||||
for chunk in chunks_to_rerank
|
||||
]
|
||||
sim_scores_floats = cross_encoder.predict(query=query.query, passages=passages)
|
||||
sim_scores_floats = cross_encoder.predict(query=query_str, passages=passages)
|
||||
|
||||
# Old logic to handle multiple cross-encoders preserved but not used
|
||||
sim_scores = [numpy.array(sim_scores_floats)]
|
||||
@@ -165,8 +165,20 @@ def semantic_reranking(
|
||||
return list(ranked_chunks), list(ranked_indices)
|
||||
|
||||
|
||||
def should_rerank(rerank_settings: RerankingDetails | None) -> bool:
|
||||
"""Based on the RerankingDetails model, only run rerank if the following conditions are met:
|
||||
- rerank_model_name is not None
|
||||
- num_rerank is greater than 0
|
||||
"""
|
||||
if not rerank_settings:
|
||||
return False
|
||||
|
||||
return bool(rerank_settings.rerank_model_name and rerank_settings.num_rerank > 0)
|
||||
|
||||
|
||||
def rerank_sections(
|
||||
query: SearchQuery,
|
||||
query_str: str,
|
||||
rerank_settings: RerankingDetails,
|
||||
sections_to_rerank: list[InferenceSection],
|
||||
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
|
||||
) -> list[InferenceSection]:
|
||||
@@ -181,16 +193,13 @@ def rerank_sections(
|
||||
"""
|
||||
chunks_to_rerank = [section.center_chunk for section in sections_to_rerank]
|
||||
|
||||
if not query.rerank_settings:
|
||||
# Should never reach this part of the flow without reranking settings
|
||||
raise RuntimeError("Reranking settings not found")
|
||||
|
||||
ranked_chunks, _ = semantic_reranking(
|
||||
query=query,
|
||||
query_str=query_str,
|
||||
rerank_settings=rerank_settings,
|
||||
chunks=chunks_to_rerank,
|
||||
rerank_metrics_callback=rerank_metrics_callback,
|
||||
)
|
||||
lower_chunks = chunks_to_rerank[query.rerank_settings.num_rerank :]
|
||||
lower_chunks = chunks_to_rerank[rerank_settings.num_rerank :]
|
||||
|
||||
# Scores from rerank cannot be meaningfully combined with scores without rerank
|
||||
# However the ordering is still important
|
||||
@@ -260,16 +269,13 @@ def search_postprocessing(
|
||||
|
||||
rerank_task_id = None
|
||||
sections_yielded = False
|
||||
if (
|
||||
search_query.rerank_settings
|
||||
and search_query.rerank_settings.rerank_model_name
|
||||
and search_query.rerank_settings.num_rerank > 0
|
||||
):
|
||||
if should_rerank(search_query.rerank_settings):
|
||||
post_processing_tasks.append(
|
||||
FunctionCall(
|
||||
rerank_sections,
|
||||
(
|
||||
search_query,
|
||||
search_query.query,
|
||||
search_query.rerank_settings, # Cannot be None here
|
||||
retrieved_sections,
|
||||
rerank_metrics_callback,
|
||||
),
|
||||
|
||||
@@ -50,11 +50,11 @@ def retrieval_preprocessing(
|
||||
search_request: SearchRequest,
|
||||
user: User | None,
|
||||
llm: LLM,
|
||||
skip_query_analysis: bool,
|
||||
db_session: Session,
|
||||
bypass_acl: bool = False,
|
||||
skip_query_analysis: bool = False,
|
||||
base_recency_decay: float = BASE_RECENCY_DECAY,
|
||||
favor_recent_decay_multiplier: float = FAVOR_RECENT_DECAY_MULTIPLIER,
|
||||
base_recency_decay: float = BASE_RECENCY_DECAY,
|
||||
bypass_acl: bool = False,
|
||||
) -> SearchQuery:
|
||||
"""Logic is as follows:
|
||||
Any global disables apply first
|
||||
@@ -146,7 +146,7 @@ def retrieval_preprocessing(
|
||||
is_keyword, extracted_keywords = (
|
||||
parallel_results[run_query_analysis.result_id]
|
||||
if run_query_analysis
|
||||
else (None, None)
|
||||
else (False, None)
|
||||
)
|
||||
|
||||
all_query_terms = query.split()
|
||||
|
||||
10
backend/onyx/db/background_error.py
Normal file
10
backend/onyx/db/background_error.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import BackgroundError
|
||||
|
||||
|
||||
def create_background_error(
|
||||
db_session: Session, message: str, cc_pair_id: int | None
|
||||
) -> None:
|
||||
db_session.add(BackgroundError(message=message, cc_pair_id=cc_pair_id))
|
||||
db_session.commit()
|
||||
@@ -483,6 +483,10 @@ class ConnectorCredentialPair(Base):
|
||||
primaryjoin="foreign(ConnectorCredentialPair.creator_id) == remote(User.id)",
|
||||
)
|
||||
|
||||
background_errors: Mapped[list["BackgroundError"]] = relationship(
|
||||
"BackgroundError", back_populates="cc_pair", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
|
||||
class Document(Base):
|
||||
__tablename__ = "document"
|
||||
@@ -2115,6 +2119,31 @@ class StandardAnswer(Base):
|
||||
)
|
||||
|
||||
|
||||
class BackgroundError(Base):
|
||||
"""Important background errors. Serves to:
|
||||
1. Ensure that important logs are kept around and not lost on rotation/container restarts
|
||||
2. A trail for high-signal events so that the debugger doesn't need to remember/know every
|
||||
possible relevant log line.
|
||||
"""
|
||||
|
||||
__tablename__ = "background_error"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
message: Mapped[str] = mapped_column(String)
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
|
||||
# option to link the error to a specific CC Pair
|
||||
cc_pair_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("connector_credential_pair.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
|
||||
cc_pair: Mapped["ConnectorCredentialPair | None"] = relationship(
|
||||
"ConnectorCredentialPair", back_populates="background_errors"
|
||||
)
|
||||
|
||||
|
||||
"""Tables related to Permission Sync"""
|
||||
|
||||
|
||||
|
||||
@@ -204,6 +204,14 @@ def create_update_persona(
|
||||
if not all_prompt_ids:
|
||||
raise ValueError("No prompt IDs provided")
|
||||
|
||||
# Default persona validation
|
||||
if create_persona_request.is_default_persona:
|
||||
if not create_persona_request.is_public:
|
||||
raise ValueError("Cannot make a default persona non public")
|
||||
|
||||
if user and user.role != UserRole.ADMIN:
|
||||
raise ValueError("Only admins can make a default persona")
|
||||
|
||||
persona = upsert_persona(
|
||||
persona_id=persona_id,
|
||||
user=user,
|
||||
@@ -510,6 +518,7 @@ def upsert_persona(
|
||||
existing_persona.is_visible = is_visible
|
||||
existing_persona.search_start_date = search_start_date
|
||||
existing_persona.labels = labels or []
|
||||
existing_persona.is_default_persona = is_default_persona
|
||||
# Do not delete any associations manually added unless
|
||||
# a new updated list is provided
|
||||
if document_sets is not None:
|
||||
@@ -590,6 +599,23 @@ def delete_old_default_personas(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def update_persona_is_default(
|
||||
persona_id: int,
|
||||
is_default: bool,
|
||||
db_session: Session,
|
||||
user: User | None = None,
|
||||
) -> None:
|
||||
persona = fetch_persona_by_id_for_user(
|
||||
db_session=db_session, persona_id=persona_id, user=user, get_editable=True
|
||||
)
|
||||
|
||||
if not persona.is_public:
|
||||
persona.is_public = True
|
||||
|
||||
persona.is_default_persona = is_default
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def update_persona_visibility(
|
||||
persona_id: int,
|
||||
is_visible: bool,
|
||||
|
||||
@@ -6,6 +6,7 @@ from fastapi import HTTPException
|
||||
from fastapi_users.password import PasswordHelper
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql import expression
|
||||
from sqlalchemy.sql.elements import ColumnElement
|
||||
@@ -274,7 +275,7 @@ def _generate_ext_permissioned_user(email: str) -> User:
|
||||
|
||||
|
||||
def batch_add_ext_perm_user_if_not_exists(
|
||||
db_session: Session, emails: list[str]
|
||||
db_session: Session, emails: list[str], continue_on_error: bool = False
|
||||
) -> list[User]:
|
||||
lower_emails = [email.lower() for email in emails]
|
||||
found_users, missing_lower_emails = _get_users_by_emails(db_session, lower_emails)
|
||||
@@ -283,10 +284,23 @@ def batch_add_ext_perm_user_if_not_exists(
|
||||
for email in missing_lower_emails:
|
||||
new_users.append(_generate_ext_permissioned_user(email=email))
|
||||
|
||||
db_session.add_all(new_users)
|
||||
db_session.commit()
|
||||
|
||||
return found_users + new_users
|
||||
try:
|
||||
db_session.add_all(new_users)
|
||||
db_session.commit()
|
||||
except IntegrityError:
|
||||
db_session.rollback()
|
||||
if not continue_on_error:
|
||||
raise
|
||||
for user in new_users:
|
||||
try:
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
except IntegrityError:
|
||||
db_session.rollback()
|
||||
continue
|
||||
# Fetch all users again to ensure we have the most up-to-date list
|
||||
all_users, _ = _get_users_by_emails(db_session, lower_emails)
|
||||
return all_users
|
||||
|
||||
|
||||
def delete_user_from_db(
|
||||
|
||||
@@ -17,6 +17,7 @@ from uuid import UUID
|
||||
|
||||
import httpx # type: ignore
|
||||
import requests # type: ignore
|
||||
from retry import retry
|
||||
|
||||
from onyx.configs.chat_configs import DOC_TIME_DECAY
|
||||
from onyx.configs.chat_configs import NUM_RETURNED_HITS
|
||||
@@ -549,6 +550,11 @@ class VespaIndex(DocumentIndex):
|
||||
time.monotonic() - update_start,
|
||||
)
|
||||
|
||||
@retry(
|
||||
tries=3,
|
||||
delay=1,
|
||||
backoff=2,
|
||||
)
|
||||
def _update_single_chunk(
|
||||
self,
|
||||
doc_chunk_id: UUID,
|
||||
@@ -559,6 +565,7 @@ class VespaIndex(DocumentIndex):
|
||||
) -> None:
|
||||
"""
|
||||
Update a single "chunk" (document) in Vespa using its chunk ID.
|
||||
Retries if we encounter transient HTTPStatusError (e.g., overload).
|
||||
"""
|
||||
|
||||
update_dict: dict[str, dict] = {"fields": {}}
|
||||
@@ -567,13 +574,11 @@ class VespaIndex(DocumentIndex):
|
||||
update_dict["fields"][BOOST] = {"assign": fields.boost}
|
||||
|
||||
if fields.document_sets is not None:
|
||||
# WeightedSet<string> needs a map { item: weight, ... }
|
||||
update_dict["fields"][DOCUMENT_SETS] = {
|
||||
"assign": {document_set: 1 for document_set in fields.document_sets}
|
||||
}
|
||||
|
||||
if fields.access is not None:
|
||||
# Similar to above
|
||||
update_dict["fields"][ACCESS_CONTROL_LIST] = {
|
||||
"assign": {acl_entry: 1 for acl_entry in fields.access.to_acl()}
|
||||
}
|
||||
@@ -585,7 +590,10 @@ class VespaIndex(DocumentIndex):
|
||||
logger.error("Update request received but nothing to update.")
|
||||
return
|
||||
|
||||
vespa_url = f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{doc_chunk_id}?create=true"
|
||||
vespa_url = (
|
||||
f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{doc_chunk_id}"
|
||||
"?create=true"
|
||||
)
|
||||
|
||||
try:
|
||||
resp = http_client.put(
|
||||
@@ -595,8 +603,11 @@ class VespaIndex(DocumentIndex):
|
||||
)
|
||||
resp.raise_for_status()
|
||||
except httpx.HTTPStatusError as e:
|
||||
error_message = f"Failed to update doc chunk {doc_chunk_id} (doc_id={doc_id}). Details: {e.response.text}"
|
||||
logger.error(error_message)
|
||||
logger.error(
|
||||
f"Failed to update doc chunk {doc_chunk_id} (doc_id={doc_id}). "
|
||||
f"Details: {e.response.text}"
|
||||
)
|
||||
# Re-raise so the @retry decorator will catch and retry
|
||||
raise
|
||||
|
||||
def update_single(
|
||||
|
||||
@@ -146,6 +146,23 @@ def _index_vespa_chunk(
|
||||
|
||||
title = document.get_title_for_document_index()
|
||||
|
||||
metadata_json = document.metadata
|
||||
cleaned_metadata_json: dict[str, str | list[str]] = {}
|
||||
for key, value in metadata_json.items():
|
||||
cleaned_key = remove_invalid_unicode_chars(key)
|
||||
if isinstance(value, list):
|
||||
cleaned_metadata_json[cleaned_key] = [
|
||||
remove_invalid_unicode_chars(item) for item in value
|
||||
]
|
||||
else:
|
||||
cleaned_metadata_json[cleaned_key] = remove_invalid_unicode_chars(value)
|
||||
|
||||
metadata_list = document.get_metadata_str_attributes()
|
||||
if metadata_list:
|
||||
metadata_list = [
|
||||
remove_invalid_unicode_chars(metadata) for metadata in metadata_list
|
||||
]
|
||||
|
||||
vespa_document_fields = {
|
||||
DOCUMENT_ID: document.id,
|
||||
CHUNK_ID: chunk.chunk_id,
|
||||
@@ -166,10 +183,10 @@ def _index_vespa_chunk(
|
||||
SEMANTIC_IDENTIFIER: remove_invalid_unicode_chars(document.semantic_identifier),
|
||||
SECTION_CONTINUATION: chunk.section_continuation,
|
||||
LARGE_CHUNK_REFERENCE_IDS: chunk.large_chunk_reference_ids,
|
||||
METADATA: json.dumps(document.metadata),
|
||||
METADATA: json.dumps(cleaned_metadata_json),
|
||||
# Save as a list for efficient extraction as an Attribute
|
||||
METADATA_LIST: chunk.source_document.get_metadata_str_attributes(),
|
||||
METADATA_SUFFIX: chunk.metadata_suffix_keyword,
|
||||
METADATA_LIST: metadata_list,
|
||||
METADATA_SUFFIX: remove_invalid_unicode_chars(chunk.metadata_suffix_keyword),
|
||||
EMBEDDINGS: embeddings_name_vector_map,
|
||||
TITLE_EMBEDDING: chunk.title_embedding,
|
||||
DOC_UPDATED_AT: _vespa_get_updated_at_attribute(document.doc_updated_at),
|
||||
|
||||
@@ -396,9 +396,14 @@ class DefaultMultiLLM(LLM):
|
||||
self._record_call(processed_prompt)
|
||||
|
||||
try:
|
||||
print(
|
||||
"model is",
|
||||
f"{self.config.model_provider}/{self.config.deployment_name or self.config.model_name}",
|
||||
)
|
||||
return litellm.completion(
|
||||
mock_response=MOCK_LLM_RESPONSE,
|
||||
# model choice
|
||||
# model="openai/gpt-4",
|
||||
model=f"{self.config.model_provider}/{self.config.deployment_name or self.config.model_name}",
|
||||
# NOTE: have to pass in None instead of empty string for these
|
||||
# otherwise litellm can have some issues with bedrock
|
||||
|
||||
@@ -99,7 +99,7 @@ def _check_tokenizer_cache(
|
||||
|
||||
if not tokenizer:
|
||||
logger.info(
|
||||
f"Falling back to default embedding model: {DOCUMENT_ENCODER_MODEL}"
|
||||
f"Falling back to default embedding model tokenizer: {DOCUMENT_ENCODER_MODEL}"
|
||||
)
|
||||
tokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)
|
||||
|
||||
|
||||
@@ -120,6 +120,7 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
|
||||
queue=OnyxCeleryQueues.VESPA_METADATA_SYNC,
|
||||
task_id=custom_task_id,
|
||||
priority=OnyxCeleryPriority.MEDIUM,
|
||||
ignore_result=True,
|
||||
)
|
||||
|
||||
num_tasks_sent += 1
|
||||
|
||||
@@ -132,6 +132,7 @@ class RedisConnectorDelete:
|
||||
queue=OnyxCeleryQueues.CONNECTOR_DELETION,
|
||||
task_id=custom_task_id,
|
||||
priority=OnyxCeleryPriority.MEDIUM,
|
||||
ignore_result=True,
|
||||
)
|
||||
|
||||
async_results.append(result)
|
||||
|
||||
@@ -11,6 +11,7 @@ from redis.lock import Lock as RedisLock
|
||||
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
@@ -49,7 +50,7 @@ class RedisConnectorPermissionSync:
|
||||
# it's impossible to get the exact state of the system at a single point in time
|
||||
# so we need a signal with a TTL to bridge gaps in our checks
|
||||
ACTIVE_PREFIX = PREFIX + "_active"
|
||||
ACTIVE_TTL = 3600
|
||||
ACTIVE_TTL = CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT * 2
|
||||
|
||||
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
|
||||
self.tenant_id: str | None = tenant_id
|
||||
@@ -195,6 +196,7 @@ class RedisConnectorPermissionSync:
|
||||
queue=OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT,
|
||||
task_id=custom_task_id,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
ignore_result=True,
|
||||
)
|
||||
async_results.append(result)
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_PRUNING_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
@@ -49,7 +50,7 @@ class RedisConnectorPrune:
|
||||
# it's impossible to get the exact state of the system at a single point in time
|
||||
# so we need a signal with a TTL to bridge gaps in our checks
|
||||
ACTIVE_PREFIX = PREFIX + "_active"
|
||||
ACTIVE_TTL = 3600
|
||||
ACTIVE_TTL = CELERY_PRUNING_LOCK_TIMEOUT * 2
|
||||
|
||||
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
|
||||
self.tenant_id: str | None = tenant_id
|
||||
@@ -201,6 +202,7 @@ class RedisConnectorPrune:
|
||||
queue=OnyxCeleryQueues.CONNECTOR_DELETION,
|
||||
task_id=custom_task_id,
|
||||
priority=OnyxCeleryPriority.MEDIUM,
|
||||
ignore_result=True,
|
||||
)
|
||||
|
||||
async_results.append(result)
|
||||
|
||||
29
backend/onyx/redis/redis_utils.py
Normal file
29
backend/onyx/redis/redis_utils.py
Normal file
@@ -0,0 +1,29 @@
|
||||
from onyx.redis.redis_connector_credential_pair import (
|
||||
RedisGlobalConnectorCredentialPair,
|
||||
)
|
||||
from onyx.redis.redis_connector_delete import RedisConnectorDelete
|
||||
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
|
||||
from onyx.redis.redis_connector_index import RedisConnectorIndex
|
||||
from onyx.redis.redis_connector_prune import RedisConnectorPrune
|
||||
from onyx.redis.redis_document_set import RedisDocumentSet
|
||||
from onyx.redis.redis_usergroup import RedisUserGroup
|
||||
|
||||
|
||||
def is_fence(key_bytes: bytes) -> bool:
|
||||
key_str = key_bytes.decode("utf-8")
|
||||
if key_str == RedisGlobalConnectorCredentialPair.FENCE_KEY:
|
||||
return True
|
||||
if key_str.startswith(RedisDocumentSet.FENCE_PREFIX):
|
||||
return True
|
||||
if key_str.startswith(RedisUserGroup.FENCE_PREFIX):
|
||||
return True
|
||||
if key_str.startswith(RedisConnectorDelete.FENCE_PREFIX):
|
||||
return True
|
||||
if key_str.startswith(RedisConnectorPrune.FENCE_PREFIX):
|
||||
return True
|
||||
if key_str.startswith(RedisConnectorIndex.FENCE_PREFIX):
|
||||
return True
|
||||
if key_str.startswith(RedisConnectorPermissionSync.FENCE_PREFIX):
|
||||
return True
|
||||
|
||||
return False
|
||||
@@ -162,6 +162,11 @@ def load_personas_from_yaml(
|
||||
else persona.get("is_visible")
|
||||
),
|
||||
db_session=db_session,
|
||||
is_default_persona=(
|
||||
existing_persona.is_default_persona
|
||||
if existing_persona is not None
|
||||
else persona.get("is_default_persona", False)
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -41,6 +41,7 @@ personas:
|
||||
icon_color: "#6FB1FF"
|
||||
display_priority: 0
|
||||
is_visible: true
|
||||
is_default_persona: true
|
||||
starter_messages:
|
||||
- name: "Give me an overview of what's here"
|
||||
message: "Sample some documents and tell me what you find."
|
||||
@@ -66,6 +67,7 @@ personas:
|
||||
icon_color: "#FF6F6F"
|
||||
display_priority: 1
|
||||
is_visible: true
|
||||
is_default_persona: true
|
||||
starter_messages:
|
||||
- name: "Summarize a document"
|
||||
message: "If I have provided a document please summarize it for me. If not, please ask me to upload a document either by dragging it into the input bar or clicking the +file icon."
|
||||
@@ -91,6 +93,7 @@ personas:
|
||||
icon_color: "#6FFF8D"
|
||||
display_priority: 2
|
||||
is_visible: false
|
||||
is_default_persona: true
|
||||
starter_messages:
|
||||
- name: "Document Search"
|
||||
message: "Hi! Could you help me find information about our team structure and reporting lines from our internal documents?"
|
||||
@@ -117,6 +120,7 @@ personas:
|
||||
image_generation: true
|
||||
display_priority: 3
|
||||
is_visible: true
|
||||
is_default_persona: true
|
||||
starter_messages:
|
||||
- name: "Create visuals for a presentation"
|
||||
message: "Generate someone presenting a graph which clearly demonstrates an upwards trajectory."
|
||||
|
||||
@@ -32,6 +32,7 @@ from onyx.db.persona import get_personas_for_user
|
||||
from onyx.db.persona import mark_persona_as_deleted
|
||||
from onyx.db.persona import mark_persona_as_not_deleted
|
||||
from onyx.db.persona import update_all_personas_display_priority
|
||||
from onyx.db.persona import update_persona_is_default
|
||||
from onyx.db.persona import update_persona_label
|
||||
from onyx.db.persona import update_persona_public_status
|
||||
from onyx.db.persona import update_persona_shared_users
|
||||
@@ -56,7 +57,6 @@ from onyx.tools.utils import is_image_generation_available
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.telemetry import create_milestone_and_report
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@@ -72,6 +72,10 @@ class IsPublicRequest(BaseModel):
|
||||
is_public: bool
|
||||
|
||||
|
||||
class IsDefaultRequest(BaseModel):
|
||||
is_default_persona: bool
|
||||
|
||||
|
||||
@admin_router.patch("/{persona_id}/visible")
|
||||
def patch_persona_visibility(
|
||||
persona_id: int,
|
||||
@@ -106,6 +110,25 @@ def patch_user_presona_public_status(
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
|
||||
|
||||
@admin_router.patch("/{persona_id}/default")
|
||||
def patch_persona_default_status(
|
||||
persona_id: int,
|
||||
is_default_request: IsDefaultRequest,
|
||||
user: User | None = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
try:
|
||||
update_persona_is_default(
|
||||
persona_id=persona_id,
|
||||
is_default=is_default_request.is_default_persona,
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.exception("Failed to update persona default status")
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
|
||||
|
||||
@admin_router.put("/display-priority")
|
||||
def patch_persona_display_priority(
|
||||
display_priority_request: DisplayPriorityRequest,
|
||||
|
||||
@@ -76,6 +76,7 @@ def gpt_search(
|
||||
user=None,
|
||||
llm=llm,
|
||||
fast_llm=fast_llm,
|
||||
skip_query_analysis=True,
|
||||
db_session=db_session,
|
||||
).reranked_sections
|
||||
|
||||
|
||||
@@ -12,10 +12,10 @@ class PageType(str, Enum):
|
||||
SEARCH = "search"
|
||||
|
||||
|
||||
class GatingType(str, Enum):
|
||||
FULL = "full" # Complete restriction of access to the product or service
|
||||
PARTIAL = "partial" # Full access but warning (no credit card on file)
|
||||
NONE = "none" # No restrictions, full access to all features
|
||||
class ApplicationStatus(str, Enum):
|
||||
PAYMENT_REMINDER = "payment_reminder"
|
||||
GATED_ACCESS = "gated_access"
|
||||
ACTIVE = "active"
|
||||
|
||||
|
||||
class Notification(BaseModel):
|
||||
@@ -43,7 +43,7 @@ class Settings(BaseModel):
|
||||
|
||||
maximum_chat_retention_days: int | None = None
|
||||
gpu_enabled: bool | None = None
|
||||
product_gating: GatingType = GatingType.NONE
|
||||
application_status: ApplicationStatus = ApplicationStatus.ACTIVE
|
||||
anonymous_user_enabled: bool | None = None
|
||||
pro_search_disabled: bool | None = None
|
||||
auto_scroll: bool | None = None
|
||||
|
||||
@@ -34,7 +34,7 @@ Now respond to the following:
|
||||
""".strip()
|
||||
|
||||
|
||||
class BaseTool(Tool):
|
||||
class BaseTool(Tool[None]):
|
||||
def build_next_prompt(
|
||||
self,
|
||||
prompt_builder: "AnswerPromptBuilder",
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import model_validator
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.context.search.enums import SearchType
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import InferenceSection
|
||||
|
||||
|
||||
class ToolResponse(BaseModel):
|
||||
@@ -57,5 +60,15 @@ class SearchQueryInfo(BaseModel):
|
||||
recency_bias_multiplier: float
|
||||
|
||||
|
||||
class SearchToolOverrideKwargs(BaseModel):
|
||||
force_no_rerank: bool
|
||||
alternate_db_session: Session | None
|
||||
retrieved_sections_callback: Callable[[list[InferenceSection]], None] | None
|
||||
skip_query_analysis: bool
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
CHAT_SESSION_ID_PLACEHOLDER = "CHAT_SESSION_ID"
|
||||
MESSAGE_ID_PLACEHOLDER = "MESSAGE_ID"
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import abc
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from typing import Generic
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TypeVar
|
||||
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.models import PreviousMessage
|
||||
@@ -14,7 +16,10 @@ if TYPE_CHECKING:
|
||||
from onyx.tools.models import ToolResponse
|
||||
|
||||
|
||||
class Tool(abc.ABC):
|
||||
OVERRIDE_T = TypeVar("OVERRIDE_T")
|
||||
|
||||
|
||||
class Tool(abc.ABC, Generic[OVERRIDE_T]):
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def name(self) -> str:
|
||||
@@ -57,7 +62,9 @@ class Tool(abc.ABC):
|
||||
"""Actual execution of the tool"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def run(self, **kwargs: Any) -> Generator["ToolResponse", None, None]:
|
||||
def run(
|
||||
self, override_kwargs: OVERRIDE_T | None = None, **llm_kwargs: Any
|
||||
) -> Generator["ToolResponse", None, None]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
|
||||
@@ -74,6 +74,7 @@ class CustomToolCallSummary(BaseModel):
|
||||
tool_result: Any # The response data
|
||||
|
||||
|
||||
# override_kwargs is not supported for custom tools
|
||||
class CustomTool(BaseTool):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -235,7 +236,9 @@ class CustomTool(BaseTool):
|
||||
|
||||
"""Actual execution of the tool"""
|
||||
|
||||
def run(self, **kwargs: Any) -> Generator[ToolResponse, None, None]:
|
||||
def run(
|
||||
self, override_kwargs: dict[str, Any] | None = None, **kwargs: Any
|
||||
) -> Generator[ToolResponse, None, None]:
|
||||
request_body = kwargs.get(REQUEST_BODY)
|
||||
|
||||
path_params = {}
|
||||
|
||||
@@ -79,7 +79,8 @@ class ImageShape(str, Enum):
|
||||
LANDSCAPE = "landscape"
|
||||
|
||||
|
||||
class ImageGenerationTool(Tool):
|
||||
# override_kwargs is not supported for image generation tools
|
||||
class ImageGenerationTool(Tool[None]):
|
||||
_NAME = "run_image_generation"
|
||||
_DESCRIPTION = "Generate an image from a prompt."
|
||||
_DISPLAY_NAME = "Image Generation"
|
||||
@@ -255,7 +256,9 @@ class ImageGenerationTool(Tool):
|
||||
"An error occurred during image generation. Please try again later."
|
||||
)
|
||||
|
||||
def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]:
|
||||
def run(
|
||||
self, override_kwargs: None = None, **kwargs: str
|
||||
) -> Generator[ToolResponse, None, None]:
|
||||
prompt = cast(str, kwargs["prompt"])
|
||||
shape = ImageShape(kwargs.get("shape", ImageShape.SQUARE))
|
||||
format = self.output_format
|
||||
|
||||
@@ -106,7 +106,8 @@ def internet_search_response_to_search_docs(
|
||||
]
|
||||
|
||||
|
||||
class InternetSearchTool(Tool):
|
||||
# override_kwargs is not supported for internet search tools
|
||||
class InternetSearchTool(Tool[None]):
|
||||
_NAME = "run_internet_search"
|
||||
_DISPLAY_NAME = "Internet Search"
|
||||
_DESCRIPTION = "Perform an internet search for up-to-date information."
|
||||
@@ -242,7 +243,9 @@ class InternetSearchTool(Tool):
|
||||
],
|
||||
)
|
||||
|
||||
def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]:
|
||||
def run(
|
||||
self, override_kwargs: None = None, **kwargs: str
|
||||
) -> Generator[ToolResponse, None, None]:
|
||||
query = cast(str, kwargs["internet_search_query"])
|
||||
|
||||
results = self._perform_search(query)
|
||||
|
||||
@@ -39,6 +39,7 @@ from onyx.secondary_llm_flows.choose_search import check_if_need_search
|
||||
from onyx.secondary_llm_flows.query_expansion import history_based_query_rephrase
|
||||
from onyx.tools.message import ToolCallSummary
|
||||
from onyx.tools.models import SearchQueryInfo
|
||||
from onyx.tools.models import SearchToolOverrideKwargs
|
||||
from onyx.tools.models import ToolResponse
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.tools.tool_implementations.search.search_utils import llm_doc_to_dict
|
||||
@@ -77,7 +78,7 @@ HINT: if you are unfamiliar with the user input OR think the user input is a typ
|
||||
"""
|
||||
|
||||
|
||||
class SearchTool(Tool):
|
||||
class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
_NAME = "run_search"
|
||||
_DISPLAY_NAME = "Search Tool"
|
||||
_DESCRIPTION = SEARCH_TOOL_DESCRIPTION
|
||||
@@ -275,14 +276,19 @@ class SearchTool(Tool):
|
||||
|
||||
yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=llm_docs)
|
||||
|
||||
def run(self, **kwargs: Any) -> Generator[ToolResponse, None, None]:
|
||||
query = cast(str, kwargs["query"])
|
||||
force_no_rerank = cast(bool, kwargs.get("force_no_rerank", False))
|
||||
alternate_db_session = cast(Session, kwargs.get("alternate_db_session", None))
|
||||
retrieved_sections_callback = cast(
|
||||
Callable[[list[InferenceSection]], None],
|
||||
kwargs.get("retrieved_sections_callback"),
|
||||
)
|
||||
def run(
|
||||
self, override_kwargs: SearchToolOverrideKwargs | None = None, **llm_kwargs: Any
|
||||
) -> Generator[ToolResponse, None, None]:
|
||||
query = cast(str, llm_kwargs["query"])
|
||||
force_no_rerank = False
|
||||
alternate_db_session = None
|
||||
retrieved_sections_callback = None
|
||||
skip_query_analysis = False
|
||||
if override_kwargs:
|
||||
force_no_rerank = override_kwargs.force_no_rerank
|
||||
alternate_db_session = override_kwargs.alternate_db_session
|
||||
retrieved_sections_callback = override_kwargs.retrieved_sections_callback
|
||||
skip_query_analysis = override_kwargs.skip_query_analysis
|
||||
|
||||
if self.selected_sections:
|
||||
yield from self._build_response_for_specified_sections(query)
|
||||
@@ -324,6 +330,7 @@ class SearchTool(Tool):
|
||||
user=self.user,
|
||||
llm=self.llm,
|
||||
fast_llm=self.fast_llm,
|
||||
skip_query_analysis=skip_query_analysis,
|
||||
bypass_acl=self.bypass_acl,
|
||||
db_session=alternate_db_session or self.db_session,
|
||||
prompt_config=self.prompt_config,
|
||||
|
||||
@@ -86,7 +86,10 @@ def run_functions_in_parallel(
|
||||
Executes a list of FunctionCalls in parallel and stores the results in a dictionary where the keys
|
||||
are the result_id of the FunctionCall and the values are the results of the call.
|
||||
"""
|
||||
results = {}
|
||||
results: dict[str, Any] = {}
|
||||
|
||||
if len(function_calls) == 0:
|
||||
return results
|
||||
|
||||
with ThreadPoolExecutor(max_workers=len(function_calls)) as executor:
|
||||
future_to_id = {
|
||||
|
||||
@@ -256,16 +256,28 @@ def get_documents_for_tenant_connector(
|
||||
|
||||
|
||||
def search_for_document(
|
||||
index_name: str, document_id: str, max_hits: int | None = 10
|
||||
index_name: str,
|
||||
document_id: str | None = None,
|
||||
tenant_id: str | None = None,
|
||||
max_hits: int | None = 10,
|
||||
) -> List[Dict[str, Any]]:
|
||||
yql_query = (
|
||||
f'select * from sources {index_name} where document_id contains "{document_id}"'
|
||||
)
|
||||
yql_query = f"select * from sources {index_name}"
|
||||
|
||||
conditions = []
|
||||
if document_id is not None:
|
||||
conditions.append(f'document_id contains "{document_id}"')
|
||||
|
||||
if tenant_id is not None:
|
||||
conditions.append(f'tenant_id contains "{tenant_id}"')
|
||||
|
||||
if conditions:
|
||||
yql_query += " where " + " and ".join(conditions)
|
||||
|
||||
params: dict[str, Any] = {"yql": yql_query}
|
||||
if max_hits is not None:
|
||||
params["hits"] = max_hits
|
||||
with get_vespa_http_client() as client:
|
||||
response = client.get(f"{SEARCH_ENDPOINT}/search/", params=params)
|
||||
response = client.get(f"{SEARCH_ENDPOINT}search/", params=params)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
documents = result.get("root", {}).get("children", [])
|
||||
@@ -582,8 +594,15 @@ class VespaDebugging:
|
||||
) -> None:
|
||||
update_document(self.tenant_id, connector_id, doc_id, fields)
|
||||
|
||||
def search_for_document(self, document_id: str) -> List[Dict[str, Any]]:
|
||||
return search_for_document(self.index_name, document_id)
|
||||
def delete_documents_for_tenant(self, count: int | None = None) -> None:
|
||||
if not self.tenant_id:
|
||||
raise Exception("Tenant ID is not set")
|
||||
delete_documents_for_tenant(self.index_name, self.tenant_id, count=count)
|
||||
|
||||
def search_for_document(
|
||||
self, document_id: str | None = None, tenant_id: str | None = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
return search_for_document(self.index_name, document_id, tenant_id)
|
||||
|
||||
def delete_document(self, connector_id: int, doc_id: str) -> None:
|
||||
# Delete a document.
|
||||
@@ -600,6 +619,147 @@ class VespaDebugging:
|
||||
get_document_acls(self.tenant_id, cc_pair_id, n)
|
||||
|
||||
|
||||
def delete_where(
|
||||
index_name: str,
|
||||
selection: str,
|
||||
cluster: str = "default",
|
||||
bucket_space: str | None = None,
|
||||
continuation: str | None = None,
|
||||
time_chunk: str | None = None,
|
||||
timeout: str | None = None,
|
||||
tracelevel: int | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Removes visited documents in `cluster` where the given selection
|
||||
is true, using Vespa's 'delete where' endpoint.
|
||||
|
||||
:param index_name: Typically <namespace>/<document-type> from your schema
|
||||
:param selection: The selection string, e.g., "true" or "foo contains 'bar'"
|
||||
:param cluster: The name of the cluster where documents reside
|
||||
:param bucket_space: e.g. 'global' or 'default'
|
||||
:param continuation: For chunked visits
|
||||
:param time_chunk: If you want to chunk the visit by time
|
||||
:param timeout: e.g. '10s'
|
||||
:param tracelevel: Increase for verbose logs
|
||||
"""
|
||||
# Using index_name of form <namespace>/<document-type>, e.g. "nomic_ai_nomic_embed_text_v1"
|
||||
# This route ends with "/docid/" since the actual ID is not specified — we rely on "selection".
|
||||
path = f"/document/v1/{index_name}/docid/"
|
||||
|
||||
params = {
|
||||
"cluster": cluster,
|
||||
"selection": selection,
|
||||
}
|
||||
|
||||
# Optional parameters
|
||||
if bucket_space is not None:
|
||||
params["bucketSpace"] = bucket_space
|
||||
if continuation is not None:
|
||||
params["continuation"] = continuation
|
||||
if time_chunk is not None:
|
||||
params["timeChunk"] = time_chunk
|
||||
if timeout is not None:
|
||||
params["timeout"] = timeout
|
||||
if tracelevel is not None:
|
||||
params["tracelevel"] = tracelevel # type: ignore
|
||||
|
||||
with get_vespa_http_client() as client:
|
||||
url = f"{VESPA_APPLICATION_ENDPOINT}{path}"
|
||||
logger.info(f"Performing 'delete where' on {url} with selection={selection}...")
|
||||
response = client.delete(url, params=params)
|
||||
# (Optionally, you can keep fetching `continuation` from the JSON response
|
||||
# if you have more documents to delete in chunks.)
|
||||
response.raise_for_status() # will raise HTTPError if not 2xx
|
||||
logger.info(f"Delete where completed with status: {response.status_code}")
|
||||
print(f"Delete where completed with status: {response.status_code}")
|
||||
|
||||
|
||||
def delete_documents_for_tenant(
|
||||
index_name: str,
|
||||
tenant_id: str,
|
||||
route: str | None = None,
|
||||
condition: str | None = None,
|
||||
timeout: str | None = None,
|
||||
tracelevel: int | None = None,
|
||||
count: int | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
For the given tenant_id and index_name (often in the form <namespace>/<document-type>),
|
||||
find documents via search_for_document, then delete them one at a time using Vespa's
|
||||
/document/v1/<namespace>/<document-type>/docid/<document-id> endpoint.
|
||||
|
||||
:param index_name: Typically <namespace>/<document-type> from your schema
|
||||
:param tenant_id: The tenant to match in your Vespa search
|
||||
:param route: Optional route parameter for delete
|
||||
:param condition: Optional conditional remove
|
||||
:param timeout: e.g. '10s'
|
||||
:param tracelevel: Increase for verbose logs
|
||||
"""
|
||||
deleted_count = 0
|
||||
while True:
|
||||
# Search for documents with the given tenant_id
|
||||
docs = search_for_document(
|
||||
index_name=index_name,
|
||||
document_id=None,
|
||||
tenant_id=tenant_id,
|
||||
max_hits=100, # Fetch in batches of 100
|
||||
)
|
||||
|
||||
if not docs:
|
||||
logger.info("No more documents found to delete.")
|
||||
break
|
||||
|
||||
with get_vespa_http_client() as client:
|
||||
for doc in docs:
|
||||
if count is not None and deleted_count >= count:
|
||||
logger.info(f"Reached maximum delete limit of {count} documents.")
|
||||
return
|
||||
|
||||
fields = doc.get("fields", {})
|
||||
doc_id_value = fields.get("document_id") or fields.get("documentid")
|
||||
tenant_id = fields.get("tenant_id")
|
||||
if tenant_id != tenant_id:
|
||||
raise Exception("Tenant ID mismatch")
|
||||
|
||||
if not doc_id_value:
|
||||
logger.warning(
|
||||
"Skipping a document that has no document_id in 'fields'."
|
||||
)
|
||||
continue
|
||||
|
||||
url = f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{doc_id_value}"
|
||||
|
||||
params = {}
|
||||
if condition:
|
||||
params["condition"] = condition
|
||||
if route:
|
||||
params["route"] = route
|
||||
if timeout:
|
||||
params["timeout"] = timeout
|
||||
if tracelevel is not None:
|
||||
params["tracelevel"] = str(tracelevel)
|
||||
|
||||
response = client.delete(url, params=params)
|
||||
if response.status_code == 200:
|
||||
logger.info(f"Successfully deleted doc_id={doc_id_value}")
|
||||
deleted_count += 1
|
||||
else:
|
||||
logger.error(
|
||||
f"Failed to delete doc_id={doc_id_value}, "
|
||||
f"status={response.status_code}, response={response.text}"
|
||||
)
|
||||
print(
|
||||
f"Could not delete doc_id={doc_id_value}. "
|
||||
f"Status={response.status_code}, response={response.text}"
|
||||
)
|
||||
raise Exception(
|
||||
f"Could not delete doc_id={doc_id_value}. "
|
||||
f"Status={response.status_code}, response={response.text}"
|
||||
)
|
||||
|
||||
logger.info(f"Deleted {deleted_count} documents in total.")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="Vespa debugging tool")
|
||||
parser.add_argument(
|
||||
@@ -612,6 +772,7 @@ def main() -> None:
|
||||
"update",
|
||||
"delete",
|
||||
"get_acls",
|
||||
"delete-all-documents",
|
||||
],
|
||||
required=True,
|
||||
help="Action to perform",
|
||||
@@ -626,11 +787,20 @@ def main() -> None:
|
||||
parser.add_argument(
|
||||
"--fields", help="Fields to update, in JSON format (for update)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--count",
|
||||
type=int,
|
||||
help="Maximum number of documents to delete (for delete-all-documents)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
vespa_debug = VespaDebugging(args.tenant_id)
|
||||
|
||||
if args.action == "config":
|
||||
if args.action == "delete-all-documents":
|
||||
if not args.tenant_id:
|
||||
parser.error("--tenant-id is required for delete-all-documents action")
|
||||
vespa_debug.delete_documents_for_tenant(count=args.count)
|
||||
elif args.action == "config":
|
||||
vespa_debug.print_config()
|
||||
elif args.action == "connect":
|
||||
vespa_debug.check_connectivity()
|
||||
|
||||
@@ -9,6 +9,8 @@ from onyx.connectors.airtable.airtable_connector import AirtableConnector
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import Section
|
||||
|
||||
BASE_VIEW_ID = "viwVUEJjWPd8XYjh8"
|
||||
|
||||
|
||||
class AirtableConfig(BaseModel):
|
||||
base_id: str
|
||||
@@ -46,6 +48,8 @@ def create_test_document(
|
||||
days_since_status_change: int | None,
|
||||
attachments: list[tuple[str, str]] | None = None,
|
||||
all_fields_as_metadata: bool = False,
|
||||
share_id: str | None = None,
|
||||
view_id: str | None = None,
|
||||
) -> Document:
|
||||
base_id = os.environ.get("AIRTABLE_TEST_BASE_ID")
|
||||
table_id = os.environ.get("AIRTABLE_TEST_TABLE_ID")
|
||||
@@ -60,7 +64,13 @@ def create_test_document(
|
||||
f"Required environment variables not set: {', '.join(missing_vars)}. "
|
||||
"These variables are required to run Airtable connector tests."
|
||||
)
|
||||
link_base = f"https://airtable.com/{base_id}/{table_id}"
|
||||
link_base = f"https://airtable.com/{base_id}"
|
||||
if share_id:
|
||||
link_base = f"{link_base}/{share_id}"
|
||||
link_base = f"{link_base}/{table_id}"
|
||||
if view_id:
|
||||
link_base = f"{link_base}/{view_id}"
|
||||
|
||||
sections = []
|
||||
|
||||
if not all_fields_as_metadata:
|
||||
@@ -214,6 +224,7 @@ def test_airtable_connector_basic(
|
||||
assignee="Chris Weaver (chris@onyx.app)",
|
||||
submitted_by="Chris Weaver (chris@onyx.app)",
|
||||
all_fields_as_metadata=False,
|
||||
view_id=BASE_VIEW_ID,
|
||||
),
|
||||
create_test_document(
|
||||
id="reccSlIA4pZEFxPBg",
|
||||
@@ -234,6 +245,7 @@ def test_airtable_connector_basic(
|
||||
)
|
||||
],
|
||||
all_fields_as_metadata=False,
|
||||
view_id=BASE_VIEW_ID,
|
||||
),
|
||||
]
|
||||
|
||||
@@ -285,6 +297,81 @@ def test_airtable_connector_all_metadata(
|
||||
)
|
||||
],
|
||||
all_fields_as_metadata=True,
|
||||
view_id=BASE_VIEW_ID,
|
||||
),
|
||||
]
|
||||
|
||||
# Compare documents using the utility function
|
||||
compare_documents(doc_batch, expected_docs)
|
||||
|
||||
|
||||
def test_airtable_connector_with_share_and_view(
|
||||
mock_get_unstructured_api_key: MagicMock, airtable_config: AirtableConfig
|
||||
) -> None:
|
||||
"""Test behavior when using share_id and view_id for URL generation."""
|
||||
SHARE_ID = "shrkfjEzDmLaDtK83"
|
||||
|
||||
connector = AirtableConnector(
|
||||
base_id=airtable_config.base_id,
|
||||
table_name_or_id=airtable_config.table_identifier,
|
||||
treat_all_non_attachment_fields_as_metadata=False,
|
||||
share_id=SHARE_ID,
|
||||
view_id=BASE_VIEW_ID,
|
||||
)
|
||||
connector.load_credentials(
|
||||
{
|
||||
"airtable_access_token": airtable_config.access_token,
|
||||
}
|
||||
)
|
||||
doc_batch_generator = connector.load_from_state()
|
||||
doc_batch = next(doc_batch_generator)
|
||||
with pytest.raises(StopIteration):
|
||||
next(doc_batch_generator)
|
||||
|
||||
assert len(doc_batch) == 2
|
||||
|
||||
expected_docs = [
|
||||
create_test_document(
|
||||
id="rec8BnxDLyWeegOuO",
|
||||
title="Slow Internet",
|
||||
description="The internet connection is very slow.",
|
||||
priority="Medium",
|
||||
status="In Progress",
|
||||
ticket_id="2",
|
||||
created_time="2024-12-24T21:02:49.000Z",
|
||||
status_last_changed="2024-12-24T21:02:49.000Z",
|
||||
days_since_status_change=0,
|
||||
assignee="Chris Weaver (chris@onyx.app)",
|
||||
submitted_by="Chris Weaver (chris@onyx.app)",
|
||||
all_fields_as_metadata=False,
|
||||
share_id=SHARE_ID,
|
||||
view_id=BASE_VIEW_ID,
|
||||
),
|
||||
create_test_document(
|
||||
id="reccSlIA4pZEFxPBg",
|
||||
title="Printer Issue",
|
||||
description="The office printer is not working.",
|
||||
priority="High",
|
||||
status="Open",
|
||||
ticket_id="1",
|
||||
created_time="2024-12-24T21:02:49.000Z",
|
||||
status_last_changed="2024-12-24T21:02:49.000Z",
|
||||
days_since_status_change=0,
|
||||
assignee="Chris Weaver (chris@onyx.app)",
|
||||
submitted_by="Chris Weaver (chris@onyx.app)",
|
||||
attachments=[
|
||||
(
|
||||
"Test.pdf:\ntesting!!!",
|
||||
(
|
||||
f"https://airtable.com/{airtable_config.base_id}/{SHARE_ID}/"
|
||||
f"{os.environ['AIRTABLE_TEST_TABLE_ID']}/{BASE_VIEW_ID}/reccSlIA4pZEFxPBg/"
|
||||
"fld1u21zkJACIvAEF/attlj2UBWNEDZngCc?blocks=hide"
|
||||
),
|
||||
)
|
||||
],
|
||||
all_fields_as_metadata=False,
|
||||
share_id=SHARE_ID,
|
||||
view_id=BASE_VIEW_ID,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@@ -34,11 +34,11 @@ def test_jira_connector_basic(jira_connector: JiraConnector) -> None:
|
||||
doc = doc_batch[0]
|
||||
|
||||
assert doc.id == "https://danswerai.atlassian.net/browse/AS-2"
|
||||
assert doc.semantic_identifier == "test123small"
|
||||
assert doc.semantic_identifier == "AS-2: test123small"
|
||||
assert doc.source == DocumentSource.JIRA
|
||||
assert doc.metadata == {"priority": "Medium", "status": "Backlog"}
|
||||
assert doc.secondary_owners is None
|
||||
assert doc.title is None
|
||||
assert doc.title == "AS-2 test123small"
|
||||
assert doc.from_ingestion_api is False
|
||||
assert doc.additional_info is None
|
||||
|
||||
|
||||
@@ -23,6 +23,9 @@ def bedrock_provider() -> WellKnownLLMProviderDescriptor:
|
||||
return provider
|
||||
|
||||
|
||||
@pytest.mark.xfail(
|
||||
reason="Credentials not yet available due to compliance work needed",
|
||||
)
|
||||
def test_bedrock_llm_configuration(
|
||||
client: TestClient, bedrock_provider: WellKnownLLMProviderDescriptor
|
||||
) -> None:
|
||||
|
||||
@@ -66,7 +66,7 @@ class PersonaManager:
|
||||
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/persona",
|
||||
json=persona_creation_request.model_dump(),
|
||||
json=persona_creation_request.model_dump(mode="json"),
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
@@ -119,6 +119,7 @@ class PersonaManager:
|
||||
) -> DATestPersona:
|
||||
system_prompt = system_prompt or f"System prompt for {persona.name}"
|
||||
task_prompt = task_prompt or f"Task prompt for {persona.name}"
|
||||
|
||||
persona_update_request = PersonaUpsertRequest(
|
||||
name=name or persona.name,
|
||||
description=description or persona.description,
|
||||
@@ -146,7 +147,7 @@ class PersonaManager:
|
||||
|
||||
response = requests.patch(
|
||||
f"{API_SERVER_URL}/persona/{persona.id}",
|
||||
json=persona_update_request.model_dump(),
|
||||
json=persona_update_request.model_dump(mode="json"),
|
||||
headers=user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS,
|
||||
|
||||
@@ -58,6 +58,7 @@ def test_persona_permissions(reset: None) -> None:
|
||||
description="A persona created by basic user",
|
||||
is_public=False,
|
||||
groups=[],
|
||||
users=[admin_user.id],
|
||||
user_performing_action=basic_user,
|
||||
)
|
||||
PersonaManager.verify(basic_user_persona, user_performing_action=basic_user)
|
||||
@@ -139,9 +140,14 @@ def test_persona_permissions(reset: None) -> None:
|
||||
|
||||
"""Test admin permissions"""
|
||||
# Admin can edit any persona
|
||||
|
||||
# the persona was shared with the admin user on creation
|
||||
# this edit call will simulate having the same user in the list twice.
|
||||
# The server side should dedupe and handle this correctly (prior bug)
|
||||
PersonaManager.edit(
|
||||
persona=basic_user_persona,
|
||||
description="Updated by admin",
|
||||
description="Updated by admin 2",
|
||||
users=[admin_user.id, admin_user.id],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
PersonaManager.verify(basic_user_persona, user_performing_action=admin_user)
|
||||
|
||||
@@ -84,6 +84,9 @@ ENV NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=${NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED}
|
||||
ARG NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK
|
||||
ENV NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=${NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK}
|
||||
|
||||
ARG NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY
|
||||
ENV NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY=${NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY}
|
||||
|
||||
# Use NODE_OPTIONS in the build command
|
||||
RUN NODE_OPTIONS="${NODE_OPTIONS}" npx next build
|
||||
|
||||
@@ -145,7 +148,6 @@ ENV NEXT_PUBLIC_DISABLE_LOGOUT=${NEXT_PUBLIC_DISABLE_LOGOUT}
|
||||
ARG NEXT_PUBLIC_CUSTOM_REFRESH_URL
|
||||
ENV NEXT_PUBLIC_CUSTOM_REFRESH_URL=${NEXT_PUBLIC_CUSTOM_REFRESH_URL}
|
||||
|
||||
|
||||
ARG NEXT_PUBLIC_POSTHOG_KEY
|
||||
ARG NEXT_PUBLIC_POSTHOG_HOST
|
||||
ENV NEXT_PUBLIC_POSTHOG_KEY=${NEXT_PUBLIC_POSTHOG_KEY}
|
||||
@@ -166,6 +168,9 @@ ENV NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=${NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED}
|
||||
ARG NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK
|
||||
ENV NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=${NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK}
|
||||
|
||||
ARG NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY
|
||||
ENV NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY=${NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY}
|
||||
|
||||
# Note: Don't expose ports here, Compose will handle that for us if necessary.
|
||||
# If you want to run this without compose, specify the ports to
|
||||
# expose via cli
|
||||
|
||||
@@ -23,12 +23,12 @@ _Note:_ if you are having problems accessing the ^, try setting the `WEB_DOMAIN`
|
||||
`http://127.0.0.1:3000` and accessing it there.
|
||||
|
||||
## Testing
|
||||
This testing process will reset your application into a clean state.
|
||||
|
||||
This testing process will reset your application into a clean state.
|
||||
Don't run these tests if you don't want to do this!
|
||||
|
||||
Bring up the entire application.
|
||||
|
||||
|
||||
1. Reset the instance
|
||||
|
||||
```cd backend
|
||||
@@ -59,4 +59,4 @@ may use this for local troubleshooting and testing.
|
||||
```
|
||||
cd web
|
||||
npx chromatic --playwright --project-token={your token here}
|
||||
```
|
||||
```
|
||||
|
||||
@@ -3,7 +3,13 @@
|
||||
import React from "react";
|
||||
import { Option } from "@/components/Dropdown";
|
||||
import { generateRandomIconShape } from "@/lib/assistantIconUtils";
|
||||
import { CCPairBasicInfo, DocumentSet, User, UserGroup } from "@/lib/types";
|
||||
import {
|
||||
CCPairBasicInfo,
|
||||
DocumentSet,
|
||||
User,
|
||||
UserGroup,
|
||||
UserRole,
|
||||
} from "@/lib/types";
|
||||
import { Separator } from "@/components/ui/separator";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { ArrayHelpers, FieldArray, Form, Formik, FormikProps } from "formik";
|
||||
@@ -33,9 +39,8 @@ import {
|
||||
TooltipTrigger,
|
||||
} from "@/components/ui/tooltip";
|
||||
import Link from "next/link";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { useRouter, useSearchParams } from "next/navigation";
|
||||
import { useEffect, useMemo, useState } from "react";
|
||||
import { FiInfo } from "react-icons/fi";
|
||||
import * as Yup from "yup";
|
||||
import CollapsibleSection from "./CollapsibleSection";
|
||||
import { SuccessfulPersonaUpdateRedirectType } from "./enums";
|
||||
@@ -71,11 +76,11 @@ import {
|
||||
Option as DropdownOption,
|
||||
} from "@/components/Dropdown";
|
||||
import { SourceChip } from "@/app/chat/input/ChatInputBar";
|
||||
import { TagIcon, UserIcon, XIcon } from "lucide-react";
|
||||
import { TagIcon, UserIcon, XIcon, InfoIcon } from "lucide-react";
|
||||
import { LLMSelector } from "@/components/llm/LLMSelector";
|
||||
import useSWR from "swr";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import { DeleteEntityModal } from "@/components/modals/DeleteEntityModal";
|
||||
import { ConfirmEntityModal } from "@/components/modals/ConfirmEntityModal";
|
||||
import Title from "@/components/ui/title";
|
||||
import { SEARCH_TOOL_ID } from "@/app/chat/tools/constants";
|
||||
|
||||
@@ -127,6 +132,8 @@ export function AssistantEditor({
|
||||
}) {
|
||||
const { refreshAssistants, isImageGenerationAvailable } = useAssistants();
|
||||
const router = useRouter();
|
||||
const searchParams = useSearchParams();
|
||||
const isAdminPage = searchParams.get("admin") === "true";
|
||||
|
||||
const { popup, setPopup } = usePopup();
|
||||
const { labels, refreshLabels, createLabel, updateLabel, deleteLabel } =
|
||||
@@ -216,6 +223,8 @@ export function AssistantEditor({
|
||||
enabledToolsMap[tool.id] = personaCurrentToolIds.includes(tool.id);
|
||||
});
|
||||
|
||||
const [showVisibilityWarning, setShowVisibilityWarning] = useState(false);
|
||||
|
||||
const initialValues = {
|
||||
name: existingPersona?.name ?? "",
|
||||
description: existingPersona?.description ?? "",
|
||||
@@ -252,6 +261,7 @@ export function AssistantEditor({
|
||||
(u) => u.id !== existingPersona.owner?.id
|
||||
) ?? [],
|
||||
selectedGroups: existingPersona?.groups ?? [],
|
||||
is_default_persona: existingPersona?.is_default_persona ?? false,
|
||||
};
|
||||
|
||||
interface AssistantPrompt {
|
||||
@@ -308,24 +318,12 @@ export function AssistantEditor({
|
||||
const [isRequestSuccessful, setIsRequestSuccessful] = useState(false);
|
||||
|
||||
const { data: userGroups } = useUserGroups();
|
||||
// const { data: allUsers } = useUsers({ includeApiKeys: false }) as {
|
||||
// data: MinimalUserSnapshot[] | undefined;
|
||||
// };
|
||||
|
||||
const { data: users } = useSWR<MinimalUserSnapshot[]>(
|
||||
"/api/users",
|
||||
errorHandlingFetcher
|
||||
);
|
||||
|
||||
const mapUsersToMinimalSnapshot = (users: any): MinimalUserSnapshot[] => {
|
||||
if (!users || !Array.isArray(users.users)) return [];
|
||||
return users.users.map((user: any) => ({
|
||||
id: user.id,
|
||||
name: user.name,
|
||||
email: user.email,
|
||||
}));
|
||||
};
|
||||
|
||||
const [deleteModalOpen, setDeleteModalOpen] = useState(false);
|
||||
|
||||
if (!labels) {
|
||||
@@ -346,9 +344,7 @@ export function AssistantEditor({
|
||||
if (response.ok) {
|
||||
await refreshAssistants();
|
||||
router.push(
|
||||
redirectType === SuccessfulPersonaUpdateRedirectType.ADMIN
|
||||
? `/admin/assistants?u=${Date.now()}`
|
||||
: `/chat`
|
||||
isAdminPage ? `/admin/assistants?u=${Date.now()}` : `/chat`
|
||||
);
|
||||
} else {
|
||||
setPopup({
|
||||
@@ -374,8 +370,9 @@ export function AssistantEditor({
|
||||
<BackButton />
|
||||
</div>
|
||||
)}
|
||||
|
||||
{labelToDelete && (
|
||||
<DeleteEntityModal
|
||||
<ConfirmEntityModal
|
||||
entityType="label"
|
||||
entityName={labelToDelete.name}
|
||||
onClose={() => setLabelToDelete(null)}
|
||||
@@ -398,7 +395,7 @@ export function AssistantEditor({
|
||||
/>
|
||||
)}
|
||||
{deleteModalOpen && existingPersona && (
|
||||
<DeleteEntityModal
|
||||
<ConfirmEntityModal
|
||||
entityType="Persona"
|
||||
entityName={existingPersona.name}
|
||||
onClose={closeDeleteModal}
|
||||
@@ -439,6 +436,7 @@ export function AssistantEditor({
|
||||
label_ids: Yup.array().of(Yup.number()),
|
||||
selectedUsers: Yup.array().of(Yup.object()),
|
||||
selectedGroups: Yup.array().of(Yup.number()),
|
||||
is_default_persona: Yup.boolean().required(),
|
||||
})
|
||||
.test(
|
||||
"system-prompt-or-task-prompt",
|
||||
@@ -459,6 +457,19 @@ export function AssistantEditor({
|
||||
"Must provide either Instructions or Reminders (Advanced)",
|
||||
});
|
||||
}
|
||||
)
|
||||
.test(
|
||||
"default-persona-public",
|
||||
"Default persona must be public",
|
||||
function (values) {
|
||||
if (values.is_default_persona && !values.is_public) {
|
||||
return this.createError({
|
||||
path: "is_public",
|
||||
message: "Default persona must be public",
|
||||
});
|
||||
}
|
||||
return true;
|
||||
}
|
||||
)}
|
||||
onSubmit={async (values, formikHelpers) => {
|
||||
if (
|
||||
@@ -499,7 +510,6 @@ export function AssistantEditor({
|
||||
const submissionData: PersonaUpsertParameters = {
|
||||
...values,
|
||||
existing_prompt_id: existingPrompt?.id ?? null,
|
||||
is_default_persona: admin!,
|
||||
starter_messages: starterMessages,
|
||||
groups: groups,
|
||||
users: values.is_public
|
||||
@@ -563,8 +573,9 @@ export function AssistantEditor({
|
||||
}
|
||||
|
||||
await refreshAssistants();
|
||||
|
||||
router.push(
|
||||
redirectType === SuccessfulPersonaUpdateRedirectType.ADMIN
|
||||
isAdminPage
|
||||
? `/admin/assistants?u=${Date.now()}`
|
||||
: `/chat?assistantId=${assistantId}`
|
||||
);
|
||||
@@ -1005,6 +1016,22 @@ export function AssistantEditor({
|
||||
{showAdvancedOptions && (
|
||||
<>
|
||||
<div className="max-w-4xl w-full">
|
||||
{user?.role == UserRole.ADMIN && (
|
||||
<BooleanFormField
|
||||
onChange={(checked) => {
|
||||
if (checked) {
|
||||
setFieldValue("is_public", true);
|
||||
setFieldValue("is_default_persona", true);
|
||||
}
|
||||
}}
|
||||
name="is_default_persona"
|
||||
label="Featured Assistant"
|
||||
subtext="If set, this assistant will be pinned for all new users and appear in the Featured list in the assistant explorer. This also makes the assistant public."
|
||||
/>
|
||||
)}
|
||||
|
||||
<Separator />
|
||||
|
||||
<div className="flex gap-x-2 items-center ">
|
||||
<div className="block font-medium text-sm">Access</div>
|
||||
</div>
|
||||
@@ -1014,22 +1041,60 @@ export function AssistantEditor({
|
||||
|
||||
<div className="min-h-[100px]">
|
||||
<div className="flex items-center mb-2">
|
||||
<SwitchField
|
||||
name="is_public"
|
||||
size="md"
|
||||
onCheckedChange={(checked) => {
|
||||
setFieldValue("is_public", checked);
|
||||
if (checked) {
|
||||
setFieldValue("selectedUsers", []);
|
||||
setFieldValue("selectedGroups", []);
|
||||
}
|
||||
}}
|
||||
/>
|
||||
<TooltipProvider delayDuration={0}>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<div>
|
||||
<SwitchField
|
||||
name="is_public"
|
||||
size="md"
|
||||
onCheckedChange={(checked) => {
|
||||
if (values.is_default_persona && !checked) {
|
||||
setShowVisibilityWarning(true);
|
||||
} else {
|
||||
setFieldValue("is_public", checked);
|
||||
if (!checked) {
|
||||
// Even though this code path should not be possible,
|
||||
// we set the default persona to false to be safe
|
||||
setFieldValue(
|
||||
"is_default_persona",
|
||||
false
|
||||
);
|
||||
}
|
||||
if (checked) {
|
||||
setFieldValue("selectedUsers", []);
|
||||
setFieldValue("selectedGroups", []);
|
||||
}
|
||||
}
|
||||
}}
|
||||
disabled={values.is_default_persona}
|
||||
/>
|
||||
</div>
|
||||
</TooltipTrigger>
|
||||
{values.is_default_persona && (
|
||||
<TooltipContent side="top" align="center">
|
||||
Default persona must be public. Set
|
||||
"Default Persona" to false to change
|
||||
visibility.
|
||||
</TooltipContent>
|
||||
)}
|
||||
</Tooltip>
|
||||
</TooltipProvider>
|
||||
<span className="text-sm ml-2">
|
||||
{values.is_public ? "Public" : "Private"}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
{showVisibilityWarning && (
|
||||
<div className="flex items-center text-warning mt-2">
|
||||
<InfoIcon size={16} className="mr-2" />
|
||||
<span className="text-sm">
|
||||
Default persona must be public. Visibility has been
|
||||
automatically set to public.
|
||||
</span>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{values.is_public ? (
|
||||
<p className="text-sm text-text-dark">
|
||||
Anyone from your organization can view and use this
|
||||
|
||||
@@ -11,13 +11,14 @@ import { DraggableTable } from "@/components/table/DraggableTable";
|
||||
import {
|
||||
deletePersona,
|
||||
personaComparator,
|
||||
togglePersonaDefault,
|
||||
togglePersonaVisibility,
|
||||
} from "./lib";
|
||||
import { FiEdit2 } from "react-icons/fi";
|
||||
import { TrashIcon } from "@/components/icons/icons";
|
||||
import { useUser } from "@/components/user/UserProvider";
|
||||
import { useAssistants } from "@/components/context/AssistantsContext";
|
||||
import { DeleteEntityModal } from "@/components/modals/DeleteEntityModal";
|
||||
import { ConfirmEntityModal } from "@/components/modals/ConfirmEntityModal";
|
||||
|
||||
function PersonaTypeDisplay({ persona }: { persona: Persona }) {
|
||||
if (persona.builtin_persona) {
|
||||
@@ -56,6 +57,9 @@ export function PersonasTable() {
|
||||
const [finalPersonas, setFinalPersonas] = useState<Persona[]>([]);
|
||||
const [deleteModalOpen, setDeleteModalOpen] = useState(false);
|
||||
const [personaToDelete, setPersonaToDelete] = useState<Persona | null>(null);
|
||||
const [defaultModalOpen, setDefaultModalOpen] = useState(false);
|
||||
const [personaToToggleDefault, setPersonaToToggleDefault] =
|
||||
useState<Persona | null>(null);
|
||||
|
||||
useEffect(() => {
|
||||
const editable = editablePersonas.sort(personaComparator);
|
||||
@@ -126,11 +130,39 @@ export function PersonasTable() {
|
||||
}
|
||||
};
|
||||
|
||||
const openDefaultModal = (persona: Persona) => {
|
||||
setPersonaToToggleDefault(persona);
|
||||
setDefaultModalOpen(true);
|
||||
};
|
||||
|
||||
const closeDefaultModal = () => {
|
||||
setDefaultModalOpen(false);
|
||||
setPersonaToToggleDefault(null);
|
||||
};
|
||||
|
||||
const handleToggleDefault = async () => {
|
||||
if (personaToToggleDefault) {
|
||||
const response = await togglePersonaDefault(
|
||||
personaToToggleDefault.id,
|
||||
personaToToggleDefault.is_default_persona
|
||||
);
|
||||
if (response.ok) {
|
||||
await refreshAssistants();
|
||||
closeDefaultModal();
|
||||
} else {
|
||||
setPopup({
|
||||
type: "error",
|
||||
message: `Failed to update persona - ${await response.text()}`,
|
||||
});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div>
|
||||
{popup}
|
||||
{deleteModalOpen && personaToDelete && (
|
||||
<DeleteEntityModal
|
||||
<ConfirmEntityModal
|
||||
entityType="Persona"
|
||||
entityName={personaToDelete.name}
|
||||
onClose={closeDeleteModal}
|
||||
@@ -138,8 +170,35 @@ export function PersonasTable() {
|
||||
/>
|
||||
)}
|
||||
|
||||
{defaultModalOpen && personaToToggleDefault && (
|
||||
<ConfirmEntityModal
|
||||
variant="action"
|
||||
entityType="Assistant"
|
||||
entityName={personaToToggleDefault.name}
|
||||
onClose={closeDefaultModal}
|
||||
onSubmit={handleToggleDefault}
|
||||
actionButtonText={
|
||||
personaToToggleDefault.is_default_persona
|
||||
? "Remove Featured"
|
||||
: "Set as Featured"
|
||||
}
|
||||
additionalDetails={
|
||||
personaToToggleDefault.is_default_persona
|
||||
? `Removing "${personaToToggleDefault.name}" as a featured assistant will not affect its visibility or accessibility.`
|
||||
: `Setting "${personaToToggleDefault.name}" as a featured assistant will make it public and visible to all users. This action cannot be undone.`
|
||||
}
|
||||
/>
|
||||
)}
|
||||
|
||||
<DraggableTable
|
||||
headers={["Name", "Description", "Type", "Is Visible", "Delete"]}
|
||||
headers={[
|
||||
"Name",
|
||||
"Description",
|
||||
"Type",
|
||||
"Featured Assistant",
|
||||
"Is Visible",
|
||||
"Delete",
|
||||
]}
|
||||
isAdmin={isAdmin}
|
||||
rows={finalPersonas.map((persona) => {
|
||||
const isEditable = editablePersonas.includes(persona);
|
||||
@@ -152,7 +211,9 @@ export function PersonasTable() {
|
||||
className="mr-1 my-auto cursor-pointer"
|
||||
onClick={() =>
|
||||
router.push(
|
||||
`/admin/assistants/${persona.id}?u=${Date.now()}`
|
||||
`/assistants/edit/${
|
||||
persona.id
|
||||
}?u=${Date.now()}&admin=true`
|
||||
)
|
||||
}
|
||||
/>
|
||||
@@ -168,6 +229,30 @@ export function PersonasTable() {
|
||||
{persona.description}
|
||||
</p>,
|
||||
<PersonaTypeDisplay key={persona.id} persona={persona} />,
|
||||
<div
|
||||
key="is_default_persona"
|
||||
onClick={() => {
|
||||
if (isEditable) {
|
||||
openDefaultModal(persona);
|
||||
}
|
||||
}}
|
||||
className={`px-1 py-0.5 rounded flex ${
|
||||
isEditable
|
||||
? "hover:bg-accent-background-hovered cursor-pointer"
|
||||
: ""
|
||||
} select-none w-fit`}
|
||||
>
|
||||
<div className="my-auto flex-none w-22">
|
||||
{!persona.is_default_persona ? (
|
||||
<div className="text-error">Not Featured</div>
|
||||
) : (
|
||||
"Featured"
|
||||
)}
|
||||
</div>
|
||||
<div className="ml-1 my-auto">
|
||||
<CustomCheckbox checked={persona.is_default_persona} />
|
||||
</div>
|
||||
</div>,
|
||||
<div
|
||||
key="is_visible"
|
||||
onClick={async () => {
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { deletePersona } from "../lib";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { SuccessfulPersonaUpdateRedirectType } from "../enums";
|
||||
|
||||
export function DeletePersonaButton({
|
||||
personaId,
|
||||
redirectType,
|
||||
}: {
|
||||
personaId: number;
|
||||
redirectType: SuccessfulPersonaUpdateRedirectType;
|
||||
}) {
|
||||
const router = useRouter();
|
||||
|
||||
return (
|
||||
<Button
|
||||
variant="destructive"
|
||||
onClick={async () => {
|
||||
const response = await deletePersona(personaId);
|
||||
if (response.ok) {
|
||||
router.push(
|
||||
redirectType === SuccessfulPersonaUpdateRedirectType.ADMIN
|
||||
? `/admin/assistants?u=${Date.now()}`
|
||||
: `/chat`
|
||||
);
|
||||
} else {
|
||||
alert(`Failed to delete persona - ${await response.text()}`);
|
||||
}
|
||||
}}
|
||||
>
|
||||
Delete
|
||||
</Button>
|
||||
);
|
||||
}
|
||||
@@ -1,43 +0,0 @@
|
||||
import { ErrorCallout } from "@/components/ErrorCallout";
|
||||
import { AssistantEditor } from "../AssistantEditor";
|
||||
import { BackButton } from "@/components/BackButton";
|
||||
|
||||
import { DeletePersonaButton } from "./DeletePersonaButton";
|
||||
import { fetchAssistantEditorInfoSS } from "@/lib/assistants/fetchPersonaEditorInfoSS";
|
||||
import { SuccessfulPersonaUpdateRedirectType } from "../enums";
|
||||
import { RobotIcon } from "@/components/icons/icons";
|
||||
import { AdminPageTitle } from "@/components/admin/Title";
|
||||
import CardSection from "@/components/admin/CardSection";
|
||||
import Title from "@/components/ui/title";
|
||||
|
||||
export default async function Page(props: { params: Promise<{ id: string }> }) {
|
||||
const params = await props.params;
|
||||
const [values, error] = await fetchAssistantEditorInfoSS(params.id);
|
||||
|
||||
let body;
|
||||
if (!values) {
|
||||
body = (
|
||||
<ErrorCallout errorTitle="Something went wrong :(" errorMsg={error} />
|
||||
);
|
||||
} else {
|
||||
body = (
|
||||
<>
|
||||
<CardSection className="!border-none !bg-transparent !ring-none">
|
||||
<AssistantEditor
|
||||
{...values}
|
||||
admin
|
||||
defaultPublic={true}
|
||||
redirectType={SuccessfulPersonaUpdateRedirectType.ADMIN}
|
||||
/>
|
||||
</CardSection>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="w-full">
|
||||
<AdminPageTitle title="Edit Assistant" icon={<RobotIcon size={32} />} />
|
||||
{body}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -261,6 +261,22 @@ export function personaComparator(a: Persona, b: Persona) {
|
||||
return closerToZeroNegativesFirstComparator(a.id, b.id);
|
||||
}
|
||||
|
||||
export const togglePersonaDefault = async (
|
||||
personaId: number,
|
||||
isDefault: boolean
|
||||
) => {
|
||||
const response = await fetch(`/api/admin/persona/${personaId}/default`, {
|
||||
method: "PATCH",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
is_default_persona: !isDefault,
|
||||
}),
|
||||
});
|
||||
return response;
|
||||
};
|
||||
|
||||
export const togglePersonaVisibility = async (
|
||||
personaId: number,
|
||||
isVisible: boolean
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
import { AssistantEditor } from "../AssistantEditor";
|
||||
import { ErrorCallout } from "@/components/ErrorCallout";
|
||||
import { fetchAssistantEditorInfoSS } from "@/lib/assistants/fetchPersonaEditorInfoSS";
|
||||
import { SuccessfulPersonaUpdateRedirectType } from "../enums";
|
||||
|
||||
export default async function Page() {
|
||||
const [values, error] = await fetchAssistantEditorInfoSS();
|
||||
|
||||
if (!values) {
|
||||
return (
|
||||
<ErrorCallout errorTitle="Something went wrong :(" errorMsg={error} />
|
||||
);
|
||||
} else {
|
||||
return (
|
||||
<div className="w-full">
|
||||
<AssistantEditor
|
||||
{...values}
|
||||
admin
|
||||
defaultPublic={true}
|
||||
redirectType={SuccessfulPersonaUpdateRedirectType.ADMIN}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -29,7 +29,7 @@ export default async function Page() {
|
||||
<Separator />
|
||||
|
||||
<Title>Create an Assistant</Title>
|
||||
<CreateButton href="/admin/assistants/new" text="New Assistant" />
|
||||
<CreateButton href="/assistants/new?admin=true" text="New Assistant" />
|
||||
|
||||
<Separator />
|
||||
|
||||
|
||||
@@ -100,7 +100,7 @@ export function SlackChannelConfigsTable({
|
||||
slackChannelConfig.persona
|
||||
) ? (
|
||||
<Link
|
||||
href={`/admin/assistants/${slackChannelConfig.persona.id}`}
|
||||
href={`/assistants/${slackChannelConfig.persona.id}`}
|
||||
className="text-primary hover:underline"
|
||||
>
|
||||
{slackChannelConfig.persona.name}
|
||||
|
||||
@@ -502,9 +502,10 @@ export default function AddConnector({
|
||||
{oauthSupportedSources.includes(connector) &&
|
||||
(NEXT_PUBLIC_CLOUD_ENABLED ||
|
||||
NEXT_PUBLIC_TEST_ENV) && (
|
||||
<button
|
||||
<Button
|
||||
variant="navigate"
|
||||
onClick={handleAuthorize}
|
||||
className="mt-6 text-sm bg-blue-500 px-2 py-1.5 flex text-text-200 flex-none rounded"
|
||||
className="mt-6 "
|
||||
disabled={isAuthorizing}
|
||||
hidden={!isAuthorizeVisible}
|
||||
>
|
||||
@@ -513,7 +514,7 @@ export default function AddConnector({
|
||||
: `Authorize with ${getSourceDisplayName(
|
||||
connector
|
||||
)}`}
|
||||
</button>
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
@@ -5,7 +5,7 @@ import useSWR, { mutate } from "swr";
|
||||
import { FetchError, errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import { ErrorCallout } from "@/components/ErrorCallout";
|
||||
import { LoadingAnimation } from "@/components/Loading";
|
||||
import { usePopup } from "@/components/admin/connectors/Popup";
|
||||
import { PopupSpec, usePopup } from "@/components/admin/connectors/Popup";
|
||||
import { ValidSources } from "@/lib/types";
|
||||
import { usePublicCredentials } from "@/lib/hooks";
|
||||
import Title from "@/components/ui/title";
|
||||
@@ -32,7 +32,11 @@ const useConnectorsByCredentialId = (credential_id: number | null) => {
|
||||
};
|
||||
};
|
||||
|
||||
const GDriveMain = ({}: {}) => {
|
||||
const GDriveMain = ({
|
||||
setPopup,
|
||||
}: {
|
||||
setPopup: (popup: PopupSpec | null) => void;
|
||||
}) => {
|
||||
const { isAdmin, user } = useUser();
|
||||
|
||||
// tries getting the uploaded credential json
|
||||
@@ -97,8 +101,6 @@ const GDriveMain = ({}: {}) => {
|
||||
refreshConnectorsByCredentialId,
|
||||
} = useConnectorsByCredentialId(credential_id);
|
||||
|
||||
const { popup, setPopup } = usePopup();
|
||||
|
||||
const appCredentialSuccessfullyFetched =
|
||||
appCredentialData ||
|
||||
(isAppCredentialError && isAppCredentialError.status === 404);
|
||||
@@ -173,10 +175,7 @@ const GDriveMain = ({}: {}) => {
|
||||
|
||||
return (
|
||||
<>
|
||||
{popup}
|
||||
<Title className="mb-2 mt-6 ml-auto mr-auto">
|
||||
Step 1: Provide your Credentials
|
||||
</Title>
|
||||
<Title className="mb-2 mt-6">Step 1: Provide your Credentials</Title>
|
||||
<DriveJsonUploadSection
|
||||
setPopup={setPopup}
|
||||
appCredentialData={appCredentialData}
|
||||
@@ -186,9 +185,7 @@ const GDriveMain = ({}: {}) => {
|
||||
|
||||
{isAdmin && (
|
||||
<>
|
||||
<Title className="mb-2 mt-6 ml-auto mr-auto">
|
||||
Step 2: Authenticate with Onyx
|
||||
</Title>
|
||||
<Title className="mb-2 mt-6">Step 2: Authenticate with Onyx</Title>
|
||||
<DriveAuthSection
|
||||
setPopup={setPopup}
|
||||
refreshCredentials={refreshCredentials}
|
||||
|
||||
@@ -188,7 +188,7 @@ export const DocumentSetCreationForm = ({
|
||||
flex
|
||||
cursor-pointer ` +
|
||||
(isSelected
|
||||
? " bg-background-strong"
|
||||
? " bg-background-200"
|
||||
: " hover:bg-accent-background-hovered")
|
||||
}
|
||||
onClick={() => {
|
||||
@@ -304,7 +304,7 @@ export const DocumentSetCreationForm = ({
|
||||
flex
|
||||
cursor-pointer ` +
|
||||
(isSelected
|
||||
? " bg-background-strong"
|
||||
? " bg-background-200"
|
||||
: " hover:bg-accent-background-hovered")
|
||||
}
|
||||
onClick={() => {
|
||||
|
||||
@@ -235,8 +235,8 @@ export function EmbeddingModelSelection({
|
||||
onClick={() => setModelTab(null)}
|
||||
className={`mr-4 p-2 font-bold ${
|
||||
!modelTab
|
||||
? "rounded bg-background-900 dark:bg-neutral-900 text-text-100 dark:text-neutral-100 underline"
|
||||
: " hover:underline bg-background-100 dark:bg-neutral-700"
|
||||
? "rounded bg-neutral-900 dark:bg-neutral-950 text-neutral-100 dark:text-neutral-300 underline"
|
||||
: " hover:underline bg-neutral-100 dark:bg-neutral-900"
|
||||
}`}
|
||||
>
|
||||
Current
|
||||
@@ -246,8 +246,8 @@ export function EmbeddingModelSelection({
|
||||
onClick={() => setModelTab("cloud")}
|
||||
className={`mx-2 p-2 font-bold ${
|
||||
modelTab == "cloud"
|
||||
? "rounded bg-background-900 dark:bg-neutral-900 text-text-100 dark:text-neutral-100 underline"
|
||||
: " hover:underline bg-background-100 dark:bg-neutral-700"
|
||||
? "rounded bg-neutral-900 dark:bg-neutral-950 text-neutral-100 dark:text-neutral-300 underline"
|
||||
: " hover:underline bg-neutral-100 dark:bg-neutral-900"
|
||||
}`}
|
||||
>
|
||||
Cloud-based
|
||||
@@ -258,8 +258,8 @@ export function EmbeddingModelSelection({
|
||||
onClick={() => setModelTab("open")}
|
||||
className={` mx-2 p-2 font-bold ${
|
||||
modelTab == "open"
|
||||
? "rounded bg-background-900 dark:bg-neutral-900 text-text-100 dark:text-neutral-100 underline"
|
||||
: "hover:underline bg-background-100 dark:bg-neutral-700"
|
||||
? "rounded bg-neutral-900 dark:bg-neutral-950 text-neutral-100 dark:text-neutral-300 underline"
|
||||
: "hover:underline bg-neutral-100 dark:bg-neutral-900"
|
||||
}`}
|
||||
>
|
||||
Self-hosted
|
||||
|
||||
@@ -116,8 +116,8 @@ const RerankingDetailsForm = forwardRef<
|
||||
onClick={() => setModelTab("cloud")}
|
||||
className={`mr-2 p-2 font-bold ${
|
||||
modelTab == "cloud"
|
||||
? "rounded bg-background-900 text-text-100 underline"
|
||||
: " hover:underline bg-background-100"
|
||||
? "rounded bg-neutral-900 dark:bg-neutral-950 text-neutral-100 dark:text-neutral-300 underline"
|
||||
: " hover:underline bg-neutral-100 dark:bg-neutral-900"
|
||||
}`}
|
||||
>
|
||||
Cloud-based
|
||||
@@ -129,8 +129,8 @@ const RerankingDetailsForm = forwardRef<
|
||||
onClick={() => setModelTab("open")}
|
||||
className={` mx-2 p-2 font-bold ${
|
||||
modelTab == "open"
|
||||
? "rounded bg-background-900 text-text-100 underline"
|
||||
: "hover:underline bg-background-100"
|
||||
? "rounded bg-neutral-900 dark:bg-neutral-950 text-neutral-100 dark:text-neutral-300 underline"
|
||||
: "hover:underline bg-neutral-100 dark:bg-neutral-900"
|
||||
}`}
|
||||
>
|
||||
Self-hosted
|
||||
@@ -140,7 +140,7 @@ const RerankingDetailsForm = forwardRef<
|
||||
<div className="px-2">
|
||||
<button
|
||||
onClick={() => resetRerankingValues()}
|
||||
className="mx-2 p-2 font-bold rounded bg-background-100 text-text-900 hover:underline"
|
||||
className={`mx-2 p-2 font-bold rounded bg-neutral-100 dark:bg-neutral-900 text-neutral-900 dark:text-neutral-100 hover:underline`}
|
||||
>
|
||||
Remove Reranking
|
||||
</button>
|
||||
@@ -177,7 +177,7 @@ const RerankingDetailsForm = forwardRef<
|
||||
key={`${card.rerank_provider_type}-${card.modelName}`}
|
||||
className={`p-4 border rounded-lg cursor-pointer transition-all duration-200 ${
|
||||
isSelected
|
||||
? "border-blue-500 bg-blue-50 dark:bg-blue-900 dark:border-blue-700 shadow-md"
|
||||
? "border-blue-800 bg-blue-50 dark:bg-blue-950 dark:border-blue-700 shadow-md"
|
||||
: "border-background-200 hover:border-blue-300 hover:shadow-sm dark:border-neutral-700 dark:hover:border-blue-300"
|
||||
}`}
|
||||
onClick={() => {
|
||||
|
||||
@@ -19,7 +19,7 @@ import { Dispatch, SetStateAction, useEffect, useState } from "react";
|
||||
import { CustomEmbeddingModelForm } from "@/components/embedding/CustomEmbeddingModelForm";
|
||||
import { deleteSearchSettings } from "./utils";
|
||||
import { usePopup } from "@/components/admin/connectors/Popup";
|
||||
import { DeleteEntityModal } from "@/components/modals/DeleteEntityModal";
|
||||
import { ConfirmEntityModal } from "@/components/modals/ConfirmEntityModal";
|
||||
import { AdvancedSearchConfiguration } from "../interfaces";
|
||||
import CardSection from "@/components/admin/CardSection";
|
||||
|
||||
@@ -322,7 +322,7 @@ export default function CloudEmbeddingPage({
|
||||
OpenAI for embeddings.
|
||||
</Text>
|
||||
<div className="flex items-center text-sm text-text-700">
|
||||
<FiInfo className="text-text-400 mr-2" size={16} />
|
||||
<FiInfo className="text-neutral-400 mr-2" size={16} />
|
||||
<Text>
|
||||
You'll need: API version, base URL, API key, model
|
||||
name, and deployment name.
|
||||
@@ -456,7 +456,7 @@ export function CloudModelCard({
|
||||
>
|
||||
{popup}
|
||||
{showDeleteModel && (
|
||||
<DeleteEntityModal
|
||||
<ConfirmEntityModal
|
||||
entityName={model.model_name}
|
||||
entityType="embedding model configuration"
|
||||
onSubmit={() => deleteModel()}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
export enum GatingType {
|
||||
FULL = "full",
|
||||
PARTIAL = "partial",
|
||||
NONE = "none",
|
||||
export enum ApplicationStatus {
|
||||
PAYMENT_REMINDER = "payment_reminder",
|
||||
GATED_ACCESS = "gated_access",
|
||||
ACTIVE = "active",
|
||||
}
|
||||
|
||||
export interface Settings {
|
||||
@@ -11,7 +11,7 @@ export interface Settings {
|
||||
needs_reindexing: boolean;
|
||||
gpu_enabled: boolean;
|
||||
pro_search_disabled: boolean | null;
|
||||
product_gating: GatingType;
|
||||
application_status: ApplicationStatus;
|
||||
auto_scroll: boolean;
|
||||
}
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ export default function SidebarWrapper<T extends object>({
|
||||
size = "sm",
|
||||
children,
|
||||
}: SidebarWrapperProps<T>) {
|
||||
const [toggledSidebar, setToggledSidebar] = useState(initiallyToggled);
|
||||
const [sidebarVisible, setSidebarVisible] = useState(initiallyToggled);
|
||||
const [showDocSidebar, setShowDocSidebar] = useState(false); // State to track if sidebar is open
|
||||
// Used to maintain a "time out" for history sidebar so our existing refs can have time to process change
|
||||
const [untoggled, setUntoggled] = useState(false);
|
||||
@@ -41,13 +41,13 @@ export default function SidebarWrapper<T extends object>({
|
||||
const toggleSidebar = useCallback(() => {
|
||||
Cookies.set(
|
||||
SIDEBAR_TOGGLED_COOKIE_NAME,
|
||||
String(!toggledSidebar).toLocaleLowerCase()
|
||||
String(!sidebarVisible).toLocaleLowerCase()
|
||||
),
|
||||
{
|
||||
path: "/",
|
||||
};
|
||||
setToggledSidebar((toggledSidebar) => !toggledSidebar);
|
||||
}, [toggledSidebar]);
|
||||
setSidebarVisible((sidebarVisible) => !sidebarVisible);
|
||||
}, [sidebarVisible]);
|
||||
|
||||
const sidebarElementRef = useRef<HTMLDivElement>(null);
|
||||
const { folders, openedFolders, chatSessions } = useChatContext();
|
||||
@@ -63,7 +63,7 @@ export default function SidebarWrapper<T extends object>({
|
||||
|
||||
const settings = useContext(SettingsContext);
|
||||
useSidebarVisibility({
|
||||
toggledSidebar,
|
||||
sidebarVisible,
|
||||
sidebarElementRef,
|
||||
showDocSidebar,
|
||||
setShowDocSidebar,
|
||||
@@ -94,7 +94,7 @@ export default function SidebarWrapper<T extends object>({
|
||||
duration-300
|
||||
ease-in-out
|
||||
${
|
||||
!untoggled && (showDocSidebar || toggledSidebar)
|
||||
!untoggled && (showDocSidebar || sidebarVisible)
|
||||
? "opacity-100 w-[250px] translate-x-0"
|
||||
: "opacity-0 w-[200px] pointer-events-none -translate-x-10"
|
||||
}`}
|
||||
@@ -107,7 +107,7 @@ export default function SidebarWrapper<T extends object>({
|
||||
explicitlyUntoggle={explicitlyUntoggle}
|
||||
ref={sidebarElementRef}
|
||||
toggleSidebar={toggleSidebar}
|
||||
toggled={toggledSidebar}
|
||||
toggled={sidebarVisible}
|
||||
existingChats={chatSessions}
|
||||
currentChatSession={null}
|
||||
folders={folders}
|
||||
@@ -117,7 +117,7 @@ export default function SidebarWrapper<T extends object>({
|
||||
|
||||
<div className="absolute px-2 left-0 w-full top-0">
|
||||
<FunctionalHeader
|
||||
sidebarToggled={toggledSidebar}
|
||||
sidebarToggled={sidebarVisible}
|
||||
toggleSidebar={toggleSidebar}
|
||||
page="chat"
|
||||
/>
|
||||
@@ -132,7 +132,7 @@ export default function SidebarWrapper<T extends object>({
|
||||
bg-opacity-80
|
||||
duration-300
|
||||
ease-in-out
|
||||
${toggledSidebar ? "w-[250px]" : "w-[0px]"}`}
|
||||
${sidebarVisible ? "w-[250px]" : "w-[0px]"}`}
|
||||
/>
|
||||
|
||||
<div
|
||||
@@ -144,7 +144,7 @@ export default function SidebarWrapper<T extends object>({
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<FixedLogo backgroundToggled={toggledSidebar || showDocSidebar} />
|
||||
<FixedLogo backgroundToggled={sidebarVisible || showDocSidebar} />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import React, { useContext, useState, useRef, useLayoutEffect } from "react";
|
||||
import React, { useState, useRef, useLayoutEffect } from "react";
|
||||
import { useRouter } from "next/navigation";
|
||||
import {
|
||||
FiMoreHorizontal,
|
||||
@@ -8,7 +8,7 @@ import {
|
||||
FiLock,
|
||||
FiUnlock,
|
||||
} from "react-icons/fi";
|
||||
import { FaHashtag } from "react-icons/fa";
|
||||
|
||||
import {
|
||||
Popover,
|
||||
PopoverTrigger,
|
||||
@@ -26,14 +26,12 @@ import {
|
||||
TooltipTrigger,
|
||||
} from "@/components/ui/tooltip";
|
||||
import { PinnedIcon } from "@/components/icons/icons";
|
||||
import {
|
||||
deletePersona,
|
||||
togglePersonaPublicStatus,
|
||||
} from "@/app/admin/assistants/lib";
|
||||
import { deletePersona } from "@/app/admin/assistants/lib";
|
||||
import { PencilIcon } from "lucide-react";
|
||||
import { SettingsContext } from "@/components/settings/SettingsProvider";
|
||||
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
|
||||
import { truncateString } from "@/lib/utils";
|
||||
import { usePopup } from "@/components/admin/connectors/Popup";
|
||||
import { Button } from "@/components/ui/button";
|
||||
|
||||
export const AssistantBadge = ({
|
||||
text,
|
||||
@@ -63,6 +61,7 @@ const AssistantCard: React.FC<{
|
||||
const { user, toggleAssistantPinnedStatus } = useUser();
|
||||
const router = useRouter();
|
||||
const { refreshAssistants, pinnedAssistants } = useAssistants();
|
||||
const { popup, setPopup } = usePopup();
|
||||
|
||||
const isOwnedByUser = checkUserOwnsAssistant(user, persona);
|
||||
|
||||
@@ -72,7 +71,34 @@ const AssistantCard: React.FC<{
|
||||
|
||||
const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled();
|
||||
|
||||
const handleDelete = () => setActivePopover("delete");
|
||||
const [isDeleteConfirmation, setIsDeleteConfirmation] = useState(false);
|
||||
|
||||
const handleDelete = () => {
|
||||
setIsDeleteConfirmation(true);
|
||||
};
|
||||
|
||||
const confirmDelete = async () => {
|
||||
const response = await deletePersona(persona.id);
|
||||
if (response.ok) {
|
||||
await refreshAssistants();
|
||||
setActivePopover(null);
|
||||
setIsDeleteConfirmation(false);
|
||||
setPopup({
|
||||
message: `${persona.name} has been successfully deleted.`,
|
||||
type: "success",
|
||||
});
|
||||
} else {
|
||||
setPopup({
|
||||
message: `Failed to delete assistant - ${await response.text()}`,
|
||||
type: "error",
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
const cancelDelete = () => {
|
||||
setIsDeleteConfirmation(false);
|
||||
};
|
||||
|
||||
const handleEdit = () => {
|
||||
router.push(`/assistants/edit/${persona.id}`);
|
||||
setActivePopover(null);
|
||||
@@ -100,6 +126,7 @@ const AssistantCard: React.FC<{
|
||||
|
||||
return (
|
||||
<div className="w-full text-text-800 p-2 overflow-visible pb-4 pt-3 bg-transparent dark:bg-neutral-800/80 rounded shadow-[0px_0px_4px_0px_rgba(0,0,0,0.25)] flex flex-col">
|
||||
{popup}
|
||||
<div className="w-full flex">
|
||||
<div className="ml-2 flex-none mr-2 mt-1 w-10 h-10">
|
||||
<AssistantIcon assistant={persona} size="large" />
|
||||
@@ -157,55 +184,84 @@ const AssistantCard: React.FC<{
|
||||
<FiMoreHorizontal size={16} />
|
||||
</button>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent className={`w-32 z-[10000] p-2`}>
|
||||
<div className="flex flex-col text-sm space-y-1">
|
||||
<button
|
||||
onClick={isOwnedByUser ? handleEdit : undefined}
|
||||
className={`w-full flex items-center text-left px-2 py-1 rounded ${
|
||||
isOwnedByUser
|
||||
? "hover:bg-neutral-200 dark:hover:bg-neutral-700"
|
||||
: "opacity-50 cursor-not-allowed"
|
||||
}`}
|
||||
disabled={!isOwnedByUser}
|
||||
>
|
||||
<FiEdit size={12} className="inline mr-2" />
|
||||
Edit
|
||||
</button>
|
||||
{isPaidEnterpriseFeaturesEnabled && isOwnedByUser && (
|
||||
<PopoverContent
|
||||
className={`${
|
||||
isDeleteConfirmation ? "w-64" : "w-32"
|
||||
} z-[10000] p-2`}
|
||||
>
|
||||
{!isDeleteConfirmation ? (
|
||||
<div className="flex flex-col text-sm space-y-1">
|
||||
<button
|
||||
onClick={
|
||||
onClick={isOwnedByUser ? handleEdit : undefined}
|
||||
className={`w-full flex items-center text-left px-2 py-1 rounded ${
|
||||
isOwnedByUser
|
||||
? () => {
|
||||
router.push(
|
||||
`/assistants/stats/${persona.id}`
|
||||
);
|
||||
closePopover();
|
||||
}
|
||||
: undefined
|
||||
}
|
||||
className={`w-full text-left items-center px-2 py-1 rounded ${
|
||||
isOwnedByUser
|
||||
? "hover:bg-neutral-200 dark:hover:bg-neutral-800"
|
||||
? "hover:bg-neutral-200 dark:hover:bg-neutral-700"
|
||||
: "opacity-50 cursor-not-allowed"
|
||||
}`}
|
||||
disabled={!isOwnedByUser}
|
||||
>
|
||||
<FiBarChart size={12} className="inline mr-2" />
|
||||
Stats
|
||||
<FiEdit size={12} className="inline mr-2" />
|
||||
Edit
|
||||
</button>
|
||||
)}
|
||||
<button
|
||||
onClick={isOwnedByUser ? handleDelete : undefined}
|
||||
className={`w-full text-left items-center px-2 py-1 rounded ${
|
||||
isOwnedByUser
|
||||
? "hover:bg-neutral-200 dark:hover:bg-neutral- text-red-600 dark:text-red-400"
|
||||
: "opacity-50 cursor-not-allowed text-red-300 dark:text-red-500"
|
||||
}`}
|
||||
disabled={!isOwnedByUser}
|
||||
>
|
||||
<FiTrash size={12} className="inline mr-2" />
|
||||
Delete
|
||||
</button>
|
||||
</div>
|
||||
{isPaidEnterpriseFeaturesEnabled && isOwnedByUser && (
|
||||
<button
|
||||
onClick={
|
||||
isOwnedByUser
|
||||
? () => {
|
||||
router.push(
|
||||
`/assistants/stats/${persona.id}`
|
||||
);
|
||||
closePopover();
|
||||
}
|
||||
: undefined
|
||||
}
|
||||
className={`w-full text-left items-center px-2 py-1 rounded ${
|
||||
isOwnedByUser
|
||||
? "hover:bg-neutral-200 dark:hover:bg-neutral-800"
|
||||
: "opacity-50 cursor-not-allowed"
|
||||
}`}
|
||||
>
|
||||
<FiBarChart size={12} className="inline mr-2" />
|
||||
Stats
|
||||
</button>
|
||||
)}
|
||||
<button
|
||||
onClick={isOwnedByUser ? handleDelete : undefined}
|
||||
className={`w-full text-left items-center px-2 py-1 rounded ${
|
||||
isOwnedByUser
|
||||
? "hover:bg-neutral-200 dark:hover:bg-neutral- text-red-600 dark:text-red-400"
|
||||
: "opacity-50 cursor-not-allowed text-red-300 dark:text-red-500"
|
||||
}`}
|
||||
disabled={!isOwnedByUser}
|
||||
>
|
||||
<FiTrash size={12} className="inline mr-2" />
|
||||
Delete
|
||||
</button>
|
||||
</div>
|
||||
) : (
|
||||
<div className="w-full">
|
||||
<p className="text-sm mb-3">
|
||||
Are you sure you want to delete assistant{" "}
|
||||
<b>{persona.name}</b>?
|
||||
</p>
|
||||
<div className="flex justify-center gap-2">
|
||||
<Button
|
||||
variant="secondary"
|
||||
size="sm"
|
||||
onClick={cancelDelete}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
variant="destructive"
|
||||
size="sm"
|
||||
onClick={confirmDelete}
|
||||
>
|
||||
Delete
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
</div>
|
||||
|
||||
@@ -5,10 +5,8 @@ import { useRouter } from "next/navigation";
|
||||
import AssistantCard from "./AssistantCard";
|
||||
import { useAssistants } from "@/components/context/AssistantsContext";
|
||||
import { useUser } from "@/components/user/UserProvider";
|
||||
import { FilterIcon } from "lucide-react";
|
||||
import { FilterIcon, XIcon } from "lucide-react";
|
||||
import { checkUserOwnsAssistant } from "@/lib/assistants/checkOwnership";
|
||||
import { Dialog, DialogContent } from "@/components/ui/dialog";
|
||||
import { Modal } from "@/components/Modal";
|
||||
|
||||
export const AssistantBadgeSelector = ({
|
||||
text,
|
||||
@@ -24,8 +22,8 @@ export const AssistantBadgeSelector = ({
|
||||
className={`
|
||||
select-none ${
|
||||
selected
|
||||
? "bg-neutral-900 text-white"
|
||||
: "bg-transparent text-neutral-900"
|
||||
? "bg-background-900 text-white"
|
||||
: "bg-transparent text-text-900"
|
||||
} w-12 h-5 text-center px-1 py-0.5 rounded-lg cursor-pointer text-[12px] font-normal leading-[10px] border border-black justify-center items-center gap-1 inline-flex`}
|
||||
onClick={toggleFilter}
|
||||
>
|
||||
@@ -109,16 +107,20 @@ export function AssistantModal({
|
||||
|
||||
const featuredAssistants = [
|
||||
...memoizedCurrentlyVisibleAssistants.filter(
|
||||
(assistant) => assistant.builtin_persona || assistant.is_default_persona
|
||||
(assistant) => assistant.is_default_persona
|
||||
),
|
||||
];
|
||||
const allAssistants = memoizedCurrentlyVisibleAssistants.filter(
|
||||
(assistant) => !assistant.builtin_persona && !assistant.is_default_persona
|
||||
(assistant) => !assistant.is_default_persona
|
||||
);
|
||||
|
||||
return (
|
||||
<div className="fixed inset-0 bg-neutral-950/80 bg-opacity-50 flex items-center justify-center z-50">
|
||||
<div
|
||||
onClick={hideModal}
|
||||
className="fixed inset-0 bg-neutral-950/80 bg-opacity-50 flex items-center justify-center z-50"
|
||||
>
|
||||
<div
|
||||
onClick={(e) => e.stopPropagation()}
|
||||
className="p-0 max-w-4xl overflow-hidden max-h-[80vh] w-[95%] bg-background rounded-md shadow-2xl transform transition-all duration-300 ease-in-out relative w-11/12 max-w-4xl pt-10 pb-10 px-10 overflow-hidden flex flex-col"
|
||||
style={{
|
||||
position: "fixed",
|
||||
@@ -128,129 +130,142 @@ export function AssistantModal({
|
||||
margin: 0,
|
||||
}}
|
||||
>
|
||||
<div className="absolute top-2 right-2">
|
||||
<button
|
||||
onClick={hideModal}
|
||||
className="cursor-pointer text-neutral-500 hover:text-neutral-700 dark:text-neutral-400 dark:hover:text-neutral-300 transition-colors duration-200 p-2"
|
||||
aria-label="Close modal"
|
||||
>
|
||||
<XIcon className="w-5 h-5" />
|
||||
</button>
|
||||
</div>
|
||||
<div className="flex overflow-hidden flex-col h-full">
|
||||
<div className="flex flex-col sticky top-0 z-10">
|
||||
<div className="flex px-2 justify-between items-center gap-x-2 mb-0">
|
||||
<div className="h-12 w-full rounded-lg flex-col justify-center items-start gap-2.5 inline-flex">
|
||||
<div className="h-12 rounded-md w-full shadow-[0px_0px_2px_0px_rgba(0,0,0,0.25)] border border-[#dcdad4] flex items-center px-3">
|
||||
{!isSearchFocused && (
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
className="h-5 w-5 text-gray-400"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
>
|
||||
<path
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
strokeWidth={2}
|
||||
d="M21 21l-6-6m2-5a7 7 0 11-14 0 7 7 0 0114 0z"
|
||||
/>
|
||||
</svg>
|
||||
)}
|
||||
<input
|
||||
value={searchQuery}
|
||||
onChange={(e) => setSearchQuery(e.target.value)}
|
||||
onFocus={() => setIsSearchFocused(true)}
|
||||
onBlur={() => setIsSearchFocused(false)}
|
||||
type="text"
|
||||
className="w-full h-full bg-transparent outline-none text-black"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<button
|
||||
onClick={() => router.push("/assistants/new")}
|
||||
className="h-10 cursor-pointer px-6 py-3 bg-black rounded-md border border-black justify-center items-center gap-2.5 inline-flex"
|
||||
>
|
||||
<div className="text-[#fffcf4] text-lg font-normal leading-normal">
|
||||
Create
|
||||
</div>
|
||||
</button>
|
||||
</div>
|
||||
<div className="px-2 flex py-4 items-center gap-x-2 flex-wrap">
|
||||
<FilterIcon size={16} />
|
||||
<AssistantBadgeSelector
|
||||
text="Pinned"
|
||||
selected={assistantFilters[AssistantFilter.Pinned]}
|
||||
toggleFilter={() =>
|
||||
toggleAssistantFilter(AssistantFilter.Pinned)
|
||||
}
|
||||
/>
|
||||
|
||||
<AssistantBadgeSelector
|
||||
text="Mine"
|
||||
selected={assistantFilters[AssistantFilter.Mine]}
|
||||
toggleFilter={() => toggleAssistantFilter(AssistantFilter.Mine)}
|
||||
/>
|
||||
<AssistantBadgeSelector
|
||||
text="Private"
|
||||
selected={assistantFilters[AssistantFilter.Private]}
|
||||
toggleFilter={() =>
|
||||
toggleAssistantFilter(AssistantFilter.Private)
|
||||
}
|
||||
/>
|
||||
<AssistantBadgeSelector
|
||||
text="Public"
|
||||
selected={assistantFilters[AssistantFilter.Public]}
|
||||
toggleFilter={() =>
|
||||
toggleAssistantFilter(AssistantFilter.Public)
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
<div className="w-full border-t border-neutral-200" />
|
||||
</div>
|
||||
|
||||
<div className="flex-grow overflow-y-auto">
|
||||
<h2 className="text-2xl font-semibold text-gray-800 mb-2 px-4 py-2">
|
||||
Featured Assistants
|
||||
</h2>
|
||||
|
||||
<div className="w-full px-2 pb-2 grid grid-cols-1 md:grid-cols-2 gap-x-6 gap-y-6">
|
||||
{featuredAssistants.length > 0 ? (
|
||||
featuredAssistants.map((assistant, index) => (
|
||||
<div key={index}>
|
||||
<AssistantCard
|
||||
pinned={pinnedAssistants
|
||||
.map((a) => a.id)
|
||||
.includes(assistant.id)}
|
||||
persona={assistant}
|
||||
closeModal={hideModal}
|
||||
<div className="flex overflow-hidden flex-col h-full">
|
||||
<div className="flex flex-col sticky top-0 z-10">
|
||||
<div className="flex px-2 justify-between items-center gap-x-2 mb-0">
|
||||
<div className="h-12 w-full rounded-lg flex-col justify-center items-start gap-2.5 inline-flex">
|
||||
<div className="h-12 rounded-md w-full shadow-[0px_0px_2px_0px_rgba(0,0,0,0.25)] border border-background-300 flex items-center px-3">
|
||||
{!isSearchFocused && (
|
||||
<svg
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
className="h-5 w-5 text-text-400"
|
||||
fill="none"
|
||||
viewBox="0 0 24 24"
|
||||
stroke="currentColor"
|
||||
>
|
||||
<path
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
strokeWidth={2}
|
||||
d="M21 21l-6-6m2-5a7 7 0 11-14 0 7 7 0 0114 0z"
|
||||
/>
|
||||
</svg>
|
||||
)}
|
||||
<input
|
||||
value={searchQuery}
|
||||
onChange={(e) => setSearchQuery(e.target.value)}
|
||||
onFocus={() => setIsSearchFocused(true)}
|
||||
onBlur={() => setIsSearchFocused(false)}
|
||||
type="text"
|
||||
className="w-full h-full bg-transparent outline-none text-black"
|
||||
/>
|
||||
</div>
|
||||
))
|
||||
) : (
|
||||
<div className="col-span-2 text-center text-gray-500">
|
||||
No featured assistants match filters
|
||||
</div>
|
||||
)}
|
||||
<button
|
||||
onClick={() => router.push("/assistants/new")}
|
||||
className="h-10 cursor-pointer px-6 py-3 bg-background-800 hover:bg-black rounded-md border border-black justify-center items-center gap-2.5 inline-flex"
|
||||
>
|
||||
<div className="text-text-50 text-lg font-normal leading-normal">
|
||||
Create
|
||||
</div>
|
||||
</button>
|
||||
</div>
|
||||
<div className="px-2 flex py-4 items-center gap-x-2 flex-wrap">
|
||||
<FilterIcon className="text-text-800" size={16} />
|
||||
<AssistantBadgeSelector
|
||||
text="Pinned"
|
||||
selected={assistantFilters[AssistantFilter.Pinned]}
|
||||
toggleFilter={() =>
|
||||
toggleAssistantFilter(AssistantFilter.Pinned)
|
||||
}
|
||||
/>
|
||||
|
||||
<AssistantBadgeSelector
|
||||
text="Mine"
|
||||
selected={assistantFilters[AssistantFilter.Mine]}
|
||||
toggleFilter={() =>
|
||||
toggleAssistantFilter(AssistantFilter.Mine)
|
||||
}
|
||||
/>
|
||||
<AssistantBadgeSelector
|
||||
text="Private"
|
||||
selected={assistantFilters[AssistantFilter.Private]}
|
||||
toggleFilter={() =>
|
||||
toggleAssistantFilter(AssistantFilter.Private)
|
||||
}
|
||||
/>
|
||||
<AssistantBadgeSelector
|
||||
text="Public"
|
||||
selected={assistantFilters[AssistantFilter.Public]}
|
||||
toggleFilter={() =>
|
||||
toggleAssistantFilter(AssistantFilter.Public)
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
<div className="w-full border-t border-background-200" />
|
||||
</div>
|
||||
|
||||
{allAssistants && allAssistants.length > 0 && (
|
||||
<>
|
||||
<h2 className="text-2xl font-semibold text-gray-800 mt-4 mb-2 px-4 py-2">
|
||||
All Assistants
|
||||
</h2>
|
||||
<div className="flex-grow overflow-y-auto">
|
||||
<h2 className="text-2xl font-semibold text-text-800 mb-2 px-4 py-2">
|
||||
Featured Assistants
|
||||
</h2>
|
||||
|
||||
<div className="w-full mt-2 px-2 pb-2 grid grid-cols-1 md:grid-cols-2 gap-x-6 gap-y-6">
|
||||
{allAssistants
|
||||
.sort((a, b) => b.id - a.id)
|
||||
.map((assistant, index) => (
|
||||
<div key={index}>
|
||||
<AssistantCard
|
||||
pinned={
|
||||
user?.preferences?.pinned_assistants?.includes(
|
||||
assistant.id
|
||||
) ?? false
|
||||
}
|
||||
persona={assistant}
|
||||
closeModal={hideModal}
|
||||
/>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
<div className="w-full px-2 pb-10 grid grid-cols-1 md:grid-cols-2 gap-x-6 gap-y-6">
|
||||
{featuredAssistants.length > 0 ? (
|
||||
featuredAssistants.map((assistant, index) => (
|
||||
<div key={index}>
|
||||
<AssistantCard
|
||||
pinned={pinnedAssistants
|
||||
.map((a) => a.id)
|
||||
.includes(assistant.id)}
|
||||
persona={assistant}
|
||||
closeModal={hideModal}
|
||||
/>
|
||||
</div>
|
||||
))
|
||||
) : (
|
||||
<div className="col-span-2 text-center text-text-500">
|
||||
No featured assistants match filters
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{allAssistants && allAssistants.length > 0 && (
|
||||
<>
|
||||
<h2 className="text-2xl font-semibold text-text-800 mt-4 mb-2 px-4 py-2">
|
||||
All Assistants
|
||||
</h2>
|
||||
|
||||
<div className="w-full mt-2 px-2 pb-2 grid grid-cols-1 md:grid-cols-2 gap-x-6 gap-y-6">
|
||||
{allAssistants
|
||||
.sort((a, b) => b.id - a.id)
|
||||
.map((assistant, index) => (
|
||||
<div key={index}>
|
||||
<AssistantCard
|
||||
pinned={
|
||||
user?.preferences?.pinned_assistants?.includes(
|
||||
assistant.id
|
||||
) ?? false
|
||||
}
|
||||
persona={assistant}
|
||||
closeModal={hideModal}
|
||||
/>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -97,7 +97,6 @@ import {
|
||||
} from "@/components/resizable/constants";
|
||||
import FixedLogo from "../../components/logo/FixedLogo";
|
||||
|
||||
import { DeleteEntityModal } from "../../components/modals/DeleteEntityModal";
|
||||
import { MinimalMarkdown } from "@/components/chat/MinimalMarkdown";
|
||||
import ExceptionTraceModal from "@/components/modals/ExceptionTraceModal";
|
||||
|
||||
@@ -130,6 +129,7 @@ import {
|
||||
useSidebarShortcut,
|
||||
} from "@/lib/browserUtilities";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { ConfirmEntityModal } from "@/components/modals/ConfirmEntityModal";
|
||||
|
||||
const TEMP_USER_MESSAGE_ID = -1;
|
||||
const TEMP_ASSISTANT_MESSAGE_ID = -2;
|
||||
@@ -138,12 +138,12 @@ const SYSTEM_MESSAGE_ID = -3;
|
||||
export function ChatPage({
|
||||
toggle,
|
||||
documentSidebarInitialWidth,
|
||||
toggledSidebar,
|
||||
sidebarVisible,
|
||||
firstMessage,
|
||||
}: {
|
||||
toggle: (toggled?: boolean) => void;
|
||||
documentSidebarInitialWidth?: number;
|
||||
toggledSidebar: boolean;
|
||||
sidebarVisible: boolean;
|
||||
firstMessage?: string;
|
||||
}) {
|
||||
const router = useRouter();
|
||||
@@ -204,7 +204,7 @@ export function ChatPage({
|
||||
const settings = useContext(SettingsContext);
|
||||
const enterpriseSettings = settings?.enterpriseSettings;
|
||||
|
||||
const [documentSidebarToggled, setDocumentSidebarToggled] = useState(false);
|
||||
const [documentSidebarVisible, setDocumentSidebarVisible] = useState(false);
|
||||
const [proSearchEnabled, setProSearchEnabled] = useState(proSearchToggled);
|
||||
const [streamingAllowed, setStreamingAllowed] = useState(false);
|
||||
const toggleProSearch = () => {
|
||||
@@ -243,7 +243,7 @@ export function ChatPage({
|
||||
if (user?.is_anonymous_user) {
|
||||
Cookies.set(
|
||||
SIDEBAR_TOGGLED_COOKIE_NAME,
|
||||
String(!toggledSidebar).toLocaleLowerCase()
|
||||
String(!sidebarVisible).toLocaleLowerCase()
|
||||
);
|
||||
toggle(false);
|
||||
}
|
||||
@@ -1024,10 +1024,10 @@ export function ChatPage({
|
||||
if (
|
||||
(!personaIncludesRetrieval &&
|
||||
(!selectedDocuments || selectedDocuments.length === 0) &&
|
||||
documentSidebarToggled) ||
|
||||
documentSidebarVisible) ||
|
||||
chatSessionIdRef.current == undefined
|
||||
) {
|
||||
setDocumentSidebarToggled(false);
|
||||
setDocumentSidebarVisible(false);
|
||||
}
|
||||
clientScrollToBottom();
|
||||
}, [chatSessionIdRef.current]);
|
||||
@@ -1122,6 +1122,7 @@ export function ChatPage({
|
||||
"Continue Generating (pick up exactly where you left off)",
|
||||
});
|
||||
};
|
||||
const [gener, setFinishedStreaming] = useState(false);
|
||||
|
||||
const onSubmit = async ({
|
||||
messageIdToResend,
|
||||
@@ -1272,6 +1273,7 @@ export function ChatPage({
|
||||
let finalMessage: BackendMessage | null = null;
|
||||
let toolCall: ToolCallMetadata | null = null;
|
||||
let isImprovement: boolean | undefined = undefined;
|
||||
let isStreamingQuestions = true;
|
||||
|
||||
let initialFetchDetails: null | {
|
||||
user_message_id: number;
|
||||
@@ -1442,11 +1444,22 @@ export function ChatPage({
|
||||
Object.hasOwn(packet, "stop_reason") &&
|
||||
Object.hasOwn(packet, "level_question_num")
|
||||
) {
|
||||
if ((packet as StreamStopInfo).stream_type == "main_answer") {
|
||||
setFinishedStreaming(true);
|
||||
updateChatState("streaming", frozenSessionId);
|
||||
}
|
||||
if (
|
||||
(packet as StreamStopInfo).stream_type == "sub_questions" &&
|
||||
(packet as StreamStopInfo).level_question_num == undefined
|
||||
) {
|
||||
isStreamingQuestions = false;
|
||||
}
|
||||
sub_questions = constructSubQuestions(
|
||||
sub_questions,
|
||||
packet as StreamStopInfo
|
||||
);
|
||||
} else if (Object.hasOwn(packet, "sub_question")) {
|
||||
updateChatState("toolBuilding", frozenSessionId);
|
||||
is_generating = true;
|
||||
sub_questions = constructSubQuestions(
|
||||
sub_questions,
|
||||
@@ -1606,6 +1619,7 @@ export function ChatPage({
|
||||
latestChildMessageId: initialFetchDetails.assistant_message_id,
|
||||
},
|
||||
{
|
||||
isStreamingQuestions: isStreamingQuestions,
|
||||
is_generating: is_generating,
|
||||
isImprovement: isImprovement,
|
||||
messageId: initialFetchDetails.assistant_message_id!,
|
||||
@@ -1613,7 +1627,7 @@ export function ChatPage({
|
||||
second_level_message: second_level_answer,
|
||||
type: error ? "error" : "assistant",
|
||||
retrievalType,
|
||||
query: finalMessage?.rephrased_query,
|
||||
query: finalMessage?.rephrased_query || query,
|
||||
documents: documents,
|
||||
citations: finalMessage?.citations || {},
|
||||
files: finalMessage?.files || aiMessageImages || [],
|
||||
@@ -1805,7 +1819,7 @@ export function ChatPage({
|
||||
}
|
||||
Cookies.set(
|
||||
SIDEBAR_TOGGLED_COOKIE_NAME,
|
||||
String(!toggledSidebar).toLocaleLowerCase()
|
||||
String(!sidebarVisible).toLocaleLowerCase()
|
||||
),
|
||||
{
|
||||
path: "/",
|
||||
@@ -1822,7 +1836,7 @@ export function ChatPage({
|
||||
const sidebarElementRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
useSidebarVisibility({
|
||||
toggledSidebar,
|
||||
sidebarVisible,
|
||||
sidebarElementRef,
|
||||
showDocSidebar: showHistorySidebar,
|
||||
setShowDocSidebar: setShowHistorySidebar,
|
||||
@@ -2003,7 +2017,7 @@ export function ChatPage({
|
||||
|
||||
useEffect(() => {
|
||||
if (!retrievalEnabled) {
|
||||
setDocumentSidebarToggled(false);
|
||||
setDocumentSidebarVisible(false);
|
||||
}
|
||||
}, [retrievalEnabled]);
|
||||
|
||||
@@ -2068,10 +2082,10 @@ export function ChatPage({
|
||||
const [showAssistantsModal, setShowAssistantsModal] = useState(false);
|
||||
|
||||
const toggleDocumentSidebar = () => {
|
||||
if (!documentSidebarToggled) {
|
||||
setDocumentSidebarToggled(true);
|
||||
if (!documentSidebarVisible) {
|
||||
setDocumentSidebarVisible(true);
|
||||
} else {
|
||||
setDocumentSidebarToggled(false);
|
||||
setDocumentSidebarVisible(false);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -2122,7 +2136,7 @@ export function ChatPage({
|
||||
<ChatPopup />
|
||||
|
||||
{showDeleteAllModal && (
|
||||
<DeleteEntityModal
|
||||
<ConfirmEntityModal
|
||||
entityType="All Chats"
|
||||
entityName="all your chat sessions"
|
||||
onClose={() => setShowDeleteAllModal(false)}
|
||||
@@ -2178,11 +2192,11 @@ export function ChatPage({
|
||||
/>
|
||||
)}
|
||||
|
||||
{retrievalEnabled && documentSidebarToggled && settings?.isMobile && (
|
||||
{retrievalEnabled && documentSidebarVisible && settings?.isMobile && (
|
||||
<div className="md:hidden">
|
||||
<Modal
|
||||
hideDividerForTitle
|
||||
onOutsideClick={() => setDocumentSidebarToggled(false)}
|
||||
onOutsideClick={() => setDocumentSidebarVisible(false)}
|
||||
title="Sources"
|
||||
>
|
||||
<DocumentResults
|
||||
@@ -2198,7 +2212,7 @@ export function ChatPage({
|
||||
modal={true}
|
||||
ref={innerSidebarElementRef}
|
||||
closeSidebar={() => {
|
||||
setDocumentSidebarToggled(false);
|
||||
setDocumentSidebarVisible(false);
|
||||
}}
|
||||
selectedMessage={aiMessage}
|
||||
selectedDocuments={selectedDocuments}
|
||||
@@ -2278,21 +2292,21 @@ export function ChatPage({
|
||||
duration-300
|
||||
ease-in-out
|
||||
${
|
||||
!untoggled && (showHistorySidebar || toggledSidebar)
|
||||
!untoggled && (showHistorySidebar || sidebarVisible)
|
||||
? "opacity-100 w-[250px] translate-x-0"
|
||||
: "opacity-0 w-[250px] pointer-events-none -translate-x-10"
|
||||
}`}
|
||||
>
|
||||
<div className="w-full relative">
|
||||
<HistorySidebar
|
||||
liveAssistant={liveAssistant}
|
||||
setShowAssistantsModal={setShowAssistantsModal}
|
||||
explicitlyUntoggle={explicitlyUntoggle}
|
||||
reset={() => setMessage("")}
|
||||
page="chat"
|
||||
ref={innerSidebarElementRef}
|
||||
toggleSidebar={toggleSidebar}
|
||||
toggled={toggledSidebar}
|
||||
currentAssistantId={liveAssistant?.id}
|
||||
toggled={sidebarVisible}
|
||||
existingChats={chatSessions}
|
||||
currentChatSession={selectedChatSession}
|
||||
folders={folders}
|
||||
@@ -2314,7 +2328,7 @@ export function ChatPage({
|
||||
duration-300
|
||||
ease-in-out
|
||||
${
|
||||
documentSidebarToggled &&
|
||||
documentSidebarVisible &&
|
||||
!settings?.isMobile &&
|
||||
"opacity-100 w-[350px]"
|
||||
}`}
|
||||
@@ -2339,7 +2353,7 @@ export function ChatPage({
|
||||
ease-in-out
|
||||
h-full
|
||||
${
|
||||
documentSidebarToggled && !settings?.isMobile
|
||||
documentSidebarVisible && !settings?.isMobile
|
||||
? "w-[400px]"
|
||||
: "w-[0px]"
|
||||
}
|
||||
@@ -2358,7 +2372,7 @@ export function ChatPage({
|
||||
modal={false}
|
||||
ref={innerSidebarElementRef}
|
||||
closeSidebar={() =>
|
||||
setTimeout(() => setDocumentSidebarToggled(false), 300)
|
||||
setTimeout(() => setDocumentSidebarVisible(false), 300)
|
||||
}
|
||||
selectedMessage={aiMessage}
|
||||
selectedDocuments={selectedDocuments}
|
||||
@@ -2367,12 +2381,12 @@ export function ChatPage({
|
||||
selectedDocumentTokens={selectedDocumentTokens}
|
||||
maxTokens={maxTokens}
|
||||
initialWidth={400}
|
||||
isOpen={documentSidebarToggled && !settings?.isMobile}
|
||||
isOpen={documentSidebarVisible && !settings?.isMobile}
|
||||
/>
|
||||
</div>
|
||||
|
||||
<BlurBackground
|
||||
visible={!untoggled && (showHistorySidebar || toggledSidebar)}
|
||||
visible={!untoggled && (showHistorySidebar || sidebarVisible)}
|
||||
onClick={() => toggleSidebar()}
|
||||
/>
|
||||
|
||||
@@ -2387,7 +2401,7 @@ export function ChatPage({
|
||||
{liveAssistant && (
|
||||
<FunctionalHeader
|
||||
toggleUserSettings={() => setUserSettingsToggled(true)}
|
||||
sidebarToggled={toggledSidebar}
|
||||
sidebarToggled={sidebarVisible}
|
||||
reset={() => setMessage("")}
|
||||
page="chat"
|
||||
setSharingModalVisible={
|
||||
@@ -2395,8 +2409,8 @@ export function ChatPage({
|
||||
? setSharingModalVisible
|
||||
: undefined
|
||||
}
|
||||
documentSidebarToggled={
|
||||
documentSidebarToggled && !settings?.isMobile
|
||||
documentSidebarVisible={
|
||||
documentSidebarVisible && !settings?.isMobile
|
||||
}
|
||||
toggleSidebar={toggleSidebar}
|
||||
currentChatSession={selectedChatSession}
|
||||
@@ -2424,7 +2438,7 @@ export function ChatPage({
|
||||
duration-300
|
||||
ease-in-out
|
||||
h-full
|
||||
${toggledSidebar ? "w-[200px]" : "w-[0px]"}
|
||||
${sidebarVisible ? "w-[200px]" : "w-[0px]"}
|
||||
`}
|
||||
></div>
|
||||
)}
|
||||
@@ -2450,7 +2464,7 @@ export function ChatPage({
|
||||
duration-300
|
||||
ease-in-out
|
||||
h-full
|
||||
${toggledSidebar ? "w-[200px]" : "w-[0px]"}
|
||||
${sidebarVisible ? "w-[200px]" : "w-[0px]"}
|
||||
`}
|
||||
></div>
|
||||
)}
|
||||
@@ -2635,8 +2649,14 @@ export function ChatPage({
|
||||
{message.sub_questions &&
|
||||
message.sub_questions.length > 0 ? (
|
||||
<AgenticMessage
|
||||
isStreamingQuestions={
|
||||
message.isStreamingQuestions ?? false
|
||||
}
|
||||
isGenerating={
|
||||
message.is_generating ?? false
|
||||
}
|
||||
docSidebarToggled={
|
||||
documentSidebarToggled &&
|
||||
documentSidebarVisible &&
|
||||
(selectedMessageForDocDisplay ==
|
||||
message.messageId ||
|
||||
selectedMessageForDocDisplay ==
|
||||
@@ -2732,7 +2752,8 @@ export function ChatPage({
|
||||
setMessageAsLatest(messageId);
|
||||
}}
|
||||
isActive={
|
||||
messageHistory.length - 1 == i
|
||||
messageHistory.length - 1 == i ||
|
||||
messageHistory.length - 2 == i
|
||||
}
|
||||
selectedDocuments={selectedDocuments}
|
||||
toggleDocumentSelection={(
|
||||
@@ -2740,8 +2761,8 @@ export function ChatPage({
|
||||
) => {
|
||||
if (
|
||||
(!second &&
|
||||
!documentSidebarToggled) ||
|
||||
(documentSidebarToggled &&
|
||||
!documentSidebarVisible) ||
|
||||
(documentSidebarVisible &&
|
||||
selectedMessageForDocDisplay ===
|
||||
message.messageId)
|
||||
) {
|
||||
@@ -2749,8 +2770,8 @@ export function ChatPage({
|
||||
}
|
||||
if (
|
||||
(second &&
|
||||
!documentSidebarToggled) ||
|
||||
(documentSidebarToggled &&
|
||||
!documentSidebarVisible) ||
|
||||
(documentSidebarVisible &&
|
||||
selectedMessageForDocDisplay ===
|
||||
secondLevelMessage?.messageId)
|
||||
) {
|
||||
@@ -2851,8 +2872,8 @@ export function ChatPage({
|
||||
selectedDocuments={selectedDocuments}
|
||||
toggleDocumentSelection={() => {
|
||||
if (
|
||||
!documentSidebarToggled ||
|
||||
(documentSidebarToggled &&
|
||||
!documentSidebarVisible ||
|
||||
(documentSidebarVisible &&
|
||||
selectedMessageForDocDisplay ===
|
||||
message.messageId)
|
||||
) {
|
||||
@@ -3070,12 +3091,13 @@ export function ChatPage({
|
||||
<div className="mx-auto w-fit !pointer-events-none flex sticky justify-center">
|
||||
<button
|
||||
onClick={() => clientScrollToBottom()}
|
||||
className="p-1 pointer-events-auto rounded-2xl bg-background-strong border border-border mx-auto "
|
||||
className="p-1 pointer-events-auto text-neutral-700 dark:text-neutral-800 rounded-2xl bg-neutral-200 border border-border mx-auto "
|
||||
>
|
||||
<FiArrowDown size={18} />
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="pointer-events-auto w-[95%] mx-auto relative mb-8">
|
||||
<ChatInputBar
|
||||
proSearchEnabled={proSearchEnabled}
|
||||
@@ -3147,7 +3169,7 @@ export function ChatPage({
|
||||
ease-in-out
|
||||
h-full
|
||||
${
|
||||
documentSidebarToggled && !settings?.isMobile
|
||||
documentSidebarVisible && !settings?.isMobile
|
||||
? "w-[350px]"
|
||||
: "w-[0px]"
|
||||
}
|
||||
@@ -3162,7 +3184,7 @@ export function ChatPage({
|
||||
style={{ transition: "width 0.30s ease-out" }}
|
||||
className={`flex-none bg-transparent transition-all bg-opacity-80 duration-300 ease-in-out h-full
|
||||
${
|
||||
toggledSidebar && !settings?.isMobile
|
||||
sidebarVisible && !settings?.isMobile
|
||||
? "w-[250px] "
|
||||
: "w-[0px]"
|
||||
}`}
|
||||
@@ -3174,7 +3196,7 @@ export function ChatPage({
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
<FixedLogo backgroundToggled={toggledSidebar || showHistorySidebar} />
|
||||
<FixedLogo backgroundToggled={sidebarVisible || showHistorySidebar} />
|
||||
</div>
|
||||
{/* Right Sidebar - DocumentSidebar */}
|
||||
</div>
|
||||
|
||||
@@ -115,21 +115,61 @@ export function RefinemenetBadge({
|
||||
const isDone = displayedPhases.includes(StreamingPhase.COMPLETE);
|
||||
|
||||
// Expand/collapse, hover states
|
||||
const [expanded, setExpanded] = useState(true);
|
||||
const [expanded] = useState(true);
|
||||
const [toolTipHoveredInternal, setToolTipHoveredInternal] = useState(false);
|
||||
const [isHovered, setIsHovered] = useState(false);
|
||||
const [shouldShow, setShouldShow] = useState(true);
|
||||
|
||||
// Refs for bounding area checks
|
||||
const containerRef = useRef<HTMLDivElement>(null);
|
||||
const tooltipRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
// Keep the tooltip open if hovered on container or tooltip
|
||||
// Remove the old onMouseLeave calls and rely on bounding area checks
|
||||
useEffect(() => {
|
||||
function handleMouseMove(e: MouseEvent) {
|
||||
if (!containerRef.current || !tooltipRef.current) return;
|
||||
|
||||
const containerRect = containerRef.current.getBoundingClientRect();
|
||||
const tooltipRect = tooltipRef.current.getBoundingClientRect();
|
||||
const [x, y] = [e.clientX, e.clientY];
|
||||
|
||||
const inContainer =
|
||||
x >= containerRect.left &&
|
||||
x <= containerRect.right &&
|
||||
y >= containerRect.top &&
|
||||
y <= containerRect.bottom;
|
||||
|
||||
const inTooltip =
|
||||
x >= tooltipRect.left &&
|
||||
x <= tooltipRect.right &&
|
||||
y >= tooltipRect.top &&
|
||||
y <= tooltipRect.bottom;
|
||||
|
||||
// If not hovering in either region, close tooltip
|
||||
if (!inContainer && !inTooltip) {
|
||||
setToolTipHoveredInternal(false);
|
||||
setToolTipHovered(false);
|
||||
setIsHovered(false);
|
||||
}
|
||||
}
|
||||
|
||||
window.addEventListener("mousemove", handleMouseMove);
|
||||
return () => {
|
||||
window.removeEventListener("mousemove", handleMouseMove);
|
||||
};
|
||||
}, [setToolTipHovered]);
|
||||
|
||||
// Once "done", hide after a short delay if not hovered
|
||||
useEffect(() => {
|
||||
if (isDone) {
|
||||
const timer = setTimeout(() => {
|
||||
setShouldShow(false);
|
||||
setCanShowResponse(true);
|
||||
}, 800); // e.g. 0.8s
|
||||
}, 800);
|
||||
return () => clearTimeout(timer);
|
||||
}
|
||||
}, [isDone, isHovered]);
|
||||
}, [isDone, isHovered, setCanShowResponse]);
|
||||
|
||||
if (!shouldShow) {
|
||||
return null; // entire box disappears
|
||||
@@ -137,13 +177,22 @@ export function RefinemenetBadge({
|
||||
|
||||
return (
|
||||
<TooltipProvider delayDuration={0}>
|
||||
{/*
|
||||
IMPORTANT: We rely on open={ isHovered || toolTipHoveredInternal }
|
||||
to keep the tooltip visible if either the badge or tooltip is hovered.
|
||||
*/}
|
||||
<Tooltip open={isHovered || toolTipHoveredInternal}>
|
||||
<div
|
||||
className="relative w-fit max-w-sm"
|
||||
onMouseEnter={() => setIsHovered(true)}
|
||||
onMouseLeave={() => setIsHovered(false)}
|
||||
ref={containerRef}
|
||||
// onMouseEnter keeps the tooltip open
|
||||
onMouseEnter={() => {
|
||||
setIsHovered(true);
|
||||
setToolTipHoveredInternal(true);
|
||||
setToolTipHovered(true);
|
||||
}}
|
||||
// Remove the explicit onMouseLeave – the global bounding check will close it
|
||||
>
|
||||
{/* Original snippet's tooltip usage */}
|
||||
<TooltipTrigger asChild>
|
||||
<div className="flex items-center gap-x-1 text-black text-sm font-medium cursor-pointer hover:text-blue-600 transition-colors duration-200">
|
||||
<p className="text-sm loading-text font-medium">
|
||||
@@ -159,36 +208,32 @@ export function RefinemenetBadge({
|
||||
</TooltipTrigger>
|
||||
{expanded && (
|
||||
<TooltipContent
|
||||
ref={tooltipRef}
|
||||
// onMouseEnter keeps the tooltip open when cursor enters tooltip
|
||||
onMouseEnter={() => {
|
||||
setToolTipHoveredInternal(true);
|
||||
setToolTipHovered(true);
|
||||
}}
|
||||
onMouseLeave={() => {
|
||||
setToolTipHoveredInternal(false);
|
||||
}}
|
||||
// Remove onMouseLeave and rely on bounding box logic to close
|
||||
side="bottom"
|
||||
align="start"
|
||||
className="w-fit -mt-1 p-4 bg-white border-2 border-border shadow-lg rounded-md"
|
||||
width="w-fit"
|
||||
className=" -mt-1 p-4 bg-[#fff] dark:bg-[#000] border-2 border-border dark:border-neutral-800 shadow-lg rounded-md"
|
||||
>
|
||||
{/* If not done, show the "Refining" box + a chevron */}
|
||||
|
||||
{/* Expanded area: each displayed phase in order */}
|
||||
|
||||
<div className="items-start flex flex-col gap-y-2">
|
||||
{currentState !== StreamingPhase.WAITING ? (
|
||||
Array.from(new Set(displayedPhases)).map((phase, index) => {
|
||||
const phaseIndex = displayedPhases.indexOf(phase);
|
||||
// The last displayed item is "running" if not COMPLETE
|
||||
let status = ToggleState.Done;
|
||||
if (
|
||||
index ===
|
||||
Array.from(new Set(displayedPhases)).length - 1
|
||||
Array.from(new Set(displayedPhases)).length - 1 &&
|
||||
phase !== StreamingPhase.COMPLETE
|
||||
) {
|
||||
status = ToggleState.InProgress;
|
||||
}
|
||||
if (phase === StreamingPhase.COMPLETE) {
|
||||
status = ToggleState.Done;
|
||||
}
|
||||
|
||||
return (
|
||||
<div
|
||||
@@ -338,6 +383,7 @@ export function StatusRefinement({
|
||||
onMouseLeave={() => setToolTipHovered(false)}
|
||||
side="bottom"
|
||||
align="start"
|
||||
width="w-fit"
|
||||
className="w-fit p-4 bg-[#fff] border-2 border-border dark:border-neutral-800 shadow-lg rounded-md"
|
||||
>
|
||||
{/* If not done, show the "Refining" box + a chevron */}
|
||||
@@ -355,7 +401,6 @@ export function StatusRefinement({
|
||||
</div>
|
||||
<span className="text-neutral-800 text-sm font-medium">
|
||||
{StreamingPhaseText[phase]}
|
||||
LLL
|
||||
</span>
|
||||
</div>
|
||||
))}
|
||||
|
||||
@@ -5,18 +5,22 @@ import FunctionalWrapper from "../../components/chat/FunctionalWrapper";
|
||||
|
||||
export default function WrappedChat({
|
||||
firstMessage,
|
||||
defaultSidebarOff,
|
||||
}: {
|
||||
firstMessage?: string;
|
||||
// This is required for the chrome extension side panel
|
||||
// we don't want to show the sidebar by default when the user opens the side panel
|
||||
defaultSidebarOff?: boolean;
|
||||
}) {
|
||||
const { toggledSidebar } = useChatContext();
|
||||
const { sidebarInitiallyVisible } = useChatContext();
|
||||
|
||||
return (
|
||||
<FunctionalWrapper
|
||||
initiallyToggled={toggledSidebar}
|
||||
content={(toggledSidebar, toggle) => (
|
||||
initiallyVisible={sidebarInitiallyVisible && !defaultSidebarOff}
|
||||
content={(sidebarVisible, toggle) => (
|
||||
<ChatPage
|
||||
toggle={toggle}
|
||||
toggledSidebar={toggledSidebar}
|
||||
sidebarVisible={sidebarVisible}
|
||||
firstMessage={firstMessage}
|
||||
/>
|
||||
)}
|
||||
|
||||
@@ -86,19 +86,20 @@ export function AgenticToggle({
|
||||
</TooltipTrigger>
|
||||
<TooltipContent
|
||||
side="top"
|
||||
className="w-72 p-4 bg-white rounded-lg shadow-lg border border-background-200 dark:border-neutral-900"
|
||||
width="w-72"
|
||||
className="p-4 bg-white rounded-lg shadow-lg border border-background-200 dark:border-neutral-900"
|
||||
>
|
||||
<div className="flex items-center space-x-2 mb-3">
|
||||
<h3 className="text-sm font-semibold text-text-900">
|
||||
<h3 className="text-sm font-semibold text-neutral-900">
|
||||
Agent Search (BETA)
|
||||
</h3>
|
||||
</div>
|
||||
<p className="text-xs text-text-600 mb-2">
|
||||
<p className="text-xs text-neutarl-600 dark:text-neutral-700 mb-2">
|
||||
Use AI agents to break down questions and run deep iterative
|
||||
research through promising pathways. Gives more thorough and
|
||||
accurate responses but takes slightly longer.
|
||||
</p>
|
||||
<ul className="text-xs text-text-600 list-disc list-inside">
|
||||
<ul className="text-xs text-text-600 dark:text-neutral-700 list-disc list-inside">
|
||||
<li>Improved accuracy of search results</li>
|
||||
<li>Less hallucinations</li>
|
||||
<li>More comprehensive answers</li>
|
||||
|
||||
@@ -819,27 +819,17 @@ export function ChatInputBar({
|
||||
chatState == "toolBuilding" ||
|
||||
chatState == "loading"
|
||||
? chatState != "streaming"
|
||||
? "bg-neutral-900 dark:bg-neutral-400 "
|
||||
: "bg-neutral-500 dark:bg-neutral-50"
|
||||
: ""
|
||||
? "bg-neutral-500 dark:bg-neutral-400 "
|
||||
: "bg-neutral-900 dark:bg-neutral-50"
|
||||
: "bg-red-200"
|
||||
} h-[22px] w-[22px] rounded-full`}
|
||||
onClick={() => {
|
||||
if (
|
||||
chatState == "streaming" ||
|
||||
chatState == "toolBuilding" ||
|
||||
chatState == "loading"
|
||||
) {
|
||||
if (chatState == "streaming") {
|
||||
stopGenerating();
|
||||
} else if (message) {
|
||||
onSubmit();
|
||||
}
|
||||
}}
|
||||
disabled={
|
||||
(chatState == "streaming" ||
|
||||
chatState == "toolBuilding" ||
|
||||
chatState == "loading") &&
|
||||
chatState != "streaming"
|
||||
}
|
||||
>
|
||||
{chatState == "streaming" ||
|
||||
chatState == "toolBuilding" ||
|
||||
|
||||
@@ -110,6 +110,7 @@ export interface Message {
|
||||
second_level_message?: string;
|
||||
second_level_subquestions?: SubQuestionDetail[] | null;
|
||||
isImprovement?: boolean | null;
|
||||
isStreamingQuestions?: boolean;
|
||||
}
|
||||
|
||||
export interface BackendChatSession {
|
||||
@@ -219,6 +220,7 @@ export interface SubQuestionDetail extends BaseQuestionIdentifier {
|
||||
context_docs?: { top_documents: OnyxDocument[] } | null;
|
||||
is_complete?: boolean;
|
||||
is_stopped?: boolean;
|
||||
answer_streaming?: boolean;
|
||||
}
|
||||
|
||||
export interface SubQueryDetail {
|
||||
@@ -245,9 +247,6 @@ export const constructSubQuestions = (
|
||||
}
|
||||
|
||||
const updatedSubQuestions = [...subQuestions];
|
||||
// .filter(
|
||||
// (sq) => sq.level_question_num !== 0
|
||||
// );
|
||||
|
||||
if ("stop_reason" in newDetail) {
|
||||
const { level, level_question_num } = newDetail;
|
||||
@@ -255,8 +254,12 @@ export const constructSubQuestions = (
|
||||
(sq) => sq.level === level && sq.level_question_num === level_question_num
|
||||
);
|
||||
if (subQuestion) {
|
||||
subQuestion.is_complete = true;
|
||||
subQuestion.is_stopped = true;
|
||||
if (newDetail.stream_type == "sub_answer") {
|
||||
subQuestion.answer_streaming = false;
|
||||
} else {
|
||||
subQuestion.is_complete = true;
|
||||
subQuestion.is_stopped = true;
|
||||
}
|
||||
}
|
||||
} else if ("top_documents" in newDetail) {
|
||||
const { level, level_question_num, top_documents } = newDetail;
|
||||
|
||||
@@ -31,7 +31,7 @@ export default async function Layout({
|
||||
llmProviders,
|
||||
folders,
|
||||
openedFolders,
|
||||
toggleSidebar,
|
||||
sidebarInitiallyVisible,
|
||||
defaultAssistantId,
|
||||
shouldShowWelcomeModal,
|
||||
ccPairs,
|
||||
@@ -47,7 +47,7 @@ export default async function Layout({
|
||||
proSearchToggled,
|
||||
inputPrompts,
|
||||
chatSessions,
|
||||
toggledSidebar: toggleSidebar,
|
||||
sidebarInitiallyVisible,
|
||||
availableSources,
|
||||
ccPairs,
|
||||
documentSets,
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user