mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-02-21 01:35:46 +00:00)

Compare commits: 20 commits, monitoring...bugfix/ves

| Author | SHA1 | Date |
|---|---|---|
| | 1ac3ec7575 | |
| | e9492ce9ec | |
| | 35574369ed | |
| | eff433bdc5 | |
| | 3260d793d1 | |
| | 1a7aca06b9 | |
| | c6434db7eb | |
| | 667b9e04c5 | |
| | 29c84d7707 | |
| | 17c915b11b | |
| | 95ca592d6d | |
| | e39a27fd6b | |
| | 26d3c952c6 | |
| | 53683e2f3c | |
| | 0c0113a481 | |
| | c0f381e471 | |
| | 5ed83f1148 | |
| | 9db7b67a6c | |
| | 2850048c6b | |
| | 61058e5fcd | |
@@ -65,6 +65,7 @@ jobs:
NEXT_PUBLIC_POSTHOG_KEY=${{ secrets.POSTHOG_KEY }}
NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY=${{ secrets.STRIPE_PUBLISHABLE_KEY }}
NEXT_PUBLIC_GTM_ENABLED=true
NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true

@@ -4,9 +4,6 @@ on:
push:
tags:
- "*"
paths:
- 'backend/model_server/**'
- 'backend/Dockerfile.model_server'

env:
REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-model-server-cloud' || 'onyxdotapp/onyx-model-server' }}
@@ -15,7 +12,32 @@ env:
BUILDKIT_PROGRESS: plain

jobs:
# 1) Preliminary job to check if the changed files are relevant
check_model_server_changes:
runs-on: ubuntu-latest
outputs:
changed: ${{ steps.check.outputs.changed }}
steps:
- name: Checkout code
uses: actions/checkout@v4

- name: Check if relevant files changed
id: check
run: |
# Default to "false"
echo "changed=false" >> $GITHUB_OUTPUT

# Compare the previous commit (github.event.before) to the current one (github.sha)
# If any file in backend/model_server/** or backend/Dockerfile.model_server is changed,
# set changed=true
if git diff --name-only ${{ github.event.before }} ${{ github.sha }} \
| grep -E '^backend/model_server/|^backend/Dockerfile.model_server'; then
echo "changed=true" >> $GITHUB_OUTPUT
fi

build-amd64:
needs: [check_model_server_changes]
if: needs.check_model_server_changes.outputs.changed == 'true'
runs-on:
[runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}-amd64"]
steps:
@@ -55,6 +77,8 @@ jobs:
provenance: false

build-arm64:
needs: [check_model_server_changes]
if: needs.check_model_server_changes.outputs.changed == 'true'
runs-on:
[runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}-arm64"]
steps:
@@ -94,7 +118,8 @@ jobs:
provenance: false

merge-and-scan:
needs: [build-amd64, build-arm64]
needs: [build-amd64, build-arm64, check_model_server_changes]
if: needs.check_model_server_changes.outputs.changed == 'true'
runs-on: ubuntu-latest
steps:
- name: Login to Docker Hub

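For clarity, the same changed-files check can be expressed as a small Python sketch; this is illustrative only (the workflow itself runs the plain `git diff | grep` step shown above), and the function name is made up.

```python
# Hypothetical Python equivalent of the shell step above, for illustration only.
import re
import subprocess


def model_server_files_changed(before_sha: str, current_sha: str) -> bool:
    """Return True if any model-server file changed between the two commits."""
    diff = subprocess.run(
        ["git", "diff", "--name-only", before_sha, current_sha],
        capture_output=True,
        text=True,
        check=True,
    )
    pattern = re.compile(r"^backend/model_server/|^backend/Dockerfile\.model_server")
    return any(pattern.search(path) for path in diff.stdout.splitlines())
```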
@@ -0,0 +1,40 @@
"""Add background errors table

Revision ID: f39c5794c10a
Revises: 2cdeff6d8c93
Create Date: 2025-02-12 17:11:14.527876

"""
from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "f39c5794c10a"
down_revision = "2cdeff6d8c93"
branch_labels = None
depends_on = None


def upgrade() -> None:
op.create_table(
"background_error",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("message", sa.String(), nullable=False),
sa.Column(
"time_created",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.Column("cc_pair_id", sa.Integer(), nullable=True),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(
["cc_pair_id"],
["connector_credential_pair.id"],
ondelete="CASCADE",
),
)


def downgrade() -> None:
op.drop_table("background_error")
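For reference, a minimal sketch of the ORM model this migration implies; the real model definition is not part of this diff, so the class name and declarative base are assumptions, while the columns mirror the migration above.

```python
# Hypothetical ORM counterpart of the "background_error" table; names are assumed.
from datetime import datetime

from sqlalchemy import DateTime, ForeignKey, Integer, String, func
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class BackgroundError(Base):
    __tablename__ = "background_error"

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    message: Mapped[str] = mapped_column(String, nullable=False)
    time_created: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now(), nullable=False
    )
    # Nullable so errors not tied to a connector/credential pair can still be recorded.
    cc_pair_id: Mapped[int | None] = mapped_column(
        Integer,
        ForeignKey("connector_credential_pair.id", ondelete="CASCADE"),
        nullable=True,
    )
```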
@@ -1,10 +1,10 @@
from datetime import timedelta
from typing import Any

from onyx.background.celery.tasks.beat_schedule import BEAT_EXPIRES_DEFAULT
from onyx.background.celery.tasks.beat_schedule import (
beat_system_tasks as base_beat_system_tasks,
beat_cloud_tasks as base_beat_system_tasks,
)
from onyx.background.celery.tasks.beat_schedule import BEAT_EXPIRES_DEFAULT
from onyx.background.celery.tasks.beat_schedule import (
beat_task_templates as base_beat_task_templates,
)

@@ -77,3 +77,5 @@ POSTHOG_HOST = os.environ.get("POSTHOG_HOST") or "https://us.i.posthog.com"
HUBSPOT_TRACKING_URL = os.environ.get("HUBSPOT_TRACKING_URL")

ANONYMOUS_USER_COOKIE_NAME = "onyx_anonymous_user"

GATED_TENANTS_KEY = "gated_tenants"

@@ -1,5 +1,6 @@
from ee.onyx.db.external_perm import ExternalUserGroup
from ee.onyx.external_permissions.confluence.constants import ALL_CONF_EMAILS_GROUP_NAME
from onyx.background.error_logging import emit_background_error
from onyx.connectors.confluence.onyx_confluence import build_confluence_client
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
from onyx.connectors.confluence.utils import get_user_email_from_username__server
@@ -10,7 +11,7 @@ logger = setup_logger()

def _build_group_member_email_map(
confluence_client: OnyxConfluence,
confluence_client: OnyxConfluence, cc_pair_id: int
) -> dict[str, set[str]]:
group_member_emails: dict[str, set[str]] = {}
for user_result in confluence_client.paginated_cql_user_retrieval():
@@ -18,8 +19,11 @@ def _build_group_member_email_map(

user = user_result.get("user", {})
if not user:
logger.warning(f"user result missing user field: {user_result}")
msg = f"user result missing user field: {user_result}"
emit_background_error(msg, cc_pair_id=cc_pair_id)
logger.error(msg)
continue

email = user.get("email")
if not email:
# This field is only present in Confluence Server
@@ -32,7 +36,12 @@ def _build_group_member_email_map(
)
if not email:
# If we still don't have an email, skip this user
logger.warning(f"user result missing email field: {user_result}")
msg = f"user result missing email field: {user_result}"
if user.get("type") == "app":
logger.warning(msg)
else:
emit_background_error(msg, cc_pair_id=cc_pair_id)
logger.error(msg)
continue

all_users_groups: set[str] = set()
@@ -42,11 +51,18 @@ def _build_group_member_email_map(
group_member_emails.setdefault(group_id, set()).add(email)
all_users_groups.add(group_id)

if not group_member_emails:
logger.warning(f"No groups found for user with email: {email}")
if not all_users_groups:
msg = f"No groups found for user with email: {email}"
emit_background_error(msg, cc_pair_id=cc_pair_id)
logger.error(msg)
else:
logger.debug(f"Found groups {all_users_groups} for user with email {email}")

if not group_member_emails:
msg = "No groups found for any users."
emit_background_error(msg, cc_pair_id=cc_pair_id)
logger.error(msg)

return group_member_emails


@@ -61,6 +77,7 @@ def confluence_group_sync(

group_member_email_map = _build_group_member_email_map(
confluence_client=confluence_client,
cc_pair_id=cc_pair.id,
)
onyx_groups: list[ExternalUserGroup] = []
all_found_emails = set()

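The helper emit_background_error is imported and called throughout this changeset, but its implementation is not shown here. A plausible minimal sketch, assuming it simply persists a row into the background_error table added by the migration above; the session helper usage and model name are assumptions, not the verified code.

```python
# Hypothetical sketch of onyx/background/error_logging.py; not part of this diff.
from onyx.db.engine import get_session_with_tenant  # assumed to default to the current tenant
from onyx.db.models import BackgroundError  # assumed model name


def emit_background_error(message: str, cc_pair_id: int | None = None) -> None:
    """Persist a background error row so it can be surfaced to admins later."""
    with get_session_with_tenant() as db_session:
        db_session.add(BackgroundError(message=message, cc_pair_id=cc_pair_id))
        db_session.commit()
```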
@@ -83,6 +83,7 @@ def handle_search_request(
user=user,
llm=llm,
fast_llm=fast_llm,
skip_query_analysis=False,
db_session=db_session,
bypass_acl=False,
)

@@ -18,11 +18,16 @@ from ee.onyx.server.tenants.anonymous_user_path import (
from ee.onyx.server.tenants.anonymous_user_path import modify_anonymous_user_path
from ee.onyx.server.tenants.anonymous_user_path import validate_anonymous_user_path
from ee.onyx.server.tenants.billing import fetch_billing_information
from ee.onyx.server.tenants.billing import fetch_stripe_checkout_session
from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
from ee.onyx.server.tenants.models import AnonymousUserPath
from ee.onyx.server.tenants.models import BillingInformation
from ee.onyx.server.tenants.models import ImpersonateRequest
from ee.onyx.server.tenants.models import ProductGatingRequest
from ee.onyx.server.tenants.models import ProductGatingResponse
from ee.onyx.server.tenants.models import SubscriptionSessionResponse
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
from ee.onyx.server.tenants.product_gating import store_product_gating
from ee.onyx.server.tenants.provisioning import delete_user_from_control_plane
from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
from ee.onyx.server.tenants.user_mapping import remove_all_users_from_tenant
@@ -39,12 +44,9 @@ from onyx.db.auth import get_user_count
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.engine import get_session_with_tenant
from onyx.db.notification import create_notification
from onyx.db.users import delete_user_from_db
from onyx.db.users import get_user_by_email
from onyx.server.manage.models import UserByEmail
from onyx.server.settings.store import load_settings
from onyx.server.settings.store import store_settings
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

@@ -126,37 +128,29 @@ async def login_as_anonymous_user(
@router.post("/product-gating")
def gate_product(
product_gating_request: ProductGatingRequest, _: None = Depends(control_plane_dep)
) -> None:
) -> ProductGatingResponse:
"""
Gating the product means that the product is not available to the tenant.
They will be directed to the billing page.
We gate the product when
1) User has ended free trial without adding payment method
2) User's card has declined
We gate the product when their subscription has ended.
"""
tenant_id = product_gating_request.tenant_id
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
try:
store_product_gating(
product_gating_request.tenant_id, product_gating_request.application_status
)
return ProductGatingResponse(updated=True, error=None)

settings = load_settings()
settings.product_gating = product_gating_request.product_gating
store_settings(settings)

if product_gating_request.notification:
with get_session_with_tenant(tenant_id) as db_session:
create_notification(None, product_gating_request.notification, db_session)

if token is not None:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
except Exception as e:
logger.exception("Failed to gate product")
return ProductGatingResponse(updated=False, error=str(e))


@router.get("/billing-information", response_model=BillingInformation)
@router.get("/billing-information")
async def billing_information(
_: User = Depends(current_admin_user),
) -> BillingInformation:
) -> BillingInformation | SubscriptionStatusResponse:
logger.info("Fetching billing information")
return BillingInformation(
**fetch_billing_information(CURRENT_TENANT_ID_CONTEXTVAR.get())
)
return fetch_billing_information(CURRENT_TENANT_ID_CONTEXTVAR.get())


@router.post("/create-customer-portal-session")
@@ -169,9 +163,10 @@ async def create_customer_portal_session(_: User = Depends(current_admin_user))
if not stripe_customer_id:
raise HTTPException(status_code=400, detail="Stripe customer ID not found")
logger.info(stripe_customer_id)

portal_session = stripe.billing_portal.Session.create(
customer=stripe_customer_id,
return_url=f"{WEB_DOMAIN}/admin/cloud-settings",
return_url=f"{WEB_DOMAIN}/admin/billing",
)
logger.info(portal_session)
return {"url": portal_session.url}
@@ -180,6 +175,20 @@ async def create_customer_portal_session(_: User = Depends(current_admin_user))
raise HTTPException(status_code=500, detail=str(e))


@router.post("/create-subscription-session")
async def create_subscription_session(
_: User = Depends(current_admin_user),
) -> SubscriptionSessionResponse:
try:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
session_id = fetch_stripe_checkout_session(tenant_id)
return SubscriptionSessionResponse(sessionId=session_id)

except Exception as e:
logger.exception("Failed to create resubscription session")
raise HTTPException(status_code=500, detail=str(e))


@router.post("/impersonate")
async def impersonate_user(
impersonate_request: ImpersonateRequest,

@@ -6,6 +6,7 @@ import stripe
from ee.onyx.configs.app_configs import STRIPE_PRICE_ID
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
from ee.onyx.server.tenants.access import generate_data_plane_token
from ee.onyx.server.tenants.models import BillingInformation
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from onyx.utils.logger import setup_logger

@@ -14,6 +15,19 @@ stripe.api_key = STRIPE_SECRET_KEY
logger = setup_logger()


def fetch_stripe_checkout_session(tenant_id: str) -> str:
token = generate_data_plane_token()
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
}
url = f"{CONTROL_PLANE_API_BASE_URL}/create-checkout-session"
params = {"tenant_id": tenant_id}
response = requests.post(url, headers=headers, params=params)
response.raise_for_status()
return response.json()["sessionId"]


def fetch_tenant_stripe_information(tenant_id: str) -> dict:
token = generate_data_plane_token()
headers = {
@@ -27,7 +41,7 @@ def fetch_tenant_stripe_information(tenant_id: str) -> dict:
return response.json()


def fetch_billing_information(tenant_id: str) -> dict:
def fetch_billing_information(tenant_id: str) -> BillingInformation:
logger.info("Fetching billing information")
token = generate_data_plane_token()
headers = {
@@ -38,7 +52,7 @@ def fetch_billing_information(tenant_id: str) -> dict:
params = {"tenant_id": tenant_id}
response = requests.get(url, headers=headers, params=params)
response.raise_for_status()
billing_info = response.json()
billing_info = BillingInformation(**response.json())
return billing_info

@@ -1,7 +1,8 @@
from datetime import datetime

from pydantic import BaseModel

from onyx.configs.constants import NotificationType
from onyx.server.settings.models import GatingType
from onyx.server.settings.models import ApplicationStatus


class CheckoutSessionCreationRequest(BaseModel):
@@ -15,15 +16,24 @@ class CreateTenantRequest(BaseModel):

class ProductGatingRequest(BaseModel):
tenant_id: str
product_gating: GatingType
notification: NotificationType | None = None
application_status: ApplicationStatus


class SubscriptionStatusResponse(BaseModel):
subscribed: bool


class BillingInformation(BaseModel):
stripe_subscription_id: str
status: str
current_period_start: datetime
current_period_end: datetime
number_of_seats: int
cancel_at_period_end: bool
canceled_at: datetime | None
trial_start: datetime | None
trial_end: datetime | None
seats: int
subscription_status: str
billing_start: str
billing_end: str
payment_method_enabled: bool


@@ -48,3 +58,12 @@ class TenantDeletionPayload(BaseModel):

class AnonymousUserPath(BaseModel):
anonymous_user_path: str | None


class ProductGatingResponse(BaseModel):
updated: bool
error: str | None


class SubscriptionSessionResponse(BaseModel):
sessionId: str

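Taken together with the endpoint change above, the control plane now sends an application_status instead of product_gating/notification, and gets an explicit success/error response back. A minimal sketch of the new request and response shapes; the tenant id is a placeholder.

```python
# Illustrative payloads for POST /product-gating after this change.
from ee.onyx.server.tenants.models import ProductGatingRequest, ProductGatingResponse
from onyx.server.settings.models import ApplicationStatus

request = ProductGatingRequest(
    tenant_id="tenant_1234",  # placeholder
    application_status=ApplicationStatus.GATED_ACCESS,
)

# On success the endpoint now returns:
response = ProductGatingResponse(updated=True, error=None)
```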
backend/ee/onyx/server/tenants/product_gating.py (new file, 51 lines)
@@ -0,0 +1,51 @@
from typing import cast

from ee.onyx.configs.app_configs import GATED_TENANTS_KEY
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.server.settings.models import ApplicationStatus
from onyx.server.settings.store import load_settings
from onyx.server.settings.store import store_settings
from onyx.setup import setup_logger
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

logger = setup_logger()


def update_tenant_gating(tenant_id: str, status: ApplicationStatus) -> None:
redis_client = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)

# Store the full status
status_key = f"tenant:{tenant_id}:status"
redis_client.set(status_key, status.value)

# Maintain the GATED_ACCESS set
if status == ApplicationStatus.GATED_ACCESS:
redis_client.sadd(GATED_TENANTS_KEY, tenant_id)
else:
redis_client.srem(GATED_TENANTS_KEY, tenant_id)


def store_product_gating(tenant_id: str, application_status: ApplicationStatus) -> None:
try:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)

settings = load_settings()
settings.application_status = application_status
store_settings(settings)

# Store gated tenant information in Redis
update_tenant_gating(tenant_id, application_status)

if token is not None:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)

except Exception:
logger.exception("Failed to gate product")
raise


def get_gated_tenants() -> set[str]:
redis_client = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
return cast(set[str], redis_client.smembers(GATED_TENANTS_KEY))

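A short usage sketch of the new gating helpers as they might be driven from the control-plane path above; the tenant id is a placeholder and the snippet is illustrative only.

```python
# Illustrative only; tenant id is a placeholder.
from ee.onyx.server.tenants.product_gating import get_gated_tenants, store_product_gating
from onyx.server.settings.models import ApplicationStatus

tenant_id = "tenant_1234"

# Persists the status in the tenant's settings and mirrors it into the Redis
# "gated_tenants" set kept under the cloud tenant id.
store_product_gating(tenant_id, ApplicationStatus.GATED_ACCESS)

# Other workers can then check gating from the Redis replica without a DB round trip.
# Depending on the Redis client configuration, set members may come back as bytes.
gated = get_gated_tenants()
print(tenant_id in gated or tenant_id.encode() in gated)
```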
@@ -21,10 +21,11 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
from onyx.configs.agent_configs import AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS
from onyx.configs.agent_configs import AGENT_RERANKING_STATS
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import SearchRequest
from onyx.context.search.pipeline import retrieval_preprocessing
from onyx.context.search.models import RerankingDetails
from onyx.context.search.postprocessing.postprocessing import rerank_sections
from onyx.context.search.postprocessing.postprocessing import should_rerank
from onyx.db.engine import get_session_context_manager
from onyx.db.search_settings import get_current_search_settings


def rerank_documents(
@@ -39,6 +40,8 @@ def rerank_documents(

# Rerank post retrieval and verification. First, create a search query
# then create the list of reranked sections
# If no question defined/question is None in the state, use the original
# question from the search request as query

graph_config = cast(GraphConfig, config["metadata"]["config"])
question = (
@@ -47,39 +50,28 @@ def rerank_documents(
assert (
graph_config.tooling.search_tool
), "search_tool must be provided for agentic search"
with get_session_context_manager() as db_session:
# we ignore some of the user specified fields since this search is
# internal to agentic search, but we still want to pass through
# persona (for stuff like document sets) and rerank settings
# (to not make an unnecessary db call).
search_request = SearchRequest(
query=question,
persona=graph_config.inputs.search_request.persona,
rerank_settings=graph_config.inputs.search_request.rerank_settings,
)
_search_query = retrieval_preprocessing(
search_request=search_request,
user=graph_config.tooling.search_tool.user,  # bit of a hack
llm=graph_config.tooling.fast_llm,
db_session=db_session,
)

# skip section filtering
# Note that these are passed in values from the API and are overrides which are typically None
rerank_settings = graph_config.inputs.search_request.rerank_settings

if (
_search_query.rerank_settings
and _search_query.rerank_settings.rerank_model_name
and _search_query.rerank_settings.num_rerank > 0
and len(verified_documents) > 0
):
if rerank_settings is None:
with get_session_context_manager() as db_session:
search_settings = get_current_search_settings(db_session)
if not search_settings.disable_rerank_for_streaming:
rerank_settings = RerankingDetails.from_db_model(search_settings)

if should_rerank(rerank_settings) and len(verified_documents) > 0:
if len(verified_documents) > 1:
reranked_documents = rerank_sections(
_search_query,
verified_documents,
query_str=question,
# if runnable, then rerank_settings is not None
rerank_settings=cast(RerankingDetails, rerank_settings),
sections_to_rerank=verified_documents,
)
else:
num = "No" if len(verified_documents) == 0 else "One"
logger.warning(f"{num} verified document(s) found, skipping reranking")
logger.warning(
f"{len(verified_documents)} verified document(s) found, skipping reranking"
)
reranked_documents = verified_documents
else:
logger.warning("No reranking settings found, using unranked documents")

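should_rerank is imported from postprocessing but its body is not shown in this diff. Judging from the inline condition it replaces, a plausible reading is the following sketch; this is an assumption, not the verified implementation.

```python
# Hypothetical sketch mirroring the condition that was previously inlined above.
from onyx.context.search.models import RerankingDetails


def should_rerank(rerank_settings: RerankingDetails | None) -> bool:
    """Rerank only when a model is configured and at least one result is requested."""
    if rerank_settings is None:
        return False
    return bool(rerank_settings.rerank_model_name) and rerank_settings.num_rerank > 0
```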
@@ -23,6 +23,7 @@ from onyx.configs.agent_configs import AGENT_RETRIEVAL_STATS
from onyx.context.search.models import InferenceSection
from onyx.db.engine import get_session_context_manager
from onyx.tools.models import SearchQueryInfo
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
@@ -67,9 +68,12 @@ def retrieve_documents(
with get_session_context_manager() as db_session:
for tool_response in search_tool.run(
query=query_to_retrieve,
force_no_rerank=True,
alternate_db_session=db_session,
retrieved_sections_callback=callback_container.append,
override_kwargs=SearchToolOverrideKwargs(
force_no_rerank=True,
alternate_db_session=db_session,
retrieved_sections_callback=callback_container.append,
skip_query_analysis=not state.base_search,
),
):
# get retrieved docs to send to the rest of the graph
if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:

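The call sites above now pass the former loose keyword arguments through a single SearchToolOverrideKwargs object. Its definition is not part of this diff; the fields used here imply roughly the sketch below, with field types assumed from how they are populated.

```python
# Hypothetical reconstruction of SearchToolOverrideKwargs based solely on the
# fields passed at the call sites in this diff; the real definition may differ.
from typing import Callable

from pydantic import BaseModel, ConfigDict
from sqlalchemy.orm import Session

from onyx.context.search.models import InferenceSection


class SearchToolOverrideKwargs(BaseModel):
    # Raw Session objects and callbacks need arbitrary_types_allowed.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    force_no_rerank: bool = False
    alternate_db_session: Session | None = None
    retrieved_sections_callback: Callable[[list[InferenceSection]], None] | None = None
    skip_query_analysis: bool = False
```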
@@ -58,6 +58,7 @@ from onyx.prompts.agent_search import (
)
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
from onyx.tools.force import ForceUseTool
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool_constructor import SearchToolConfig
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
@@ -218,7 +219,10 @@ def get_test_config(
using_tool_calling_llm=using_tool_calling_llm,
)

chat_session_id = os.environ.get("ONYX_AS_CHAT_SESSION_ID")
chat_session_id = (
os.environ.get("ONYX_AS_CHAT_SESSION_ID")
or "00000000-0000-0000-0000-000000000000"
)
assert (
chat_session_id is not None
), "ONYX_AS_CHAT_SESSION_ID must be set for backend tests"
@@ -341,8 +345,12 @@ def retrieve_search_docs(
with get_session_context_manager() as db_session:
for tool_response in search_tool.run(
query=question,
force_no_rerank=True,
alternate_db_session=db_session,
override_kwargs=SearchToolOverrideKwargs(
force_no_rerank=True,
alternate_db_session=db_session,
retrieved_sections_callback=None,
skip_query_analysis=False,
),
):
# get retrieved docs to send to the rest of the graph
if tool_response.id == SEARCH_RESPONSE_SUMMARY_ID:

@@ -1,7 +1,7 @@
import smtplib
from datetime import datetime
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from textwrap import dedent

from onyx.configs.app_configs import EMAIL_CONFIGURED
from onyx.configs.app_configs import EMAIL_FROM
@@ -13,23 +13,150 @@ from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import TENANT_ID_COOKIE_NAME
from onyx.db.models import User

HTML_EMAIL_TEMPLATE = """\
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width" />
<title>{title}</title>
<style>
body, table, td, a {{
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif;
text-size-adjust: 100%;
margin: 0;
padding: 0;
-webkit-font-smoothing: antialiased;
-webkit-text-size-adjust: none;
}}
body {{
background-color: #f7f7f7;
color: #333;
}}
.body-content {{
color: #333;
}}
.email-container {{
width: 100%;
max-width: 600px;
margin: 0 auto;
background-color: #ffffff;
border-radius: 6px;
overflow: hidden;
border: 1px solid #eaeaea;
}}
.header {{
background-color: #000000;
padding: 20px;
text-align: center;
}}
.header img {{
max-width: 140px;
}}
.body-content {{
padding: 20px 30px;
}}
.title {{
font-size: 20px;
font-weight: bold;
margin: 0 0 10px;
}}
.message {{
font-size: 16px;
line-height: 1.5;
margin: 0 0 20px;
}}
.cta-button {{
display: inline-block;
padding: 12px 20px;
background-color: #000000;
color: #ffffff !important;
text-decoration: none;
border-radius: 4px;
font-weight: 500;
}}
.footer {{
font-size: 13px;
color: #6A7280;
text-align: center;
padding: 20px;
}}
.footer a {{
color: #6b7280;
text-decoration: underline;
}}
</style>
</head>
<body>
<table role="presentation" class="email-container" cellpadding="0" cellspacing="0">
<tr>
<td class="header">
<img
style="background-color: #ffffff; border-radius: 8px;"
src="https://www.onyx.app/logos/customer/onyx.png"
alt="Onyx Logo"
>
</td>
</tr>
<tr>
<td class="body-content">
<h1 class="title">{heading}</h1>
<div class="message">
{message}
</div>
{cta_block}
</td>
</tr>
<tr>
<td class="footer">
© {year} Onyx. All rights reserved.
<br>
Have questions? Join our Slack community <a href="https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA">here</a>.
</td>
</tr>
</table>
</body>
</html>
"""


def build_html_email(
heading: str, message: str, cta_text: str | None = None, cta_link: str | None = None
) -> str:
if cta_text and cta_link:
cta_block = f'<a class="cta-button" href="{cta_link}">{cta_text}</a>'
else:
cta_block = ""
return HTML_EMAIL_TEMPLATE.format(
title=heading,
heading=heading,
message=message,
cta_block=cta_block,
year=datetime.now().year,
)


def send_email(
user_email: str,
subject: str,
body: str,
html_body: str,
text_body: str,
mail_from: str = EMAIL_FROM,
) -> None:
if not EMAIL_CONFIGURED:
raise ValueError("Email is not configured.")

msg = MIMEMultipart()
msg = MIMEMultipart("alternative")
msg["Subject"] = subject
msg["To"] = user_email
if mail_from:
msg["From"] = mail_from

msg.attach(MIMEText(body))
part_text = MIMEText(text_body, "plain")
part_html = MIMEText(html_body, "html")

msg.attach(part_text)
msg.attach(part_html)

try:
with smtplib.SMTP(SMTP_SERVER, SMTP_PORT) as s:
@@ -40,26 +167,44 @@ def send_email(
raise e


def send_subscription_cancellation_email(user_email: str) -> None:
# Example usage of the reusable HTML
subject = "Your Onyx Subscription Has Been Canceled"
heading = "Subscription Canceled"
message = (
"<p>We’re sorry to see you go.</p>"
"<p>Your subscription has been canceled and will end on your next billing date.</p>"
"<p>If you change your mind, you can always come back!</p>"
)
cta_text = "Renew Subscription"
cta_link = "https://www.onyx.app/pricing"
html_content = build_html_email(heading, message, cta_text, cta_link)
text_content = (
"We're sorry to see you go.\n"
"Your subscription has been canceled and will end on your next billing date.\n"
"If you change your mind, visit https://www.onyx.app/pricing"
)
send_email(user_email, subject, html_content, text_content)


def send_user_email_invite(user_email: str, current_user: User) -> None:
subject = "Invitation to Join Onyx Organization"
body = dedent(
f"""\
Hello,

You have been invited to join an organization on Onyx.

To join the organization, please visit the following link:

{WEB_DOMAIN}/auth/signup?email={user_email}

You'll be asked to set a password or login with Google to complete your registration.

Best regards,
The Onyx Team
"""
heading = "You've Been Invited!"
message = (
f"<p>You have been invited by {current_user.email} to join an organization on Onyx.</p>"
"<p>To join the organization, please click the button below to set a password "
"or login with Google and complete your registration.</p>"
)

send_email(user_email, subject, body, current_user.email)
cta_text = "Join Organization"
cta_link = f"{WEB_DOMAIN}/auth/signup?email={user_email}"
html_content = build_html_email(heading, message, cta_text, cta_link)
text_content = (
f"You have been invited by {current_user.email} to join an organization on Onyx.\n"
"To join the organization, please visit the following link:\n"
f"{WEB_DOMAIN}/auth/signup?email={user_email}\n"
"You'll be asked to set a password or login with Google to complete your registration."
)
send_email(user_email, subject, html_content, text_content)


def send_forgot_password_email(
@@ -68,13 +213,15 @@ def send_forgot_password_email(
mail_from: str = EMAIL_FROM,
tenant_id: str | None = None,
) -> None:
# Builds a forgot password email with or without fancy HTML
subject = "Onyx Forgot Password"
link = f"{WEB_DOMAIN}/auth/reset-password?token={token}"
if tenant_id:
link += f"&{TENANT_ID_COOKIE_NAME}={tenant_id}"
# Keep search param same name as cookie for simplicity
body = f"Click the following link to reset your password: {link}"
send_email(user_email, subject, body, mail_from)
message = f"<p>Click the following link to reset your password:</p><p>{link}</p>"
html_content = build_html_email("Reset Your Password", message)
text_content = f"Click the following link to reset your password: {link}"
send_email(user_email, subject, html_content, text_content, mail_from)


def send_user_verification_email(
@@ -82,7 +229,12 @@ def send_user_verification_email(
token: str,
mail_from: str = EMAIL_FROM,
) -> None:
# Builds a verification email
subject = "Onyx Email Verification"
link = f"{WEB_DOMAIN}/auth/verify-email?token={token}"
body = f"Click the following link to verify your email address: {link}"
send_email(user_email, subject, body, mail_from)
message = (
f"<p>Click the following link to verify your email address:</p><p>{link}</p>"
)
html_content = build_html_email("Verify Your Email", message)
text_content = f"Click the following link to verify your email address: {link}"
send_email(user_email, subject, html_content, text_content, mail_from)

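A short usage sketch of the new helpers: build_html_email renders the shared template, and send_email now takes separate HTML and plain-text bodies. The module path, addresses, and URLs below are placeholders, and the snippet assumes SMTP is configured (EMAIL_CONFIGURED is true).

```python
# Illustrative only; module path and values are assumptions.
from onyx.auth.email_utils import build_html_email, send_email  # assumed module path

html_body = build_html_email(
    heading="Reset Your Password",
    message="<p>Click the link below to reset your password.</p>",
    cta_text="Reset Password",
    cta_link="https://example.com/auth/reset-password?token=abc123",
)
text_body = "Reset your password: https://example.com/auth/reset-password?token=abc123"

send_email(
    user_email="user@example.com",
    subject="Onyx Forgot Password",
    html_body=html_body,
    text_body=text_body,
)
```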
@@ -144,7 +144,6 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
# As currently designed, when this worker starts as "primary", we reinitialize redis
# to a clean state (for our purposes, anyway)
r.delete(OnyxRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK)
r.delete(OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK)

r.delete(OnyxRedisConstants.ACTIVE_FENCES)

@@ -19,6 +19,7 @@ BEAT_EXPIRES_DEFAULT = 15 * 60  # 15 minutes (in seconds)

# hack to slow down task dispatch in the cloud until
# we have a better implementation (backpressure, etc)
# Note that DynamicTenantScheduler can adjust the runtime value for this via Redis
CLOUD_BEAT_MULTIPLIER_DEFAULT = 8.0

# tasks that run in either self-hosted on cloud
@@ -56,16 +57,7 @@ beat_task_templates.extend(
{
"name": "check-for-pruning",
"task": OnyxCeleryTask.CHECK_FOR_PRUNING,
"schedule": timedelta(hours=1),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "monitor-vespa-sync",
"task": OnyxCeleryTask.MONITOR_VESPA_SYNC,
"schedule": timedelta(seconds=5),
"schedule": timedelta(seconds=20),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
@@ -141,14 +133,14 @@ def make_cloud_generator_task(task: dict[str, Any]) -> dict[str, Any]:
return cloud_task


# tasks that only run in the cloud
# tasks that only run in the cloud and are system wide
# the name attribute must start with ONYX_CLOUD_CELERY_TASK_PREFIX = "cloud" to be seen
# by the DynamicTenantScheduler as system wide task and not a per tenant task
beat_system_tasks: list[dict] = [
beat_cloud_tasks: list[dict] = [
# cloud specific tasks
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-alembic",
"task": OnyxCeleryTask.CLOUD_CHECK_ALEMBIC,
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor-alembic",
"task": OnyxCeleryTask.CLOUD_MONITOR_ALEMBIC,
"schedule": timedelta(hours=1),
"options": {
"queue": OnyxCeleryQueues.MONITORING,
@@ -156,11 +148,37 @@ beat_system_tasks: list[dict] = [
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor-celery-queues",
"task": OnyxCeleryTask.CLOUD_MONITOR_CELERY_QUEUES,
"schedule": timedelta(seconds=30),
"options": {
"queue": OnyxCeleryQueues.MONITORING,
"priority": OnyxCeleryPriority.HIGH,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
]

# tasks that only run self hosted
tasks_to_schedule: list[dict] = []
if not MULTI_TENANT:
tasks_to_schedule = beat_task_templates
tasks_to_schedule.extend(
[
{
"name": "monitor-celery-queues",
"task": OnyxCeleryTask.MONITOR_CELERY_QUEUES,
"schedule": timedelta(seconds=10),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
"queue": OnyxCeleryQueues.MONITORING,
},
},
]
)

tasks_to_schedule.extend(beat_task_templates)


def generate_cloud_tasks(
@@ -180,23 +198,24 @@ def generate_cloud_tasks(
if beat_multiplier <= 0:
raise ValueError("beat_multiplier must be positive!")

# start with the incoming beat tasks
cloud_tasks: list[dict] = copy.deepcopy(beat_tasks)
cloud_tasks: list[dict] = []

# generate our cloud tasks from the templates
# generate our tenant aware cloud tasks from the templates
for beat_template in beat_templates:
cloud_task = make_cloud_generator_task(beat_template)
cloud_tasks.append(cloud_task)

# factor in the cloud multiplier
# factor in the cloud multiplier for the above
for cloud_task in cloud_tasks:
cloud_task["schedule"] = cloud_task["schedule"] * beat_multiplier

# add the fixed cloud/system beat tasks. No multiplier for these.
cloud_tasks.extend(copy.deepcopy(beat_tasks))
return cloud_tasks


def get_cloud_tasks_to_schedule(beat_multiplier: float) -> list[dict[str, Any]]:
return generate_cloud_tasks(beat_system_tasks, beat_task_templates, beat_multiplier)
return generate_cloud_tasks(beat_cloud_tasks, beat_task_templates, beat_multiplier)


def get_tasks_to_schedule() -> list[dict[str, Any]]:

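To make the scheduling change concrete: after this diff, only the per-tenant tasks generated from beat_task_templates are slowed down by beat_multiplier, while the fixed cloud/system tasks appended afterwards keep their original schedule. A toy illustration of the arithmetic (values are examples, not real task definitions):

```python
# Toy illustration of the multiplier behavior described above; nothing here is a real task.
from datetime import timedelta

beat_multiplier = 8.0  # CLOUD_BEAT_MULTIPLIER_DEFAULT

template_schedule = timedelta(seconds=20)  # e.g. the monitor-vespa-sync template
system_schedule = timedelta(seconds=30)    # e.g. a fixed cloud/system monitoring task

# Template-derived, tenant-aware tasks get stretched by the multiplier...
effective_template_schedule = template_schedule * beat_multiplier  # 160 seconds
# ...while the fixed cloud/system tasks are added afterwards, unchanged.
effective_system_schedule = system_schedule                        # still 30 seconds

print(effective_template_schedule, effective_system_schedule)
```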
@@ -1,10 +1,14 @@
import traceback
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast

from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session

@@ -12,18 +16,35 @@ from onyx.background.celery.apps.app_base import task_logger
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.connector import fetch_connector_by_id
from onyx.db.connector_credential_pair import add_deletion_failure_message
from onyx.db.connector_credential_pair import (
delete_connector_credential_pair__no_commit,
)
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import get_connector_credential_pairs
from onyx.db.document import get_document_ids_for_connector_credential_pair
from onyx.db.document_set import delete_document_set_cc_pair_relationship__no_commit
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.index_attempt import delete_index_attempts
from onyx.db.search_settings import get_all_search_settings
from onyx.db.sync_record import cleanup_sync_records
from onyx.db.sync_record import insert_sync_record
from onyx.db.sync_record import update_sync_record_status
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_delete import RedisConnectorDelete
from onyx.redis.redis_connector_delete import RedisConnectorDeletePayload
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.utils.variable_functionality import (
fetch_versioned_implementation_with_fallback,
)
from onyx.utils.variable_functionality import noop_fallback


class TaskDependencyError(RuntimeError):
@@ -42,6 +63,7 @@ def check_for_connector_deletion_task(
self: Task, *, tenant_id: str | None
) -> bool | None:
r = get_redis_client(tenant_id=tenant_id)
r_replica = get_redis_replica_client(tenant_id=tenant_id)

lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
@@ -77,6 +99,18 @@ def check_for_connector_deletion_task(
# clear the stop signal if it exists ... no longer needed
redis_connector.stop.set_fence(False)

lock_beat.reacquire()
keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES))
for key in keys:
key_bytes = cast(bytes, key)

if not r.exists(key_bytes):
r.srem(OnyxRedisConstants.ACTIVE_FENCES, key_bytes)
continue

key_str = key_bytes.decode("utf-8")
if key_str.startswith(RedisConnectorDelete.FENCE_PREFIX):
monitor_connector_deletion_taskset(tenant_id, key_bytes, r)
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
@@ -212,3 +246,158 @@ def try_generate_document_cc_pair_cleanup_tasks(
redis_connector.delete.set_fence(fence_payload)

return tasks_generated


def monitor_connector_deletion_taskset(
tenant_id: str | None, key_bytes: bytes, r: Redis
) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
if cc_pair_id_str is None:
task_logger.warning(f"could not parse cc_pair_id from {fence_key}")
return

cc_pair_id = int(cc_pair_id_str)

redis_connector = RedisConnector(tenant_id, cc_pair_id)

fence_data = redis_connector.delete.payload
if not fence_data:
task_logger.warning(
f"Connector deletion - fence payload invalid: cc_pair={cc_pair_id}"
)
return

if fence_data.num_tasks is None:
# the fence is setting up but isn't ready yet
return

remaining = redis_connector.delete.get_remaining()
task_logger.info(
f"Connector deletion progress: cc_pair={cc_pair_id} remaining={remaining} initial={fence_data.num_tasks}"
)
if remaining > 0:
with get_session_with_tenant(tenant_id) as db_session:
update_sync_record_status(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.CONNECTOR_DELETION,
sync_status=SyncStatus.IN_PROGRESS,
num_docs_synced=remaining,
)
return

with get_session_with_tenant(tenant_id) as db_session:
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
if not cc_pair:
task_logger.warning(
f"Connector deletion - cc_pair not found: cc_pair={cc_pair_id}"
)
return

try:
doc_ids = get_document_ids_for_connector_credential_pair(
db_session, cc_pair.connector_id, cc_pair.credential_id
)
if len(doc_ids) > 0:
# NOTE(rkuo): if this happens, documents somehow got added while
# deletion was in progress. Likely a bug gating off pruning and indexing
# work before deletion starts.
task_logger.warning(
"Connector deletion - documents still found after taskset completion. "
"Clearing the current deletion attempt and allowing deletion to restart: "
f"cc_pair={cc_pair_id} "
f"docs_deleted={fence_data.num_tasks} "
f"docs_remaining={len(doc_ids)}"
)

# We don't want to waive off why we get into this state, but resetting
# our attempt and letting the deletion restart is a good way to recover
redis_connector.delete.reset()
raise RuntimeError(
"Connector deletion - documents still found after taskset completion"
)

# clean up the rest of the related Postgres entities
# index attempts
delete_index_attempts(
db_session=db_session,
cc_pair_id=cc_pair_id,
)

# document sets
delete_document_set_cc_pair_relationship__no_commit(
db_session=db_session,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
)

# user groups
cleanup_user_groups = fetch_versioned_implementation_with_fallback(
"onyx.db.user_group",
"delete_user_group_cc_pair_relationship__no_commit",
noop_fallback,
)
cleanup_user_groups(
cc_pair_id=cc_pair_id,
db_session=db_session,
)

# finally, delete the cc-pair
delete_connector_credential_pair__no_commit(
db_session=db_session,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
)
# if there are no credentials left, delete the connector
connector = fetch_connector_by_id(
db_session=db_session,
connector_id=cc_pair.connector_id,
)
if not connector or not len(connector.credentials):
task_logger.info(
"Connector deletion - Found no credentials left for connector, deleting connector"
)
db_session.delete(connector)
db_session.commit()

update_sync_record_status(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.CONNECTOR_DELETION,
sync_status=SyncStatus.SUCCESS,
num_docs_synced=fence_data.num_tasks,
)

except Exception as e:
db_session.rollback()
stack_trace = traceback.format_exc()
error_message = f"Error: {str(e)}\n\nStack Trace:\n{stack_trace}"
add_deletion_failure_message(db_session, cc_pair_id, error_message)

update_sync_record_status(
db_session=db_session,
entity_id=cc_pair_id,
sync_type=SyncType.CONNECTOR_DELETION,
sync_status=SyncStatus.FAILED,
num_docs_synced=fence_data.num_tasks,
)

task_logger.exception(
f"Connector deletion exceptioned: "
f"cc_pair={cc_pair_id} connector={cc_pair.connector_id} credential={cc_pair.credential_id}"
)
raise e

task_logger.info(
f"Connector deletion succeeded: "
f"cc_pair={cc_pair_id} "
f"connector={cc_pair.connector_id} "
f"credential={cc_pair.credential_id} "
f"docs_deleted={fence_data.num_tasks}"
)

redis_connector.delete.reset()

@@ -175,6 +175,24 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> bool
)

r.set(OnyxRedisSignals.BLOCK_VALIDATE_PERMISSION_SYNC_FENCES, 1, ex=300)

# use a lookup table to find active fences. We still have to verify the fence
# exists since it is an optimization and not the source of truth.
lock_beat.reacquire()
keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES))
for key in keys:
key_bytes = cast(bytes, key)

if not r.exists(key_bytes):
r.srem(OnyxRedisConstants.ACTIVE_FENCES, key_bytes)
continue

key_str = key_bytes.decode("utf-8")
if key_str.startswith(RedisConnectorPermissionSync.FENCE_PREFIX):
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_permissions_taskset(
tenant_id, key_bytes, r, db_session
)
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
@@ -349,6 +367,7 @@ def connector_permission_sync_generator_task(
OnyxRedisLocks.CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX
+ f"_{redis_connector.id}",
timeout=CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT,
thread_local=False,
)

acquired = lock.acquire(blocking=False)
@@ -756,7 +775,7 @@ class PermissionSyncCallback(IndexingHeartbeatInterface):
raise


"""Monitoring CCPair permissions utils, called in monitor_vespa_sync"""
"""Monitoring CCPair permissions utils"""


def monitor_ccpair_permissions_taskset(

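Both the deletion and permission-sync hunks above use the same ACTIVE_FENCES lookup pattern: read the candidate set from the replica, confirm each key still exists on the primary, prune stale entries, and dispatch on the fence prefix. A condensed restatement of that loop as a generic helper; this helper does not exist in the codebase and is shown only to summarize the pattern.

```python
# Condensed restatement of the lookup-table loop used in the hunks above (illustrative only).
from typing import Any, Callable, cast

from redis import Redis

from onyx.configs.constants import OnyxRedisConstants


def iterate_active_fences(
    r: Redis, r_replica: Redis, prefix: str, handle: Callable[[bytes], None]
) -> None:
    # The replica set is an optimization, not the source of truth, so each key
    # is re-checked against the primary before being handled.
    keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES))
    for key in keys:
        key_bytes = cast(bytes, key)
        if not r.exists(key_bytes):
            r.srem(OnyxRedisConstants.ACTIVE_FENCES, key_bytes)
            continue
        if key_bytes.decode("utf-8").startswith(prefix):
            handle(key_bytes)
```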
@@ -26,11 +26,11 @@ from ee.onyx.external_permissions.sync_params import (
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.error_logging import emit_background_error
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
@@ -72,18 +72,26 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
"""Returns boolean indicating if external group sync is due."""

if cc_pair.access_type != AccessType.SYNC:
return False

# skip external group sync if not active
if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
task_logger.error(
f"Recieved non-sync CC Pair {cc_pair.id} for external "
f"group sync. Actual access type: {cc_pair.access_type}"
)
return False

if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
task_logger.debug(
f"Skipping group sync for CC Pair {cc_pair.id} - "
f"CC Pair is being deleted"
)
return False

# If there is not group sync function for the connector, we don't run the sync
# This is fine because all sources dont necessarily have a concept of groups
if not GROUP_PERMISSIONS_FUNC_MAP.get(cc_pair.connector.source):
task_logger.debug(
f"Skipping group sync for CC Pair {cc_pair.id} - "
f"no group sync function for {cc_pair.connector.source}"
)
return False

# If the last sync is None, it has never been run so we run the sync
@@ -125,6 +133,9 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool

# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
task_logger.warning(
f"Failed to acquire beat lock for external group sync: {tenant_id}"
)
return None

try:
@@ -205,20 +216,12 @@ def try_creating_external_group_sync_task(

redis_connector = RedisConnector(tenant_id, cc_pair_id)

LOCK_TIMEOUT = 30

lock: RedisLock = r.lock(
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_external_group_sync_tasks",
timeout=LOCK_TIMEOUT,
)

acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2)
if not acquired:
return None

try:
# Dont kick off a new sync if the previous one is still running
if redis_connector.external_group_sync.fenced:
logger.warning(
f"Skipping external group sync for CC Pair {cc_pair_id} - already running."
)
return None

redis_connector.external_group_sync.generator_clear()
@@ -269,9 +272,6 @@ def try_creating_external_group_sync_task(
f"Unexpected exception while trying to create external group sync task: cc_pair={cc_pair_id}"
)
return None
finally:
if lock.owned():
lock.release()

return payload_id

@@ -304,22 +304,26 @@ def connector_external_group_sync_generator_task(
start = time.monotonic()
while True:
if time.monotonic() - start > CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT:
raise ValueError(
msg = (
f"connector_external_group_sync_generator_task - timed out waiting for fence to be ready: "
f"fence={redis_connector.external_group_sync.fence_key}"
)
emit_background_error(msg, cc_pair_id=cc_pair_id)
raise ValueError(msg)

if not redis_connector.external_group_sync.fenced:  # The fence must exist
raise ValueError(
msg = (
f"connector_external_group_sync_generator_task - fence not found: "
f"fence={redis_connector.external_group_sync.fence_key}"
)
emit_background_error(msg, cc_pair_id=cc_pair_id)
raise ValueError(msg)

payload = redis_connector.external_group_sync.payload  # The payload must exist
if not payload:
raise ValueError(
"connector_external_group_sync_generator_task: payload invalid or not found"
)
msg = "connector_external_group_sync_generator_task: payload invalid or not found"
emit_background_error(msg, cc_pair_id=cc_pair_id)
raise ValueError(msg)

if payload.celery_task_id is None:
logger.info(
@@ -344,9 +348,9 @@ def connector_external_group_sync_generator_task(

acquired = lock.acquire(blocking=False)
if not acquired:
task_logger.warning(
f"External group sync task already running, exiting...: cc_pair={cc_pair_id}"
)
msg = f"External group sync task already running, exiting...: cc_pair={cc_pair_id}"
emit_background_error(msg, cc_pair_id=cc_pair_id)
task_logger.error(msg)
return None

try:
@@ -367,9 +371,9 @@ def connector_external_group_sync_generator_task(

ext_group_sync_func = GROUP_PERMISSIONS_FUNC_MAP.get(source_type)
if ext_group_sync_func is None:
raise ValueError(
f"No external group sync func found for {source_type} for cc_pair: {cc_pair_id}"
)
msg = f"No external group sync func found for {source_type} for cc_pair: {cc_pair_id}"
emit_background_error(msg, cc_pair_id=cc_pair_id)
raise ValueError(msg)

logger.info(
f"Syncing external groups for {source_type} for cc_pair: {cc_pair_id}"
@@ -400,9 +404,9 @@ def connector_external_group_sync_generator_task(
sync_status=SyncStatus.SUCCESS,
)
except Exception as e:
task_logger.exception(
f"External group sync exceptioned: cc_pair={cc_pair_id} payload_id={payload.id}"
)
msg = f"External group sync exceptioned: cc_pair={cc_pair_id} payload_id={payload.id}"
task_logger.exception(msg)
emit_background_error(msg + f"\n\n{e}", cc_pair_id=cc_pair_id)

with get_session_with_tenant(tenant_id) as db_session:
update_sync_record_status(
@@ -492,9 +496,11 @@ def validate_external_group_sync_fence(
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
if cc_pair_id_str is None:
task_logger.warning(
msg = (
f"validate_external_group_sync_fence - could not parse id from {fence_key}"
)
emit_background_error(msg)
task_logger.error(msg)
return

cc_pair_id = int(cc_pair_id_str)
@@ -509,12 +515,14 @@ def validate_external_group_sync_fence(
try:
payload = redis_connector.external_group_sync.payload
except ValidationError:
task_logger.exception(
msg = (
"validate_external_group_sync_fence - "
"Resetting fence because fence schema is out of date: "
f"cc_pair={cc_pair_id} "
f"fence={fence_key}"
)
task_logger.exception(msg)
emit_background_error(msg, cc_pair_id=cc_pair_id)

redis_connector.external_group_sync.reset()
return
@@ -551,12 +559,15 @@ def validate_external_group_sync_fence(
# return

# celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
logger.warning(
"validate_external_group_sync_fence - "
"Resetting fence because no associated celery tasks were found: "
f"cc_pair={cc_pair_id} "
f"fence={fence_key} "
f"payload_id={payload.id}"
emit_background_error(
message=(
"validate_external_group_sync_fence - "
"Resetting fence because no associated celery tasks were found: "
f"cc_pair={cc_pair_id} "
f"fence={fence_key} "
f"payload_id={payload.id}"
),
cc_pair_id=cc_pair_id,
)

redis_connector.external_group_sync.reset()

@@ -6,13 +6,18 @@ from datetime import datetime
|
||||
from datetime import timezone
|
||||
from http import HTTPStatus
|
||||
from time import sleep
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import sentry_sdk
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from celery.result import AsyncResult
|
||||
from celery.states import READY_STATES
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
|
||||
@@ -30,6 +35,7 @@ from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisConstants
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.configs.constants import OnyxRedisSignals
|
||||
from onyx.db.connector import mark_ccpair_with_indexing_trigger
|
||||
@@ -37,6 +43,7 @@ from onyx.db.connector_credential_pair import fetch_connector_credential_pairs
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.enums import IndexingMode
|
||||
from onyx.db.enums import IndexingStatus
|
||||
from onyx.db.index_attempt import get_index_attempt
|
||||
from onyx.db.index_attempt import get_last_attempt_for_cc_pair
|
||||
from onyx.db.index_attempt import mark_attempt_canceled
|
||||
@@ -47,9 +54,12 @@ from onyx.db.swap_index import check_index_swap
|
||||
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from onyx.natural_language_processing.search_nlp_models import warm_up_bi_encoder
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.redis.redis_connector_index import RedisConnectorIndex
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.redis.redis_pool import redis_lock_dump
|
||||
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
|
||||
from onyx.redis.redis_utils import is_fence
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
|
||||
@@ -60,6 +70,150 @@ from shared_configs.configs import SENTRY_DSN
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def monitor_ccpair_indexing_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
# if the fence doesn't exist, there's nothing to do
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
composite_id = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
if composite_id is None:
|
||||
task_logger.warning(
|
||||
f"Connector indexing: could not parse composite_id from {fence_key}"
|
||||
)
|
||||
return
|
||||
|
||||
# parse out metadata and initialize the helper class with it
|
||||
parts = composite_id.split("/")
|
||||
if len(parts) != 2:
|
||||
return
|
||||
|
||||
cc_pair_id = int(parts[0])
|
||||
search_settings_id = int(parts[1])
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
redis_connector_index = redis_connector.new_index(search_settings_id)
|
||||
if not redis_connector_index.fenced:
|
||||
return
|
||||
|
||||
payload = redis_connector_index.payload
|
||||
if not payload:
|
||||
return
|
||||
|
||||
elapsed_started_str = None
|
||||
if payload.started:
|
||||
elapsed_started = datetime.now(timezone.utc) - payload.started
|
||||
elapsed_started_str = f"{elapsed_started.total_seconds():.2f}"
|
||||
|
||||
elapsed_submitted = datetime.now(timezone.utc) - payload.submitted
|
||||
|
||||
progress = redis_connector_index.get_progress()
|
||||
if progress is not None:
|
||||
task_logger.info(
|
||||
f"Connector indexing progress: "
|
||||
f"attempt={payload.index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"progress={progress} "
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
|
||||
f"elapsed_started={elapsed_started_str}"
|
||||
)
|
||||
|
||||
if payload.index_attempt_id is None or payload.celery_task_id is None:
|
||||
# the task is still setting up
|
||||
return
|
||||
|
||||
# never use any blocking methods on the result from inside a task!
|
||||
result: AsyncResult = AsyncResult(payload.celery_task_id)
|
||||
|
||||
# inner/outer/inner double check pattern to avoid race conditions when checking for
|
||||
# bad state
|
||||
|
||||
# Verify: if the generator isn't complete, the task must not be in READY state
|
||||
# inner = get_completion / generator_complete not signaled
|
||||
# outer = result.state in READY state
|
||||
status_int = redis_connector_index.get_completion()
|
||||
if status_int is None: # inner signal not set ... possible error
|
||||
task_state = result.state
|
||||
if (
|
||||
task_state in READY_STATES
|
||||
): # outer signal in terminal state ... possible error
|
||||
# Now double check!
|
||||
if redis_connector_index.get_completion() is None:
|
||||
# inner signal still not set (and cannot change when outer result_state is READY)
|
||||
# Task is finished but generator complete isn't set.
|
||||
# We have a problem! Worker may have crashed.
|
||||
task_result = str(result.result)
|
||||
task_traceback = str(result.traceback)
|
||||
|
||||
msg = (
|
||||
f"Connector indexing aborted or exceptioned: "
|
||||
f"attempt={payload.index_attempt_id} "
|
||||
f"celery_task={payload.celery_task_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
|
||||
f"result.state={task_state} "
|
||||
f"result.result={task_result} "
|
||||
f"result.traceback={task_traceback}"
|
||||
)
|
||||
task_logger.warning(msg)
|
||||
|
||||
try:
|
||||
index_attempt = get_index_attempt(
|
||||
db_session, payload.index_attempt_id
|
||||
)
|
||||
if index_attempt:
|
||||
if (
|
||||
index_attempt.status != IndexingStatus.CANCELED
|
||||
and index_attempt.status != IndexingStatus.FAILED
|
||||
):
|
||||
mark_attempt_failed(
|
||||
index_attempt_id=payload.index_attempt_id,
|
||||
db_session=db_session,
|
||||
failure_reason=msg,
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
"Connector indexing - Transient exception marking index attempt as failed: "
|
||||
f"attempt={payload.index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
redis_connector_index.reset()
|
||||
return
|
||||
|
||||
if redis_connector_index.watchdog_signaled():
|
||||
# if the generator is complete, don't clean up until the watchdog has exited
|
||||
task_logger.info(
|
||||
f"Connector indexing - Delaying finalization until watchdog has exited: "
|
||||
f"attempt={payload.index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"progress={progress} "
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
|
||||
f"elapsed_started={elapsed_started_str}"
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
status_enum = HTTPStatus(status_int)
|
||||
|
||||
task_logger.info(
|
||||
f"Connector indexing finished: "
|
||||
f"attempt={payload.index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"progress={progress} "
|
||||
f"status={status_enum.name} "
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
|
||||
f"elapsed_started={elapsed_started_str}"
|
||||
)
|
||||
|
||||
redis_connector_index.reset()
|
||||
|
||||
|
||||
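# The inner/outer/inner double check in monitor_ccpair_indexing_taskset above avoids a
# race between the worker finishing and the completion signal being written. A compact
# sketch of the same check; get_completion is a stand-in for
# redis_connector_index.get_completion and is an assumption, not the repo's API.
from collections.abc import Callable

from celery.result import AsyncResult
from celery.states import READY_STATES


def looks_crashed(get_completion: Callable[[], int | None], celery_task_id: str) -> bool:
    if get_completion() is not None:  # inner: generator signaled completion
        return False
    result: AsyncResult = AsyncResult(celery_task_id)
    if result.state not in READY_STATES:  # outer: task still queued or running
        return False
    # inner again: the state cannot change once READY, so a still-missing completion
    # signal means the worker finished without signaling (likely aborted or crashed)
    return get_completion() is None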
@shared_task(
    name=OnyxCeleryTask.CHECK_FOR_INDEXING,
    soft_time_limit=300,
@@ -91,6 +245,25 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
    try:
        locked = True

        # SPECIAL 0/3: sync lookup table for active fences
        # we want to run this less frequently than the overall task
        if not redis_client.exists(OnyxRedisSignals.BLOCK_BUILD_FENCE_LOOKUP_TABLE):
            # build a lookup table of existing fences
            # this is just a migration concern and should be unnecessary once
            # lookup tables are rolled out
            for key_bytes in redis_client_replica.scan_iter(
                count=SCAN_ITER_COUNT_DEFAULT
            ):
                if is_fence(key_bytes) and not redis_client.sismember(
                    OnyxRedisConstants.ACTIVE_FENCES, key_bytes
                ):
                    logger.warning(f"Adding {key_bytes} to the lookup table.")
                    redis_client.sadd(OnyxRedisConstants.ACTIVE_FENCES, key_bytes)

            redis_client.set(OnyxRedisSignals.BLOCK_BUILD_FENCE_LOOKUP_TABLE, 1, ex=300)
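# Standalone sketch of the lookup-table refresh above: mirror fence keys found on the
# replica into a master-side set, then block further rebuilds for the signal's TTL.
# The key names, prefix, and scan count below are illustrative assumptions only.
import redis

ACTIVE_FENCES_KEY = "active_fences"
BUILD_SIGNAL_KEY = "signal:block_build_fence_lookup_table"
FENCE_PREFIX = b"connector_fence:"  # hypothetical prefix standing in for is_fence()


def refresh_fence_lookup_table(master: redis.Redis, replica: redis.Redis) -> None:
    if master.exists(BUILD_SIGNAL_KEY):
        return  # rebuilt recently, skip this pass

    for key_bytes in replica.scan_iter(count=4096):
        if key_bytes.startswith(FENCE_PREFIX) and not master.sismember(
            ACTIVE_FENCES_KEY, key_bytes
        ):
            master.sadd(ACTIVE_FENCES_KEY, key_bytes)

    master.set(BUILD_SIGNAL_KEY, 1, ex=300)  # matches the ex=300 above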
|
||||
# 1/3: KICKOFF
|
||||
|
||||
# check for search settings swap
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
old_search_settings = check_index_swap(db_session=db_session)
|
||||
@@ -197,6 +370,8 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
|
||||
lock_beat.reacquire()
|
||||
|
||||
# 2/3: VALIDATE
|
||||
|
||||
# Fail any index attempts in the DB that don't have fences
|
||||
# This shouldn't ever happen!
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
@@ -236,6 +411,26 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
task_logger.exception("Exception while validating indexing fences")
|
||||
|
||||
redis_client.set(OnyxRedisSignals.BLOCK_VALIDATE_INDEXING_FENCES, 1, ex=60)
|
||||
|
||||
# 3/3: FINALIZE
|
||||
lock_beat.reacquire()
|
||||
keys = cast(
|
||||
set[Any], redis_client_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES)
|
||||
)
|
||||
for key in keys:
|
||||
key_bytes = cast(bytes, key)
|
||||
|
||||
if not redis_client.exists(key_bytes):
|
||||
redis_client.srem(OnyxRedisConstants.ACTIVE_FENCES, key_bytes)
|
||||
continue
|
||||
|
||||
key_str = key_bytes.decode("utf-8")
|
||||
if key_str.startswith(RedisConnectorIndex.FENCE_PREFIX):
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_indexing_taskset(
|
||||
tenant_id, key_bytes, redis_client_replica, db_session
|
||||
)
|
||||
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
|
||||
@@ -17,7 +17,8 @@ from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.tasks.vespa.tasks import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
@@ -722,7 +723,7 @@ def monitor_background_processes(self: Task, *, tenant_id: str | None) -> None:
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CLOUD_CHECK_ALEMBIC,
|
||||
name=OnyxCeleryTask.CLOUD_MONITOR_ALEMBIC,
|
||||
)
|
||||
def cloud_check_alembic() -> bool | None:
|
||||
"""A task to verify that all tenants are on the same alembic revision.
|
||||
@@ -852,3 +853,55 @@ def cloud_check_alembic() -> bool | None:
        f"cloud_check_alembic finished: num_tenants={len(tenant_ids)} elapsed={time_elapsed:.2f}"
    )
    return True


@shared_task(
    name=OnyxCeleryTask.CLOUD_MONITOR_CELERY_QUEUES, ignore_result=True, bind=True
)
def cloud_monitor_celery_queues(
    self: Task,
) -> None:
    return monitor_celery_queues_helper(self)


@shared_task(name=OnyxCeleryTask.MONITOR_CELERY_QUEUES, ignore_result=True, bind=True)
def monitor_celery_queues(self: Task, *, tenant_id: str | None) -> None:
    return monitor_celery_queues_helper(self)


def monitor_celery_queues_helper(
    task: Task,
) -> None:
    """A task to monitor all celery queue lengths."""

    r_celery = task.app.broker_connection().channel().client  # type: ignore
    n_celery = celery_get_queue_length("celery", r_celery)
    n_indexing = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery)
    n_sync = celery_get_queue_length(OnyxCeleryQueues.VESPA_METADATA_SYNC, r_celery)
    n_deletion = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_DELETION, r_celery)
    n_pruning = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_PRUNING, r_celery)
    n_permissions_sync = celery_get_queue_length(
        OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC, r_celery
    )
    n_external_group_sync = celery_get_queue_length(
        OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC, r_celery
    )
    n_permissions_upsert = celery_get_queue_length(
        OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT, r_celery
    )

    n_indexing_prefetched = celery_get_unacked_task_ids(
        OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
    )

    task_logger.info(
        f"Queue lengths: celery={n_celery} "
        f"indexing={n_indexing} "
        f"indexing_prefetched={len(n_indexing_prefetched)} "
        f"sync={n_sync} "
        f"deletion={n_deletion} "
        f"pruning={n_pruning} "
        f"permissions_sync={n_permissions_sync} "
        f"external_group_sync={n_external_group_sync} "
        f"permissions_upsert={n_permissions_upsert} "
    )

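# celery_get_queue_length (imported above from onyx.background.celery.celery_redis)
# reports backlog per queue. With Celery's Redis broker, ready messages sit in a list
# keyed by the queue name, so a rough standalone approximation looks like the sketch
# below; priority sub-queues and unacked (prefetched) messages are ignored for brevity.
import redis


def approximate_queue_length(r_broker: redis.Redis, queue_name: str) -> int:
    return int(r_broker.llen(queue_name))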
@@ -122,34 +122,39 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
|
||||
return None
|
||||
|
||||
try:
|
||||
cc_pair_ids: list[int] = []
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pairs = get_connector_credential_pairs(db_session)
|
||||
for cc_pair_entry in cc_pairs:
|
||||
cc_pair_ids.append(cc_pair_entry.id)
|
||||
# the entire task needs to run frequently in order to finalize pruning
|
||||
|
||||
for cc_pair_id in cc_pair_ids:
|
||||
lock_beat.reacquire()
|
||||
# but pruning only kicks off once per hour
|
||||
if not r.exists(OnyxRedisSignals.BLOCK_PRUNING):
|
||||
cc_pair_ids: list[int] = []
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
if not cc_pair:
|
||||
continue
|
||||
cc_pairs = get_connector_credential_pairs(db_session)
|
||||
for cc_pair_entry in cc_pairs:
|
||||
cc_pair_ids.append(cc_pair_entry.id)
|
||||
|
||||
if not _is_pruning_due(cc_pair):
|
||||
continue
|
||||
for cc_pair_id in cc_pair_ids:
|
||||
lock_beat.reacquire()
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
if not cc_pair:
|
||||
continue
|
||||
|
||||
payload_id = try_creating_prune_generator_task(
|
||||
self.app, cc_pair, db_session, r, tenant_id
|
||||
)
|
||||
if not payload_id:
|
||||
continue
|
||||
if not _is_pruning_due(cc_pair):
|
||||
continue
|
||||
|
||||
task_logger.info(
|
||||
f"Pruning queued: cc_pair={cc_pair.id} id={payload_id}"
|
||||
)
|
||||
payload_id = try_creating_prune_generator_task(
|
||||
self.app, cc_pair, db_session, r, tenant_id
|
||||
)
|
||||
if not payload_id:
|
||||
continue
|
||||
|
||||
task_logger.info(
|
||||
f"Pruning queued: cc_pair={cc_pair.id} id={payload_id}"
|
||||
)
|
||||
r.set(OnyxRedisSignals.BLOCK_PRUNING, 1, ex=3600)
|
||||
|
||||
# we want to run this less frequently than the overall task
|
||||
lock_beat.reacquire()
|
||||
@@ -163,6 +168,22 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
|
||||
task_logger.exception("Exception while validating pruning fences")
|
||||
|
||||
r.set(OnyxRedisSignals.BLOCK_VALIDATE_PRUNING_FENCES, 1, ex=300)
|
||||
|
||||
# use a lookup table to find active fences. We still have to verify the fence
|
||||
# exists since it is an optimization and not the source of truth.
|
||||
lock_beat.reacquire()
|
||||
keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES))
|
||||
for key in keys:
|
||||
key_bytes = cast(bytes, key)
|
||||
|
||||
if not r.exists(key_bytes):
|
||||
r.srem(OnyxRedisConstants.ACTIVE_FENCES, key_bytes)
|
||||
continue
|
||||
|
||||
key_str = key_bytes.decode("utf-8")
|
||||
if key_str.startswith(RedisConnectorPrune.FENCE_PREFIX):
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_pruning_taskset(tenant_id, key_bytes, r, db_session)
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
@@ -481,7 +502,7 @@ def connector_pruning_generator_task(
|
||||
)
|
||||
|
||||
|
||||
"""Monitoring pruning utils, called in monitor_vespa_sync"""
|
||||
"""Monitoring pruning utils"""
|
||||
|
||||
|
||||
def monitor_ccpair_pruning_taskset(
|
||||
|
||||
@@ -8,6 +8,7 @@ from celery.exceptions import SoftTimeLimitExceeded
from redis.lock import Lock as RedisLock
from tenacity import RetryError

from ee.onyx.server.tenants.product_gating import get_gated_tenants
from onyx.access.access import get_access_for_document
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.tasks.beat_schedule import BEAT_EXPIRES_DEFAULT
@@ -252,7 +253,11 @@ def cloud_beat_task_generator(

    try:
        tenant_ids = get_all_tenant_ids()
        gated_tenants = get_gated_tenants()
        for tenant_id in tenant_ids:
            if tenant_id in gated_tenants:
                continue

            current_time = time.monotonic()
            if current_time - last_lock_time >= (CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4):
                lock_beat.reacquire()
@@ -270,6 +275,7 @@ def cloud_beat_task_generator(
                queue=queue,
                priority=priority,
                expires=expires,
                ignore_result=True,
            )
    except SoftTimeLimitExceeded:
        task_logger.info(

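# The gating change above is just a set-membership filter applied before scheduling
# beat tasks for a tenant. Illustrative sketch (tenant ids are plain strings here):
def schedulable_tenants(all_tenant_ids: list[str], gated_tenants: set[str]) -> list[str]:
    # gated tenants are skipped entirely; no per-tenant beat tasks are generated for them
    return [tenant_id for tenant_id in all_tenant_ids if tenant_id not in gated_tenants]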
@@ -1,9 +1,5 @@
|
||||
import random
|
||||
import time
|
||||
import traceback
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from http import HTTPStatus
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
@@ -13,8 +9,6 @@ from celery import Celery
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from celery.result import AsyncResult
|
||||
from celery.states import READY_STATES
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -22,47 +16,27 @@ from tenacity import RetryError
|
||||
|
||||
from onyx.access.access import get_access_for_document
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.background.celery.tasks.doc_permission_syncing.tasks import (
|
||||
monitor_ccpair_permissions_taskset,
|
||||
)
|
||||
from onyx.background.celery.tasks.pruning.tasks import monitor_ccpair_pruning_taskset
|
||||
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
|
||||
from onyx.background.celery.tasks.shared.tasks import LIGHT_SOFT_TIME_LIMIT
|
||||
from onyx.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.app_configs import VESPA_SYNC_MAX_TASKS
|
||||
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisConstants
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.configs.constants import OnyxRedisSignals
|
||||
from onyx.db.connector import fetch_connector_by_id
|
||||
from onyx.db.connector_credential_pair import add_deletion_failure_message
|
||||
from onyx.db.connector_credential_pair import (
|
||||
delete_connector_credential_pair__no_commit,
|
||||
)
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from onyx.db.document import count_documents_by_needs_sync
|
||||
from onyx.db.document import get_document
|
||||
from onyx.db.document import get_document_ids_for_connector_credential_pair
|
||||
from onyx.db.document import mark_document_as_synced
|
||||
from onyx.db.document_set import delete_document_set
|
||||
from onyx.db.document_set import delete_document_set_cc_pair_relationship__no_commit
|
||||
from onyx.db.document_set import fetch_document_sets
|
||||
from onyx.db.document_set import fetch_document_sets_for_document
|
||||
from onyx.db.document_set import get_document_set_by_id
|
||||
from onyx.db.document_set import mark_document_set_as_synced
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.enums import IndexingStatus
|
||||
from onyx.db.enums import SyncStatus
|
||||
from onyx.db.enums import SyncType
|
||||
from onyx.db.index_attempt import delete_index_attempts
|
||||
from onyx.db.index_attempt import get_index_attempt
|
||||
from onyx.db.index_attempt import mark_attempt_failed
|
||||
from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.search_settings import get_active_search_settings
|
||||
@@ -72,20 +46,14 @@ from onyx.db.sync_record import update_sync_record_status
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.document_index.interfaces import VespaDocumentFields
|
||||
from onyx.httpx.httpx_pool import HttpxPool
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.redis.redis_connector_credential_pair import RedisConnectorCredentialPair
|
||||
from onyx.redis.redis_connector_credential_pair import (
|
||||
RedisGlobalConnectorCredentialPair,
|
||||
)
|
||||
from onyx.redis.redis_connector_delete import RedisConnectorDelete
|
||||
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
|
||||
from onyx.redis.redis_connector_index import RedisConnectorIndex
|
||||
from onyx.redis.redis_connector_prune import RedisConnectorPrune
|
||||
from onyx.redis.redis_document_set import RedisDocumentSet
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.redis.redis_pool import redis_lock_dump
|
||||
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
|
||||
from onyx.redis.redis_usergroup import RedisUserGroup
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
@@ -94,7 +62,6 @@ from onyx.utils.variable_functionality import (
|
||||
)
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
from onyx.utils.variable_functionality import noop_fallback
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -114,6 +81,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
|
||||
time_start = time.monotonic()
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
r_replica = get_redis_replica_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK,
|
||||
@@ -125,6 +93,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
|
||||
return None
|
||||
|
||||
try:
|
||||
# 1/3: KICKOFF
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
try_generate_stale_document_sync_tasks(
|
||||
self.app, VESPA_SYNC_MAX_TASKS, db_session, r, lock_beat, tenant_id
|
||||
@@ -151,9 +120,8 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
|
||||
# endregion
|
||||
|
||||
# check if any user groups are not synced
|
||||
lock_beat.reacquire()
|
||||
if global_version.is_ee_version():
|
||||
lock_beat.reacquire()
|
||||
|
||||
try:
|
||||
fetch_user_groups = fetch_versioned_implementation(
|
||||
"onyx.db.user_group", "fetch_user_groups"
|
||||
@@ -179,6 +147,35 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
|
||||
self.app, usergroup_id, db_session, r, lock_beat, tenant_id
|
||||
)
|
||||
|
||||
# 2/3: VALIDATE: TODO
|
||||
|
||||
# 3/3: FINALIZE
|
||||
lock_beat.reacquire()
|
||||
keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES))
|
||||
for key in keys:
|
||||
key_bytes = cast(bytes, key)
|
||||
|
||||
if not r.exists(key_bytes):
|
||||
r.srem(OnyxRedisConstants.ACTIVE_FENCES, key_bytes)
|
||||
continue
|
||||
|
||||
key_str = key_bytes.decode("utf-8")
|
||||
if key_str == RedisGlobalConnectorCredentialPair.FENCE_KEY:
|
||||
monitor_connector_taskset(r)
|
||||
elif key_str.startswith(RedisDocumentSet.FENCE_PREFIX):
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_document_set_taskset(tenant_id, key_bytes, r, db_session)
|
||||
elif key_str.startswith(RedisUserGroup.FENCE_PREFIX):
|
||||
monitor_usergroup_taskset = (
|
||||
fetch_versioned_implementation_with_fallback(
|
||||
"onyx.background.celery.tasks.vespa.tasks",
|
||||
"monitor_usergroup_taskset",
|
||||
noop_fallback,
|
||||
)
|
||||
)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_usergroup_taskset(tenant_id, key_bytes, r, db_session)
|
||||
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
@@ -506,475 +503,6 @@ def monitor_document_set_taskset(
|
||||
rds.reset()
|
||||
|
||||
|
||||
def monitor_connector_deletion_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis
|
||||
) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
if cc_pair_id_str is None:
|
||||
task_logger.warning(f"could not parse cc_pair_id from {fence_key}")
|
||||
return
|
||||
|
||||
cc_pair_id = int(cc_pair_id_str)
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
|
||||
fence_data = redis_connector.delete.payload
|
||||
if not fence_data:
|
||||
task_logger.warning(
|
||||
f"Connector deletion - fence payload invalid: cc_pair={cc_pair_id}"
|
||||
)
|
||||
return
|
||||
|
||||
if fence_data.num_tasks is None:
|
||||
# the fence is setting up but isn't ready yet
|
||||
return
|
||||
|
||||
remaining = redis_connector.delete.get_remaining()
|
||||
task_logger.info(
|
||||
f"Connector deletion progress: cc_pair={cc_pair_id} remaining={remaining} initial={fence_data.num_tasks}"
|
||||
)
|
||||
if remaining > 0:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
sync_type=SyncType.CONNECTOR_DELETION,
|
||||
sync_status=SyncStatus.IN_PROGRESS,
|
||||
num_docs_synced=remaining,
|
||||
)
|
||||
return
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
if not cc_pair:
|
||||
task_logger.warning(
|
||||
f"Connector deletion - cc_pair not found: cc_pair={cc_pair_id}"
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
doc_ids = get_document_ids_for_connector_credential_pair(
|
||||
db_session, cc_pair.connector_id, cc_pair.credential_id
|
||||
)
|
||||
if len(doc_ids) > 0:
|
||||
# NOTE(rkuo): if this happens, documents somehow got added while
|
||||
# deletion was in progress. Likely a bug gating off pruning and indexing
|
||||
# work before deletion starts.
|
||||
task_logger.warning(
|
||||
"Connector deletion - documents still found after taskset completion. "
|
||||
"Clearing the current deletion attempt and allowing deletion to restart: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"docs_deleted={fence_data.num_tasks} "
|
||||
f"docs_remaining={len(doc_ids)}"
|
||||
)
|
||||
|
||||
            # We don't want to wave away how we got into this state, but resetting
            # our attempt and letting the deletion restart is a good way to recover
redis_connector.delete.reset()
|
||||
raise RuntimeError(
|
||||
"Connector deletion - documents still found after taskset completion"
|
||||
)
|
||||
|
||||
# clean up the rest of the related Postgres entities
|
||||
# index attempts
|
||||
delete_index_attempts(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
|
||||
# document sets
|
||||
delete_document_set_cc_pair_relationship__no_commit(
|
||||
db_session=db_session,
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
)
|
||||
|
||||
# user groups
|
||||
cleanup_user_groups = fetch_versioned_implementation_with_fallback(
|
||||
"onyx.db.user_group",
|
||||
"delete_user_group_cc_pair_relationship__no_commit",
|
||||
noop_fallback,
|
||||
)
|
||||
cleanup_user_groups(
|
||||
cc_pair_id=cc_pair_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# finally, delete the cc-pair
|
||||
delete_connector_credential_pair__no_commit(
|
||||
db_session=db_session,
|
||||
connector_id=cc_pair.connector_id,
|
||||
credential_id=cc_pair.credential_id,
|
||||
)
|
||||
# if there are no credentials left, delete the connector
|
||||
connector = fetch_connector_by_id(
|
||||
db_session=db_session,
|
||||
connector_id=cc_pair.connector_id,
|
||||
)
|
||||
if not connector or not len(connector.credentials):
|
||||
task_logger.info(
|
||||
"Connector deletion - Found no credentials left for connector, deleting connector"
|
||||
)
|
||||
db_session.delete(connector)
|
||||
db_session.commit()
|
||||
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
sync_type=SyncType.CONNECTOR_DELETION,
|
||||
sync_status=SyncStatus.SUCCESS,
|
||||
num_docs_synced=fence_data.num_tasks,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
db_session.rollback()
|
||||
stack_trace = traceback.format_exc()
|
||||
error_message = f"Error: {str(e)}\n\nStack Trace:\n{stack_trace}"
|
||||
add_deletion_failure_message(db_session, cc_pair_id, error_message)
|
||||
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
sync_type=SyncType.CONNECTOR_DELETION,
|
||||
sync_status=SyncStatus.FAILED,
|
||||
num_docs_synced=fence_data.num_tasks,
|
||||
)
|
||||
|
||||
task_logger.exception(
|
||||
f"Connector deletion exceptioned: "
|
||||
f"cc_pair={cc_pair_id} connector={cc_pair.connector_id} credential={cc_pair.credential_id}"
|
||||
)
|
||||
raise e
|
||||
|
||||
task_logger.info(
|
||||
f"Connector deletion succeeded: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"connector={cc_pair.connector_id} "
|
||||
f"credential={cc_pair.credential_id} "
|
||||
f"docs_deleted={fence_data.num_tasks}"
|
||||
)
|
||||
|
||||
redis_connector.delete.reset()
|
||||
|
||||
|
||||
def monitor_ccpair_indexing_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
# if the fence doesn't exist, there's nothing to do
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
composite_id = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
if composite_id is None:
|
||||
task_logger.warning(
|
||||
f"Connector indexing: could not parse composite_id from {fence_key}"
|
||||
)
|
||||
return
|
||||
|
||||
# parse out metadata and initialize the helper class with it
|
||||
parts = composite_id.split("/")
|
||||
if len(parts) != 2:
|
||||
return
|
||||
|
||||
cc_pair_id = int(parts[0])
|
||||
search_settings_id = int(parts[1])
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
redis_connector_index = redis_connector.new_index(search_settings_id)
|
||||
if not redis_connector_index.fenced:
|
||||
return
|
||||
|
||||
payload = redis_connector_index.payload
|
||||
if not payload:
|
||||
return
|
||||
|
||||
elapsed_started_str = None
|
||||
if payload.started:
|
||||
elapsed_started = datetime.now(timezone.utc) - payload.started
|
||||
elapsed_started_str = f"{elapsed_started.total_seconds():.2f}"
|
||||
|
||||
elapsed_submitted = datetime.now(timezone.utc) - payload.submitted
|
||||
|
||||
progress = redis_connector_index.get_progress()
|
||||
if progress is not None:
|
||||
task_logger.info(
|
||||
f"Connector indexing progress: "
|
||||
f"attempt={payload.index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"progress={progress} "
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
|
||||
f"elapsed_started={elapsed_started_str}"
|
||||
)
|
||||
|
||||
if payload.index_attempt_id is None or payload.celery_task_id is None:
|
||||
# the task is still setting up
|
||||
return
|
||||
|
||||
# never use any blocking methods on the result from inside a task!
|
||||
result: AsyncResult = AsyncResult(payload.celery_task_id)
|
||||
|
||||
# inner/outer/inner double check pattern to avoid race conditions when checking for
|
||||
# bad state
|
||||
|
||||
# Verify: if the generator isn't complete, the task must not be in READY state
|
||||
# inner = get_completion / generator_complete not signaled
|
||||
# outer = result.state in READY state
|
||||
status_int = redis_connector_index.get_completion()
|
||||
if status_int is None: # inner signal not set ... possible error
|
||||
task_state = result.state
|
||||
if (
|
||||
task_state in READY_STATES
|
||||
): # outer signal in terminal state ... possible error
|
||||
# Now double check!
|
||||
if redis_connector_index.get_completion() is None:
|
||||
# inner signal still not set (and cannot change when outer result_state is READY)
|
||||
# Task is finished but generator complete isn't set.
|
||||
# We have a problem! Worker may have crashed.
|
||||
task_result = str(result.result)
|
||||
task_traceback = str(result.traceback)
|
||||
|
||||
msg = (
|
||||
f"Connector indexing aborted or exceptioned: "
|
||||
f"attempt={payload.index_attempt_id} "
|
||||
f"celery_task={payload.celery_task_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
|
||||
f"result.state={task_state} "
|
||||
f"result.result={task_result} "
|
||||
f"result.traceback={task_traceback}"
|
||||
)
|
||||
task_logger.warning(msg)
|
||||
|
||||
try:
|
||||
index_attempt = get_index_attempt(
|
||||
db_session, payload.index_attempt_id
|
||||
)
|
||||
if index_attempt:
|
||||
if (
|
||||
index_attempt.status != IndexingStatus.CANCELED
|
||||
and index_attempt.status != IndexingStatus.FAILED
|
||||
):
|
||||
mark_attempt_failed(
|
||||
index_attempt_id=payload.index_attempt_id,
|
||||
db_session=db_session,
|
||||
failure_reason=msg,
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
"Connector indexing - Transient exception marking index attempt as failed: "
|
||||
f"attempt={payload.index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
redis_connector_index.reset()
|
||||
return
|
||||
|
||||
if redis_connector_index.watchdog_signaled():
|
||||
# if the generator is complete, don't clean up until the watchdog has exited
|
||||
task_logger.info(
|
||||
f"Connector indexing - Delaying finalization until watchdog has exited: "
|
||||
f"attempt={payload.index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"progress={progress} "
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
|
||||
f"elapsed_started={elapsed_started_str}"
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
status_enum = HTTPStatus(status_int)
|
||||
|
||||
task_logger.info(
|
||||
f"Connector indexing finished: "
|
||||
f"attempt={payload.index_attempt_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"progress={progress} "
|
||||
f"status={status_enum.name} "
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
|
||||
f"elapsed_started={elapsed_started_str}"
|
||||
)
|
||||
|
||||
redis_connector_index.reset()
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.MONITOR_VESPA_SYNC,
|
||||
ignore_result=True,
|
||||
soft_time_limit=300,
|
||||
bind=True,
|
||||
)
|
||||
def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool | None:
|
||||
"""This is a celery beat task that monitors and finalizes various long running tasks.
|
||||
|
||||
The name monitor_vespa_sync is a bit of a misnomer since it checks many different tasks
|
||||
now. Should change that at some point.
|
||||
|
||||
It scans for fence values and then gets the counts of any associated tasksets.
|
||||
For many tasks, the count is 0, that means all tasks finished and we should clean up.
|
||||
|
||||
This task lock timeout is CELERY_METADATA_SYNC_BEAT_LOCK_TIMEOUT seconds, so don't
|
||||
do anything too expensive in this function!
|
||||
|
||||
Returns True if the task actually did work, False if it exited early to prevent overlap
|
||||
"""
|
||||
task_logger.info(f"monitor_vespa_sync starting: tenant={tenant_id}")
|
||||
|
||||
time_start = time.monotonic()
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
    # Replica usage notes
    #
    # False negatives are OK (i.e. we fail to see a key that exists on the master).
    # We simply skip the monitoring work and it will be caught on the next pass.
    #
    # False positives are not OK, and are possible if we clear a fence on the master and
    # then read from the replica. In this case, monitoring work could be done on a fence
    # that no longer exists. To avoid this, we scan from the replica, but double check
    # the result on the master.
    r_replica = get_redis_replica_client(tenant_id=tenant_id)
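# Sketch of the replica-read rule described above: enumerate candidates from the
# replica, but confirm each key on the master before doing any monitoring work, so
# replica lag can only delay work, never act on a fence that was already cleared.
import redis


def confirmed_fence_keys(master: redis.Redis, replica: redis.Redis, set_key: str) -> list[bytes]:
    confirmed = []
    for key_bytes in replica.smembers(set_key):
        if not master.exists(key_bytes):
            master.srem(set_key, key_bytes)  # cleared on the master; drop from the lookup table
            continue
        confirmed.append(key_bytes)
    return confirmed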
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.MONITOR_VESPA_SYNC_BEAT_LOCK,
|
||||
timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
# prevent overlapping tasks
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return None
|
||||
|
||||
try:
|
||||
# print current queue lengths
|
||||
time.monotonic()
|
||||
# we don't need every tenant polling redis for this info.
|
||||
if not MULTI_TENANT or random.randint(1, 10) == 10:
|
||||
r_celery = self.app.broker_connection().channel().client # type: ignore
|
||||
n_celery = celery_get_queue_length("celery", r_celery)
|
||||
n_indexing = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
|
||||
)
|
||||
n_sync = celery_get_queue_length(
|
||||
OnyxCeleryQueues.VESPA_METADATA_SYNC, r_celery
|
||||
)
|
||||
n_deletion = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CONNECTOR_DELETION, r_celery
|
||||
)
|
||||
n_pruning = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CONNECTOR_PRUNING, r_celery
|
||||
)
|
||||
n_permissions_sync = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC, r_celery
|
||||
)
|
||||
n_external_group_sync = celery_get_queue_length(
|
||||
OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC, r_celery
|
||||
)
|
||||
n_permissions_upsert = celery_get_queue_length(
|
||||
OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT, r_celery
|
||||
)
|
||||
|
||||
prefetched = celery_get_unacked_task_ids(
|
||||
OnyxCeleryQueues.CONNECTOR_INDEXING, r_celery
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"Queue lengths: celery={n_celery} "
|
||||
f"indexing={n_indexing} "
|
||||
f"indexing_prefetched={len(prefetched)} "
|
||||
f"sync={n_sync} "
|
||||
f"deletion={n_deletion} "
|
||||
f"pruning={n_pruning} "
|
||||
f"permissions_sync={n_permissions_sync} "
|
||||
f"external_group_sync={n_external_group_sync} "
|
||||
f"permissions_upsert={n_permissions_upsert} "
|
||||
)
|
||||
|
||||
# we want to run this less frequently than the overall task
|
||||
if not r.exists(OnyxRedisSignals.BLOCK_BUILD_FENCE_LOOKUP_TABLE):
|
||||
# build a lookup table of existing fences
|
||||
# this is just a migration concern and should be unnecessary once
|
||||
# lookup tables are rolled out
|
||||
for key_bytes in r_replica.scan_iter(count=SCAN_ITER_COUNT_DEFAULT):
|
||||
if is_fence(key_bytes) and not r.sismember(
|
||||
OnyxRedisConstants.ACTIVE_FENCES, key_bytes
|
||||
):
|
||||
logger.warning(f"Adding {key_bytes} to the lookup table.")
|
||||
r.sadd(OnyxRedisConstants.ACTIVE_FENCES, key_bytes)
|
||||
|
||||
r.set(OnyxRedisSignals.BLOCK_BUILD_FENCE_LOOKUP_TABLE, 1, ex=300)
|
||||
|
||||
# use a lookup table to find active fences. We still have to verify the fence
|
||||
# exists since it is an optimization and not the source of truth.
|
||||
keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES))
|
||||
for key in keys:
|
||||
key_bytes = cast(bytes, key)
|
||||
|
||||
if not r.exists(key_bytes):
|
||||
r.srem(OnyxRedisConstants.ACTIVE_FENCES, key_bytes)
|
||||
continue
|
||||
|
||||
key_str = key_bytes.decode("utf-8")
|
||||
if key_str == RedisGlobalConnectorCredentialPair.FENCE_KEY:
|
||||
monitor_connector_taskset(r)
|
||||
elif key_str.startswith(RedisDocumentSet.FENCE_PREFIX):
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_document_set_taskset(tenant_id, key_bytes, r, db_session)
|
||||
elif key_str.startswith(RedisUserGroup.FENCE_PREFIX):
|
||||
monitor_usergroup_taskset = (
|
||||
fetch_versioned_implementation_with_fallback(
|
||||
"onyx.background.celery.tasks.vespa.tasks",
|
||||
"monitor_usergroup_taskset",
|
||||
noop_fallback,
|
||||
)
|
||||
)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_usergroup_taskset(tenant_id, key_bytes, r, db_session)
|
||||
elif key_str.startswith(RedisConnectorDelete.FENCE_PREFIX):
|
||||
monitor_connector_deletion_taskset(tenant_id, key_bytes, r)
|
||||
elif key_str.startswith(RedisConnectorPrune.FENCE_PREFIX):
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_pruning_taskset(tenant_id, key_bytes, r, db_session)
|
||||
elif key_str.startswith(RedisConnectorIndex.FENCE_PREFIX):
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_indexing_taskset(tenant_id, key_bytes, r, db_session)
|
||||
elif key_str.startswith(RedisConnectorPermissionSync.FENCE_PREFIX):
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_permissions_taskset(
|
||||
tenant_id, key_bytes, r, db_session
|
||||
)
|
||||
else:
|
||||
pass
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
)
|
||||
return False
|
||||
except Exception:
|
||||
task_logger.exception("monitor_vespa_sync exceptioned.")
|
||||
return False
|
||||
finally:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
else:
|
||||
task_logger.error(
|
||||
"monitor_vespa_sync - Lock not owned on completion: "
|
||||
f"tenant={tenant_id}"
|
||||
# f"timings={timings}"
|
||||
)
|
||||
redis_lock_dump(lock_beat, r)
|
||||
|
||||
time_elapsed = time.monotonic() - time_start
|
||||
task_logger.info(f"monitor_vespa_sync finished: elapsed={time_elapsed:.2f}")
|
||||
return True
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.VESPA_METADATA_SYNC_TASK,
|
||||
bind=True,
|
||||
@@ -1072,23 +600,3 @@ def vespa_metadata_sync_task(
|
||||
self.retry(exc=e, countdown=countdown)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def is_fence(key_bytes: bytes) -> bool:
|
||||
key_str = key_bytes.decode("utf-8")
|
||||
if key_str == RedisGlobalConnectorCredentialPair.FENCE_KEY:
|
||||
return True
|
||||
if key_str.startswith(RedisDocumentSet.FENCE_PREFIX):
|
||||
return True
|
||||
if key_str.startswith(RedisUserGroup.FENCE_PREFIX):
|
||||
return True
|
||||
if key_str.startswith(RedisConnectorDelete.FENCE_PREFIX):
|
||||
return True
|
||||
if key_str.startswith(RedisConnectorPrune.FENCE_PREFIX):
|
||||
return True
|
||||
if key_str.startswith(RedisConnectorIndex.FENCE_PREFIX):
|
||||
return True
|
||||
if key_str.startswith(RedisConnectorPermissionSync.FENCE_PREFIX):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
backend/onyx/background/error_logging.py (new file, 13 lines)
@@ -0,0 +1,13 @@
from onyx.db.background_error import create_background_error
from onyx.db.engine import get_session_with_tenant


def emit_background_error(
    message: str,
    cc_pair_id: int | None = None,
) -> None:
    """Currently just saves a row in the background_errors table.

    In the future, could create notifications based on the severity."""
    with get_session_with_tenant() as db_session:
        create_background_error(db_session, message, cc_pair_id)
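# Typical call site for the helper above, mirroring the pattern used throughout this
# diff: build a message, persist it so it survives log rotation, then log or re-raise.
# The cc_pair_id value and function name here are illustrative only.
from onyx.background.error_logging import emit_background_error


def handle_sync_failure(exc: Exception, cc_pair_id: int = 42) -> None:
    msg = f"External group sync exceptioned: cc_pair={cc_pair_id}"
    emit_background_error(msg + f"\n\n{exc}", cc_pair_id=cc_pair_id)
    # the caller would usually also task_logger.exception(msg) and/or re-raise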
@@ -107,9 +107,9 @@ CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT = 5 * 60  # 5 min

# needs to be long enough to cover the maximum time it takes to download an object
# if we can get callbacks as object bytes download, we could lower this a lot.
CELERY_PRUNING_LOCK_TIMEOUT = 300  # 5 min
CELERY_PRUNING_LOCK_TIMEOUT = 3600  # 1 hour (in seconds)

CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT = 300  # 5 min
CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT = 3600  # 1 hour (in seconds)

CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT = 300  # 5 min

@@ -298,7 +298,6 @@ class OnyxRedisLocks:
    CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK = (
        "da_lock:check_connector_external_group_sync_beat"
    )
    MONITOR_VESPA_SYNC_BEAT_LOCK = "da_lock:monitor_vespa_sync_beat"
    MONITOR_BACKGROUND_PROCESSES_LOCK = "da_lock:monitor_background_processes"

    CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX = (
@@ -324,6 +323,7 @@ class OnyxRedisSignals:
    BLOCK_VALIDATE_PERMISSION_SYNC_FENCES = (
        "signal:block_validate_permission_sync_fences"
    )
    BLOCK_PRUNING = "signal:block_pruning"
    BLOCK_VALIDATE_PRUNING_FENCES = "signal:block_validate_pruning_fences"
    BLOCK_BUILD_FENCE_LOOKUP_TABLE = "signal:block_build_fence_lookup_table"

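# The OnyxRedisSignals keys above act as coarse rate limiters: set a short-lived key
# after a block of work and skip the work while the key still exists. Generic sketch
# of that pattern (the signal name and interval are parameters, not real constants):
from collections.abc import Callable

import redis


def run_at_most_every(
    r: redis.Redis, signal_key: str, interval_s: int, fn: Callable[[], None]
) -> bool:
    if r.exists(signal_key):
        return False  # ran within the last interval_s seconds
    fn()
    r.set(signal_key, 1, ex=interval_s)  # e.g. ex=3600 for the hourly pruning kickoff
    return True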
@@ -354,7 +354,10 @@ class OnyxCeleryTask:
|
||||
DEFAULT = "celery"
|
||||
|
||||
CLOUD_BEAT_TASK_GENERATOR = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_generate_beat_tasks"
|
||||
CLOUD_CHECK_ALEMBIC = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check_alembic"
|
||||
CLOUD_MONITOR_ALEMBIC = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor_alembic"
|
||||
CLOUD_MONITOR_CELERY_QUEUES = (
|
||||
f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor_celery_queues"
|
||||
)
|
||||
|
||||
CHECK_FOR_CONNECTOR_DELETION = "check_for_connector_deletion_task"
|
||||
CHECK_FOR_VESPA_SYNC_TASK = "check_for_vespa_sync_task"
|
||||
@@ -364,8 +367,8 @@ class OnyxCeleryTask:
|
||||
CHECK_FOR_EXTERNAL_GROUP_SYNC = "check_for_external_group_sync"
|
||||
CHECK_FOR_LLM_MODEL_UPDATE = "check_for_llm_model_update"
|
||||
|
||||
MONITOR_VESPA_SYNC = "monitor_vespa_sync"
|
||||
MONITOR_BACKGROUND_PROCESSES = "monitor_background_processes"
|
||||
MONITOR_CELERY_QUEUES = "monitor_celery_queues"
|
||||
|
||||
KOMBU_MESSAGE_CLEANUP_TASK = "kombu_message_cleanup_task"
|
||||
CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK = (
|
||||
|
||||
@@ -145,7 +145,8 @@ def fetch_jira_issues_batch(
                id=page_url,
                sections=[Section(link=page_url, text=ticket_content)],
                source=DocumentSource.JIRA,
                semantic_identifier=issue.fields.summary,
                semantic_identifier=f"{issue.key}: {issue.fields.summary}",
                title=f"{issue.key} {issue.fields.summary}",
                doc_updated_at=time_str_to_utc(issue.fields.updated),
                primary_owners=list(people) or None,
                # TODO add secondary_owners (commenters) if needed

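# The Jira change above prefixes the issue key so search results read like the ticket.
# Tiny illustration with made-up values:
issue_key = "PROJ-123"
summary = "Fix login redirect loop"
semantic_identifier = f"{issue_key}: {summary}"  # -> "PROJ-123: Fix login redirect loop"
title = f"{issue_key} {summary}"                 # -> "PROJ-123 Fix login redirect loop"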
@@ -51,6 +51,7 @@ class SearchPipeline:
|
||||
user: User | None,
|
||||
llm: LLM,
|
||||
fast_llm: LLM,
|
||||
skip_query_analysis: bool,
|
||||
db_session: Session,
|
||||
bypass_acl: bool = False, # NOTE: VERY DANGEROUS, USE WITH CAUTION
|
||||
retrieval_metrics_callback: (
|
||||
@@ -61,10 +62,13 @@ class SearchPipeline:
|
||||
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
|
||||
prompt_config: PromptConfig | None = None,
|
||||
):
|
||||
# NOTE: The Search Request contains a lot of fields that are overrides, many of them can be None
|
||||
# and typically are None. The preprocessing will fetch default values to replace these empty overrides.
|
||||
self.search_request = search_request
|
||||
self.user = user
|
||||
self.llm = llm
|
||||
self.fast_llm = fast_llm
|
||||
self.skip_query_analysis = skip_query_analysis
|
||||
self.db_session = db_session
|
||||
self.bypass_acl = bypass_acl
|
||||
self.retrieval_metrics_callback = retrieval_metrics_callback
|
||||
@@ -106,6 +110,7 @@ class SearchPipeline:
|
||||
search_request=self.search_request,
|
||||
user=self.user,
|
||||
llm=self.llm,
|
||||
skip_query_analysis=self.skip_query_analysis,
|
||||
db_session=self.db_session,
|
||||
bypass_acl=self.bypass_acl,
|
||||
)
|
||||
@@ -160,6 +165,12 @@ class SearchPipeline:
|
||||
that have a corresponding chunk.
|
||||
|
||||
This step should be fast for any document index implementation.
|
||||
|
||||
Current implementation timing is approximately broken down in timing as:
|
||||
- 200 ms to get the embedding of the query
|
||||
- 15 ms to get chunks from the document index
|
||||
- possibly more to get additional surrounding chunks
|
||||
- possibly more for query expansion (multilingual)
|
||||
"""
|
||||
if self._retrieved_sections is not None:
|
||||
return self._retrieved_sections
|
||||
|
||||
@@ -15,6 +15,7 @@ from onyx.context.search.models import InferenceChunk
|
||||
from onyx.context.search.models import InferenceChunkUncleaned
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.context.search.models import MAX_METRICS_CONTENT
|
||||
from onyx.context.search.models import RerankingDetails
|
||||
from onyx.context.search.models import RerankMetricsContainer
|
||||
from onyx.context.search.models import SearchQuery
|
||||
from onyx.document_index.document_index_utils import (
|
||||
@@ -77,7 +78,8 @@ def cleanup_chunks(chunks: list[InferenceChunkUncleaned]) -> list[InferenceChunk
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def semantic_reranking(
|
||||
query: SearchQuery,
|
||||
query_str: str,
|
||||
rerank_settings: RerankingDetails,
|
||||
chunks: list[InferenceChunk],
|
||||
model_min: int = CROSS_ENCODER_RANGE_MIN,
|
||||
model_max: int = CROSS_ENCODER_RANGE_MAX,
|
||||
@@ -88,11 +90,9 @@ def semantic_reranking(
|
||||
|
||||
Note: this updates the chunks in place, it updates the chunk scores which came from retrieval
|
||||
"""
|
||||
rerank_settings = query.rerank_settings
|
||||
|
||||
if not rerank_settings or not rerank_settings.rerank_model_name:
|
||||
# Should never reach this part of the flow without reranking settings
|
||||
raise RuntimeError("Reranking flow should not be running")
|
||||
assert (
|
||||
rerank_settings.rerank_model_name
|
||||
), "Reranking flow cannot run without a specific model"
|
||||
|
||||
chunks_to_rerank = chunks[: rerank_settings.num_rerank]
|
||||
|
||||
@@ -107,7 +107,7 @@ def semantic_reranking(
|
||||
f"{chunk.semantic_identifier or chunk.title or ''}\n{chunk.content}"
|
||||
for chunk in chunks_to_rerank
|
||||
]
|
||||
sim_scores_floats = cross_encoder.predict(query=query.query, passages=passages)
|
||||
sim_scores_floats = cross_encoder.predict(query=query_str, passages=passages)
|
||||
|
||||
# Old logic to handle multiple cross-encoders preserved but not used
|
||||
sim_scores = [numpy.array(sim_scores_floats)]
|
||||
@@ -165,8 +165,20 @@ def semantic_reranking(
    return list(ranked_chunks), list(ranked_indices)


def should_rerank(rerank_settings: RerankingDetails | None) -> bool:
    """Based on the RerankingDetails model, only run rerank if the following conditions are met:
    - rerank_model_name is not None
    - num_rerank is greater than 0
    """
    if not rerank_settings:
        return False

    return bool(rerank_settings.rerank_model_name and rerank_settings.num_rerank > 0)

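# should_rerank() above centralizes the condition that callers previously spelled out
# inline; the equivalent inline check it replaces looks roughly like this:
def _should_rerank_inline(rerank_settings: RerankingDetails | None) -> bool:
    return bool(
        rerank_settings
        and rerank_settings.rerank_model_name
        and rerank_settings.num_rerank > 0
    )
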
def rerank_sections(
|
||||
query: SearchQuery,
|
||||
query_str: str,
|
||||
rerank_settings: RerankingDetails,
|
||||
sections_to_rerank: list[InferenceSection],
|
||||
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
|
||||
) -> list[InferenceSection]:
|
||||
@@ -181,16 +193,13 @@ def rerank_sections(
|
||||
"""
|
||||
chunks_to_rerank = [section.center_chunk for section in sections_to_rerank]
|
||||
|
||||
if not query.rerank_settings:
|
||||
# Should never reach this part of the flow without reranking settings
|
||||
raise RuntimeError("Reranking settings not found")
|
||||
|
||||
ranked_chunks, _ = semantic_reranking(
|
||||
query=query,
|
||||
query_str=query_str,
|
||||
rerank_settings=rerank_settings,
|
||||
chunks=chunks_to_rerank,
|
||||
rerank_metrics_callback=rerank_metrics_callback,
|
||||
)
|
||||
lower_chunks = chunks_to_rerank[query.rerank_settings.num_rerank :]
|
||||
lower_chunks = chunks_to_rerank[rerank_settings.num_rerank :]
|
||||
|
||||
# Scores from rerank cannot be meaningfully combined with scores without rerank
|
||||
# However the ordering is still important
|
||||
@@ -260,16 +269,13 @@ def search_postprocessing(
|
||||
|
||||
rerank_task_id = None
|
||||
sections_yielded = False
|
||||
if (
|
||||
search_query.rerank_settings
|
||||
and search_query.rerank_settings.rerank_model_name
|
||||
and search_query.rerank_settings.num_rerank > 0
|
||||
):
|
||||
if should_rerank(search_query.rerank_settings):
|
||||
post_processing_tasks.append(
|
||||
FunctionCall(
|
||||
rerank_sections,
|
||||
(
|
||||
search_query,
|
||||
search_query.query,
|
||||
search_query.rerank_settings, # Cannot be None here
|
||||
retrieved_sections,
|
||||
rerank_metrics_callback,
|
||||
),
|
||||
|
||||
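Illustrative usage sketch for the new helper (reference only; every name below comes from the hunks above):

rerank_settings = search_query.rerank_settings
if should_rerank(rerank_settings):
    ranked_sections = rerank_sections(
        query=search_query,
        query_str=search_query.query,
        rerank_settings=rerank_settings,  # should_rerank guarantees a usable model name
        sections_to_rerank=retrieved_sections,
    )
else:
    ranked_sections = retrieved_sections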
@@ -50,11 +50,11 @@ def retrieval_preprocessing(
search_request: SearchRequest,
user: User | None,
llm: LLM,
skip_query_analysis: bool,
db_session: Session,
bypass_acl: bool = False,
skip_query_analysis: bool = False,
base_recency_decay: float = BASE_RECENCY_DECAY,
favor_recent_decay_multiplier: float = FAVOR_RECENT_DECAY_MULTIPLIER,
base_recency_decay: float = BASE_RECENCY_DECAY,
bypass_acl: bool = False,
) -> SearchQuery:
"""Logic is as follows:
Any global disables apply first

@@ -146,7 +146,7 @@ def retrieval_preprocessing(
is_keyword, extracted_keywords = (
parallel_results[run_query_analysis.result_id]
if run_query_analysis
else (None, None)
else (False, None)
)

all_query_terms = query.split()
backend/onyx/db/background_error.py (new file, +10 lines)
@@ -0,0 +1,10 @@
from sqlalchemy.orm import Session

from onyx.db.models import BackgroundError


def create_background_error(
db_session: Session, message: str, cc_pair_id: int | None
) -> None:
db_session.add(BackgroundError(message=message, cc_pair_id=cc_pair_id))
db_session.commit()

@@ -483,6 +483,10 @@ class ConnectorCredentialPair(Base):
primaryjoin="foreign(ConnectorCredentialPair.creator_id) == remote(User.id)",
)

background_errors: Mapped[list["BackgroundError"]] = relationship(
"BackgroundError", back_populates="cc_pair", cascade="all, delete-orphan"
)


class Document(Base):
__tablename__ = "document"

@@ -2115,6 +2119,31 @@ class StandardAnswer(Base):
)


class BackgroundError(Base):
"""Important background errors. Serves to:
1. Ensure that important logs are kept around and not lost on rotation/container restarts
2. Keep a trail of high-signal events so that the debugger doesn't need to remember/know every
possible relevant log line.
"""

__tablename__ = "background_error"

id: Mapped[int] = mapped_column(primary_key=True)
message: Mapped[str] = mapped_column(String)
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now()
)

# option to link the error to a specific CC Pair
cc_pair_id: Mapped[int | None] = mapped_column(
ForeignKey("connector_credential_pair.id", ondelete="CASCADE"), nullable=True
)

cc_pair: Mapped["ConnectorCredentialPair | None"] = relationship(
"ConnectorCredentialPair", back_populates="background_errors"
)


"""Tables related to Permission Sync"""
@@ -396,9 +396,14 @@ class DefaultMultiLLM(LLM):
self._record_call(processed_prompt)

try:
print(
"model is",
f"{self.config.model_provider}/{self.config.deployment_name or self.config.model_name}",
)
return litellm.completion(
mock_response=MOCK_LLM_RESPONSE,
# model choice
# model="openai/gpt-4",
model=f"{self.config.model_provider}/{self.config.deployment_name or self.config.model_name}",
# NOTE: have to pass in None instead of empty string for these
# otherwise litellm can have some issues with bedrock

@@ -99,7 +99,7 @@ def _check_tokenizer_cache(

if not tokenizer:
logger.info(
f"Falling back to default embedding model: {DOCUMENT_ENCODER_MODEL}"
f"Falling back to default embedding model tokenizer: {DOCUMENT_ENCODER_MODEL}"
)
tokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)

@@ -120,6 +120,7 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
queue=OnyxCeleryQueues.VESPA_METADATA_SYNC,
task_id=custom_task_id,
priority=OnyxCeleryPriority.MEDIUM,
ignore_result=True,
)

num_tasks_sent += 1

@@ -132,6 +132,7 @@ class RedisConnectorDelete:
queue=OnyxCeleryQueues.CONNECTOR_DELETION,
task_id=custom_task_id,
priority=OnyxCeleryPriority.MEDIUM,
ignore_result=True,
)

async_results.append(result)

@@ -11,6 +11,7 @@ from redis.lock import Lock as RedisLock

from onyx.access.models import DocExternalAccess
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask

@@ -49,7 +50,7 @@ class RedisConnectorPermissionSync:
# it's impossible to get the exact state of the system at a single point in time
# so we need a signal with a TTL to bridge gaps in our checks
ACTIVE_PREFIX = PREFIX + "_active"
ACTIVE_TTL = 3600
ACTIVE_TTL = CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT * 2

def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
self.tenant_id: str | None = tenant_id

@@ -195,6 +196,7 @@ class RedisConnectorPermissionSync:
queue=OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT,
task_id=custom_task_id,
priority=OnyxCeleryPriority.HIGH,
ignore_result=True,
)
async_results.append(result)

@@ -10,6 +10,7 @@ from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session

from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_PRUNING_LOCK_TIMEOUT
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask

@@ -49,7 +50,7 @@ class RedisConnectorPrune:
# it's impossible to get the exact state of the system at a single point in time
# so we need a signal with a TTL to bridge gaps in our checks
ACTIVE_PREFIX = PREFIX + "_active"
ACTIVE_TTL = 3600
ACTIVE_TTL = CELERY_PRUNING_LOCK_TIMEOUT * 2

def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
self.tenant_id: str | None = tenant_id

@@ -201,6 +202,7 @@ class RedisConnectorPrune:
queue=OnyxCeleryQueues.CONNECTOR_DELETION,
task_id=custom_task_id,
priority=OnyxCeleryPriority.MEDIUM,
ignore_result=True,
)

async_results.append(result)
backend/onyx/redis/redis_utils.py (new file, +29 lines)
@@ -0,0 +1,29 @@
from onyx.redis.redis_connector_credential_pair import (
RedisGlobalConnectorCredentialPair,
)
from onyx.redis.redis_connector_delete import RedisConnectorDelete
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from onyx.redis.redis_connector_index import RedisConnectorIndex
from onyx.redis.redis_connector_prune import RedisConnectorPrune
from onyx.redis.redis_document_set import RedisDocumentSet
from onyx.redis.redis_usergroup import RedisUserGroup


def is_fence(key_bytes: bytes) -> bool:
key_str = key_bytes.decode("utf-8")
if key_str == RedisGlobalConnectorCredentialPair.FENCE_KEY:
return True
if key_str.startswith(RedisDocumentSet.FENCE_PREFIX):
return True
if key_str.startswith(RedisUserGroup.FENCE_PREFIX):
return True
if key_str.startswith(RedisConnectorDelete.FENCE_PREFIX):
return True
if key_str.startswith(RedisConnectorPrune.FENCE_PREFIX):
return True
if key_str.startswith(RedisConnectorIndex.FENCE_PREFIX):
return True
if key_str.startswith(RedisConnectorPermissionSync.FENCE_PREFIX):
return True

return False
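Short usage sketch (illustrative; the key-scan pattern is an assumption, only is_fence comes from the new file above):

import redis

r = redis.Redis()
# collect fence keys, e.g. when inspecting leftover fences while debugging
fence_keys = [key for key in r.scan_iter("*") if is_fence(key)]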
@@ -162,6 +162,11 @@ def load_personas_from_yaml(
else persona.get("is_visible")
),
db_session=db_session,
is_default_persona=(
existing_persona.is_default_persona
if existing_persona is not None
else persona.get("is_default_persona", False)
),
)

@@ -41,6 +41,7 @@ personas:
icon_color: "#6FB1FF"
display_priority: 0
is_visible: true
is_default_persona: true
starter_messages:
- name: "Give me an overview of what's here"
message: "Sample some documents and tell me what you find."

@@ -66,6 +67,7 @@ personas:
icon_color: "#FF6F6F"
display_priority: 1
is_visible: true
is_default_persona: true
starter_messages:
- name: "Summarize a document"
message: "If I have provided a document please summarize it for me. If not, please ask me to upload a document either by dragging it into the input bar or clicking the +file icon."

@@ -91,6 +93,7 @@ personas:
icon_color: "#6FFF8D"
display_priority: 2
is_visible: false
is_default_persona: true
starter_messages:
- name: "Document Search"
message: "Hi! Could you help me find information about our team structure and reporting lines from our internal documents?"

@@ -117,6 +120,7 @@ personas:
image_generation: true
display_priority: 3
is_visible: true
is_default_persona: true
starter_messages:
- name: "Create visuals for a presentation"
message: "Generate someone presenting a graph which clearly demonstrates an upwards trajectory."

@@ -76,6 +76,7 @@ def gpt_search(
user=None,
llm=llm,
fast_llm=fast_llm,
skip_query_analysis=True,
db_session=db_session,
).reranked_sections

@@ -12,10 +12,10 @@ class PageType(str, Enum):
SEARCH = "search"


class GatingType(str, Enum):
FULL = "full" # Complete restriction of access to the product or service
PARTIAL = "partial" # Full access but warning (no credit card on file)
NONE = "none" # No restrictions, full access to all features
class ApplicationStatus(str, Enum):
PAYMENT_REMINDER = "payment_reminder"
GATED_ACCESS = "gated_access"
ACTIVE = "active"


class Notification(BaseModel):

@@ -43,7 +43,7 @@ class Settings(BaseModel):

maximum_chat_retention_days: int | None = None
gpu_enabled: bool | None = None
product_gating: GatingType = GatingType.NONE
application_status: ApplicationStatus = ApplicationStatus.ACTIVE
anonymous_user_enabled: bool | None = None
pro_search_disabled: bool | None = None
auto_scroll: bool | None = None

@@ -34,7 +34,7 @@ Now respond to the following:
""".strip()


class BaseTool(Tool):
class BaseTool(Tool[None]):
def build_next_prompt(
self,
prompt_builder: "AnswerPromptBuilder",

@@ -1,11 +1,14 @@
from collections.abc import Callable
from typing import Any
from uuid import UUID

from pydantic import BaseModel
from pydantic import model_validator
from sqlalchemy.orm import Session

from onyx.context.search.enums import SearchType
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import InferenceSection


class ToolResponse(BaseModel):

@@ -57,5 +60,15 @@ class SearchQueryInfo(BaseModel):
recency_bias_multiplier: float


class SearchToolOverrideKwargs(BaseModel):
force_no_rerank: bool
alternate_db_session: Session | None
retrieved_sections_callback: Callable[[list[InferenceSection]], None] | None
skip_query_analysis: bool

class Config:
arbitrary_types_allowed = True


CHAT_SESSION_ID_PLACEHOLDER = "CHAT_SESSION_ID"
MESSAGE_ID_PLACEHOLDER = "MESSAGE_ID"
@@ -1,7 +1,9 @@
import abc
from collections.abc import Generator
from typing import Any
from typing import Generic
from typing import TYPE_CHECKING
from typing import TypeVar

from onyx.llm.interfaces import LLM
from onyx.llm.models import PreviousMessage

@@ -14,7 +16,10 @@ if TYPE_CHECKING:
from onyx.tools.models import ToolResponse


class Tool(abc.ABC):
OVERRIDE_T = TypeVar("OVERRIDE_T")


class Tool(abc.ABC, Generic[OVERRIDE_T]):
@property
@abc.abstractmethod
def name(self) -> str:

@@ -57,7 +62,9 @@ class Tool(abc.ABC):
"""Actual execution of the tool"""

@abc.abstractmethod
def run(self, **kwargs: Any) -> Generator["ToolResponse", None, None]:
def run(
self, override_kwargs: OVERRIDE_T | None = None, **llm_kwargs: Any
) -> Generator["ToolResponse", None, None]:
raise NotImplementedError

@abc.abstractmethod

@@ -74,6 +74,7 @@ class CustomToolCallSummary(BaseModel):
tool_result: Any  # The response data


# override_kwargs is not supported for custom tools
class CustomTool(BaseTool):
def __init__(
self,

@@ -235,7 +236,9 @@ class CustomTool(BaseTool):

"""Actual execution of the tool"""

def run(self, **kwargs: Any) -> Generator[ToolResponse, None, None]:
def run(
self, override_kwargs: dict[str, Any] | None = None, **kwargs: Any
) -> Generator[ToolResponse, None, None]:
request_body = kwargs.get(REQUEST_BODY)

path_params = {}

@@ -79,7 +79,8 @@ class ImageShape(str, Enum):
LANDSCAPE = "landscape"


class ImageGenerationTool(Tool):
# override_kwargs is not supported for image generation tools
class ImageGenerationTool(Tool[None]):
_NAME = "run_image_generation"
_DESCRIPTION = "Generate an image from a prompt."
_DISPLAY_NAME = "Image Generation"

@@ -255,7 +256,9 @@ class ImageGenerationTool(Tool):
"An error occurred during image generation. Please try again later."
)

def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]:
def run(
self, override_kwargs: None = None, **kwargs: str
) -> Generator[ToolResponse, None, None]:
prompt = cast(str, kwargs["prompt"])
shape = ImageShape(kwargs.get("shape", ImageShape.SQUARE))
format = self.output_format

@@ -106,7 +106,8 @@ def internet_search_response_to_search_docs(
]


class InternetSearchTool(Tool):
# override_kwargs is not supported for internet search tools
class InternetSearchTool(Tool[None]):
_NAME = "run_internet_search"
_DISPLAY_NAME = "Internet Search"
_DESCRIPTION = "Perform an internet search for up-to-date information."

@@ -242,7 +243,9 @@ class InternetSearchTool(Tool):
],
)

def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]:
def run(
self, override_kwargs: None = None, **kwargs: str
) -> Generator[ToolResponse, None, None]:
query = cast(str, kwargs["internet_search_query"])

results = self._perform_search(query)

@@ -39,6 +39,7 @@ from onyx.secondary_llm_flows.choose_search import check_if_need_search
from onyx.secondary_llm_flows.query_expansion import history_based_query_rephrase
from onyx.tools.message import ToolCallSummary
from onyx.tools.models import SearchQueryInfo
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_utils import llm_doc_to_dict

@@ -77,7 +78,7 @@ HINT: if you are unfamiliar with the user input OR think the user input is a typ
"""


class SearchTool(Tool):
class SearchTool(Tool[SearchToolOverrideKwargs]):
_NAME = "run_search"
_DISPLAY_NAME = "Search Tool"
_DESCRIPTION = SEARCH_TOOL_DESCRIPTION

@@ -275,14 +276,19 @@ class SearchTool(Tool):

yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=llm_docs)

def run(self, **kwargs: Any) -> Generator[ToolResponse, None, None]:
query = cast(str, kwargs["query"])
force_no_rerank = cast(bool, kwargs.get("force_no_rerank", False))
alternate_db_session = cast(Session, kwargs.get("alternate_db_session", None))
retrieved_sections_callback = cast(
Callable[[list[InferenceSection]], None],
kwargs.get("retrieved_sections_callback"),
)
def run(
self, override_kwargs: SearchToolOverrideKwargs | None = None, **llm_kwargs: Any
) -> Generator[ToolResponse, None, None]:
query = cast(str, llm_kwargs["query"])
force_no_rerank = False
alternate_db_session = None
retrieved_sections_callback = None
skip_query_analysis = False
if override_kwargs:
force_no_rerank = override_kwargs.force_no_rerank
alternate_db_session = override_kwargs.alternate_db_session
retrieved_sections_callback = override_kwargs.retrieved_sections_callback
skip_query_analysis = override_kwargs.skip_query_analysis

if self.selected_sections:
yield from self._build_response_for_specified_sections(query)

@@ -324,6 +330,7 @@ class SearchTool(Tool):
user=self.user,
llm=self.llm,
fast_llm=self.fast_llm,
skip_query_analysis=skip_query_analysis,
bypass_acl=self.bypass_acl,
db_session=alternate_db_session or self.db_session,
prompt_config=self.prompt_config,
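Illustrative call under the new run signature (sketch only; the search_tool instance is assumed to already be constructed):

override = SearchToolOverrideKwargs(
    force_no_rerank=True,
    alternate_db_session=None,
    retrieved_sections_callback=None,
    skip_query_analysis=True,
)
for tool_response in search_tool.run(override_kwargs=override, query="example query"):
    ...  # consume ToolResponse objects exactly as before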
@@ -256,16 +256,28 @@ def get_documents_for_tenant_connector(


def search_for_document(
index_name: str, document_id: str, max_hits: int | None = 10
index_name: str,
document_id: str | None = None,
tenant_id: str | None = None,
max_hits: int | None = 10,
) -> List[Dict[str, Any]]:
yql_query = (
f'select * from sources {index_name} where document_id contains "{document_id}"'
)
yql_query = f"select * from sources {index_name}"

conditions = []
if document_id is not None:
conditions.append(f'document_id contains "{document_id}"')

if tenant_id is not None:
conditions.append(f'tenant_id contains "{tenant_id}"')

if conditions:
yql_query += " where " + " and ".join(conditions)

params: dict[str, Any] = {"yql": yql_query}
if max_hits is not None:
params["hits"] = max_hits
with get_vespa_http_client() as client:
response = client.get(f"{SEARCH_ENDPOINT}/search/", params=params)
response = client.get(f"{SEARCH_ENDPOINT}search/", params=params)
response.raise_for_status()
result = response.json()
documents = result.get("root", {}).get("children", [])

@@ -582,8 +594,15 @@ class VespaDebugging:
) -> None:
update_document(self.tenant_id, connector_id, doc_id, fields)

def search_for_document(self, document_id: str) -> List[Dict[str, Any]]:
return search_for_document(self.index_name, document_id)
def delete_documents_for_tenant(self, count: int | None = None) -> None:
if not self.tenant_id:
raise Exception("Tenant ID is not set")
delete_documents_for_tenant(self.index_name, self.tenant_id, count=count)

def search_for_document(
self, document_id: str | None = None, tenant_id: str | None = None
) -> List[Dict[str, Any]]:
return search_for_document(self.index_name, document_id, tenant_id)

def delete_document(self, connector_id: int, doc_id: str) -> None:
# Delete a document.

@@ -600,6 +619,147 @@ class VespaDebugging:
get_document_acls(self.tenant_id, cc_pair_id, n)


def delete_where(
index_name: str,
selection: str,
cluster: str = "default",
bucket_space: str | None = None,
continuation: str | None = None,
time_chunk: str | None = None,
timeout: str | None = None,
tracelevel: int | None = None,
) -> None:
"""
Removes visited documents in `cluster` where the given selection
is true, using Vespa's 'delete where' endpoint.

:param index_name: Typically <namespace>/<document-type> from your schema
:param selection: The selection string, e.g., "true" or "foo contains 'bar'"
:param cluster: The name of the cluster where documents reside
:param bucket_space: e.g. 'global' or 'default'
:param continuation: For chunked visits
:param time_chunk: If you want to chunk the visit by time
:param timeout: e.g. '10s'
:param tracelevel: Increase for verbose logs
"""
# Using index_name of form <namespace>/<document-type>, e.g. "nomic_ai_nomic_embed_text_v1"
# This route ends with "/docid/" since the actual ID is not specified — we rely on "selection".
path = f"/document/v1/{index_name}/docid/"

params = {
"cluster": cluster,
"selection": selection,
}

# Optional parameters
if bucket_space is not None:
params["bucketSpace"] = bucket_space
if continuation is not None:
params["continuation"] = continuation
if time_chunk is not None:
params["timeChunk"] = time_chunk
if timeout is not None:
params["timeout"] = timeout
if tracelevel is not None:
params["tracelevel"] = tracelevel  # type: ignore

with get_vespa_http_client() as client:
url = f"{VESPA_APPLICATION_ENDPOINT}{path}"
logger.info(f"Performing 'delete where' on {url} with selection={selection}...")
response = client.delete(url, params=params)
# (Optionally, you can keep fetching `continuation` from the JSON response
# if you have more documents to delete in chunks.)
response.raise_for_status()  # will raise HTTPError if not 2xx
logger.info(f"Delete where completed with status: {response.status_code}")
print(f"Delete where completed with status: {response.status_code}")
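Minimal usage sketch (illustrative only; the index name is the example value from the comment above):

# removes every visible document in the example index
delete_where(index_name="nomic_ai_nomic_embed_text_v1", selection="true")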
def delete_documents_for_tenant(
index_name: str,
tenant_id: str,
route: str | None = None,
condition: str | None = None,
timeout: str | None = None,
tracelevel: int | None = None,
count: int | None = None,
) -> None:
"""
For the given tenant_id and index_name (often in the form <namespace>/<document-type>),
find documents via search_for_document, then delete them one at a time using Vespa's
/document/v1/<namespace>/<document-type>/docid/<document-id> endpoint.

:param index_name: Typically <namespace>/<document-type> from your schema
:param tenant_id: The tenant to match in your Vespa search
:param route: Optional route parameter for delete
:param condition: Optional conditional remove
:param timeout: e.g. '10s'
:param tracelevel: Increase for verbose logs
"""
deleted_count = 0
while True:
# Search for documents with the given tenant_id
docs = search_for_document(
index_name=index_name,
document_id=None,
tenant_id=tenant_id,
max_hits=100,  # Fetch in batches of 100
)

if not docs:
logger.info("No more documents found to delete.")
break

with get_vespa_http_client() as client:
for doc in docs:
if count is not None and deleted_count >= count:
logger.info(f"Reached maximum delete limit of {count} documents.")
return
fields = doc.get("fields", {})
doc_id_value = fields.get("document_id") or fields.get("documentid")
doc_tenant_id = fields.get("tenant_id")
if doc_tenant_id != tenant_id:
raise Exception("Tenant ID mismatch")

if not doc_id_value:
logger.warning(
"Skipping a document that has no document_id in 'fields'."
)
continue

url = f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{doc_id_value}"

params = {}
if condition:
params["condition"] = condition
if route:
params["route"] = route
if timeout:
params["timeout"] = timeout
if tracelevel is not None:
params["tracelevel"] = str(tracelevel)

response = client.delete(url, params=params)
if response.status_code == 200:
logger.info(f"Successfully deleted doc_id={doc_id_value}")
deleted_count += 1
else:
logger.error(
f"Failed to delete doc_id={doc_id_value}, "
f"status={response.status_code}, response={response.text}"
)
print(
f"Could not delete doc_id={doc_id_value}. "
f"Status={response.status_code}, response={response.text}"
)
raise Exception(
f"Could not delete doc_id={doc_id_value}. "
f"Status={response.status_code}, response={response.text}"
)

logger.info(f"Deleted {deleted_count} documents in total.")


def main() -> None:
parser = argparse.ArgumentParser(description="Vespa debugging tool")
parser.add_argument(

@@ -612,6 +772,7 @@ def main() -> None:
"update",
"delete",
"get_acls",
"delete-all-documents",
],
required=True,
help="Action to perform",

@@ -626,11 +787,20 @@ def main() -> None:
parser.add_argument(
"--fields", help="Fields to update, in JSON format (for update)"
)
parser.add_argument(
"--count",
type=int,
help="Maximum number of documents to delete (for delete-all-documents)",
)

args = parser.parse_args()
vespa_debug = VespaDebugging(args.tenant_id)

if args.action == "config":
if args.action == "delete-all-documents":
if not args.tenant_id:
parser.error("--tenant-id is required for delete-all-documents action")
vespa_debug.delete_documents_for_tenant(count=args.count)
elif args.action == "config":
vespa_debug.print_config()
elif args.action == "connect":
vespa_debug.check_connectivity()
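Illustrative Python-level use of the new helper (sketch only; the tenant id is a made-up example):

vespa_debug = VespaDebugging("tenant_abc123")  # example tenant id
vespa_debug.delete_documents_for_tenant(count=500)  # stop after at most 500 deletions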
@@ -34,11 +34,11 @@ def test_jira_connector_basic(jira_connector: JiraConnector) -> None:
doc = doc_batch[0]

assert doc.id == "https://danswerai.atlassian.net/browse/AS-2"
assert doc.semantic_identifier == "test123small"
assert doc.semantic_identifier == "AS-2: test123small"
assert doc.source == DocumentSource.JIRA
assert doc.metadata == {"priority": "Medium", "status": "Backlog"}
assert doc.secondary_owners is None
assert doc.title is None
assert doc.title == "AS-2 test123small"
assert doc.from_ingestion_api is False
assert doc.additional_info is None

@@ -23,6 +23,9 @@ def bedrock_provider() -> WellKnownLLMProviderDescriptor:
return provider


@pytest.mark.xfail(
reason="Credentials not yet available due to compliance work needed",
)
def test_bedrock_llm_configuration(
client: TestClient, bedrock_provider: WellKnownLLMProviderDescriptor
) -> None:
@@ -84,6 +84,9 @@ ENV NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=${NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED}
ARG NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK
ENV NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=${NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK}

ARG NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY
ENV NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY=${NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY}

# Use NODE_OPTIONS in the build command
RUN NODE_OPTIONS="${NODE_OPTIONS}" npx next build

@@ -145,7 +148,6 @@ ENV NEXT_PUBLIC_DISABLE_LOGOUT=${NEXT_PUBLIC_DISABLE_LOGOUT}
ARG NEXT_PUBLIC_CUSTOM_REFRESH_URL
ENV NEXT_PUBLIC_CUSTOM_REFRESH_URL=${NEXT_PUBLIC_CUSTOM_REFRESH_URL}


ARG NEXT_PUBLIC_POSTHOG_KEY
ARG NEXT_PUBLIC_POSTHOG_HOST
ENV NEXT_PUBLIC_POSTHOG_KEY=${NEXT_PUBLIC_POSTHOG_KEY}

@@ -166,6 +168,9 @@ ENV NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=${NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED}
ARG NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK
ENV NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=${NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK}

ARG NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY
ENV NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY=${NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY}

# Note: Don't expose ports here, Compose will handle that for us if necessary.
# If you want to run this without compose, specify the ports to
# expose via cli
@@ -502,9 +502,10 @@ export default function AddConnector({
{oauthSupportedSources.includes(connector) &&
(NEXT_PUBLIC_CLOUD_ENABLED ||
NEXT_PUBLIC_TEST_ENV) && (
<button
<Button
variant="navigate"
onClick={handleAuthorize}
className="mt-6 text-sm bg-blue-500 px-2 py-1.5 flex text-text-200 flex-none rounded"
className="mt-6 "
disabled={isAuthorizing}
hidden={!isAuthorizeVisible}
>

@@ -513,7 +514,7 @@ export default function AddConnector({
: `Authorize with ${getSourceDisplayName(
connector
)}`}
</button>
</Button>
)}
</div>
)}

@@ -5,7 +5,7 @@ import useSWR, { mutate } from "swr";
import { FetchError, errorHandlingFetcher } from "@/lib/fetcher";
import { ErrorCallout } from "@/components/ErrorCallout";
import { LoadingAnimation } from "@/components/Loading";
import { usePopup } from "@/components/admin/connectors/Popup";
import { PopupSpec, usePopup } from "@/components/admin/connectors/Popup";
import { ValidSources } from "@/lib/types";
import { usePublicCredentials } from "@/lib/hooks";
import Title from "@/components/ui/title";

@@ -32,7 +32,11 @@ const useConnectorsByCredentialId = (credential_id: number | null) => {
};
};

const GDriveMain = ({}: {}) => {
const GDriveMain = ({
setPopup,
}: {
setPopup: (popup: PopupSpec | null) => void;
}) => {
const { isAdmin, user } = useUser();

// tries getting the uploaded credential json

@@ -97,8 +101,6 @@ const GDriveMain = ({}: {}) => {
refreshConnectorsByCredentialId,
} = useConnectorsByCredentialId(credential_id);

const { popup, setPopup } = usePopup();

const appCredentialSuccessfullyFetched =
appCredentialData ||
(isAppCredentialError && isAppCredentialError.status === 404);

@@ -173,10 +175,7 @@ const GDriveMain = ({}: {}) => {

return (
<>
{popup}
<Title className="mb-2 mt-6 ml-auto mr-auto">
Step 1: Provide your Credentials
</Title>
<Title className="mb-2 mt-6">Step 1: Provide your Credentials</Title>
<DriveJsonUploadSection
setPopup={setPopup}
appCredentialData={appCredentialData}

@@ -186,9 +185,7 @@ const GDriveMain = ({}: {}) => {

{isAdmin && (
<>
<Title className="mb-2 mt-6 ml-auto mr-auto">
Step 2: Authenticate with Onyx
</Title>
<Title className="mb-2 mt-6">Step 2: Authenticate with Onyx</Title>
<DriveAuthSection
setPopup={setPopup}
refreshCredentials={refreshCredentials}

@@ -188,7 +188,7 @@ export const DocumentSetCreationForm = ({
flex
cursor-pointer ` +
(isSelected
? " bg-background-strong"
? " bg-background-200"
: " hover:bg-accent-background-hovered")
}
onClick={() => {

@@ -304,7 +304,7 @@ export const DocumentSetCreationForm = ({
flex
cursor-pointer ` +
(isSelected
? " bg-background-strong"
? " bg-background-200"
: " hover:bg-accent-background-hovered")
}
onClick={() => {

@@ -235,8 +235,8 @@ export function EmbeddingModelSelection({
onClick={() => setModelTab(null)}
className={`mr-4 p-2 font-bold ${
!modelTab
? "rounded bg-background-900 dark:bg-neutral-900 text-text-100 dark:text-neutral-100 underline"
: " hover:underline bg-background-100 dark:bg-neutral-700"
? "rounded bg-neutral-900 dark:bg-neutral-950 text-neutral-100 dark:text-neutral-300 underline"
: " hover:underline bg-neutral-100 dark:bg-neutral-900"
}`}
>
Current

@@ -246,8 +246,8 @@ export function EmbeddingModelSelection({
onClick={() => setModelTab("cloud")}
className={`mx-2 p-2 font-bold ${
modelTab == "cloud"
? "rounded bg-background-900 dark:bg-neutral-900 text-text-100 dark:text-neutral-100 underline"
: " hover:underline bg-background-100 dark:bg-neutral-700"
? "rounded bg-neutral-900 dark:bg-neutral-950 text-neutral-100 dark:text-neutral-300 underline"
: " hover:underline bg-neutral-100 dark:bg-neutral-900"
}`}
>
Cloud-based

@@ -258,8 +258,8 @@ export function EmbeddingModelSelection({
onClick={() => setModelTab("open")}
className={` mx-2 p-2 font-bold ${
modelTab == "open"
? "rounded bg-background-900 dark:bg-neutral-900 text-text-100 dark:text-neutral-100 underline"
: "hover:underline bg-background-100 dark:bg-neutral-700"
? "rounded bg-neutral-900 dark:bg-neutral-950 text-neutral-100 dark:text-neutral-300 underline"
: "hover:underline bg-neutral-100 dark:bg-neutral-900"
}`}
>
Self-hosted

@@ -116,8 +116,8 @@ const RerankingDetailsForm = forwardRef<
onClick={() => setModelTab("cloud")}
className={`mr-2 p-2 font-bold ${
modelTab == "cloud"
? "rounded bg-background-900 text-text-100 underline"
: " hover:underline bg-background-100"
? "rounded bg-neutral-900 dark:bg-neutral-950 text-neutral-100 dark:text-neutral-300 underline"
: " hover:underline bg-neutral-100 dark:bg-neutral-900"
}`}
>
Cloud-based

@@ -129,8 +129,8 @@ const RerankingDetailsForm = forwardRef<
onClick={() => setModelTab("open")}
className={` mx-2 p-2 font-bold ${
modelTab == "open"
? "rounded bg-background-900 text-text-100 underline"
: "hover:underline bg-background-100"
? "rounded bg-neutral-900 dark:bg-neutral-950 text-neutral-100 dark:text-neutral-300 underline"
: "hover:underline bg-neutral-100 dark:bg-neutral-900"
}`}
>
Self-hosted

@@ -140,7 +140,7 @@ const RerankingDetailsForm = forwardRef<
<div className="px-2">
<button
onClick={() => resetRerankingValues()}
className="mx-2 p-2 font-bold rounded bg-background-100 text-text-900 hover:underline"
className={`mx-2 p-2 font-bold rounded bg-neutral-100 dark:bg-neutral-900 text-neutral-900 dark:text-neutral-100 hover:underline`}
>
Remove Reranking
</button>

@@ -177,7 +177,7 @@ const RerankingDetailsForm = forwardRef<
key={`${card.rerank_provider_type}-${card.modelName}`}
className={`p-4 border rounded-lg cursor-pointer transition-all duration-200 ${
isSelected
? "border-blue-500 bg-blue-50 dark:bg-blue-900 dark:border-blue-700 shadow-md"
? "border-blue-800 bg-blue-50 dark:bg-blue-950 dark:border-blue-700 shadow-md"
: "border-background-200 hover:border-blue-300 hover:shadow-sm dark:border-neutral-700 dark:hover:border-blue-300"
}`}
onClick={() => {

@@ -322,7 +322,7 @@ export default function CloudEmbeddingPage({
OpenAI for embeddings.
</Text>
<div className="flex items-center text-sm text-text-700">
<FiInfo className="text-text-400 mr-2" size={16} />
<FiInfo className="text-neutral-400 mr-2" size={16} />
<Text>
You'll need: API version, base URL, API key, model
name, and deployment name.

@@ -1,7 +1,7 @@
export enum GatingType {
FULL = "full",
PARTIAL = "partial",
NONE = "none",
export enum ApplicationStatus {
PAYMENT_REMINDER = "payment_reminder",
GATED_ACCESS = "gated_access",
ACTIVE = "active",
}

export interface Settings {

@@ -11,7 +11,7 @@ export interface Settings {
needs_reindexing: boolean;
gpu_enabled: boolean;
pro_search_disabled: boolean | null;
product_gating: GatingType;
application_status: ApplicationStatus;
auto_scroll: boolean;
}


@@ -1627,7 +1627,7 @@ export function ChatPage({
second_level_message: second_level_answer,
type: error ? "error" : "assistant",
retrievalType,
query: finalMessage?.rephrased_query,
query: finalMessage?.rephrased_query || query,
documents: documents,
citations: finalMessage?.citations || {},
files: finalMessage?.files || aiMessageImages || [],

@@ -2291,8 +2291,6 @@ export function ChatPage({
bg-opacity-80
duration-300
ease-in-out


${
!untoggled && (showHistorySidebar || sidebarVisible)
? "opacity-100 w-[250px] translate-x-0"

@@ -3099,6 +3097,7 @@ export function ChatPage({
</button>
</div>
)}

<div className="pointer-events-auto w-[95%] mx-auto relative mb-8">
<ChatInputBar
proSearchEnabled={proSearchEnabled}

@@ -86,19 +86,20 @@ export function AgenticToggle({
</TooltipTrigger>
<TooltipContent
side="top"
className="w-72 p-4 bg-white rounded-lg shadow-lg border border-background-200 dark:border-neutral-900"
width="w-72"
className="p-4 bg-white rounded-lg shadow-lg border border-background-200 dark:border-neutral-900"
>
<div className="flex items-center space-x-2 mb-3">
<h3 className="text-sm font-semibold text-text-900">
<h3 className="text-sm font-semibold text-neutral-900">
Agent Search (BETA)
</h3>
</div>
<p className="text-xs text-text-600 mb-2">
<p className="text-xs text-neutral-600 dark:text-neutral-700 mb-2">
Use AI agents to break down questions and run deep iterative
research through promising pathways. Gives more thorough and
accurate responses but takes slightly longer.
</p>
<ul className="text-xs text-text-600 list-disc list-inside">
<ul className="text-xs text-text-600 dark:text-neutral-700 list-disc list-inside">
<li>Improved accuracy of search results</li>
<li>Fewer hallucinations</li>
<li>More comprehensive answers</li>

@@ -974,7 +974,7 @@ export const HumanMessage = ({
py-2
px-3
w-fit
bg-background-strong
bg-background-200
text-sm
rounded-lg
hover:bg-accent-background-hovered-emphasis

web/src/app/ee/admin/billing/BillingAlerts.tsx (new file, +73 lines)
@@ -0,0 +1,73 @@
import React from "react";
import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert";
import { CircleAlert, Info } from "lucide-react";
import { BillingInformation, BillingStatus } from "./interfaces";

export function BillingAlerts({
billingInformation,
}: {
billingInformation: BillingInformation;
}) {
const isTrialing = billingInformation.status === BillingStatus.TRIALING;
const isCancelled = billingInformation.cancel_at_period_end;
const isExpired =
new Date(billingInformation.current_period_end) < new Date();
const noPaymentMethod = !billingInformation.payment_method_enabled;

const messages: string[] = [];

if (isExpired) {
messages.push(
"Your subscription has expired. Please resubscribe to continue using the service."
);
}
if (isCancelled && !isExpired) {
messages.push(
`Your subscription will cancel on ${new Date(
billingInformation.current_period_end
).toLocaleDateString()}. You can resubscribe before this date to remain uninterrupted.`
);
}
if (isTrialing) {
messages.push(
`You're currently on a trial. Your trial ends on ${
billingInformation.trial_end
? new Date(billingInformation.trial_end).toLocaleDateString()
: "N/A"
}.`
);
}
if (noPaymentMethod) {
messages.push(
"You currently have no payment method on file. Please add one to avoid service interruption."
);
}

const variant = isExpired || noPaymentMethod ? "destructive" : "default";

if (messages.length === 0) return null;

return (
<Alert variant={variant}>
<AlertTitle className="flex items-center space-x-2">
{variant === "destructive" ? (
<CircleAlert className="h-4 w-4" />
) : (
<Info className="h-4 w-4" />
)}
<span>
{variant === "destructive"
? "Important Subscription Notice"
: "Subscription Notice"}
</span>
</AlertTitle>
<AlertDescription>
<ul className="list-disc list-inside space-y-1 mt-2">
{messages.map((msg, idx) => (
<li key={idx}>{msg}</li>
))}
</ul>
</AlertDescription>
</Alert>
);
}

@@ -1,18 +1,21 @@
"use client";

import { CreditCard, ArrowFatUp } from "@phosphor-icons/react";
import { useState } from "react";
import { useRouter } from "next/navigation";
import { loadStripe } from "@stripe/stripe-js";
import { usePopup } from "@/components/admin/connectors/Popup";
import { SettingsIcon } from "@/components/icons/icons";
import {
updateSubscriptionQuantity,
fetchCustomerPortal,
statusToDisplay,
useBillingInformation,
} from "./utils";
import { useEffect } from "react";
import { usePopup } from "@/components/admin/connectors/Popup";
import { fetchCustomerPortal, useBillingInformation } from "./utils";

import {
Card,
CardContent,
CardDescription,
CardHeader,
CardTitle,
} from "@/components/ui/card";
import { Button } from "@/components/ui/button";
import { CreditCard, ArrowFatUp } from "@phosphor-icons/react";
import { SubscriptionSummary } from "./SubscriptionSummary";
import { BillingAlerts } from "./BillingAlerts";

export default function BillingInformationPage() {
const router = useRouter();

@@ -24,9 +27,6 @@ export default function BillingInformationPage() {
isLoading,
} = useBillingInformation();

if (error) {
console.error("Failed to fetch billing information:", error);
}
useEffect(() => {
const url = new URL(window.location.href);
if (url.searchParams.has("session_id")) {

@@ -35,22 +35,33 @@ export default function BillingInformationPage() {
"Congratulations! Your subscription has been updated successfully.",
type: "success",
});
// Remove the session_id from the URL
url.searchParams.delete("session_id");
window.history.replaceState({}, "", url.toString());
// You might want to refresh the billing information here
// by calling an API endpoint to get the latest data
}
}, [setPopup]);

if (isLoading) {
return <div>Loading...</div>;
return <div className="text-center py-8">Loading...</div>;
}

if (error) {
console.error("Failed to fetch billing information:", error);
return (
<div className="text-center py-8 text-red-500">
Error loading billing information. Please try again later.
</div>
);
}

if (!billingInformation) {
return (
<div className="text-center py-8">No billing information available.</div>
);
}

const handleManageSubscription = async () => {
try {
const response = await fetchCustomerPortal();

if (!response.ok) {
const errorData = await response.json();
throw new Error(

@@ -61,11 +72,9 @@ export default function BillingInformationPage() {
}

const { url } = await response.json();

if (!url) {
throw new Error("No portal URL returned from the server");
}

router.push(url);
} catch (error) {
console.error("Error creating customer portal session:", error);

@@ -75,138 +84,39 @@ export default function BillingInformationPage() {
});
}
};
if (!billingInformation) {
return <div>Loading...</div>;
}

return (
<div className="space-y-8">
<div className="bg-background-50 rounded-lg p-8 border border-background-200">
{popup}
{popup}
<Card className="shadow-md">
<CardHeader>
<CardTitle className="text-2xl font-bold flex items-center">
<CreditCard className="mr-4 text-muted-foreground" size={24} />
Subscription Details
</CardTitle>
</CardHeader>
<CardContent className="space-y-6">
<SubscriptionSummary billingInformation={billingInformation} />
<BillingAlerts billingInformation={billingInformation} />
</CardContent>
</Card>

<h2 className="text-2xl font-bold mb-6 text-text-800 flex items-center">
{/* <CreditCard className="mr-4 text-text-600" size={24} /> */}
Subscription Details
</h2>

<div className="space-y-4">
<div className="bg-white p-5 rounded-lg shadow-sm transition-all duration-300 hover:shadow-md">
<div className="flex justify-between items-center">
<div>
<p className="text-lg font-medium text-text-700">Seats</p>
<p className="text-sm text-text-500">
Number of licensed users
</p>
</div>
<p className="text-xl font-semibold text-text-900">
{billingInformation.seats}
</p>
</div>
</div>

<div className="bg-white p-5 rounded-lg shadow-sm transition-all duration-300 hover:shadow-md">
<div className="flex justify-between items-center">
<div>
<p className="text-lg font-medium text-text-700">
Subscription Status
</p>
<p className="text-sm text-text-500">
Current state of your subscription
</p>
</div>
<p className="text-xl font-semibold text-text-900">
{statusToDisplay(billingInformation.subscription_status)}
</p>
</div>
</div>

<div className="bg-white p-5 rounded-lg shadow-sm transition-all duration-300 hover:shadow-md">
<div className="flex justify-between items-center">
<div>
<p className="text-lg font-medium text-text-700">
Billing Start
</p>
<p className="text-sm text-text-500">
Start date of current billing cycle
</p>
</div>
<p className="text-xl font-semibold text-text-900">
{new Date(
billingInformation.billing_start
).toLocaleDateString()}
</p>
</div>
</div>

<div className="bg-white p-5 rounded-lg shadow-sm transition-all duration-300 hover:shadow-md">
<div className="flex justify-between items-center">
<div>
<p className="text-lg font-medium text-text-700">Billing End</p>
<p className="text-sm text-text-500">
End date of current billing cycle
</p>
</div>
<p className="text-xl font-semibold text-text-900">
{new Date(billingInformation.billing_end).toLocaleDateString()}
</p>
</div>
</div>
</div>

{!billingInformation.payment_method_enabled && (
<div className="mt-4 p-4 bg-yellow-100 border-l-4 border-yellow-500 text-yellow-700">
<p className="font-bold">Notice:</p>
<p>
You'll need to add a payment method before your trial ends to
continue using the service.
</p>
</div>
)}

{billingInformation.subscription_status === "trialing" ? (
<div className="bg-white p-5 rounded-lg shadow-sm transition-all duration-300 hover:shadow-md mt-8">
<p className="text-lg font-medium text-text-700">
No cap on users during trial
</p>
</div>
) : (
<div className="flex items-center space-x-4 mt-8">
<div className="flex items-center space-x-4">
<p className="text-lg font-medium text-text-700">
Current Seats:
</p>
<p className="text-xl font-semibold text-text-900">
{billingInformation.seats}
</p>
</div>
<p className="text-sm text-text-500">
Seats automatically update based on adding, removing, or inviting
users.
</p>
</div>
)}
</div>

<div className="bg-white p-5 rounded-lg shadow-sm transition-all duration-300 hover:shadow-md">
<div className="flex justify-between items-center mb-4">
<div>
<p className="text-lg font-medium text-text-700">
Manage Subscription
</p>
<p className="text-sm text-text-500">
View your plan, update payment, or change subscription
</p>
</div>
<SettingsIcon className="text-text-600" size={20} />
</div>
<button
onClick={handleManageSubscription}
className="bg-background-600 text-white px-4 py-2 rounded-md hover:bg-background-700 transition duration-300 ease-in-out focus:outline-none focus:ring-2 focus:ring-text-500 focus:ring-opacity-50 font-medium shadow-sm text-sm flex items-center justify-center"
>
<ArrowFatUp className="mr-2" size={16} />
Manage Subscription
</button>
</div>
<Card className="shadow-md">
<CardHeader>
<CardTitle className="text-xl font-semibold">
Manage Subscription
</CardTitle>
<CardDescription>
View your plan, update payment, or change subscription
</CardDescription>
</CardHeader>
<CardContent>
<Button onClick={handleManageSubscription} className="w-full">
<ArrowFatUp className="mr-2" size={16} />
Manage Subscription
</Button>
</CardContent>
</Card>
</div>
);
}
web/src/app/ee/admin/billing/InfoItem.tsx (new file, +17 lines)
@@ -0,0 +1,17 @@
import React from "react";

interface InfoItemProps {
title: string;
value: string;
}

export function InfoItem({ title, value }: InfoItemProps) {
return (
<div className="bg-muted p-4 rounded-lg">
<p className="text-sm font-medium text-muted-foreground mb-1">{title}</p>
<p className="text-lg font-semibold text-foreground dark:text-white">
{value}
</p>
</div>
);
}

web/src/app/ee/admin/billing/SubscriptionSummary.tsx (new file, +33 lines)
@@ -0,0 +1,33 @@
import React from "react";
import { InfoItem } from "./InfoItem";
import { statusToDisplay } from "./utils";

interface SubscriptionSummaryProps {
billingInformation: any;
}

export function SubscriptionSummary({
billingInformation,
}: SubscriptionSummaryProps) {
return (
<div className="grid grid-cols-2 gap-4">
<InfoItem
title="Subscription Status"
value={statusToDisplay(billingInformation.status)}
/>
<InfoItem title="Seats" value={billingInformation.seats.toString()} />
<InfoItem
title="Billing Start"
value={new Date(
billingInformation.current_period_start
).toLocaleDateString()}
/>
<InfoItem
title="Billing End"
value={new Date(
billingInformation.current_period_end
).toLocaleDateString()}
/>
</div>
);
}

web/src/app/ee/admin/billing/interfaces.ts (new file, +19 lines)
@@ -0,0 +1,19 @@
export interface BillingInformation {
status: string;
trial_end: Date | null;
current_period_end: Date;
payment_method_enabled: boolean;
cancel_at_period_end: boolean;
current_period_start: Date;
number_of_seats: number;
canceled_at: Date | null;
trial_start: Date | null;
seats: number;
}

export enum BillingStatus {
TRIALING = "trialing",
ACTIVE = "active",
CANCELLED = "cancelled",
EXPIRED = "expired",
}

@@ -3,10 +3,16 @@ import BillingInformationPage from "./BillingInformationPage";
import { MdOutlineCreditCard } from "react-icons/md";

export interface BillingInformation {
stripe_subscription_id: string;
status: string;
current_period_start: Date;
current_period_end: Date;
number_of_seats: number;
cancel_at_period_end: boolean;
canceled_at: Date | null;
trial_start: Date | null;
trial_end: Date | null;
seats: number;
subscription_status: string;
billing_start: Date;
billing_end: Date;
payment_method_enabled: boolean;
}
@@ -35,9 +35,16 @@ export const statusToDisplay = (status: string) => {
|
||||
|
||||
export const useBillingInformation = () => {
|
||||
const url = "/api/tenants/billing-information";
|
||||
const swrResponse = useSWR<BillingInformation>(url, (url: string) =>
|
||||
fetch(url).then((res) => res.json())
|
||||
);
|
||||
const swrResponse = useSWR<BillingInformation>(url, async (url: string) => {
|
||||
const res = await fetch(url);
|
||||
if (!res.ok) {
|
||||
const errorData = await res.json();
|
||||
throw new Error(
|
||||
errorData.message || "Failed to fetch billing information"
|
||||
);
|
||||
}
|
||||
return res.json();
|
||||
});
|
||||
|
||||
return {
|
||||
...swrResponse,
|
||||
|
||||
@@ -660,11 +660,11 @@ ul > li > p {
|
||||
color: white;
|
||||
}
|
||||
|
||||
.dark li,
|
||||
.dark h1,
|
||||
.dark h2,
|
||||
.dark h3,
|
||||
.dark h4,
|
||||
.dark h5 {
|
||||
.prose.dark li,
|
||||
.prose.dark h1,
|
||||
.prose.dark h2,
|
||||
.prose.dark h3,
|
||||
.prose.dark h4,
|
||||
.prose.dark h5 {
|
||||
color: #e5e5e5;
|
||||
}
|
||||
|
||||
@@ -13,7 +13,10 @@ import {
import { Metadata } from "next";
import { buildClientUrl } from "@/lib/utilsSS";
import { Inter } from "next/font/google";
import { EnterpriseSettings, GatingType } from "./admin/settings/interfaces";
import {
  EnterpriseSettings,
  ApplicationStatus,
} from "./admin/settings/interfaces";
import { fetchAssistantData } from "@/lib/chat/fetchAssistantdata";
import { AppProvider } from "@/components/context/AppProvider";
import { PHProvider } from "./providers";
@@ -28,6 +31,7 @@ import { WebVitals } from "./web-vitals";
import { ThemeProvider } from "next-themes";
import CloudError from "@/components/errorPages/CloudErrorPage";
import Error from "@/components/errorPages/ErrorPage";
import AccessRestrictedPage from "@/components/errorPages/AccessRestrictedPage";

const inter = Inter({
  subsets: ["latin"],
@@ -75,7 +79,7 @@ export default async function RootLayout({
  ]);

  const productGating =
    combinedSettings?.settings.product_gating ?? GatingType.NONE;
    combinedSettings?.settings.application_status ?? ApplicationStatus.ACTIVE;

  const getPageContent = async (content: React.ReactNode) => (
    <html
@@ -130,40 +134,16 @@ export default async function RootLayout({
    </html>
  );

  if (productGating === ApplicationStatus.GATED_ACCESS) {
    return getPageContent(<AccessRestrictedPage />);
  }

  if (!combinedSettings) {
    return getPageContent(
      NEXT_PUBLIC_CLOUD_ENABLED ? <CloudError /> : <Error />
    );
  }

  if (productGating === GatingType.FULL) {
    return getPageContent(
      <div className="flex flex-col items-center justify-center min-h-screen">
        <div className="mb-2 flex items-center max-w-[175px]">
          <LogoType />
        </div>
        <CardSection className="w-full max-w-md">
          <h1 className="text-2xl font-bold mb-4 text-error">
            Access Restricted
          </h1>
          <p className="text-text-500 mb-4">
            We regret to inform you that your access to Onyx has been
            temporarily suspended due to a lapse in your subscription.
          </p>
          <p className="text-text-500 mb-4">
            To reinstate your access and continue benefiting from Onyx's
            powerful features, please update your payment information.
          </p>
          <p className="text-text-500">
            If you're an admin, you can resolve this by visiting the
            billing section. For other users, please reach out to your
            administrator to address this matter.
          </p>
        </CardSection>
      </div>
    );
  }

  const { assistants, hasAnyConnectors, hasImageCompatibleModel } =
    assistantsData;
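The `ApplicationStatus` enum that replaces `GatingType` here is defined in `web/src/app/admin/settings/interfaces.ts`, which is not part of this excerpt. Judging only from the members referenced in this layout and in the ClientLayout banner further down (`ACTIVE`, `GATED_ACCESS`, `PAYMENT_REMINDER`), it presumably looks roughly like the sketch below; the string values are guesses.

// Inferred from usage in this diff; the real definition may differ.
export enum ApplicationStatus {
  ACTIVE = "active",
  PAYMENT_REMINDER = "payment_reminder",
  GATED_ACCESS = "gated_access",
}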
@@ -38,7 +38,7 @@ export function EditableValue({
      onSubmit(initialValue);
    }
  }}
  className="border bg-background-strong border-background-300 rounded py-1 px-1 w-12 h-4 my-auto"
  className="border bg-background-200 border-background-300 rounded py-1 px-1 w-12 h-4 my-auto"
/>
<div
  onClick={async () => {
@@ -151,7 +151,7 @@ export const IsPublicGroupSelector = <T extends IsPublicGroupSelectorFormType>({
  cursor-pointer
  ${
    isSelected
      ? "bg-background-strong"
      ? "bg-background-200"
      : "hover:bg-accent-background-hovered"
  }
`}
@@ -67,7 +67,7 @@ const PageLink = ({
  first:ml-0
  first:rounded-l-md
  last:rounded-r-md
  ${active ? "bg-background-strong" : ""}
  ${active ? "bg-background-200" : ""}
`}
onClick={() => {
  if (pageChangeHandler) {
@@ -33,6 +33,9 @@ import { MdOutlineCreditCard } from "react-icons/md";
import { UserSettingsModal } from "@/app/chat/modal/UserSettingsModal";
import { usePopup } from "./connectors/Popup";
import { useChatContext } from "../context/ChatContext";
import { ApplicationStatus } from "@/app/admin/settings/interfaces";
import Link from "next/link";
import { Button } from "../ui/button";

export function ClientLayout({
  user,
@@ -74,6 +77,23 @@ export function ClientLayout({
          defaultModel={user?.preferences?.default_model!}
        />
      )}
      {settings?.settings.application_status ===
        ApplicationStatus.PAYMENT_REMINDER && (
        <div className="fixed top-2 left-1/2 transform -translate-x-1/2 bg-amber-400 dark:bg-amber-500 text-gray-900 dark:text-gray-100 p-4 rounded-lg shadow-lg z-50 max-w-md text-center">
          <strong className="font-bold">Warning:</strong> Your trial ends in
          less than 2 days and no payment method has been added.
          <div className="mt-2">
            <Link href="/admin/billing">
              <Button
                variant="default"
                className="bg-amber-600 hover:bg-amber-700 text-white"
              >
                Update Billing Information
              </Button>
            </Link>
          </div>
        </div>
      )}

      <div className="default-scrollbar flex-none text-text-settings-sidebar bg-background-sidebar dark:bg-[#000] w-[250px] overflow-x-hidden z-20 pt-2 pb-8 h-full border-r border-border dark:border-none miniscroll overflow-auto">
        <AdminSidebar
@@ -151,7 +151,7 @@ export function AccessTypeGroupSelector({
  cursor-pointer
  ${
    isSelected
      ? "bg-background-strong"
      ? "bg-background-200"
      : "hover:bg-accent-background-hovered"
  }
`}
@@ -86,7 +86,7 @@ export const ConnectorTitle = ({
    );
  }

  const mainSectionClassName = "text-blue-500 flex w-fit";
  const mainSectionClassName = "text-blue-500 dark:text-blue-100 flex w-fit";
  const mainDisplay = (
    <>
      {sourceMetadata.icon({ size: 20 })}
@@ -4,9 +4,7 @@ import { ValidSources } from "@/lib/types";
import useSWR, { mutate } from "swr";
import { errorHandlingFetcher } from "@/lib/fetcher";
import { FaSwatchbook } from "react-icons/fa";
import { NewChatIcon } from "@/components/icons/icons";
import { useState } from "react";
import { useUserGroups } from "@/lib/hooks";
import {
  deleteCredential,
  swapCredential,
@@ -207,7 +205,7 @@ export default function CredentialSection({
      {showCreateCredential && (
        <Modal
          onOutsideClick={closeCreateCredential}
          className="max-w-3xl rounded-lg"
          className="max-w-3xl flex flex-col items-start rounded-lg"
          title={`Create ${getSourceDisplayName(sourceType)} Credential`}
        >
          {oauthDetailsLoading ? (
@@ -160,7 +160,7 @@ export default function CreateCredential({
  }

  if (sourceType == "google_drive") {
    return <GDriveMain />;
    return <GDriveMain setPopup={setPopup} />;
  }

  const credentialTemplate: dictionaryType = credentialTemplates[sourceType];
@@ -194,7 +194,7 @@ export default function CreateCredential({
          for information on setting up this connector.
        </p>
      )}
      <CardSection className="w-full !border-0 mt-4 flex flex-col gap-y-6">
      <CardSection className="w-full items-start bg-blue-200 !border-0 mt-4 flex flex-col gap-y-6">
        <TextFormField
          name="name"
          placeholder="(Optional) credential name.."
@@ -50,7 +50,7 @@ export function ModelOption({
    <div
      className={`p-4 w-96 border rounded-lg transition-all duration-200 ${
        selected
          ? "border-blue-500 bg-blue-50 dark:bg-blue-900 dark:border-blue-700 shadow-md"
          ? "border-blue-800 bg-blue-50 dark:bg-blue-950 dark:border-blue-700 shadow-md"
          : "border-background-200 hover:border-blue-300 hover:shadow-sm"
      }`}
    >
@@ -7,8 +7,9 @@ import {
  MicrosoftIcon,
  NomicIcon,
  OpenAIIcon,
  OpenAIISVG,
  OpenSourceIcon,
  VoyageIcon,
  VoyageIconSVG,
} from "@/components/icons/icons";

export enum EmbeddingProvider {
@@ -216,7 +217,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
  {
    provider_type: EmbeddingProvider.OPENAI,
    website: "https://openai.com",
    icon: OpenAIIcon,
    icon: OpenAIISVG,
    description: "AI industry leader known for ChatGPT and DALL-E",
    apiLink: "https://platform.openai.com/api-keys",
    docsLink: "https://docs.onyx.app/guides/embedding_providers#openai-models",
@@ -295,7 +296,7 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
  {
    provider_type: EmbeddingProvider.VOYAGE,
    website: "https://www.voyageai.com",
    icon: VoyageIcon,
    icon: VoyageIconSVG,
    description: "Advanced NLP research startup born from Stanford AI Labs",
    docsLink: "https://docs.onyx.app/guides/embedding_providers#voyage-models",
    apiLink: "https://www.voyageai.com/dashboard",
148
web/src/components/errorPages/AccessRestrictedPage.tsx
Normal file
@@ -0,0 +1,148 @@
"use client";
import { FiLock } from "react-icons/fi";
import ErrorPageLayout from "./ErrorPageLayout";
import { fetchCustomerPortal } from "@/app/ee/admin/billing/utils";
import { useState } from "react";
import { useRouter } from "next/navigation";
import { Button } from "@/components/ui/button";
import { logout } from "@/lib/user";
import { loadStripe } from "@stripe/stripe-js";
import { NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY } from "@/lib/constants";

const fetchResubscriptionSession = async () => {
  const response = await fetch("/api/tenants/create-subscription-session", {
    method: "POST",
    headers: {
      "Content-Type": "application/json",
    },
  });
  if (!response.ok) {
    throw new Error("Failed to create resubscription session");
  }
  return response.json();
};

export default function AccessRestricted() {
  const [isLoading, setIsLoading] = useState(false);
  const [error, setError] = useState<string | null>(null);
  const router = useRouter();

  const handleManageSubscription = async () => {
    setIsLoading(true);
    setError(null);
    try {
      const response = await fetchCustomerPortal();

      if (!response.ok) {
        const errorData = await response.json();
        throw new Error(
          `Failed to create customer portal session: ${
            errorData.message || response.statusText
          }`
        );
      }

      const { url } = await response.json();

      if (!url) {
        throw new Error("No portal URL returned from the server");
      }

      router.push(url);
    } catch (error) {
      console.error("Error creating customer portal session:", error);
      setError("Error opening customer portal. Please try again later.");
    } finally {
      setIsLoading(false);
    }
  };

  const handleResubscribe = async () => {
    setIsLoading(true);
    setError(null);
    if (!NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY) {
      setError("Stripe public key not found");
      setIsLoading(false);
      return;
    }
    try {
      const { sessionId } = await fetchResubscriptionSession();
      const stripe = await loadStripe(NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY);

      if (stripe) {
        await stripe.redirectToCheckout({ sessionId });
      } else {
        throw new Error("Stripe failed to load");
      }
    } catch (error) {
      console.error("Error creating resubscription session:", error);
      setError("Error opening resubscription page. Please try again later.");
    } finally {
      setIsLoading(false);
    }
  };

  return (
    <ErrorPageLayout>
      <h1 className="text-2xl font-semibold flex items-center gap-2 mb-4 text-gray-800 dark:text-gray-200">
        <p>Access Restricted</p>
        <FiLock className="text-error inline-block" />
      </h1>
      <div className="space-y-4 text-gray-600 dark:text-gray-300">
        <p>
          We regret to inform you that your access to Onyx has been temporarily
          suspended due to a lapse in your subscription.
        </p>
        <p>
          To reinstate your access and continue benefiting from Onyx's
          powerful features, please update your payment information.
        </p>
        <p>
          If you're an admin, you can manage your subscription by clicking
          the button below. For other users, please reach out to your
          administrator to address this matter.
        </p>
        <div className="flex flex-col space-y-4 sm:flex-row sm:space-y-0 sm:space-x-4">
          <Button
            onClick={handleResubscribe}
            disabled={isLoading}
            className="w-full sm:w-auto"
          >
            {isLoading ? "Loading..." : "Resubscribe"}
          </Button>
          <Button
            variant="outline"
            onClick={handleManageSubscription}
            disabled={isLoading}
            className="w-full sm:w-auto"
          >
            Manage Existing Subscription
          </Button>
          <Button
            variant="outline"
            onClick={async () => {
              await logout();
              window.location.reload();
            }}
            className="w-full sm:w-auto"
          >
            Log out
          </Button>
        </div>
        {error && <p className="text-error">{error}</p>}
        <p>
          Need help? Join our{" "}
          <a
            className="text-blue-500 hover:text-blue-700 dark:text-blue-400 dark:hover:text-blue-300"
            href="https://join.slack.com/t/danswer/shared_invite/zt-1w76msxmd-HJHLe3KNFIAIzk_0dSOKaQ"
            target="_blank"
            rel="noopener noreferrer"
          >
            Slack community
          </a>{" "}
          for support.
        </p>
      </div>
    </ErrorPageLayout>
  );
}
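`handleResubscribe` above destructures `sessionId` from the `/api/tenants/create-subscription-session` response and passes it straight to `stripe.redirectToCheckout`. The backend endpoint is not part of this frontend diff, so the payload shape is only inferred:

// Inferred response shape; the endpoint itself is not shown in this diff.
interface CreateSubscriptionSessionResponse {
  sessionId: string; // Stripe Checkout session id consumed by redirectToCheckout()
}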
@@ -1162,11 +1162,114 @@ export const MistralIcon = ({
  <LogoIcon size={size} className={className} src={mistralSVG} />
);

export const VoyageIcon = ({
export const VoyageIconSVG = ({
  size = 16,
  className = defaultTailwindCSS,
}: IconProps) => (
  <LogoIcon size={size} className={className} src={voyageIcon} />
  <svg
    style={{ width: `${size}px`, height: `${size}px` }}
    className={`w-[${size}px] h-[${size}px] ` + className}
    xmlns="http://www.w3.org/2000/svg"
    viewBox="0 0 200 200"
    width="200"
    height="200"
  >
<path
|
||||
d="M0 0 C18.56364691 14.8685395 31.52865476 35.60458591 34.68359375 59.39453125 C36.85790415 84.17093249 31.86661083 108.64738046 15.83569336 128.38696289 C-0.18749615 147.32766215 -21.13158775 159.50726579 -46 162 C-70.46026633 163.68595557 -94.53744209 157.16585411 -113.375 141.1875 C-131.5680983 125.12913912 -143.31327081 103.12304227 -145.16845703 78.79052734 C-146.52072106 52.74671426 -138.40787353 29.42123969 -121 10 C-120.39929688 9.30519531 -119.79859375 8.61039063 -119.1796875 7.89453125 C-88.7732111 -25.07872563 -34.66251161 -26.29920259 0 0 Z M-111 6 C-111.96292969 6.76441406 -112.92585938 7.52882813 -113.91796875 8.31640625 C-129.12066 21.0326872 -138.48510826 41.64930525 -141 61 C-142.57102569 86.19086606 -137.40498471 109.10013392 -120.54980469 128.68505859 C-106.05757815 144.84161953 -85.8110604 156.92053779 -63.68798828 158.12597656 C-39.72189393 158.83868932 -17.08757891 154.40601729 1.1875 137.6875 C3.15800523 135.82115685 5.07881363 133.91852176 7 132 C8.22396484 130.7934375 8.22396484 130.7934375 9.47265625 129.5625 C26.2681901 112.046746 31.70691205 89.639394 31.3125 66 C30.4579168 43.32505919 19.07700136 22.58412979 3 7 C-29.27431062 -21.68827611 -78.26536136 -21.67509486 -111 6 Z "
|
||||
fill="currentColor"
|
||||
transform="translate(155,29)"
|
||||
/>
|
||||
<path
|
||||
d="M0 0 C2.62278901 2.33427271 3.96735488 4.64596813 5.4453125 7.81640625 C6.10080078 9.20956055 6.10080078 9.20956055 6.76953125 10.63085938 C7.21683594 11.59830078 7.66414063 12.56574219 8.125 13.5625 C8.58003906 14.53380859 9.03507812 15.50511719 9.50390625 16.50585938 C10.34430119 18.30011504 11.18198346 20.09564546 12.01611328 21.89282227 C12.65935931 23.27045415 13.32005367 24.64010734 14 26 C12.02 26 10.04 26 8 26 C6.515 22.535 6.515 22.535 5 19 C1.7 19 -1.6 19 -5 19 C-5.99 21.31 -6.98 23.62 -8 26 C-9.32 26 -10.64 26 -12 26 C-10.34176227 20.46347949 -7.92776074 15.38439485 -5.4375 10.1875 C-5.02564453 9.31673828 -4.61378906 8.44597656 -4.18945312 7.54882812 C-1.13502139 1.13502139 -1.13502139 1.13502139 0 0 Z M-1 8 C-3.2013866 11.80427492 -3.2013866 11.80427492 -4 16 C-1.69 16 0.62 16 3 16 C2.43260132 11.87026372 2.43260132 11.87026372 1 8 C0.34 8 -0.32 8 -1 8 Z "
|
||||
fill="currentColor"
|
||||
transform="translate(158,86)"
|
||||
/>
|
||||
<path
|
||||
d="M0 0 C2.64453125 1.0234375 2.64453125 1.0234375 4.4453125 4.296875 C4.96971298 5.65633346 5.47294966 7.0241056 5.95703125 8.3984375 C6.22064453 9.08421875 6.48425781 9.77 6.75585938 10.4765625 C7.8687821 13.4482107 8.64453125 15.82826389 8.64453125 19.0234375 C9.30453125 19.0234375 9.96453125 19.0234375 10.64453125 19.0234375 C10.75667969 18.34925781 10.86882813 17.67507812 10.984375 16.98046875 C11.77373626 13.44469078 12.95952974 10.10400184 14.20703125 6.7109375 C14.44099609 6.06576172 14.67496094 5.42058594 14.91601562 4.75585938 C15.48900132 3.17722531 16.06632589 1.60016724 16.64453125 0.0234375 C17.96453125 0.0234375 19.28453125 0.0234375 20.64453125 0.0234375 C20.11164835 5.93359329 17.66052325 10.65458241 15.08203125 15.8984375 C14.65728516 16.77757813 14.23253906 17.65671875 13.79492188 18.5625 C12.75156566 20.71955106 11.70131241 22.87294038 10.64453125 25.0234375 C9.65453125 25.0234375 8.66453125 25.0234375 7.64453125 25.0234375 C6.36851794 22.52596727 5.09866954 20.02565814 3.83203125 17.5234375 C3.29739258 16.47929688 3.29739258 16.47929688 2.75195312 15.4140625 C0.37742917 10.70858383 -1.58321849 5.98797449 -3.35546875 1.0234375 C-2.35546875 0.0234375 -2.35546875 0.0234375 0 0 Z "
|
||||
fill="currentColor"
|
||||
transform="translate(23.35546875,86.9765625)"
|
||||
/>
|
||||
<path
|
||||
d="M0 0 C4.56944444 2.13888889 4.56944444 2.13888889 6 5 C6.58094684 9.76376411 6.98189835 13.6696861 4.0625 17.625 C-0.08290736 19.4862033 -3.52913433 19.80184004 -8 19 C-11.18487773 17.20850628 -12.56721386 16.06753914 -13.9375 12.6875 C-14.04047475 8.25958558 -13.25966827 4.50191217 -10.375 1.0625 C-6.92547207 -0.48070986 -3.67744273 -0.55453501 0 0 Z M-7.66796875 3.21484375 C-9.3387892 5.45403713 -9.40271257 6.72874309 -9.375 9.5 C-9.38273437 10.2734375 -9.39046875 11.046875 -9.3984375 11.84375 C-8.90844456 14.49547648 -8.12507645 15.38331504 -6 17 C-3.17884512 17.42317323 -1.66049093 17.38718434 0.8125 15.9375 C2.65621741 12.92932949 2.30257262 10.44932782 2 7 C1.54910181 4.59436406 1.54910181 4.59436406 0 3 C-4.00690889 1.63330935 -4.00690889 1.63330935 -7.66796875 3.21484375 Z "
|
||||
fill="currentColor"
|
||||
transform="translate(58,93)"
|
||||
/>
|
||||
<path
|
||||
d="M0 0 C0.91007812 0.00902344 1.82015625 0.01804687 2.7578125 0.02734375 C3.45648438 0.03894531 4.15515625 0.05054687 4.875 0.0625 C5.205 1.3825 5.535 2.7025 5.875 4.0625 C4.6375 3.815 3.4 3.5675 2.125 3.3125 C-1.0391959 2.93032359 -1.83705309 2.89394571 -4.6875 4.5625 C-6.71059726 8.08093001 -6.12332701 10.21181009 -5.125 14.0625 C-3.22744856 16.41223818 -3.22744856 16.41223818 0 16.1875 C0.94875 16.14625 1.8975 16.105 2.875 16.0625 C2.875 14.4125 2.875 12.7625 2.875 11.0625 C4.525 11.3925 6.175 11.7225 7.875 12.0625 C8.1875 14.375 8.1875 14.375 7.875 17.0625 C5.25185816 19.29988569 3.33979578 19.9932751 -0.0625 20.5 C-3.96030088 19.9431713 -6.06489651 18.49667323 -9.125 16.0625 C-11.6165904 12.3251144 -11.58293285 10.48918417 -11.125 6.0625 C-7.83836921 1.02299945 -5.86190884 -0.07515268 0 0 Z "
|
||||
fill="currentColor"
|
||||
transform="translate(113.125,92.9375)"
|
||||
/>
|
||||
<path
|
||||
d="M0 0 C4.28705043 1.42901681 5.23208702 4.57025431 7.1875 8.375 C7.55552734 9.06078125 7.92355469 9.7465625 8.30273438 10.453125 C11 15.59744608 11 15.59744608 11 19 C9.35 19 7.7 19 6 19 C5.67 17.68 5.34 16.36 5 15 C2.03 14.67 -0.94 14.34 -4 14 C-4.33 15.65 -4.66 17.3 -5 19 C-5.99 19 -6.98 19 -8 19 C-7.38188466 14.44684052 -5.53234107 10.71540233 -3.4375 6.6875 C-2.9434668 5.71973633 -2.9434668 5.71973633 -2.43945312 4.73242188 C-1.63175745 3.15214772 -0.81662387 1.57567895 0 0 Z M0 6 C-0.33 7.65 -0.66 9.3 -1 11 C0.32 11 1.64 11 3 11 C2.34 9.35 1.68 7.7 1 6 C0.67 6 0.34 6 0 6 Z "
|
||||
fill="currentColor"
|
||||
transform="translate(90,93)"
|
||||
/>
|
||||
<path
|
||||
d="M0 0 C3.63 0 7.26 0 11 0 C11 0.66 11 1.32 11 2 C8.69 2 6.38 2 4 2 C4 3.98 4 5.96 4 8 C5.98 8 7.96 8 10 8 C9.67 8.99 9.34 9.98 9 11 C7.68 11 6.36 11 5 11 C4.67 12.98 4.34 14.96 4 17 C7.465 16.505 7.465 16.505 11 16 C11 16.99 11 17.98 11 19 C7.37 19 3.74 19 0 19 C0 12.73 0 6.46 0 0 Z "
|
||||
fill="currentColor"
|
||||
transform="translate(124,93)"
|
||||
/>
|
||||
<path
|
||||
d="M0 0 C2.25 -0.3125 2.25 -0.3125 5 0 C9 4.10810811 9 4.10810811 9 7 C9.78375 6.21625 10.5675 5.4325 11.375 4.625 C12.91666667 3.08333333 14.45833333 1.54166667 16 0 C16.99 0 17.98 0 19 0 C17.84356383 2.5056117 16.63134741 4.4803655 14.9375 6.6875 C12.52118995 10.81861073 12.20924288 14.29203528 12 19 C10.68 19 9.36 19 8 19 C8.00902344 18.443125 8.01804687 17.88625 8.02734375 17.3125 C7.78294047 11.0217722 5.92390505 8.0388994 1.49609375 3.62890625 C0 2 0 2 0 0 Z "
|
||||
fill="currentColor"
|
||||
transform="translate(64,93)"
|
||||
/>
|
||||
<path
|
||||
d="M0 0 C1.32 0 2.64 0 4 0 C4 8.25 4 16.5 4 25 C2.68 25 1.36 25 0 25 C0 16.75 0 8.5 0 0 Z "
|
||||
fill="currentColor"
|
||||
transform="translate(173,87)"
|
||||
/>
|
||||
<path
|
||||
d="M0 0 C0.66 0.33 1.32 0.66 2 1 C1.125 5.75 1.125 5.75 0 8 C1.093125 7.95875 2.18625 7.9175 3.3125 7.875 C7 8 7 8 10 10 C4.555 10.495 4.555 10.495 -1 11 C-1.99 13.31 -2.98 15.62 -4 18 C-5.32 18 -6.64 18 -8 18 C-6.65150163 13.64029169 -4.95092154 9.68658562 -2.875 5.625 C-2.33617187 4.56539063 -1.79734375 3.50578125 -1.2421875 2.4140625 C-0.83226562 1.61742188 -0.42234375 0.82078125 0 0 Z "
|
||||
fill="currentColor"
|
||||
transform="translate(154,94)"
|
||||
/>
|
||||
<path
|
||||
d="M0 0 C0.66 0.33 1.32 0.66 2 1 C2 1.66 2 2.32 2 3 C1.34 3 0.68 3 0 3 C-0.05429959 4.74965358 -0.09292823 6.49979787 -0.125 8.25 C-0.14820313 9.22453125 -0.17140625 10.1990625 -0.1953125 11.203125 C0.00137219 14.0196498 0.55431084 15.60949036 2 18 C1.34 18.33 0.68 18.66 0 19 C-4.69653179 15.74855491 -4.69653179 15.74855491 -5.9375 12.6875 C-6.02161912 9.07037805 -5.30970069 6.36780178 -4 3 C-1.875 1.0625 -1.875 1.0625 0 0 Z "
|
||||
fill="currentColor"
|
||||
transform="translate(50,93)"
|
||||
/>
|
||||
<path
|
||||
d="M0 0 C2.79192205 -0.05380578 5.5828141 -0.09357669 8.375 -0.125 C9.1690625 -0.14175781 9.963125 -0.15851563 10.78125 -0.17578125 C12.85492015 -0.19335473 14.92883241 -0.10335168 17 0 C17.66 0.66 18.32 1.32 19 2 C17 4 17 4 13.0859375 4.1953125 C11.51550649 4.18200376 9.94513779 4.15813602 8.375 4.125 C7.57320312 4.11597656 6.77140625 4.10695312 5.9453125 4.09765625 C3.96341477 4.07406223 1.98167019 4.03819065 0 4 C0 2.68 0 1.36 0 0 Z "
|
||||
fill="currentColor"
|
||||
transform="translate(92,187)"
|
||||
/>
|
||||
<path
|
||||
d="M0 0 C0.99 0.33 1.98 0.66 3 1 C1.66666667 4.33333333 0.33333333 7.66666667 -1 11 C0.65 11 2.3 11 4 11 C4 11.33 4 11.66 4 12 C1.36 12.33 -1.28 12.66 -4 13 C-4.33 14.98 -4.66 16.96 -5 19 C-5.99 19 -6.98 19 -8 19 C-7.38188466 14.44684052 -5.53234107 10.71540233 -3.4375 6.6875 C-2.9434668 5.71973633 -2.9434668 5.71973633 -2.43945312 4.73242188 C-1.63175745 3.15214772 -0.81662387 1.57567895 0 0 Z "
|
||||
fill="currentColor"
|
||||
transform="translate(90,93)"
|
||||
/>
|
||||
<path
|
||||
d="M0 0 C0.99 0 1.98 0 3 0 C2.43454163 3.95820859 1.19097652 6.6659053 -1 10 C-1.66 9.67 -2.32 9.34 -3 9 C-2.44271087 5.65626525 -1.64826111 2.96687001 0 0 Z "
|
||||
fill="currentColor"
|
||||
transform="translate(37,97)"
|
||||
/>
|
||||
<path
|
||||
d="M0 0 C4.92127034 -0.16682272 8.50343896 -0.24828052 13 2 C9.60268371 4.09065618 6.95730595 4.42098999 3 4 C1.125 2.5625 1.125 2.5625 0 1 C0 0.67 0 0.34 0 0 Z "
|
||||
fill="currentColor"
|
||||
transform="translate(110,12)"
|
||||
/>
|
||||
<path
|
||||
d="M0 0 C0 0.99 0 1.98 0 3 C-3.08888522 5.05925681 -3.70935927 5.2390374 -7.1875 5.125 C-9.0746875 5.063125 -9.0746875 5.063125 -11 5 C-10.67 4.34 -10.34 3.68 -10 3 C-7.96875 2.40234375 -7.96875 2.40234375 -5.5 1.9375 C-2.46226779 1.54135157 -2.46226779 1.54135157 0 0 Z "
|
||||
fill="currentColor"
|
||||
transform="translate(62,107)"
|
||||
/>
|
||||
<path
|
||||
d="M0 0 C0.66 0.33 1.32 0.66 2 1 C1.25 5.75 1.25 5.75 -1 8 C-1.66 8 -2.32 8 -3 8 C-1.125 1.125 -1.125 1.125 0 0 Z "
|
||||
fill="currentColor"
|
||||
transform="translate(154,94)"
|
||||
/>
|
||||
<path
|
||||
d="M0 0 C2.64 0 5.28 0 8 0 C8.33 1.32 8.66 2.64 9 4 C6.03 3.01 3.06 2.02 0 1 C0 0.67 0 0.34 0 0 Z "
|
||||
fill="currentColor"
|
||||
transform="translate(110,93)"
|
||||
/>
|
||||
<path
|
||||
d="M0 0 C1.67542976 0.28604898 3.34385343 0.61781233 5 1 C4.67 2.32 4.34 3.64 4 5 C2.0625 4.6875 2.0625 4.6875 0 4 C-0.33 3.01 -0.66 2.02 -1 1 C-0.67 0.67 -0.34 0.34 0 0 Z "
|
||||
fill="currentColor"
|
||||
transform="translate(21,87)"
|
||||
/>
|
||||
  </svg>
);

export const GoogleIcon = ({
@@ -1291,12 +1394,25 @@ export const GoogleSitesIcon = ({
}: IconProps) => (
  <LogoIcon size={size} className={className} src={googleSitesIcon} />
);

export const ZendeskIcon = ({
  size = 16,
  className = defaultTailwindCSS,
}: IconProps) => (
  <LogoIcon size={size} className={className} src={zendeskIcon} />
  <div
    className="rounded-full overflow-visible dark:overflow-hidden flex items-center justify-center dark:bg-[#fff]/90"
    style={{ width: size, height: size }}
  >
    <LogoIcon
      size={
        typeof window !== "undefined" &&
        window.matchMedia("(prefers-color-scheme: dark)").matches
          ? size * 0.8
          : size
      }
      className={`${className}`}
      src={zendeskIcon}
    />
  </div>
);

export const DropboxIcon = ({
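A design note on the `ZendeskIcon` change above: `window.matchMedia("(prefers-color-scheme: dark)")` reads the operating-system preference once at render time rather than the app's next-themes setting, so the icon will not resize if the preference flips while the page is open. If that ever matters, a subscription would look roughly like this sketch (not part of this diff):

// Illustrative only: react to OS color-scheme changes instead of a one-time read.
if (typeof window !== "undefined") {
  const mq = window.matchMedia("(prefers-color-scheme: dark)");
  mq.addEventListener("change", (event) => {
    console.log("OS dark mode is now:", event.matches);
  });
}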
@@ -1,7 +1,7 @@
import {
  CombinedSettings,
  EnterpriseSettings,
  GatingType,
  ApplicationStatus,
  Settings,
} from "@/app/admin/settings/interfaces";
import {
@@ -45,7 +45,7 @@ export async function fetchSettingsSS(): Promise<CombinedSettings | null> {
  if (results[0].status === 403 || results[0].status === 401) {
    settings = {
      auto_scroll: true,
      product_gating: GatingType.NONE,
      application_status: ApplicationStatus.ACTIVE,
      gpu_enabled: false,
      maximum_chat_retention_days: null,
      notifications: [],
@@ -91,3 +91,6 @@ export const NEXT_PUBLIC_ENABLE_CHROME_EXTENSION =
export const NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK =
  process.env.NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK?.toLowerCase() ===
  "true";

export const NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY =
  process.env.NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY;
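With the publishable key now exported from `@/lib/constants`, callers such as `AccessRestrictedPage` guard on it before calling `loadStripe`. A small illustrative helper (not in this commit; the `getStripe` name is made up) that memoizes the Stripe.js handle so it loads at most once:

import { loadStripe, Stripe } from "@stripe/stripe-js";
import { NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY } from "@/lib/constants";

// Hypothetical helper, not part of this diff: load Stripe.js once and return
// null when no publishable key is configured.
let stripePromise: Promise<Stripe | null> | null = null;

export function getStripe(): Promise<Stripe | null> | null {
  if (!NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY) {
    return null;
  }
  if (!stripePromise) {
    stripePromise = loadStripe(NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY);
  }
  return stripePromise;
}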
@@ -214,7 +214,7 @@ module.exports = {

        input: "var(--white-card-popover)",

        text: "var(--neutral-900)",
        text: "var(--neutral-950)",
        "text-darker": "var(--text-darker)",
        "text-dark": "var(--text-dark)",
        "sidebar-border": "var(--neutral-200-border)",