Mirror of https://github.com/onyx-dot-app/onyx.git
Synced 2026-02-18 16:25:45 +00:00

Compare commits: virtualiza...email_pret (39 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 4eb8a23892 | |
| | 46343d06e7 | |
| | acd9573fb6 | |
| | ca35fee3ae | |
| | c87261cda7 | |
| | e030b0a6fc | |
| | 61136975ad | |
| | 0c74bbf9ed | |
| | 12b2126e69 | |
| | 037943c6ff | |
| | f9485b1325 | |
| | 552a0630fe | |
| | 5bf520d8b8 | |
| | 7dc5a77946 | |
| | 03abd4a1bc | |
| | 16d6d708f6 | |
| | 9740ed32b5 | |
| | b56877cc2e | |
| | da5c83a96d | |
| | 818225c60e | |
| | d78a1fe9c6 | |
| | 05b3e594b5 | |
| | 5a4d007cf9 | |
| | 3b25a2dd84 | |
| | baee4c5f22 | |
| | 5e32f9d922 | |
| | 1454e7e07d | |
| | 6848337445 | |
| | 519fbd897e | |
| | 217569104b | |
| | 4c184bb7f0 | |
| | a222fae7c8 | |
| | 94788cda53 | |
| | fb931ee4de | |
| | bc2c56dfb6 | |
| | ae37f01f62 | |
| | ef31e14518 | |
| | 9b0cba367e | |
| | 48ac690a70 | |
@@ -4,6 +4,9 @@ on:
push:
  tags:
    - "*"
  paths:
    - 'backend/model_server/**'
    - 'backend/Dockerfile.model_server'

env:
  REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-model-server-cloud' || 'onyxdotapp/onyx-model-server' }}
40 .github/workflows/pr-integration-tests.yml (vendored)
@@ -94,16 +94,19 @@ jobs:
cd deployment/docker_compose
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
MULTI_TENANT=true \
AUTH_TYPE=basic \
AUTH_TYPE=cloud \
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
docker compose -f docker-compose.dev.yml -p danswer-stack up -d
DEV_MODE=true \
docker compose -f docker-compose.multitenant-dev.yml -p danswer-stack up -d
id: start_docker_multi_tenant

# In practice, `cloud` Auth type would require OAUTH credentials to be set.
- name: Run Multi-Tenant Integration Tests
run: |
echo "Waiting for 3 minutes to ensure API server is ready..."
sleep 180
echo "Running integration tests..."
docker run --rm --network danswer-stack_default \
--name test-runner \
@@ -119,6 +122,10 @@ jobs:
-e TEST_WEB_HOSTNAME=test-runner \
-e AUTH_TYPE=cloud \
-e MULTI_TENANT=true \
-e REQUIRE_EMAIL_VERIFICATION=false \
-e DISABLE_TELEMETRY=true \
-e IMAGE_TAG=test \
-e DEV_MODE=true \
onyxdotapp/onyx-integration:test \
/app/tests/integration/multitenant_tests
continue-on-error: true
@@ -126,17 +133,17 @@ jobs:

- name: Check multi-tenant test results
run: |
if [ ${{ steps.run_tests.outcome }} == 'failure' ]; then
echo "Integration tests failed. Exiting with error."
if [ ${{ steps.run_multitenant_tests.outcome }} == 'failure' ]; then
echo "Multi-tenant integration tests failed. Exiting with error."
exit 1
else
echo "All integration tests passed successfully."
echo "All multi-tenant integration tests passed successfully."
fi

- name: Stop multi-tenant Docker containers
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
docker compose -f docker-compose.multitenant-dev.yml -p danswer-stack down -v

- name: Start Docker containers
run: |
@@ -216,27 +223,30 @@ jobs:
echo "All integration tests passed successfully."
fi

# save before stopping the containers so the logs can be captured
- name: Save Docker logs
if: success() || failure()
# ------------------------------------------------------------
# Always gather logs BEFORE "down":
- name: Dump API server logs
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack logs > docker-compose.log
mv docker-compose.log ${{ github.workspace }}/docker-compose.log
docker compose -f docker-compose.dev.yml -p danswer-stack logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true

- name: Stop Docker containers
- name: Dump all-container logs (optional)
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
docker compose -f docker-compose.dev.yml -p danswer-stack logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true

- name: Upload logs
if: success() || failure()
if: always()
uses: actions/upload-artifact@v4
with:
name: docker-logs
name: docker-all-logs
path: ${{ github.workspace }}/docker-compose.log
# ------------------------------------------------------------

- name: Stop Docker containers
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
@@ -133,3 +133,4 @@ Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md

## ⭐Star History

[](https://star-history.com/#onyx-dot-app/onyx&Date)
@@ -0,0 +1,32 @@
"""set built in to default

Revision ID: 2cdeff6d8c93
Revises: f5437cc136c5
Create Date: 2025-02-11 14:57:51.308775

"""
from alembic import op


# revision identifiers, used by Alembic.
revision = "2cdeff6d8c93"
down_revision = "f5437cc136c5"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Prior to this migration / point in the codebase history,
    # built in personas were implicitly treated as default personas (with no option to change this)
    # This migration makes that explicit
    op.execute(
        """
        UPDATE persona
        SET is_default_persona = TRUE
        WHERE builtin_persona = TRUE
        """
    )


def downgrade() -> None:
    pass
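For context, a quick way to sanity-check the migration's effect would be a query like the following. This is a hypothetical snippet, not part of the commit; the connection URL is a placeholder, and only the `persona` table and its `builtin_persona` / `is_default_persona` columns come from the migration above.

    from sqlalchemy import create_engine, text

    # Assumed local connection URL; point at the migrated database.
    engine = create_engine("postgresql://localhost/onyx")
    with engine.connect() as conn:
        # After the migration, no built-in persona should be non-default.
        remaining = conn.execute(
            text(
                "SELECT count(*) FROM persona "
                "WHERE builtin_persona = TRUE AND is_default_persona = FALSE"
            )
        ).scalar()
        assert remaining == 0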
@@ -5,7 +5,6 @@ Revises: 47e5bef3a1d7
Create Date: 2024-11-06 13:15:53.302644

"""
import logging
from typing import cast
from alembic import op
import sqlalchemy as sa
@@ -20,13 +19,8 @@ down_revision = "47e5bef3a1d7"
branch_labels: None = None
depends_on: None = None

# Configure logging
logger = logging.getLogger("alembic.runtime.migration")
logger.setLevel(logging.INFO)


def upgrade() -> None:
logger.info(f"{revision}: create_table: slack_bot")
# Create new slack_bot table
op.create_table(
"slack_bot",
@@ -63,7 +57,6 @@ def upgrade() -> None:
)

# Handle existing Slack bot tokens first
logger.info(f"{revision}: Checking for existing Slack bot.")
bot_token = None
app_token = None
first_row_id = None
@@ -71,15 +64,12 @@ def upgrade() -> None:
try:
tokens = cast(dict, get_kv_store().load("slack_bot_tokens_config_key"))
except Exception:
logger.warning("No existing Slack bot tokens found.")
tokens = {}

bot_token = tokens.get("bot_token")
app_token = tokens.get("app_token")

if bot_token and app_token:
logger.info(f"{revision}: Found bot and app tokens.")

session = Session(bind=op.get_bind())
new_slack_bot = SlackBot(
name="Slack Bot (Migrated)",
@@ -170,10 +160,9 @@ def upgrade() -> None:
# Clean up old tokens if they existed
try:
if bot_token and app_token:
logger.info(f"{revision}: Removing old bot and app tokens.")
get_kv_store().delete("slack_bot_tokens_config_key")
except Exception:
logger.warning("tried to delete tokens in dynamic config but failed")
pass
# Rename the table
op.rename_table(
"slack_bot_config__standard_answer_category",
@@ -190,8 +179,6 @@ def upgrade() -> None:
# Drop the table with CASCADE to handle dependent objects
op.execute("DROP TABLE slack_bot_config CASCADE")

logger.info(f"{revision}: Migration complete.")


def downgrade() -> None:
# Recreate the old slack_bot_config table
@@ -273,7 +260,7 @@ def downgrade() -> None:
}
get_kv_store().store("slack_bot_tokens_config_key", tokens)
except Exception:
logger.warning("Failed to save tokens back to KV store")
pass

# Drop the new tables in reverse order
op.drop_table("slack_channel_config")
@@ -3,42 +3,44 @@ from typing import Any

from onyx.background.celery.tasks.beat_schedule import BEAT_EXPIRES_DEFAULT
from onyx.background.celery.tasks.beat_schedule import (
cloud_tasks_to_schedule as base_cloud_tasks_to_schedule,
beat_system_tasks as base_beat_system_tasks,
)
from onyx.background.celery.tasks.beat_schedule import (
tasks_to_schedule as base_tasks_to_schedule,
beat_task_templates as base_beat_task_templates,
)
from onyx.background.celery.tasks.beat_schedule import generate_cloud_tasks
from onyx.background.celery.tasks.beat_schedule import (
get_tasks_to_schedule as base_get_tasks_to_schedule,
)
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryTask
from shared_configs.configs import MULTI_TENANT

ee_cloud_tasks_to_schedule = [
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_autogenerate-usage-report",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(days=30),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
ee_beat_system_tasks: list[dict] = []

ee_beat_task_templates: list[dict] = []
ee_beat_task_templates.extend(
[
{
"name": "autogenerate-usage-report",
"task": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
"schedule": timedelta(days=30),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
"kwargs": {
"task_name": OnyxCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
{
"name": "check-ttl-management",
"task": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
"schedule": timedelta(hours=1),
"options": {
"priority": OnyxCeleryPriority.MEDIUM,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
},
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-ttl-management",
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
"schedule": timedelta(hours=1),
"options": {
"priority": OnyxCeleryPriority.HIGHEST,
"expires": BEAT_EXPIRES_DEFAULT,
},
"kwargs": {
"task_name": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
},
},
]
]
)

ee_tasks_to_schedule: list[dict] = []

@@ -65,9 +67,14 @@ if not MULTI_TENANT:
]


def get_cloud_tasks_to_schedule() -> list[dict[str, Any]]:
return ee_cloud_tasks_to_schedule + base_cloud_tasks_to_schedule
def get_cloud_tasks_to_schedule(beat_multiplier: float) -> list[dict[str, Any]]:
beat_system_tasks = ee_beat_system_tasks + base_beat_system_tasks
beat_task_templates = ee_beat_task_templates + base_beat_task_templates
cloud_tasks = generate_cloud_tasks(
beat_system_tasks, beat_task_templates, beat_multiplier
)
return cloud_tasks


def get_tasks_to_schedule() -> list[dict[str, Any]]:
return ee_tasks_to_schedule + base_tasks_to_schedule
return ee_tasks_to_schedule + base_get_tasks_to_schedule()
@@ -15,6 +15,9 @@ def make_persona_private(
group_ids: list[int] | None,
db_session: Session,
) -> None:
"""NOTE(rkuo): This function batches all updates into a single commit. If we don't
dedupe the inputs, the commit will exception."""

db_session.query(Persona__User).filter(
Persona__User.persona_id == persona_id
).delete(synchronize_session="fetch")
@@ -23,19 +26,22 @@ def make_persona_private(
).delete(synchronize_session="fetch")

if user_ids:
for user_uuid in user_ids:
db_session.add(Persona__User(persona_id=persona_id, user_id=user_uuid))
user_ids_set = set(user_ids)
for user_id in user_ids_set:
db_session.add(Persona__User(persona_id=persona_id, user_id=user_id))

create_notification(
user_id=user_uuid,
user_id=user_id,
notif_type=NotificationType.PERSONA_SHARED,
db_session=db_session,
additional_data=PersonaSharedNotificationData(
persona_id=persona_id,
).model_dump(),
)

if group_ids:
for group_id in group_ids:
group_ids_set = set(group_ids)
for group_id in group_ids_set:
db_session.add(
Persona__UserGroup(persona_id=persona_id, user_group_id=group_id)
)
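The dedupe matters because all the inserts above land in a single commit: a repeated (persona_id, user_id) pair in the input would produce two identical rows in one transaction and abort the whole commit. A minimal self-contained sketch of the failure mode and the fix (pure Python stand-in; the set-membership check stands in for the database's uniqueness check):

    persona_id = 1
    user_ids = ["u-1", "u-2", "u-1"]  # caller may pass duplicates

    seen_pairs: set[tuple[int, str]] = set()
    for user_id in set(user_ids):  # dedupe before building inserts
        pair = (persona_id, user_id)
        assert pair not in seen_pairs  # stands in for the DB unique constraint
        seen_pairs.add(pair)

    print(sorted(seen_pairs))  # [(1, 'u-1'), (1, 'u-2')]

Without the `set(...)`, the third iteration would re-insert `("u-1")` and the assertion (like the real constraint) would fire.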
@@ -365,7 +365,9 @@ def confluence_doc_sync(

slim_docs = []
logger.debug("Fetching all slim documents from confluence")
for doc_batch in confluence_connector.retrieve_all_slim_documents():
for doc_batch in confluence_connector.retrieve_all_slim_documents(
callback=callback
):
logger.debug(f"Got {len(doc_batch)} slim documents from confluence")
if callback:
if callback.should_stop():
@@ -15,6 +15,7 @@ logger = setup_logger()
def _get_slim_doc_generator(
cc_pair: ConnectorCredentialPair,
gmail_connector: GmailConnector,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
current_time = datetime.now(timezone.utc)
start_time = (
@@ -24,7 +25,9 @@ def _get_slim_doc_generator(
)

return gmail_connector.retrieve_all_slim_documents(
start=start_time, end=current_time.timestamp()
start=start_time,
end=current_time.timestamp(),
callback=callback,
)


@@ -40,7 +43,9 @@ def gmail_doc_sync(
gmail_connector = GmailConnector(**cc_pair.connector.connector_specific_config)
gmail_connector.load_credentials(cc_pair.credential.credential_json)

slim_doc_generator = _get_slim_doc_generator(cc_pair, gmail_connector)
slim_doc_generator = _get_slim_doc_generator(
cc_pair, gmail_connector, callback=callback
)

document_external_access: list[DocExternalAccess] = []
for slim_doc_batch in slim_doc_generator:
@@ -21,6 +21,7 @@ _PERMISSION_ID_PERMISSION_MAP: dict[str, dict[str, Any]] = {}
def _get_slim_doc_generator(
cc_pair: ConnectorCredentialPair,
google_drive_connector: GoogleDriveConnector,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
current_time = datetime.now(timezone.utc)
start_time = (
@@ -30,7 +31,9 @@ def _get_slim_doc_generator(
)

return google_drive_connector.retrieve_all_slim_documents(
start=start_time, end=current_time.timestamp()
start=start_time,
end=current_time.timestamp(),
callback=callback,
)
@@ -20,19 +20,11 @@ def _get_slack_document_ids_and_channels(
slack_connector = SlackPollConnector(**cc_pair.connector.connector_specific_config)
slack_connector.load_credentials(cc_pair.credential.credential_json)

slim_doc_generator = slack_connector.retrieve_all_slim_documents()
slim_doc_generator = slack_connector.retrieve_all_slim_documents(callback=callback)

channel_doc_map: dict[str, list[str]] = {}
for doc_metadata_batch in slim_doc_generator:
for doc_metadata in doc_metadata_batch:
if callback:
if callback.should_stop():
raise RuntimeError(
"_get_slack_document_ids_and_channels: Stop signal detected"
)

callback.progress("_get_slack_document_ids_and_channels", 1)

if doc_metadata.perm_sync_data is None:
continue
channel_id = doc_metadata.perm_sync_data["channel_id"]
@@ -40,6 +32,14 @@ def _get_slack_document_ids_and_channels(
channel_doc_map[channel_id] = []
channel_doc_map[channel_id].append(doc_metadata.id)

if callback:
if callback.should_stop():
raise RuntimeError(
"_get_slack_document_ids_and_channels: Stop signal detected"
)

callback.progress("_get_slack_document_ids_and_channels", 1)

return channel_doc_map
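Across the Confluence, Gmail, Google Drive, and Slack changes above, the same pattern repeats: the callback is threaded down into the slim-document generators so long-running retrieval can report liveness and honor cancellation mid-run. A minimal sketch of what an implementation of that interface could look like (hypothetical; only the two methods actually used in these diffs, should_stop and progress, are assumed):

    import threading


    class SimpleHeartbeat:
        """Hypothetical IndexingHeartbeatInterface implementation for illustration."""

        def __init__(self) -> None:
            self._stop_event = threading.Event()
            self.units_done = 0

        def should_stop(self) -> bool:
            # Checked inside retrieval loops so a cancel takes effect mid-run,
            # not only after the full document listing completes.
            return self._stop_event.is_set()

        def progress(self, tag: str, amount: int) -> None:
            # Called once per processed batch to prove the task is still alive.
            self.units_done += amount

        def stop(self) -> None:
            self._stop_event.set()

Note that the Slack change also moves the should_stop/progress check from once per document to once per batch, which cuts the per-item overhead while keeping cancellation responsive.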
@@ -64,6 +64,7 @@ async def _get_tenant_id_from_request(

try:
# Look up token data in Redis

token_data = await retrieve_auth_token_data_from_redis(request)

if not token_data:
@@ -24,6 +24,7 @@ from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
from ee.onyx.server.tenants.user_mapping import user_owns_a_tenant
from onyx.auth.users import exceptions
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.constants import MilestoneRecordType
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_sqlalchemy_engine
@@ -85,7 +86,8 @@ async def create_tenant(email: str, referral_source: str | None = None) -> str:
# Provision tenant on data plane
await provision_tenant(tenant_id, email)
# Notify control plane
await notify_control_plane(tenant_id, email, referral_source)
if not DEV_MODE:
    await notify_control_plane(tenant_id, email, referral_source)
except Exception as e:
logger.error(f"Tenant provisioning failed: {e}")
await rollback_tenant_provisioning(tenant_id)
@@ -28,3 +28,9 @@ class EmbeddingModelTextType:
@staticmethod
def get_type(provider: EmbeddingProvider, text_type: EmbedTextType) -> str:
    return EmbeddingModelTextType.PROVIDER_TEXT_TYPE_MAP[provider][text_type]


class GPUStatus:
    CUDA = "cuda"
    MAC_MPS = "mps"
    NONE = "none"
@@ -12,6 +12,7 @@ import voyageai  # type: ignore
from cohere import AsyncClient as CohereAsyncClient
from fastapi import APIRouter
from fastapi import HTTPException
from fastapi import Request
from google.oauth2 import service_account  # type: ignore
from litellm import aembedding
from litellm.exceptions import RateLimitError
@@ -320,6 +321,7 @@ async def embed_text(
prefix: str | None,
api_url: str | None,
api_version: str | None,
gpu_type: str = "UNKNOWN",
) -> list[Embedding]:
if not all(texts):
logger.error("Empty strings provided for embedding")
@@ -373,8 +375,11 @@ async def embed_text(

elapsed = time.monotonic() - start
logger.info(
f"Successfully embedded {len(texts)} texts with {total_chars} total characters "
f"with provider {provider_type} in {elapsed:.2f}"
f"event=embedding_provider "
f"texts={len(texts)} "
f"chars={total_chars} "
f"provider={provider_type} "
f"elapsed={elapsed:.2f}"
)
elif model_name is not None:
logger.info(
@@ -403,6 +408,14 @@ async def embed_text(
f"Successfully embedded {len(texts)} texts with {total_chars} total characters "
f"with local model {model_name} in {elapsed:.2f}"
)
logger.info(
f"event=embedding_model "
f"texts={len(texts)} "
f"chars={total_chars} "
f"model={model_name} "
f"gpu={gpu_type} "
f"elapsed={elapsed:.2f}"
)
else:
logger.error("Neither model name nor provider specified for embedding")
raise ValueError(
@@ -455,8 +468,15 @@ async def litellm_rerank(


@router.post("/bi-encoder-embed")
async def process_embed_request(
async def route_bi_encoder_embed(
request: Request,
embed_request: EmbedRequest,
) -> EmbedResponse:
return await process_embed_request(embed_request, request.app.state.gpu_type)


async def process_embed_request(
embed_request: EmbedRequest, gpu_type: str = "UNKNOWN"
) -> EmbedResponse:
if not embed_request.texts:
raise HTTPException(status_code=400, detail="No texts to be embedded")
@@ -484,6 +504,7 @@ async def process_embed_request(
api_url=embed_request.api_url,
api_version=embed_request.api_version,
prefix=prefix,
gpu_type=gpu_type,
)
return EmbedResponse(embeddings=embeddings)
except RateLimitError as e:
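The logging change here replaces free-form sentences with space-separated key=value pairs, which makes the lines trivially machine-parseable. A hypothetical sketch of consuming such a line (the field names follow the f-strings above; the sample values are invented):

    import re

    line = "event=embedding_model texts=32 chars=51200 model=local-model gpu=cuda elapsed=0.84"

    # Split the line into a dict of key=value fields.
    fields = dict(re.findall(r"(\w+)=(\S+)", line))

    assert fields["event"] == "embedding_model"
    chars_per_second = int(fields["chars"]) / float(fields["elapsed"])
    print(f"{chars_per_second:.0f} chars/sec on {fields['gpu']}")

This style trades readability for greppability: dashboards and alerts can filter on event= and aggregate on the numeric fields without fragile sentence parsing.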
@@ -16,6 +16,7 @@ from model_server.custom_models import router as custom_models_router
from model_server.custom_models import warm_up_intent_model
from model_server.encoders import router as encoders_router
from model_server.management_endpoints import router as management_router
from model_server.utils import get_gpu_type
from onyx import __version__
from onyx.utils.logger import setup_logger
from shared_configs.configs import INDEXING_ONLY
@@ -58,12 +59,10 @@ def _move_files_recursively(source: Path, dest: Path, overwrite: bool = False) -

@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator:
if torch.cuda.is_available():
    logger.notice("CUDA GPU is available")
elif torch.backends.mps.is_available():
    logger.notice("Mac MPS is available")
else:
    logger.notice("GPU is not available, using CPU")
gpu_type = get_gpu_type()
logger.notice(f"Torch GPU Detection: gpu_type={gpu_type}")

app.state.gpu_type = gpu_type

if TEMP_HF_CACHE_PATH.is_dir():
    logger.notice("Moving contents of temp_huggingface to huggingface cache.")
@@ -1,7 +1,9 @@
import torch
from fastapi import APIRouter
from fastapi import Response

from model_server.constants import GPUStatus
from model_server.utils import get_gpu_type

router = APIRouter(prefix="/api")


@@ -11,10 +13,7 @@ async def healthcheck() -> Response:


@router.get("/gpu-status")
async def gpu_status() -> dict[str, bool | str]:
if torch.cuda.is_available():
    return {"gpu_available": True, "type": "cuda"}
elif torch.backends.mps.is_available():
    return {"gpu_available": True, "type": "mps"}
else:
    return {"gpu_available": False, "type": "none"}
async def route_gpu_status() -> dict[str, bool | str]:
gpu_type = get_gpu_type()
gpu_available = gpu_type != GPUStatus.NONE
return {"gpu_available": gpu_available, "type": gpu_type}
@@ -8,6 +8,9 @@ from typing import Any
from typing import cast
from typing import TypeVar

import torch

from model_server.constants import GPUStatus
from onyx.utils.logger import setup_logger

logger = setup_logger()
@@ -58,3 +61,12 @@ def simple_log_function_time(
return cast(F, wrapped_sync_func)

return decorator


def get_gpu_type() -> str:
    if torch.cuda.is_available():
        return GPUStatus.CUDA
    if torch.backends.mps.is_available():
        return GPUStatus.MAC_MPS

    return GPUStatus.NONE
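Taken together, these hunks centralize GPU detection: get_gpu_type() probes torch once in the lifespan, the result is cached on app.state, and both the /gpu-status route and embed_text read the cached value instead of re-probing per request. A minimal sketch of the same pattern in a standalone app (hypothetical wiring, not the repo's actual code; detect_gpu stands in for get_gpu_type):

    from contextlib import asynccontextmanager
    from typing import AsyncGenerator

    from fastapi import FastAPI, Request


    def detect_gpu() -> str:
        # Stand-in for get_gpu_type(): probe hardware once at startup.
        return "none"


    @asynccontextmanager
    async def lifespan(app: FastAPI) -> AsyncGenerator:
        app.state.gpu_type = detect_gpu()  # cached for the process lifetime
        yield


    app = FastAPI(lifespan=lifespan)


    @app.get("/gpu-status")
    async def gpu_status(request: Request) -> dict[str, bool | str]:
        gpu_type = request.app.state.gpu_type  # no torch probe per request
        return {"gpu_available": gpu_type != "none", "type": gpu_type}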
@@ -1,7 +1,7 @@
import smtplib
from datetime import datetime
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from textwrap import dedent

from onyx.configs.app_configs import EMAIL_CONFIGURED
from onyx.configs.app_configs import EMAIL_FROM
@@ -13,23 +13,150 @@ from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import TENANT_ID_COOKIE_NAME
from onyx.db.models import User

HTML_EMAIL_TEMPLATE = """\
<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="UTF-8">
  <meta name="viewport" content="width=device-width" />
  <title>{title}</title>
  <style>
    body, table, td, a {{
      font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif;
      text-size-adjust: 100%;
      margin: 0;
      padding: 0;
      -webkit-font-smoothing: antialiased;
      -webkit-text-size-adjust: none;
    }}
    body {{
      background-color: #f7f7f7;
      color: #333;
    }}
    .body-content {{
      color: #333;
    }}
    .email-container {{
      width: 100%;
      max-width: 600px;
      margin: 0 auto;
      background-color: #ffffff;
      border-radius: 6px;
      overflow: hidden;
      border: 1px solid #eaeaea;
    }}
    .header {{
      background-color: #000000;
      padding: 20px;
      text-align: center;
    }}
    .header img {{
      max-width: 140px;
    }}
    .body-content {{
      padding: 20px 30px;
    }}
    .title {{
      font-size: 20px;
      font-weight: bold;
      margin: 0 0 10px;
    }}
    .message {{
      font-size: 16px;
      line-height: 1.5;
      margin: 0 0 20px;
    }}
    .cta-button {{
      display: inline-block;
      padding: 12px 20px;
      background-color: #000000;
      color: #ffffff !important;
      text-decoration: none;
      border-radius: 4px;
      font-weight: 500;
    }}
    .footer {{
      font-size: 13px;
      color: #6A7280;
      text-align: center;
      padding: 20px;
    }}
    .footer a {{
      color: #6b7280;
      text-decoration: underline;
    }}
  </style>
</head>
<body>
  <table role="presentation" class="email-container" cellpadding="0" cellspacing="0">
    <tr>
      <td class="header">
        <img
          style="background-color: #ffffff; border-radius: 8px;"
          src="https://www.onyx.app/logos/customer/onyx.png"
          alt="Onyx Logo"
        >
      </td>
    </tr>
    <tr>
      <td class="body-content">
        <h1 class="title">{heading}</h1>
        <div class="message">
          {message}
        </div>
        {cta_block}
      </td>
    </tr>
    <tr>
      <td class="footer">
        © {year} Onyx. All rights reserved.
        <br>
        Have questions? Join our Slack community <a href="https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA">here</a>.
      </td>
    </tr>
  </table>
</body>
</html>
"""


def build_html_email(
    heading: str, message: str, cta_text: str | None = None, cta_link: str | None = None
) -> str:
    if cta_text and cta_link:
        cta_block = f'<a class="cta-button" href="{cta_link}">{cta_text}</a>'
    else:
        cta_block = ""
    return HTML_EMAIL_TEMPLATE.format(
        title=heading,
        heading=heading,
        message=message,
        cta_block=cta_block,
        year=datetime.now().year,
    )
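A quick usage sketch of the helper above (hypothetical values; writes the rendered HTML to a file so it can be eyeballed in a browser):

    html = build_html_email(
        heading="Reset Your Password",
        message="<p>Click the button below to reset your password.</p>",
        cta_text="Reset Password",
        cta_link="https://example.com/auth/reset-password?token=abc123",  # placeholder link
    )

    with open("preview.html", "w") as f:
        f.write(html)  # open preview.html in a browser to check the layout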
def send_email(
user_email: str,
subject: str,
body: str,
html_body: str,
text_body: str,
mail_from: str = EMAIL_FROM,
) -> None:
if not EMAIL_CONFIGURED:
    raise ValueError("Email is not configured.")

msg = MIMEMultipart()
msg = MIMEMultipart("alternative")
msg["Subject"] = subject
msg["To"] = user_email
if mail_from:
    msg["From"] = mail_from

msg.attach(MIMEText(body))
part_text = MIMEText(text_body, "plain")
part_html = MIMEText(html_body, "html")

msg.attach(part_text)
msg.attach(part_html)

try:
    with smtplib.SMTP(SMTP_SERVER, SMTP_PORT) as s:
@@ -40,26 +167,44 @@ def send_email(
raise e
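Why multipart/alternative and this attach order: mail clients render the last alternative part they support, so the plain-text part goes first and the HTML part last; HTML-capable clients show the HTML, text-only clients fall back to plain text. A self-contained sketch of the same structure (placeholder addresses and bodies):

    from email.mime.multipart import MIMEMultipart
    from email.mime.text import MIMEText

    msg = MIMEMultipart("alternative")
    msg["Subject"] = "Example"
    msg["To"] = "user@example.com"        # placeholder recipient
    msg["From"] = "noreply@example.com"   # placeholder sender

    # Order matters: least-preferred part first, most-preferred part last.
    msg.attach(MIMEText("plain-text fallback", "plain"))
    msg.attach(MIMEText("<p>rich <b>HTML</b> body</p>", "html"))

    print(msg.as_string()[:200])  # inspect the generated MIME structure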
def send_subscription_cancellation_email(user_email: str) -> None:
# Example usage of the reusable HTML
subject = "Your Onyx Subscription Has Been Canceled"
heading = "Subscription Canceled"
message = (
"<p>We’re sorry to see you go.</p>"
"<p>Your subscription has been canceled and will end on your next billing date.</p>"
"<p>If you change your mind, you can always come back!</p>"
)
cta_text = "Renew Subscription"
cta_link = "https://www.onyx.app/pricing"
html_content = build_html_email(heading, message, cta_text, cta_link)
text_content = (
"We're sorry to see you go.\n"
"Your subscription has been canceled and will end on your next billing date.\n"
"If you change your mind, visit https://www.onyx.app/pricing"
)
send_email(user_email, subject, html_content, text_content)


def send_user_email_invite(user_email: str, current_user: User) -> None:
subject = "Invitation to Join Onyx Organization"
body = dedent(
f"""\
Hello,

You have been invited to join an organization on Onyx.

To join the organization, please visit the following link:

{WEB_DOMAIN}/auth/signup?email={user_email}

You'll be asked to set a password or login with Google to complete your registration.

Best regards,
The Onyx Team
"""
heading = "You've Been Invited!"
message = (
f"<p>You have been invited by {current_user.email} to join an organization on Onyx.</p>"
"<p>To join the organization, please click the button below to set a password "
"or login with Google and complete your registration.</p>"
)

send_email(user_email, subject, body, current_user.email)
cta_text = "Join Organization"
cta_link = f"{WEB_DOMAIN}/auth/signup?email={user_email}"
html_content = build_html_email(heading, message, cta_text, cta_link)
text_content = (
f"You have been invited by {current_user.email} to join an organization on Onyx.\n"
"To join the organization, please visit the following link:\n"
f"{WEB_DOMAIN}/auth/signup?email={user_email}\n"
"You'll be asked to set a password or login with Google to complete your registration."
)
send_email(user_email, subject, html_content, text_content)


def send_forgot_password_email(
@@ -68,13 +213,15 @@ def send_forgot_password_email(
mail_from: str = EMAIL_FROM,
tenant_id: str | None = None,
) -> None:
# Builds a forgot password email with or without fancy HTML
subject = "Onyx Forgot Password"
link = f"{WEB_DOMAIN}/auth/reset-password?token={token}"
if tenant_id:
link += f"&{TENANT_ID_COOKIE_NAME}={tenant_id}"
# Keep search param same name as cookie for simplicity
body = f"Click the following link to reset your password: {link}"
send_email(user_email, subject, body, mail_from)
message = f"<p>Click the following link to reset your password:</p><p>{link}</p>"
html_content = build_html_email("Reset Your Password", message)
text_content = f"Click the following link to reset your password: {link}"
send_email(user_email, subject, html_content, text_content, mail_from)


def send_user_verification_email(
@@ -82,7 +229,12 @@ def send_user_verification_email(
token: str,
mail_from: str = EMAIL_FROM,
) -> None:
# Builds a verification email
subject = "Onyx Email Verification"
link = f"{WEB_DOMAIN}/auth/verify-email?token={token}"
body = f"Click the following link to verify your email address: {link}"
send_email(user_email, subject, body, mail_from)
message = (
f"<p>Click the following link to verify your email address:</p><p>{link}</p>"
)
html_content = build_html_email("Verify Your Email", message)
text_content = f"Click the following link to verify your email address: {link}"
send_email(user_email, subject, html_content, text_content, mail_from)
@@ -1,41 +1,56 @@
from datetime import timedelta
from typing import Any
from typing import cast

from celery import Celery
from celery import signals
from celery.beat import PersistentScheduler  # type: ignore
from celery.signals import beat_init
from celery.utils.log import get_task_logger

import onyx.background.celery.apps.app_base as app_base
from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
from onyx.configs.constants import ONYX_CLOUD_REDIS_RUNTIME
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
from onyx.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
from onyx.db.engine import get_all_tenant_ids
from onyx.db.engine import SqlEngine
from onyx.utils.logger import setup_logger
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST
from shared_configs.configs import MULTI_TENANT

logger = setup_logger(__name__)
task_logger = get_task_logger(__name__)

celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.beat")


class DynamicTenantScheduler(PersistentScheduler):
"""This scheduler is useful because we can dynamically adjust task generation rates
through it."""

RELOAD_INTERVAL = 60

def __init__(self, *args: Any, **kwargs: Any) -> None:
logger.info("Initializing DynamicTenantScheduler")
super().__init__(*args, **kwargs)
self._reload_interval = timedelta(minutes=2)

self.last_beat_multiplier = CLOUD_BEAT_MULTIPLIER_DEFAULT

self._reload_interval = timedelta(
seconds=DynamicTenantScheduler.RELOAD_INTERVAL
)
self._last_reload = self.app.now() - self._reload_interval

# Let the parent class handle store initialization
self.setup_schedule()
self._try_updating_schedule()
logger.info(f"Set reload interval to {self._reload_interval}")
task_logger.info(
f"DynamicTenantScheduler initialized: reload_interval={self._reload_interval}"
)

def setup_schedule(self) -> None:
logger.info("Setting up initial schedule")
super().setup_schedule()
logger.info("Initial schedule setup complete")

def tick(self) -> float:
retval = super().tick()
@@ -44,36 +59,35 @@ class DynamicTenantScheduler(PersistentScheduler):
self._last_reload is None
or (now - self._last_reload) > self._reload_interval
):
logger.info("Reload interval reached, initiating task update")
task_logger.debug("Reload interval reached, initiating task update")
try:
self._try_updating_schedule()
except (AttributeError, KeyError) as e:
logger.exception(f"Failed to process task configuration: {str(e)}")
except Exception as e:
logger.exception(f"Unexpected error updating tasks: {str(e)}")
except (AttributeError, KeyError):
task_logger.exception("Failed to process task configuration")
except Exception:
task_logger.exception("Unexpected error updating tasks")

self._last_reload = now
logger.info("Task update completed, reset reload timer")

return retval

def _generate_schedule(
self, tenant_ids: list[str] | list[None]
self, tenant_ids: list[str] | list[None], beat_multiplier: float
) -> dict[str, dict[str, Any]]:
"""Given a list of tenant ids, generates a new beat schedule for celery."""
logger.info("Fetching tasks to schedule")

new_schedule: dict[str, dict[str, Any]] = {}

if MULTI_TENANT:
# cloud tasks only need the single task beat across all tenants
# cloud tasks are system wide and thus only need to be on the beat schedule
# once for all tenants
get_cloud_tasks_to_schedule = fetch_versioned_implementation(
"onyx.background.celery.tasks.beat_schedule",
"get_cloud_tasks_to_schedule",
)

cloud_tasks_to_schedule: list[
dict[str, Any]
] = get_cloud_tasks_to_schedule()
cloud_tasks_to_schedule: list[dict[str, Any]] = get_cloud_tasks_to_schedule(
beat_multiplier
)
for task in cloud_tasks_to_schedule:
task_name = task["name"]
cloud_task = {
@@ -82,11 +96,14 @@ class DynamicTenantScheduler(PersistentScheduler):
"kwargs": task.get("kwargs", {}),
}
if options := task.get("options"):
logger.debug(f"Adding options to task {task_name}: {options}")
task_logger.debug(f"Adding options to task {task_name}: {options}")
cloud_task["options"] = options
new_schedule[task_name] = cloud_task

# regular task beats are multiplied across all tenants
# note that currently this just schedules for a single tenant in self hosted
# and doesn't do anything in the cloud because it's much more scalable
# to schedule a single cloud beat task to dispatch per tenant tasks.
get_tasks_to_schedule = fetch_versioned_implementation(
"onyx.background.celery.tasks.beat_schedule", "get_tasks_to_schedule"
)
@@ -95,7 +112,7 @@ class DynamicTenantScheduler(PersistentScheduler):

for tenant_id in tenant_ids:
if IGNORED_SYNCING_TENANT_LIST and tenant_id in IGNORED_SYNCING_TENANT_LIST:
logger.info(
task_logger.debug(
f"Skipping tenant {tenant_id} as it is in the ignored syncing list"
)
continue
@@ -104,14 +121,14 @@ class DynamicTenantScheduler(PersistentScheduler):
task_name = task["name"]
tenant_task_name = f"{task['name']}-{tenant_id}"

logger.debug(f"Creating task configuration for {tenant_task_name}")
task_logger.debug(f"Creating task configuration for {tenant_task_name}")
tenant_task = {
"task": task["task"],
"schedule": task["schedule"],
"kwargs": {"tenant_id": tenant_id},
}
if options := task.get("options"):
logger.debug(
task_logger.debug(
f"Adding options to task {tenant_task_name}: {options}"
)
tenant_task["options"] = options
@@ -121,44 +138,57 @@ class DynamicTenantScheduler(PersistentScheduler):

def _try_updating_schedule(self) -> None:
"""Only updates the actual beat schedule on the celery app when it changes"""
do_update = False

logger.info("_try_updating_schedule starting")
r = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)

task_logger.debug("_try_updating_schedule starting")

tenant_ids = get_all_tenant_ids()
logger.info(f"Found {len(tenant_ids)} IDs")
task_logger.debug(f"Found {len(tenant_ids)} IDs")

# get current schedule and extract current tenants
current_schedule = self.schedule.items()

# there are no more per tenant beat tasks, so comment this out
# NOTE: we may not actually need this scheduler any more and should
# test reverting to a regular beat schedule implementation
# get potential new state
beat_multiplier = CLOUD_BEAT_MULTIPLIER_DEFAULT
beat_multiplier_raw = r.get(f"{ONYX_CLOUD_REDIS_RUNTIME}:beat_multiplier")
if beat_multiplier_raw is not None:
try:
beat_multiplier_bytes = cast(bytes, beat_multiplier_raw)
beat_multiplier = float(beat_multiplier_bytes.decode())
except ValueError:
task_logger.error(
f"Invalid beat_multiplier value: {beat_multiplier_raw}"
)

# current_tenants = set()
# for task_name, _ in current_schedule:
#     task_name = cast(str, task_name)
#     if task_name.startswith(ONYX_CLOUD_CELERY_TASK_PREFIX):
#         continue
new_schedule = self._generate_schedule(tenant_ids, beat_multiplier)

# if "_" in task_name:
#     # example: "check-for-condition-tenant_12345678-abcd-efgh-ijkl-12345678"
#     # -> "12345678-abcd-efgh-ijkl-12345678"
#     current_tenants.add(task_name.split("_")[-1])
# logger.info(f"Found {len(current_tenants)} existing items in schedule")
# if the schedule or beat multiplier has changed, update
while True:
if beat_multiplier != self.last_beat_multiplier:
do_update = True
break

# for tenant_id in tenant_ids:
#     if tenant_id not in current_tenants:
#         logger.info(f"Processing new tenant: {tenant_id}")
if not DynamicTenantScheduler._compare_schedules(
current_schedule, new_schedule
):
do_update = True
break

new_schedule = self._generate_schedule(tenant_ids)
break

if DynamicTenantScheduler._compare_schedules(current_schedule, new_schedule):
logger.info(
"_try_updating_schedule: Current schedule is up to date, no changes needed"
if not do_update:
# exit early if nothing changed
task_logger.info(
f"_try_updating_schedule - Schedule unchanged: "
f"tasks={len(new_schedule)} "
f"beat_multiplier={beat_multiplier}"
)
return

logger.info(
# schedule needs updating
task_logger.debug(
"Schedule update required",
extra={
"new_tasks": len(new_schedule),
@@ -185,11 +215,19 @@ class DynamicTenantScheduler(PersistentScheduler):
# Ensure changes are persisted
self.sync()

logger.info("_try_updating_schedule: Schedule updated successfully")
task_logger.info(
f"_try_updating_schedule - Schedule updated: "
f"prev_num_tasks={len(current_schedule)} "
f"prev_beat_multiplier={self.last_beat_multiplier} "
f"tasks={len(new_schedule)} "
f"beat_multiplier={beat_multiplier}"
)

self.last_beat_multiplier = beat_multiplier

@staticmethod
def _compare_schedules(schedule1: dict, schedule2: dict) -> bool:
"""Compare schedules to determine if an update is needed.
"""Compare schedules by task name only to determine if an update is needed.
True if equivalent, False if not."""
current_tasks = set(name for name, _ in schedule1)
new_tasks = set(schedule2.keys())
@@ -201,7 +239,7 @@ class DynamicTenantScheduler(PersistentScheduler):

@beat_init.connect
def on_beat_init(sender: Any, **kwargs: Any) -> None:
logger.info("beat_init signal received.")
task_logger.info("beat_init signal received.")

# Celery beat shouldn't touch the db at all. But just setting a low minimum here.
SqlEngine.set_app_name(POSTGRES_CELERY_BEAT_APP_NAME)
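Operationally, the new beat_multiplier is re-read from Redis on every reload tick, so cloud dispatch rates can be tuned at runtime without redeploying beat. A hypothetical sketch of flipping the knob (the key suffix comes from the diff above; the literal value of ONYX_CLOUD_REDIS_RUNTIME is not shown in the diff, so "runtime" below is a placeholder for it, and the connection settings are assumptions):

    import redis

    # Assumed local connection; production goes through the shared redis pool.
    r = redis.Redis(host="localhost", port=6379)

    # Schedules are multiplied by this value, so smaller values dispatch more
    # often: 4.0 fires twice as frequently as the 8.0 default, 16.0 half as often.
    r.set("runtime:beat_multiplier", "4.0")

    # The scheduler picks this up on its next reload tick (RELOAD_INTERVAL = 60s),
    # compares it to last_beat_multiplier, and rebuilds the schedule if it changed.

Note also that _compare_schedules compares by task name only, which is exactly why the multiplier is tracked separately in last_beat_multiplier: a pure interval change would otherwise never trigger an update.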
@@ -84,8 +84,10 @@ def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
def on_worker_init(sender: Worker, **kwargs: Any) -> None:
logger.info("worker_init signal received.")

EXTRA_CONCURRENCY = 4  # small extra fudge factor for connection limits

SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME)
SqlEngine.init_engine(pool_size=8, max_overflow=0)
SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=EXTRA_CONCURRENCY)  # type: ignore

app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
@@ -1,3 +1,4 @@
|
||||
import copy
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
@@ -18,242 +19,184 @@ BEAT_EXPIRES_DEFAULT = 15 * 60 # 15 minutes (in seconds)
|
||||
|
||||
# hack to slow down task dispatch in the cloud until
|
||||
# we have a better implementation (backpressure, etc)
|
||||
CLOUD_BEAT_SCHEDULE_MULTIPLIER = 4
|
||||
CLOUD_BEAT_MULTIPLIER_DEFAULT = 8.0
|
||||
|
||||
# tasks that run in either self-hosted on cloud
|
||||
beat_task_templates: list[dict] = []
|
||||
|
||||
beat_task_templates.extend(
|
||||
[
|
||||
{
|
||||
"name": "check-for-indexing",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_INDEXING,
|
||||
"schedule": timedelta(seconds=15),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-connector-deletion",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-vespa-sync",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-pruning",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_PRUNING,
|
||||
"schedule": timedelta(hours=1),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "monitor-vespa-sync",
|
||||
"task": OnyxCeleryTask.MONITOR_VESPA_SYNC,
|
||||
"schedule": timedelta(seconds=5),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-doc-permissions-sync",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
|
||||
"schedule": timedelta(seconds=30),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "check-for-external-group-sync",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.MEDIUM,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "monitor-background-processes",
|
||||
"task": OnyxCeleryTask.MONITOR_BACKGROUND_PROCESSES,
|
||||
"schedule": timedelta(minutes=5),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
"queue": OnyxCeleryQueues.MONITORING,
|
||||
},
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
# Only add the LLM model update task if the API URL is configured
|
||||
if LLM_MODEL_UPDATE_API_URL:
|
||||
beat_task_templates.append(
|
||||
{
|
||||
"name": "check-for-llm-model-update",
|
||||
"task": OnyxCeleryTask.CHECK_FOR_LLM_MODEL_UPDATE,
|
||||
"schedule": timedelta(hours=1), # Check every hour
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def make_cloud_generator_task(task: dict[str, Any]) -> dict[str, Any]:
|
||||
cloud_task: dict[str, Any] = {}
|
||||
|
||||
# constant options for cloud beat task generators
|
||||
task_schedule: timedelta = task["schedule"]
|
||||
cloud_task["schedule"] = task_schedule
|
||||
cloud_task["options"] = {}
|
||||
cloud_task["options"]["priority"] = OnyxCeleryPriority.HIGHEST
|
||||
cloud_task["options"]["expires"] = BEAT_EXPIRES_DEFAULT
|
||||
|
||||
# settings dependent on the original task
|
||||
cloud_task["name"] = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_{task['name']}"
|
||||
cloud_task["task"] = OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR
|
||||
cloud_task["kwargs"] = {}
|
||||
cloud_task["kwargs"]["task_name"] = task["task"]
|
||||
|
||||
optional_fields = ["queue", "priority", "expires"]
|
||||
for field in optional_fields:
|
||||
if field in task["options"]:
|
||||
cloud_task["kwargs"][field] = task["options"][field]
|
||||
|
||||
return cloud_task
|
||||
|
||||
|
||||
# tasks that only run in the cloud
|
||||
# the name attribute must start with ONYX_CLOUD_CELERY_TASK_PREFIX = "cloud" to be filtered
|
||||
# by the DynamicTenantScheduler
|
||||
cloud_tasks_to_schedule = [
|
||||
# the name attribute must start with ONYX_CLOUD_CELERY_TASK_PREFIX = "cloud" to be seen
|
||||
# by the DynamicTenantScheduler as system wide task and not a per tenant task
|
||||
beat_system_tasks: list[dict] = [
|
||||
# cloud specific tasks
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-alembic",
|
||||
"task": OnyxCeleryTask.CLOUD_CHECK_ALEMBIC,
|
||||
"schedule": timedelta(hours=1 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
|
||||
"schedule": timedelta(hours=1),
|
||||
"options": {
|
||||
"queue": OnyxCeleryQueues.MONITORING,
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
# remaining tasks are cloud generators for per tenant tasks
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-indexing",
|
||||
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
"schedule": timedelta(seconds=15 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGHEST,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
"kwargs": {
|
||||
"task_name": OnyxCeleryTask.CHECK_FOR_INDEXING,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-connector-deletion",
|
||||
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
"schedule": timedelta(seconds=20 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGHEST,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
"kwargs": {
|
||||
"task_name": OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-vespa-sync",
|
||||
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
"schedule": timedelta(seconds=20 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGHEST,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
"kwargs": {
|
||||
"task_name": OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-prune",
|
||||
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
"schedule": timedelta(seconds=15 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGHEST,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
"kwargs": {
|
||||
"task_name": OnyxCeleryTask.CHECK_FOR_PRUNING,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor-vespa-sync",
|
||||
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
"schedule": timedelta(seconds=15 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGHEST,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
"kwargs": {
|
||||
"task_name": OnyxCeleryTask.MONITOR_VESPA_SYNC,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-doc-permissions-sync",
|
||||
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
"schedule": timedelta(seconds=30 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGHEST,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
"kwargs": {
|
||||
"task_name": OnyxCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-external-group-sync",
|
||||
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
"schedule": timedelta(seconds=20 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGHEST,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
"kwargs": {
|
||||
"task_name": OnyxCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor-background-processes",
|
||||
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
"schedule": timedelta(minutes=5 * CLOUD_BEAT_SCHEDULE_MULTIPLIER),
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGHEST,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
"kwargs": {
|
||||
"task_name": OnyxCeleryTask.MONITOR_BACKGROUND_PROCESSES,
|
||||
"queue": OnyxCeleryQueues.MONITORING,
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
if LLM_MODEL_UPDATE_API_URL:
|
||||
cloud_tasks_to_schedule.append(
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-for-llm-model-update",
|
||||
"task": OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR,
|
||||
"schedule": timedelta(
|
||||
hours=1 * CLOUD_BEAT_SCHEDULE_MULTIPLIER
|
||||
), # Check every hour
|
||||
"options": {
|
||||
"priority": OnyxCeleryPriority.HIGHEST,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
"kwargs": {
|
||||
"task_name": OnyxCeleryTask.CHECK_FOR_LLM_MODEL_UPDATE,
|
||||
"priority": OnyxCeleryPriority.LOW,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# tasks that run in either self-hosted on cloud
|
tasks_to_schedule: list[dict] = []

if not MULTI_TENANT:
    tasks_to_schedule.extend(
        [
            {
                "name": "check-for-indexing",
                "task": OnyxCeleryTask.CHECK_FOR_INDEXING,
                "schedule": timedelta(seconds=15),
                "options": {
                    "priority": OnyxCeleryPriority.MEDIUM,
                    "expires": BEAT_EXPIRES_DEFAULT,
                },
            },
            {
                "name": "check-for-connector-deletion",
                "task": OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
                "schedule": timedelta(seconds=20),
                "options": {
                    "priority": OnyxCeleryPriority.MEDIUM,
                    "expires": BEAT_EXPIRES_DEFAULT,
                },
            },
            {
                "name": "check-for-vespa-sync",
                "task": OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
                "schedule": timedelta(seconds=20),
                "options": {
                    "priority": OnyxCeleryPriority.MEDIUM,
                    "expires": BEAT_EXPIRES_DEFAULT,
                },
            },
            {
                "name": "check-for-pruning",
                "task": OnyxCeleryTask.CHECK_FOR_PRUNING,
                "schedule": timedelta(hours=1),
                "options": {
                    "priority": OnyxCeleryPriority.MEDIUM,
                    "expires": BEAT_EXPIRES_DEFAULT,
                },
            },
            {
                "name": "monitor-vespa-sync",
                "task": OnyxCeleryTask.MONITOR_VESPA_SYNC,
                "schedule": timedelta(seconds=5),
                "options": {
                    "priority": OnyxCeleryPriority.MEDIUM,
                    "expires": BEAT_EXPIRES_DEFAULT,
                },
            },
            {
                "name": "check-for-doc-permissions-sync",
                "task": OnyxCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
                "schedule": timedelta(seconds=30),
                "options": {
                    "priority": OnyxCeleryPriority.MEDIUM,
                    "expires": BEAT_EXPIRES_DEFAULT,
                },
            },
            {
                "name": "check-for-external-group-sync",
                "task": OnyxCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
                "schedule": timedelta(seconds=20),
                "options": {
                    "priority": OnyxCeleryPriority.MEDIUM,
                    "expires": BEAT_EXPIRES_DEFAULT,
                },
            },
            {
                "name": "monitor-background-processes",
                "task": OnyxCeleryTask.MONITOR_BACKGROUND_PROCESSES,
                "schedule": timedelta(minutes=15),
                "options": {
                    "priority": OnyxCeleryPriority.LOW,
                    "expires": BEAT_EXPIRES_DEFAULT,
                    "queue": OnyxCeleryQueues.MONITORING,
                },
            },
        ]
    )

# Only add the LLM model update task if the API URL is configured
if LLM_MODEL_UPDATE_API_URL:
    tasks_to_schedule.append(
        {
            "name": "check-for-llm-model-update",
            "task": OnyxCeleryTask.CHECK_FOR_LLM_MODEL_UPDATE,
            "schedule": timedelta(hours=1),  # Check every hour
            "options": {
                "priority": OnyxCeleryPriority.LOW,
                "expires": BEAT_EXPIRES_DEFAULT,
            },
        }
    )
tasks_to_schedule = beat_task_templates


def get_cloud_tasks_to_schedule() -> list[dict[str, Any]]:
    return cloud_tasks_to_schedule
def generate_cloud_tasks(
    beat_tasks: list[dict], beat_templates: list[dict], beat_multiplier: float
) -> list[dict[str, Any]]:
    """
    beat_tasks: system wide tasks that can be sent as is
    beat_templates: task templates that will be transformed into per tenant tasks via
        the cloud_beat_task_generator
    beat_multiplier: a multiplier that can be applied on top of the task schedule
        to speed up or slow down the task generation rate. useful in production.

    Returns a list of cloud tasks, which consists of incoming tasks + tasks generated
    from incoming templates.
    """

    if beat_multiplier <= 0:
        raise ValueError("beat_multiplier must be positive!")

    # start with the incoming beat tasks
    cloud_tasks: list[dict] = copy.deepcopy(beat_tasks)

    # generate our cloud tasks from the templates
    for beat_template in beat_templates:
        cloud_task = make_cloud_generator_task(beat_template)
        cloud_tasks.append(cloud_task)

    # factor in the cloud multiplier
    for cloud_task in cloud_tasks:
        cloud_task["schedule"] = cloud_task["schedule"] * beat_multiplier

    return cloud_tasks


def get_cloud_tasks_to_schedule(beat_multiplier: float) -> list[dict[str, Any]]:
    return generate_cloud_tasks(beat_system_tasks, beat_task_templates, beat_multiplier)


def get_tasks_to_schedule() -> list[dict[str, Any]]:
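For illustration, a minimal sketch of how the multiplier feeds through get_cloud_tasks_to_schedule, assuming make_cloud_generator_task preserves the template's timedelta schedule (the template below is hypothetical):

from datetime import timedelta

template = {"name": "check-for-indexing", "schedule": timedelta(seconds=15)}

# with beat_multiplier=4.0, the generated cloud task should fire every 60 seconds
tasks = generate_cloud_tasks(beat_tasks=[], beat_templates=[template], beat_multiplier=4.0)
assert tasks[0]["schedule"] == timedelta(seconds=60)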
@@ -186,7 +186,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
                    sync_type=SyncType.CONNECTOR_DELETION,
                )
        except Exception:
            pass
            task_logger.exception("insert_sync_record exceptioned.")

    except TaskDependencyError:
        redis_connector.delete.set_fence(None)
@@ -228,12 +228,15 @@ def try_creating_permissions_sync_task(

        # create before setting fence to avoid race condition where the monitoring
        # task updates the sync record before it is created
        with get_session_with_tenant(tenant_id) as db_session:
            insert_sync_record(
                db_session=db_session,
                entity_id=cc_pair_id,
                sync_type=SyncType.EXTERNAL_PERMISSIONS,
            )
        try:
            with get_session_with_tenant(tenant_id) as db_session:
                insert_sync_record(
                    db_session=db_session,
                    entity_id=cc_pair_id,
                    sync_type=SyncType.EXTERNAL_PERMISSIONS,
                )
        except Exception:
            task_logger.exception("insert_sync_record exceptioned.")

        # set a basic fence to start
        redis_connector.permissions.set_active()
@@ -257,11 +260,10 @@ def try_creating_permissions_sync_task(
        )

        # fill in the celery task id
        redis_connector.permissions.set_active()
        payload.celery_task_id = result.id
        redis_connector.permissions.set_fence(payload)

        payload_id = payload.celery_task_id
        payload_id = payload.id
    except Exception:
        task_logger.exception(f"Unexpected exception: cc_pair={cc_pair_id}")
        return None
@@ -290,6 +292,8 @@ def connector_permission_sync_generator_task(
    This task assumes that the task has already been properly fenced
    """

    payload_id: str | None = None

    LoggerContextVars.reset()

    doc_permission_sync_ctx_dict = doc_permission_sync_ctx.get()
@@ -332,9 +336,12 @@ def connector_permission_sync_generator_task(
            sleep(1)
            continue

        payload_id = payload.id

        logger.info(
            f"connector_permission_sync_generator_task - Fence found, continuing...: "
            f"fence={redis_connector.permissions.fence_key}"
            f"fence={redis_connector.permissions.fence_key} "
            f"payload_id={payload.id}"
        )
        break

@@ -413,7 +420,9 @@ def connector_permission_sync_generator_task(
        redis_connector.permissions.generator_complete = tasks_generated

    except Exception as e:
        task_logger.exception(f"Failed to run permission sync: cc_pair={cc_pair_id}")
        task_logger.exception(
            f"Permission sync exceptioned: cc_pair={cc_pair_id} payload_id={payload_id}"
        )

        redis_connector.permissions.generator_clear()
        redis_connector.permissions.taskset_clear()
@@ -423,6 +432,10 @@ def connector_permission_sync_generator_task(
        if lock.owned():
            lock.release()

    task_logger.info(
        f"Permission sync finished: cc_pair={cc_pair_id} payload_id={payload.id}"
    )


@shared_task(
    name=OnyxCeleryTask.UPDATE_EXTERNAL_DOCUMENT_PERMISSIONS_TASK,
@@ -659,7 +672,7 @@ def validate_permission_sync_fence(
        f"tasks_scanned={tasks_scanned} tasks_not_in_celery={tasks_not_in_celery}"
    )

    # we're only active if tasks_scanned > 0 and tasks_not_in_celery == 0
    # we're active if there are still tasks to run and those tasks all exist in celery
    if tasks_scanned > 0 and tasks_not_in_celery == 0:
        redis_connector.permissions.set_active()
        return
@@ -680,7 +693,8 @@ def validate_permission_sync_fence(
        "validate_permission_sync_fence - "
        "Resetting fence because no associated celery tasks were found: "
        f"cc_pair={cc_pair_id} "
        f"fence={fence_key}"
        f"fence={fence_key} "
        f"payload_id={payload.id}"
    )

    redis_connector.permissions.reset()
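The hunks above pivot the permission sync bookkeeping onto a fence payload keyed by a short id. As a rough sketch of the shape implied by this diff (the repo's actual RedisConnector classes are not shown here, so the names below are illustrative):

from datetime import datetime
from pydantic import BaseModel
from redis import Redis


class FencePayload(BaseModel):
    # fields mirror the payload usage in these hunks
    id: str
    submitted: datetime
    started: datetime | None
    celery_task_id: str | None


def set_fence(r: Redis, fence_key: str, payload: FencePayload | None) -> None:
    # storing None clears the fence; otherwise persist the payload as JSON
    if payload is None:
        r.delete(fence_key)
        return
    r.set(fence_key, payload.model_dump_json())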
@@ -2,15 +2,17 @@ import time
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import Any
from typing import cast
from uuid import uuid4

from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from pydantic import ValidationError
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session

from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs
from ee.onyx.db.connector_credential_pair import get_cc_pairs_by_source
@@ -32,7 +34,9 @@ from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import OnyxRedisSignals
from onyx.db.connector import mark_cc_pair_as_external_group_synced
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.engine import get_session_with_tenant
@@ -49,7 +53,8 @@ from onyx.redis.redis_connector_ext_group_sync import (
    RedisConnectorExternalGroupSyncPayload,
)
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.server.utils import make_short_id
from onyx.utils.logger import setup_logger

logger = setup_logger()
@@ -107,11 +112,11 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
    bind=True,
)
def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool | None:
    r = get_redis_client(tenant_id=tenant_id)

    # we need to use celery's redis client to access its redis data
    # (which lives on a different db number)
    # r_celery: Redis = self.app.broker_connection().channel().client  # type: ignore
    r = get_redis_client(tenant_id=tenant_id)
    r_replica = get_redis_replica_client(tenant_id=tenant_id)
    r_celery: Redis = self.app.broker_connection().channel().client  # type: ignore

    lock_beat: RedisLock = r.lock(
        OnyxRedisLocks.CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK,
@@ -149,30 +154,32 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool

        lock_beat.reacquire()
        for cc_pair_id in cc_pair_ids_to_sync:
            tasks_created = try_creating_external_group_sync_task(
            payload_id = try_creating_external_group_sync_task(
                self.app, cc_pair_id, r, tenant_id
            )
            if not tasks_created:
            if not payload_id:
                continue

            task_logger.info(f"External group sync queued: cc_pair={cc_pair_id}")
            task_logger.info(
                f"External group sync queued: cc_pair={cc_pair_id} id={payload_id}"
            )

        # we want to run this less frequently than the overall task
        # lock_beat.reacquire()
        # if not r.exists(OnyxRedisSignals.VALIDATE_EXTERNAL_GROUP_SYNC_FENCES):
        #     # clear any indexing fences that don't have associated celery tasks in progress
        #     # tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
        #     # or be currently executing
        #     try:
        #         validate_external_group_sync_fences(
        #             tenant_id, self.app, r, r_celery, lock_beat
        #         )
        #     except Exception:
        #         task_logger.exception(
        #             "Exception while validating external group sync fences"
        #         )
        lock_beat.reacquire()
        if not r.exists(OnyxRedisSignals.BLOCK_VALIDATE_EXTERNAL_GROUP_SYNC_FENCES):
            # clear fences that don't have associated celery tasks in progress
            # tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
            # or be currently executing
            try:
                validate_external_group_sync_fences(
                    tenant_id, self.app, r, r_replica, r_celery, lock_beat
                )
            except Exception:
                task_logger.exception(
                    "Exception while validating external group sync fences"
                )

            # r.set(OnyxRedisSignals.VALIDATE_EXTERNAL_GROUP_SYNC_FENCES, 1, ex=60)
            r.set(OnyxRedisSignals.BLOCK_VALIDATE_EXTERNAL_GROUP_SYNC_FENCES, 1, ex=300)
    except SoftTimeLimitExceeded:
        task_logger.info(
            "Soft time limit exceeded, task is being terminated gracefully."
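The BLOCK_VALIDATE_EXTERNAL_GROUP_SYNC_FENCES key above acts as a cooldown so the expensive validation pass runs at most once per expiry window. The pattern in isolation, as a sketch using only the key name and expiry shown in this diff:

from typing import Callable

from redis import Redis

SIGNAL = "signal:block_validate_external_group_sync_fences"


def maybe_validate(r: Redis, validate: Callable[[], None]) -> None:
    # run the expensive pass at most once per 5 minutes
    if r.exists(SIGNAL):
        return
    validate()
    r.set(SIGNAL, 1, ex=300)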
@@ -191,9 +198,11 @@ def try_creating_external_group_sync_task(
    cc_pair_id: int,
    r: Redis,
    tenant_id: str | None,
) -> int | None:
) -> str | None:
    """Returns a short payload id if syncing is needed.
    Returns None if no syncing is required."""
    payload_id: str | None = None

    redis_connector = RedisConnector(tenant_id, cc_pair_id)

    LOCK_TIMEOUT = 30
@@ -215,11 +224,28 @@ def try_creating_external_group_sync_task(
        redis_connector.external_group_sync.generator_clear()
        redis_connector.external_group_sync.taskset_clear()

        # create before setting fence to avoid race condition where the monitoring
        # task updates the sync record before it is created
        try:
            with get_session_with_tenant(tenant_id) as db_session:
                insert_sync_record(
                    db_session=db_session,
                    entity_id=cc_pair_id,
                    sync_type=SyncType.EXTERNAL_GROUP,
                )
        except Exception:
            task_logger.exception("insert_sync_record exceptioned.")

        # Signal active before creating fence
        redis_connector.external_group_sync.set_active()

        payload = RedisConnectorExternalGroupSyncPayload(
            id=make_short_id(),
            submitted=datetime.now(timezone.utc),
            started=None,
            celery_task_id=None,
        )
        redis_connector.external_group_sync.set_fence(payload)

        custom_task_id = f"{redis_connector.external_group_sync.taskset_key}_{uuid4()}"

@@ -234,17 +260,10 @@ def try_creating_external_group_sync_task(
            priority=OnyxCeleryPriority.HIGH,
        )

        # create before setting fence to avoid race condition where the monitoring
        # task updates the sync record before it is created
        with get_session_with_tenant(tenant_id) as db_session:
            insert_sync_record(
                db_session=db_session,
                entity_id=cc_pair_id,
                sync_type=SyncType.EXTERNAL_GROUP,
            )

        payload.celery_task_id = result.id
        redis_connector.external_group_sync.set_fence(payload)

        payload_id = payload.id
    except Exception:
        task_logger.exception(
            f"Unexpected exception while trying to create external group sync task: cc_pair={cc_pair_id}"
@@ -254,7 +273,7 @@ def try_creating_external_group_sync_task(
        if lock.owned():
            lock.release()

    return 1
    return payload_id
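make_short_id (imported above from onyx.server.utils) supplies the compact payload id that threads through the log lines. Its implementation is not part of this diff; one plausible sketch, labeled as such:

from uuid import uuid4


def make_short_id() -> str:
    # short correlation id for log lines; the repo's actual version may differ
    return uuid4().hex[:8]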
@shared_task(
@@ -312,7 +331,8 @@ def connector_external_group_sync_generator_task(

        logger.info(
            f"connector_external_group_sync_generator_task - Fence found, continuing...: "
            f"fence={redis_connector.external_group_sync.fence_key}"
            f"fence={redis_connector.external_group_sync.fence_key} "
            f"payload_id={payload.id}"
        )
        break

@@ -381,7 +401,7 @@ def connector_external_group_sync_generator_task(
        )
    except Exception as e:
        task_logger.exception(
            f"Failed to run external group sync: cc_pair={cc_pair_id}"
            f"External group sync exceptioned: cc_pair={cc_pair_id} payload_id={payload.id}"
        )

        with get_session_with_tenant(tenant_id) as db_session:
@@ -401,32 +421,41 @@ def connector_external_group_sync_generator_task(
        if lock.owned():
            lock.release()

    task_logger.info(
        f"External group sync finished: cc_pair={cc_pair_id} payload_id={payload.id}"
    )


def validate_external_group_sync_fences(
    tenant_id: str | None,
    celery_app: Celery,
    r: Redis,
    r_replica: Redis,
    r_celery: Redis,
    lock_beat: RedisLock,
) -> None:
    reserved_sync_tasks = celery_get_unacked_task_ids(
    reserved_tasks = celery_get_unacked_task_ids(
        OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC, r_celery
    )

    # validate all existing indexing jobs
    for key_bytes in r.scan_iter(
        RedisConnectorExternalGroupSync.FENCE_PREFIX + "*",
        count=SCAN_ITER_COUNT_DEFAULT,
    ):
    # validate all existing external group sync tasks
    lock_beat.reacquire()
    keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES))
    for key in keys:
        key_bytes = cast(bytes, key)
        key_str = key_bytes.decode("utf-8")
        if not key_str.startswith(RedisConnectorExternalGroupSync.FENCE_PREFIX):
            continue

        validate_external_group_sync_fence(
            tenant_id,
            key_bytes,
            reserved_tasks,
            r_celery,
        )

        lock_beat.reacquire()
        with get_session_with_tenant(tenant_id) as db_session:
            validate_external_group_sync_fence(
                tenant_id,
                key_bytes,
                reserved_sync_tasks,
                r_celery,
                db_session,
            )

    return
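The rewrite of validate_external_group_sync_fences swaps a keyspace scan for a set lookup served by a replica. The two access patterns side by side, condensed (handle is a hypothetical stand-in for the per-fence validation call):

# old: walk every matching key on the primary (SCAN over the whole keyspace)
for key_bytes in r.scan_iter(
    RedisConnectorExternalGroupSync.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
):
    handle(key_bytes)

# new: one SMEMBERS read of a maintained lookup set, served by a replica
for key in r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES):
    if key.decode("utf-8").startswith(RedisConnectorExternalGroupSync.FENCE_PREFIX):
        handle(key)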
@@ -435,7 +464,6 @@ def validate_external_group_sync_fence(
    key_bytes: bytes,
    reserved_tasks: set[str],
    r_celery: Redis,
    db_session: Session,
) -> None:
    """Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.
    This can happen if the indexing worker hard crashes or is terminated.
@@ -478,26 +506,26 @@ def validate_external_group_sync_fence(
    if not redis_connector.external_group_sync.fenced:
        return

    payload = redis_connector.external_group_sync.payload
    if not payload:
        return

    # OK, there's actually something for us to validate

    if payload.celery_task_id is None:
        # the fence is just barely set up.
        # if redis_connector_index.active():
        #     return

        # it would be odd to get here as there isn't that much that can go wrong during
        # initial fence setup, but it's still worth making sure we can recover
        logger.info(
    try:
        payload = redis_connector.external_group_sync.payload
    except ValidationError:
        task_logger.exception(
            "validate_external_group_sync_fence - "
            f"Resetting fence in basic state without any activity: fence={fence_key}"
            "Resetting fence because fence schema is out of date: "
            f"cc_pair={cc_pair_id} "
            f"fence={fence_key}"
        )

        redis_connector.external_group_sync.reset()
        return

    if not payload:
        return

    if not payload.celery_task_id:
        return

    # OK, there's actually something for us to validate
    found = celery_find_task(
        payload.celery_task_id, OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC, r_celery
    )
@@ -527,7 +555,8 @@ def validate_external_group_sync_fence(
        "validate_external_group_sync_fence - "
        "Resetting fence because no associated celery tasks were found: "
        f"cc_pair={cc_pair_id} "
        f"fence={fence_key}"
        f"fence={fence_key} "
        f"payload_id={payload.id}"
    )

    redis_connector.external_group_sync.reset()
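The try/except ValidationError added above protects against payload schema drift across deploys: a payload written by an older build may no longer parse under the current model. The same idea as a self-contained sketch:

from datetime import datetime

from pydantic import BaseModel, ValidationError


class Payload(BaseModel):  # stand-in for the fence payload model
    id: str
    submitted: datetime


def read_payload(raw: bytes) -> Payload | None:
    try:
        return Payload.model_validate_json(raw)
    except ValidationError:
        # stored JSON predates the current schema; caller should reset the fence
        return None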
@@ -423,8 +423,8 @@ def connector_indexing_task(
            # define a callback class
            callback = IndexingCallback(
                os.getppid(),
                redis_connector.stop.fence_key,
                redis_connector_index.generator_progress_key,
                redis_connector,
                redis_connector_index,
                lock,
                r,
            )

@@ -99,16 +99,16 @@ class IndexingCallback(IndexingHeartbeatInterface):
    def __init__(
        self,
        parent_pid: int,
        stop_key: str,
        generator_progress_key: str,
        redis_connector: RedisConnector,
        redis_connector_index: RedisConnectorIndex,
        redis_lock: RedisLock,
        redis_client: Redis,
    ):
        super().__init__()
        self.parent_pid = parent_pid
        self.redis_connector: RedisConnector = redis_connector
        self.redis_connector_index: RedisConnectorIndex = redis_connector_index
        self.redis_lock: RedisLock = redis_lock
        self.stop_key: str = stop_key
        self.generator_progress_key: str = generator_progress_key
        self.redis_client = redis_client
        self.started: datetime = datetime.now(timezone.utc)
        self.redis_lock.reacquire()
@@ -120,7 +120,7 @@ class IndexingCallback(IndexingHeartbeatInterface):
        self.last_parent_check = time.monotonic()

    def should_stop(self) -> bool:
        if self.redis_client.exists(self.stop_key):
        if self.redis_connector.stop.fenced:
            return True

        return False
@@ -143,6 +143,8 @@ class IndexingCallback(IndexingHeartbeatInterface):
        # self.last_parent_check = now

        try:
            self.redis_connector.prune.set_active()

            current_time = time.monotonic()
            if current_time - self.last_lock_monotonic >= (
                CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4
@@ -165,7 +167,9 @@ class IndexingCallback(IndexingHeartbeatInterface):
            redis_lock_dump(self.redis_lock, self.redis_client)
            raise

        self.redis_client.incrby(self.generator_progress_key, amount)
        self.redis_client.incrby(
            self.redis_connector_index.generator_progress_key, amount
        )


def validate_indexing_fence(
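The CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4 comparison inside the callback is a TTL-refresh cadence: reacquire well before the lock can lapse mid-run. Sketched in isolation (lock is any object with a reacquire method, such as redis.lock.Lock):

import time


def maybe_reacquire(lock, last_reacquire: float, lock_timeout: float) -> float:
    # refresh at a quarter of the timeout so the lock never expires mid-run
    now = time.monotonic()
    if now - last_reacquire >= lock_timeout / 4:
        lock.reacquire()
        return now
    return last_reacquire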
@@ -1,28 +1,39 @@
import time
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import Any
from typing import cast
from uuid import uuid4

from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from pydantic import ValidationError
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session

from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.celery.celery_utils import extract_ids_from_runnable_connector
from onyx.background.celery.tasks.indexing.utils import IndexingCallback
from onyx.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_PRUNING_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import OnyxRedisSignals
from onyx.connectors.factory import instantiate_connector
from onyx.connectors.models import InputType
from onyx.db.connector import mark_ccpair_as_pruned
@@ -35,10 +46,15 @@ from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.models import ConnectorCredentialPair
from onyx.db.search_settings import get_current_search_settings
from onyx.db.sync_record import insert_sync_record
from onyx.db.sync_record import update_sync_record_status
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_prune import RedisConnectorPrune
from onyx.redis.redis_connector_prune import RedisConnectorPrunePayload
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.server.utils import make_short_id
from onyx.utils.logger import LoggerContextVars
from onyx.utils.logger import pruning_ctx
from onyx.utils.logger import setup_logger
@@ -93,6 +109,8 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
)
def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
    r = get_redis_client(tenant_id=tenant_id)
    r_replica = get_redis_replica_client(tenant_id=tenant_id)
    r_celery: Redis = self.app.broker_connection().channel().client  # type: ignore

    lock_beat: RedisLock = r.lock(
        OnyxRedisLocks.CHECK_PRUNE_BEAT_LOCK,
@@ -123,13 +141,28 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
            if not _is_pruning_due(cc_pair):
                continue

            tasks_created = try_creating_prune_generator_task(
            payload_id = try_creating_prune_generator_task(
                self.app, cc_pair, db_session, r, tenant_id
            )
            if not tasks_created:
            if not payload_id:
                continue

            task_logger.info(f"Pruning queued: cc_pair={cc_pair.id}")
            task_logger.info(
                f"Pruning queued: cc_pair={cc_pair.id} id={payload_id}"
            )

        # we want to run this less frequently than the overall task
        lock_beat.reacquire()
        if not r.exists(OnyxRedisSignals.BLOCK_VALIDATE_PRUNING_FENCES):
            # clear any pruning fences that don't have associated celery tasks in progress
            # tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
            # or be currently executing
            try:
                validate_pruning_fences(tenant_id, r, r_replica, r_celery, lock_beat)
            except Exception:
                task_logger.exception("Exception while validating pruning fences")

            r.set(OnyxRedisSignals.BLOCK_VALIDATE_PRUNING_FENCES, 1, ex=300)
    except SoftTimeLimitExceeded:
        task_logger.info(
            "Soft time limit exceeded, task is being terminated gracefully."
@@ -149,7 +182,7 @@ def try_creating_prune_generator_task(
    db_session: Session,
    r: Redis,
    tenant_id: str | None,
) -> int | None:
) -> str | None:
    """Checks for any conditions that should block the pruning generator task from being
    created, then creates the task.

@@ -168,7 +201,7 @@ def try_creating_prune_generator_task(

    # we need to serialize starting pruning since it can be triggered either via
    # celery beat or manually (API call)
    lock = r.lock(
    lock: RedisLock = r.lock(
        DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_creating_prune_generator_task",
        timeout=LOCK_TIMEOUT,
    )
@@ -200,7 +233,30 @@ def try_creating_prune_generator_task(

        custom_task_id = f"{redis_connector.prune.generator_task_key}_{uuid4()}"

        celery_app.send_task(
        # create before setting fence to avoid race condition where the monitoring
        # task updates the sync record before it is created
        try:
            insert_sync_record(
                db_session=db_session,
                entity_id=cc_pair.id,
                sync_type=SyncType.PRUNING,
            )
        except Exception:
            task_logger.exception("insert_sync_record exceptioned.")

        # signal active before the fence is set
        redis_connector.prune.set_active()

        # set a basic fence to start
        payload = RedisConnectorPrunePayload(
            id=make_short_id(),
            submitted=datetime.now(timezone.utc),
            started=None,
            celery_task_id=None,
        )
        redis_connector.prune.set_fence(payload)

        result = celery_app.send_task(
            OnyxCeleryTask.CONNECTOR_PRUNING_GENERATOR_TASK,
            kwargs=dict(
                cc_pair_id=cc_pair.id,
@@ -213,16 +269,11 @@ def try_creating_prune_generator_task(
            priority=OnyxCeleryPriority.LOW,
        )

        # create before setting fence to avoid race condition where the monitoring
        # task updates the sync record before it is created
        insert_sync_record(
            db_session=db_session,
            entity_id=cc_pair.id,
            sync_type=SyncType.PRUNING,
        )
        # fill in the celery task id
        payload.celery_task_id = result.id
        redis_connector.prune.set_fence(payload)

        # set this only after all tasks have been added
        redis_connector.prune.set_fence(True)
        payload_id = payload.id
    except Exception:
        task_logger.exception(f"Unexpected exception: cc_pair={cc_pair.id}")
        return None
@@ -230,7 +281,7 @@ def try_creating_prune_generator_task(
        if lock.owned():
            lock.release()

    return 1
    return payload_id
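The creation path above is order-sensitive; condensed into one sketch with the rationale as comments (the db/fence/celery arguments are duck-typed stand-ins for the objects used in the diff, not the repo's API):

def create_prune_job(db, prune_fence, celery_app) -> str | None:
    db.insert_sync_record()      # 1. DB row first so monitoring can update it
    prune_fence.set_active()     # 2. active signal so validators don't reset us
    payload = {"id": "abc123", "celery_task_id": None}
    prune_fence.set_fence(payload)  # 3. basic fence before the task is sent
    result = celery_app.send_task("connector_pruning_generator_task")
    payload["celery_task_id"] = result.id
    prune_fence.set_fence(payload)  # 4. finalize with the real celery task id
    return payload["id"]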
@shared_task(
@@ -252,6 +303,8 @@ def connector_pruning_generator_task(
    and compares those IDs to locally stored documents and deletes all locally stored IDs missing
    from the most recently pulled document ID list"""

    payload_id: str | None = None

    LoggerContextVars.reset()

    pruning_ctx_dict = pruning_ctx.get()
@@ -265,6 +318,46 @@ def connector_pruning_generator_task(

    r = get_redis_client(tenant_id=tenant_id)

    # this wait is needed to avoid a race condition where
    # the primary worker sends the task and it is immediately executed
    # before the primary worker can finalize the fence
    start = time.monotonic()
    while True:
        if time.monotonic() - start > CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT:
            raise ValueError(
                f"connector_prune_generator_task - timed out waiting for fence to be ready: "
                f"fence={redis_connector.prune.fence_key}"
            )

        if not redis_connector.prune.fenced:  # The fence must exist
            raise ValueError(
                f"connector_prune_generator_task - fence not found: "
                f"fence={redis_connector.prune.fence_key}"
            )

        payload = redis_connector.prune.payload  # The payload must exist
        if not payload:
            raise ValueError(
                "connector_prune_generator_task: payload invalid or not found"
            )

        if payload.celery_task_id is None:
            logger.info(
                f"connector_prune_generator_task - Waiting for fence: "
                f"fence={redis_connector.prune.fence_key}"
            )
            time.sleep(1)
            continue

        payload_id = payload.id

        logger.info(
            f"connector_prune_generator_task - Fence found, continuing...: "
            f"fence={redis_connector.prune.fence_key} "
            f"payload_id={payload.id}"
        )
        break

    # set thread_local=False since we don't control what thread the indexing/pruning
    # might run our callback with
    lock: RedisLock = r.lock(
@@ -294,6 +387,18 @@ def connector_pruning_generator_task(
            )
            return

        payload = redis_connector.prune.payload
        if not payload:
            raise ValueError(f"No fence payload found: cc_pair={cc_pair_id}")

        new_payload = RedisConnectorPrunePayload(
            id=payload.id,
            submitted=payload.submitted,
            started=datetime.now(timezone.utc),
            celery_task_id=payload.celery_task_id,
        )
        redis_connector.prune.set_fence(new_payload)

        task_logger.info(
            f"Pruning generator running connector: "
            f"cc_pair={cc_pair_id} "
@@ -307,10 +412,13 @@ def connector_pruning_generator_task(
            cc_pair.credential,
        )

        search_settings = get_current_search_settings(db_session)
        redis_connector_index = redis_connector.new_index(search_settings.id)

        callback = IndexingCallback(
            0,
            redis_connector.stop.fence_key,
            redis_connector.prune.generator_progress_key,
            redis_connector,
            redis_connector_index,
            lock,
            r,
        )
@@ -357,7 +465,9 @@ def connector_pruning_generator_task(
        redis_connector.prune.generator_complete = tasks_generated
    except Exception as e:
        task_logger.exception(
            f"Failed to run pruning: cc_pair={cc_pair_id} connector={connector_id}"
            f"Pruning exceptioned: cc_pair={cc_pair_id} "
            f"connector={connector_id} "
            f"payload_id={payload_id}"
        )

        redis_connector.prune.reset()
@@ -366,7 +476,9 @@ def connector_pruning_generator_task(
        if lock.owned():
            lock.release()

    task_logger.info(f"Pruning generator finished: cc_pair={cc_pair_id}")
    task_logger.info(
        f"Pruning generator finished: cc_pair={cc_pair_id} payload_id={payload_id}"
    )

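The wait loop at the top of connector_pruning_generator_task closes the submit/execute race: the worker may start before the submitter has written celery_task_id into the fence. Reduced to its essentials (prune stands in for redis_connector.prune):

import time


def wait_for_fence(prune, timeout: float) -> str:
    # spin until the submitter fills in celery_task_id, or give up
    start = time.monotonic()
    while True:
        if time.monotonic() - start > timeout:
            raise ValueError("timed out waiting for fence to be ready")
        payload = prune.payload
        if payload is not None and payload.celery_task_id is not None:
            return payload.id  # fence finalized; safe to proceed
        time.sleep(1)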
"""Monitoring pruning utils, called in monitor_vespa_sync"""
|
||||
@@ -415,4 +527,184 @@ def monitor_ccpair_pruning_taskset(
|
||||
|
||||
redis_connector.prune.taskset_clear()
|
||||
redis_connector.prune.generator_clear()
|
||||
redis_connector.prune.set_fence(False)
|
||||
redis_connector.prune.set_fence(None)
|
||||
|
||||
|
||||
def validate_pruning_fences(
|
||||
tenant_id: str | None,
|
||||
r: Redis,
|
||||
r_replica: Redis,
|
||||
r_celery: Redis,
|
||||
lock_beat: RedisLock,
|
||||
) -> None:
|
||||
# building lookup table can be expensive, so we won't bother
|
||||
# validating until the queue is small
|
||||
PERMISSION_SYNC_VALIDATION_MAX_QUEUE_LEN = 1024
|
||||
|
||||
queue_len = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_DELETION, r_celery)
|
||||
if queue_len > PERMISSION_SYNC_VALIDATION_MAX_QUEUE_LEN:
|
||||
return
|
||||
|
||||
# the queue for a single pruning generator task
|
||||
reserved_generator_tasks = celery_get_unacked_task_ids(
|
||||
OnyxCeleryQueues.CONNECTOR_PRUNING, r_celery
|
||||
)
|
||||
|
||||
# the queue for a reasonably large set of lightweight deletion tasks
|
||||
queued_upsert_tasks = celery_get_queued_task_ids(
|
||||
OnyxCeleryQueues.CONNECTOR_DELETION, r_celery
|
||||
)
|
||||
|
||||
# Use replica for this because the worst thing that happens
|
||||
# is that we don't run the validation on this pass
|
||||
keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES))
|
||||
for key in keys:
|
||||
key_bytes = cast(bytes, key)
|
||||
key_str = key_bytes.decode("utf-8")
|
||||
if not key_str.startswith(RedisConnectorPrune.FENCE_PREFIX):
|
||||
continue
|
||||
|
||||
validate_pruning_fence(
|
||||
tenant_id,
|
||||
key_bytes,
|
||||
reserved_generator_tasks,
|
||||
queued_upsert_tasks,
|
||||
r,
|
||||
r_celery,
|
||||
)
|
||||
|
||||
lock_beat.reacquire()
|
||||
|
||||
return
|
||||
|
||||
|
||||
def validate_pruning_fence(
|
||||
tenant_id: str | None,
|
||||
key_bytes: bytes,
|
||||
reserved_tasks: set[str],
|
||||
queued_tasks: set[str],
|
||||
r: Redis,
|
||||
r_celery: Redis,
|
||||
) -> None:
|
||||
"""See validate_indexing_fence for an overall idea of validation flows.
|
||||
|
||||
queued_tasks: the celery queue of lightweight permission sync tasks
|
||||
reserved_tasks: prefetched tasks for sync task generator
|
||||
"""
|
||||
# if the fence doesn't exist, there's nothing to do
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
if cc_pair_id_str is None:
|
||||
task_logger.warning(
|
||||
f"validate_pruning_fence - could not parse id from {fence_key}"
|
||||
)
|
||||
return
|
||||
|
||||
cc_pair_id = int(cc_pair_id_str)
|
||||
# parse out metadata and initialize the helper class with it
|
||||
redis_connector = RedisConnector(tenant_id, int(cc_pair_id))
|
||||
|
||||
# check to see if the fence/payload exists
|
||||
if not redis_connector.prune.fenced:
|
||||
return
|
||||
|
||||
# in the cloud, the payload format may have changed ...
|
||||
# it's a little sloppy, but just reset the fence for now if that happens
|
||||
# TODO: add intentional cleanup/abort logic
|
||||
try:
|
||||
payload = redis_connector.prune.payload
|
||||
except ValidationError:
|
||||
task_logger.exception(
|
||||
"validate_pruning_fence - "
|
||||
"Resetting fence because fence schema is out of date: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"fence={fence_key}"
|
||||
)
|
||||
|
||||
redis_connector.prune.reset()
|
||||
return
|
||||
|
||||
if not payload:
|
||||
return
|
||||
|
||||
if not payload.celery_task_id:
|
||||
return
|
||||
|
||||
# OK, there's actually something for us to validate
|
||||
|
||||
# either the generator task must be in flight or its subtasks must be
|
||||
found = celery_find_task(
|
||||
payload.celery_task_id,
|
||||
OnyxCeleryQueues.CONNECTOR_PRUNING,
|
||||
r_celery,
|
||||
)
|
||||
if found:
|
||||
# the celery task exists in the redis queue
|
||||
redis_connector.prune.set_active()
|
||||
return
|
||||
|
||||
if payload.celery_task_id in reserved_tasks:
|
||||
# the celery task was prefetched and is reserved within a worker
|
||||
redis_connector.prune.set_active()
|
||||
return
|
||||
|
||||
# look up every task in the current taskset in the celery queue
|
||||
# every entry in the taskset should have an associated entry in the celery task queue
|
||||
# because we get the celery tasks first, the entries in our own pruning taskset
|
||||
# should be roughly a subset of the tasks in celery
|
||||
|
||||
# this check isn't very exact, but should be sufficient over a period of time
|
||||
# A single successful check over some number of attempts is sufficient.
|
||||
|
||||
# TODO: if the number of tasks in celery is much lower than than the taskset length
|
||||
# we might be able to shortcut the lookup since by definition some of the tasks
|
||||
# must not exist in celery.
|
||||
|
||||
tasks_scanned = 0
|
||||
tasks_not_in_celery = 0 # a non-zero number after completing our check is bad
|
||||
|
||||
for member in r.sscan_iter(redis_connector.prune.taskset_key):
|
||||
tasks_scanned += 1
|
||||
|
||||
member_bytes = cast(bytes, member)
|
||||
member_str = member_bytes.decode("utf-8")
|
||||
if member_str in queued_tasks:
|
||||
continue
|
||||
|
||||
if member_str in reserved_tasks:
|
||||
continue
|
||||
|
||||
tasks_not_in_celery += 1
|
||||
|
||||
task_logger.info(
|
||||
"validate_pruning_fence task check: "
|
||||
f"tasks_scanned={tasks_scanned} tasks_not_in_celery={tasks_not_in_celery}"
|
||||
)
|
||||
|
||||
# we're active if there are still tasks to run and those tasks all exist in celery
|
||||
if tasks_scanned > 0 and tasks_not_in_celery == 0:
|
||||
redis_connector.prune.set_active()
|
||||
return
|
||||
|
||||
# we may want to enable this check if using the active task list somehow isn't good enough
|
||||
# if redis_connector_index.generator_locked():
|
||||
# logger.info(f"{payload.celery_task_id} is currently executing.")
|
||||
|
||||
# if we get here, we didn't find any direct indication that the associated celery tasks exist,
|
||||
# but they still might be there due to gaps in our ability to check states during transitions
|
||||
# Checking the active signal safeguards us against these transition periods
|
||||
# (which has a duration that allows us to bridge those gaps)
|
||||
if redis_connector.prune.active():
|
||||
return
|
||||
|
||||
# celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
|
||||
task_logger.warning(
|
||||
"validate_pruning_fence - "
|
||||
"Resetting fence because no associated celery tasks were found: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"fence={fence_key} "
|
||||
f"payload_id={payload.id}"
|
||||
)
|
||||
|
||||
redis_connector.prune.reset()
|
||||
return
|
||||
|
||||
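The counting loop above amounts to a subset test; a compact equivalent of the same logic:

def fence_is_live(taskset: set[str], queued: set[str], reserved: set[str]) -> bool:
    # live if there is work left and every member is still visible to celery
    return len(taskset) > 0 and all(t in queued or t in reserved for t in taskset)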
@@ -339,11 +339,15 @@ def try_generate_document_set_sync_tasks(

    # create before setting fence to avoid race condition where the monitoring
    # task updates the sync record before it is created
    insert_sync_record(
        db_session=db_session,
        entity_id=document_set_id,
        sync_type=SyncType.DOCUMENT_SET,
    )
    try:
        insert_sync_record(
            db_session=db_session,
            entity_id=document_set_id,
            sync_type=SyncType.DOCUMENT_SET,
        )
    except Exception:
        task_logger.exception("insert_sync_record exceptioned.")

    # set this only after all tasks have been added
    rds.set_fence(tasks_generated)
    return tasks_generated
@@ -411,11 +415,15 @@ def try_generate_user_group_sync_tasks(

    # create before setting fence to avoid race condition where the monitoring
    # task updates the sync record before it is created
    insert_sync_record(
        db_session=db_session,
        entity_id=usergroup_id,
        sync_type=SyncType.USER_GROUP,
    )
    try:
        insert_sync_record(
            db_session=db_session,
            entity_id=usergroup_id,
            sync_type=SyncType.USER_GROUP,
        )
    except Exception:
        task_logger.exception("insert_sync_record exceptioned.")

    # set this only after all tasks have been added
    rug.set_fence(tasks_generated)

@@ -904,7 +912,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool | None:

    # use a lookup table to find active fences. We still have to verify the fence
    # exists since it is an optimization and not the source of truth.
    keys = cast(set[Any], r.smembers(OnyxRedisConstants.ACTIVE_FENCES))
    keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES))
    for key in keys:
        key_bytes = cast(bytes, key)


@@ -140,6 +140,7 @@ def build_citations_user_message(
    context_docs: list[LlmDoc] | list[InferenceChunk],
    all_doc_useful: bool,
    history_message: str = "",
    context_type: str = "context documents",
) -> HumanMessage:
    multilingual_expansion = get_multilingual_expansion()
    task_prompt_with_reminder = build_task_prompt_reminders(
@@ -156,6 +157,7 @@ def build_citations_user_message(
        optional_ignore = "" if all_doc_useful else DEFAULT_IGNORE_STATEMENT

        user_prompt = CITATIONS_PROMPT.format(
            context_type=context_type,
            optional_ignore_statement=optional_ignore,
            context_docs_str=context_docs_str,
            task_prompt=task_prompt_with_reminder,
@@ -165,6 +167,7 @@ def build_citations_user_message(
    else:
        # if no context docs provided, assume we're in the tool calling flow
        user_prompt = CITATIONS_PROMPT_FOR_TOOL_CALLING.format(
            context_type=context_type,
            task_prompt=task_prompt_with_reminder,
            user_query=query,
            history_block=history_block,

@@ -263,6 +263,11 @@ class PostgresAdvisoryLocks(Enum):


class OnyxCeleryQueues:
    # "celery" is the default queue defined by celery and also the queue
    # we are running in the primary worker to run system tasks
    # Tasks running in this queue should be designed specifically to run quickly
    PRIMARY = "celery"

    # Light queue
    VESPA_METADATA_SYNC = "vespa_metadata_sync"
    DOC_PERMISSIONS_UPSERT = "doc_permissions_upsert"
@@ -319,6 +324,7 @@ class OnyxRedisSignals:
    BLOCK_VALIDATE_PERMISSION_SYNC_FENCES = (
        "signal:block_validate_permission_sync_fences"
    )
    BLOCK_VALIDATE_PRUNING_FENCES = "signal:block_validate_pruning_fences"
    BLOCK_BUILD_FENCE_LOOKUP_TABLE = "signal:block_build_fence_lookup_table"


@@ -340,6 +346,9 @@ ONYX_CLOUD_CELERY_TASK_PREFIX = "cloud"
# the tenant id we use for system level redis operations
ONYX_CLOUD_TENANT_ID = "cloud"

# the redis namespace for runtime variables
ONYX_CLOUD_REDIS_RUNTIME = "runtime"


class OnyxCeleryTask:
    DEFAULT = "celery"

@@ -65,10 +65,25 @@ class AirtableConnector(LoadConnector):
        base_id: str,
        table_name_or_id: str,
        treat_all_non_attachment_fields_as_metadata: bool = False,
        view_id: str | None = None,
        share_id: str | None = None,
        batch_size: int = INDEX_BATCH_SIZE,
    ) -> None:
        """Initialize an AirtableConnector.

        Args:
            base_id: The ID of the Airtable base to connect to
            table_name_or_id: The name or ID of the table to index
            treat_all_non_attachment_fields_as_metadata: If True, all fields except attachments will be treated as metadata.
                If False, only fields with types in DEFAULT_METADATA_FIELD_TYPES will be treated as metadata.
            view_id: Optional ID of a specific view to use
            share_id: Optional ID of a "share" to use for generating record URLs (https://airtable.com/developers/web/api/list-shares)
            batch_size: Number of records to process in each batch
        """
        self.base_id = base_id
        self.table_name_or_id = table_name_or_id
        self.view_id = view_id
        self.share_id = share_id
        self.batch_size = batch_size
        self._airtable_client: AirtableApi | None = None
        self.treat_all_non_attachment_fields_as_metadata = (
@@ -85,6 +100,39 @@ class AirtableConnector(LoadConnector):
            raise AirtableClientNotSetUpError()
        return self._airtable_client

    @classmethod
    def _get_record_url(
        cls,
        base_id: str,
        table_id: str,
        record_id: str,
        share_id: str | None,
        view_id: str | None,
        field_id: str | None = None,
        attachment_id: str | None = None,
    ) -> str:
        """Constructs the URL for a record, optionally including field and attachment IDs

        Full possible structure is:

        https://airtable.com/BASE_ID/SHARE_ID/TABLE_ID/VIEW_ID/RECORD_ID/FIELD_ID/ATTACHMENT_ID
        """
        # If we have a shared link, use that view for better UX
        if share_id:
            base_url = f"https://airtable.com/{base_id}/{share_id}/{table_id}"
        else:
            base_url = f"https://airtable.com/{base_id}/{table_id}"

        if view_id:
            base_url = f"{base_url}/{view_id}"

        base_url = f"{base_url}/{record_id}"

        if field_id and attachment_id:
            return f"{base_url}/{field_id}/{attachment_id}?blocks=hide"

        return base_url
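Example output of _get_record_url for made-up ids, following the branches above (share and view present, no attachment):

url = AirtableConnector._get_record_url(
    base_id="appBASE",
    table_id="tblTABLE",
    record_id="recRECORD",
    share_id="shrSHARE",
    view_id="viwVIEW",
)
# -> https://airtable.com/appBASE/shrSHARE/tblTABLE/viwVIEW/recRECORD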
    def _extract_field_values(
        self,
        field_id: str,
@@ -110,8 +158,10 @@ class AirtableConnector(LoadConnector):
        if field_type == "multipleRecordLinks":
            return []

        # default link to use for non-attachment fields
        default_link = f"https://airtable.com/{base_id}/{table_id}/{record_id}"
        # Get the base URL for this record
        default_link = self._get_record_url(
            base_id, table_id, record_id, self.share_id, self.view_id or view_id
        )

        if field_type == "multipleAttachments":
            attachment_texts: list[tuple[str, str]] = []
@@ -165,17 +215,16 @@ class AirtableConnector(LoadConnector):
                    extension=file_ext,
                )
                if attachment_text:
                    # slightly nicer loading experience if we can specify the view ID
                    if view_id:
                        attachment_link = (
                            f"https://airtable.com/{base_id}/{table_id}/{view_id}/{record_id}"
                            f"/{field_id}/{attachment_id}?blocks=hide"
                        )
                    else:
                        attachment_link = (
                            f"https://airtable.com/{base_id}/{table_id}/{record_id}"
                            f"/{field_id}/{attachment_id}?blocks=hide"
                        )
                    # Use the helper method to construct attachment URLs
                    attachment_link = self._get_record_url(
                        base_id,
                        table_id,
                        record_id,
                        self.share_id,
                        self.view_id or view_id,
                        field_id,
                        attachment_id,
                    )
                    attachment_texts.append(
                        (f"{filename}:\n{attachment_text}", attachment_link)
                    )

@@ -27,6 +27,7 @@ from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import SlimDocument
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger

logger = setup_logger()
@@ -319,6 +320,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
        self,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
        callback: IndexingHeartbeatInterface | None = None,
    ) -> GenerateSlimDocumentOutput:
        doc_metadata_list: list[SlimDocument] = []

@@ -386,4 +388,12 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
                yield doc_metadata_list[:_SLIM_DOC_BATCH_SIZE]
                doc_metadata_list = doc_metadata_list[_SLIM_DOC_BATCH_SIZE:]

            if callback:
                if callback.should_stop():
                    raise RuntimeError(
                        "retrieve_all_slim_documents: Stop signal detected"
                    )

                callback.progress("retrieve_all_slim_documents", 1)

        yield doc_metadata_list

@@ -30,6 +30,7 @@ from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import SlimDocument
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder

@@ -321,6 +322,7 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnector):
        self,
        time_range_start: SecondsSinceUnixEpoch | None = None,
        time_range_end: SecondsSinceUnixEpoch | None = None,
        callback: IndexingHeartbeatInterface | None = None,
    ) -> GenerateSlimDocumentOutput:
        query = _build_time_range_query(time_range_start, time_range_end)
        doc_batch = []
@@ -343,6 +345,15 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnector):
                if len(doc_batch) > SLIM_BATCH_SIZE:
                    yield doc_batch
                    doc_batch = []

                if callback:
                    if callback.should_stop():
                        raise RuntimeError(
                            "retrieve_all_slim_documents: Stop signal detected"
                        )

                    callback.progress("retrieve_all_slim_documents", 1)

        if doc_batch:
            yield doc_batch

@@ -368,9 +379,10 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnector):
        self,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
        callback: IndexingHeartbeatInterface | None = None,
    ) -> GenerateSlimDocumentOutput:
        try:
            yield from self._fetch_slim_threads(start, end)
            yield from self._fetch_slim_threads(start, end, callback=callback)
        except Exception as e:
            if MISSING_SCOPES_ERROR_STR in str(e):
                raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e

@@ -42,6 +42,7 @@ from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder

@@ -564,6 +565,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
        self,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
        callback: IndexingHeartbeatInterface | None = None,
    ) -> GenerateSlimDocumentOutput:
        slim_batch = []
        for file in self._fetch_drive_items(
@@ -576,15 +578,26 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
            if len(slim_batch) >= SLIM_BATCH_SIZE:
                yield slim_batch
                slim_batch = []
            if callback:
                if callback.should_stop():
                    raise RuntimeError(
                        "_extract_slim_docs_from_google_drive: Stop signal detected"
                    )

                callback.progress("_extract_slim_docs_from_google_drive", 1)

        yield slim_batch

    def retrieve_all_slim_documents(
        self,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
        callback: IndexingHeartbeatInterface | None = None,
    ) -> GenerateSlimDocumentOutput:
        try:
            yield from self._extract_slim_docs_from_google_drive(start, end)
            yield from self._extract_slim_docs_from_google_drive(
                start, end, callback=callback
            )
        except Exception as e:
            if MISSING_SCOPES_ERROR_STR in str(e):
                raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e

@@ -7,6 +7,7 @@ from pydantic import BaseModel
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import Document
from onyx.connectors.models import SlimDocument
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface


SecondsSinceUnixEpoch = float
@@ -63,6 +64,7 @@ class SlimConnector(BaseConnector):
        self,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
        callback: IndexingHeartbeatInterface | None = None,
    ) -> GenerateSlimDocumentOutput:
        raise NotImplementedError

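The SlimConnector signature above, and the connector hunks that follow, thread the same optional callback through retrieve_all_slim_documents. A minimal conforming implementation, assuming only the two methods exercised in this diff (should_stop and progress):

class NoopHeartbeat(IndexingHeartbeatInterface):
    """Callback that never cancels and discards progress updates."""

    def should_stop(self) -> bool:
        return False

    def progress(self, tag: str, amount: int) -> None:
        pass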
@@ -91,6 +91,7 @@ class LinearConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||
f"&response_type=code"
|
||||
f"&scope=read"
|
||||
f"&state={state}"
|
||||
f"&prompt=consent" # prompts user for access; allows choosing workspace
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -29,6 +29,7 @@ from onyx.connectors.onyx_jira.utils import build_jira_url
|
||||
from onyx.connectors.onyx_jira.utils import extract_jira_project
|
||||
from onyx.connectors.onyx_jira.utils import extract_text_from_adf
|
||||
from onyx.connectors.onyx_jira.utils import get_comment_strs
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
@@ -245,6 +246,7 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
jql = f"project = {self.quoted_jira_project}"
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ from onyx.connectors.salesforce.sqlite_functions import get_affected_parent_ids_
|
||||
from onyx.connectors.salesforce.sqlite_functions import get_record
|
||||
from onyx.connectors.salesforce.sqlite_functions import init_db
|
||||
from onyx.connectors.salesforce.sqlite_functions import update_sf_db_with_csv
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -176,6 +177,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
doc_metadata_list: list[SlimDocument] = []
|
||||
for parent_object_type in self.parent_object_list:
|
||||
|
||||
@@ -21,6 +21,7 @@ from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
@@ -242,6 +243,7 @@ class SlabConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
slim_doc_batch: list[SlimDocument] = []
|
||||
for post_id in get_all_post_ids(self.slab_bot_token):
|
||||
|
||||
@@ -27,6 +27,7 @@ from onyx.connectors.slack.utils import get_message_link
|
||||
from onyx.connectors.slack.utils import make_paginated_slack_api_call_w_retries
|
||||
from onyx.connectors.slack.utils import make_slack_api_call_w_retries
|
||||
from onyx.connectors.slack.utils import SlackTextCleaner
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
@@ -98,6 +99,7 @@ def get_channel_messages(
|
||||
channel: dict[str, Any],
|
||||
oldest: str | None = None,
|
||||
latest: str | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> Generator[list[MessageType], None, None]:
|
||||
"""Get all messages in a channel"""
|
||||
# join so that the bot can access messages
|
||||
@@ -115,6 +117,11 @@ def get_channel_messages(
|
||||
oldest=oldest,
|
||||
latest=latest,
|
||||
):
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
raise RuntimeError("get_channel_messages: Stop signal detected")
|
||||
|
||||
callback.progress("get_channel_messages", 0)
|
||||
yield cast(list[MessageType], result["messages"])
|
||||
|
||||
|
||||
@@ -325,6 +332,7 @@ def _get_all_doc_ids(
    channels: list[str] | None = None,
    channel_name_regex_enabled: bool = False,
    msg_filter_func: Callable[[MessageType], bool] = default_msg_filter,
    callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
    """
    Get all document ids in the workspace, channel by channel
@@ -342,6 +350,7 @@ def _get_all_doc_ids(
        channel_message_batches = get_channel_messages(
            client=client,
            channel=channel,
            callback=callback,
        )

        message_ts_set: set[str] = set()
@@ -390,6 +399,7 @@ class SlackPollConnector(PollConnector, SlimConnector):
        self,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
        callback: IndexingHeartbeatInterface | None = None,
    ) -> GenerateSlimDocumentOutput:
        if self.client is None:
            raise ConnectorMissingCredentialError("Slack")
@@ -398,6 +408,7 @@ class SlackPollConnector(PollConnector, SlimConnector):
            client=self.client,
            channels=self.channels,
            channel_name_regex_enabled=self.channel_regex_enabled,
            callback=callback,
        )

    def poll_source(

@@ -20,6 +20,7 @@ from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import SlimDocument
from onyx.file_processing.html_utils import parse_html_page_basic
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.retry_wrapper import retry_builder
@@ -405,6 +406,7 @@ class ZendeskConnector(LoadConnector, PollConnector, SlimConnector):
        self,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
        callback: IndexingHeartbeatInterface | None = None,
    ) -> GenerateSlimDocumentOutput:
        slim_doc_batch: list[SlimDocument] = []
        if self.content_type == "articles":

@@ -152,7 +152,7 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
    # if not specified, all assistants are shown
    temperature_override_enabled: Mapped[bool] = mapped_column(Boolean, default=False)
    auto_scroll: Mapped[bool] = mapped_column(Boolean, default=True)
    shortcut_enabled: Mapped[bool] = mapped_column(Boolean, default=True)
    shortcut_enabled: Mapped[bool] = mapped_column(Boolean, default=False)
    chosen_assistants: Mapped[list[int] | None] = mapped_column(
        postgresql.JSONB(), nullable=True, default=None
    )

@@ -204,6 +204,14 @@ def create_update_persona(
        if not all_prompt_ids:
            raise ValueError("No prompt IDs provided")

        # Default persona validation
        if create_persona_request.is_default_persona:
            if not create_persona_request.is_public:
                raise ValueError("Cannot make a default persona non public")

            if user and user.role != UserRole.ADMIN:
                raise ValueError("Only admins can make a default persona")

        persona = upsert_persona(
            persona_id=persona_id,
            user=user,
@@ -228,6 +236,7 @@ def create_update_persona(
            num_chunks=create_persona_request.num_chunks,
            llm_relevance_filter=create_persona_request.llm_relevance_filter,
            llm_filter_extraction=create_persona_request.llm_filter_extraction,
            is_default_persona=create_persona_request.is_default_persona,
        )

        versioned_make_persona_private = fetch_versioned_implementation(
@@ -509,6 +518,7 @@ def upsert_persona(
        existing_persona.is_visible = is_visible
        existing_persona.search_start_date = search_start_date
        existing_persona.labels = labels or []
        existing_persona.is_default_persona = is_default_persona
        # Do not delete any associations manually added unless
        # a new updated list is provided
        if document_sets is not None:
@@ -589,6 +599,23 @@ def delete_old_default_personas(
    db_session.commit()


def update_persona_is_default(
    persona_id: int,
    is_default: bool,
    db_session: Session,
    user: User | None = None,
) -> None:
    persona = fetch_persona_by_id_for_user(
        db_session=db_session, persona_id=persona_id, user=user, get_editable=True
    )

    if not persona.is_public:
        persona.is_public = True

    persona.is_default_persona = is_default
    db_session.commit()


def update_persona_visibility(
    persona_id: int,
    is_visible: bool,

@@ -146,6 +146,23 @@ def _index_vespa_chunk(

    title = document.get_title_for_document_index()

    metadata_json = document.metadata
    cleaned_metadata_json: dict[str, str | list[str]] = {}
    for key, value in metadata_json.items():
        cleaned_key = remove_invalid_unicode_chars(key)
        if isinstance(value, list):
            cleaned_metadata_json[cleaned_key] = [
                remove_invalid_unicode_chars(item) for item in value
            ]
        else:
            cleaned_metadata_json[cleaned_key] = remove_invalid_unicode_chars(value)

    metadata_list = document.get_metadata_str_attributes()
    if metadata_list:
        metadata_list = [
            remove_invalid_unicode_chars(metadata) for metadata in metadata_list
        ]

    vespa_document_fields = {
        DOCUMENT_ID: document.id,
        CHUNK_ID: chunk.chunk_id,
@@ -166,10 +183,10 @@ def _index_vespa_chunk(
        SEMANTIC_IDENTIFIER: remove_invalid_unicode_chars(document.semantic_identifier),
        SECTION_CONTINUATION: chunk.section_continuation,
        LARGE_CHUNK_REFERENCE_IDS: chunk.large_chunk_reference_ids,
        METADATA: json.dumps(document.metadata),
        METADATA: json.dumps(cleaned_metadata_json),
        # Save as a list for efficient extraction as an Attribute
        METADATA_LIST: chunk.source_document.get_metadata_str_attributes(),
        METADATA_SUFFIX: chunk.metadata_suffix_keyword,
        METADATA_LIST: metadata_list,
        METADATA_SUFFIX: remove_invalid_unicode_chars(chunk.metadata_suffix_keyword),
        EMBEDDINGS: embeddings_name_vector_map,
        TITLE_EMBEDDING: chunk.title_embedding,
        DOC_UPDATED_AT: _vespa_get_updated_at_attribute(document.doc_updated_at),

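The hunk above routes every metadata key and value through `remove_invalid_unicode_chars` before the `json.dumps` call, so Vespa never receives characters it cannot store. That helper's implementation is not shown in this diff; a minimal sketch of what such a cleaner could look like, assuming its job is to drop lone surrogates and other sequences that are valid in Python strings but not in UTF-8 payloads:

# Hypothetical stand-in for onyx's remove_invalid_unicode_chars; the real
# helper lives elsewhere in the repo and may differ.
def remove_invalid_unicode_chars_sketch(text: str) -> str:
    # Round-tripping through UTF-8 with errors="ignore" silently drops
    # unpaired surrogates (U+D800..U+DFFF) and anything else the codec rejects.
    return text.encode("utf-8", errors="ignore").decode("utf-8", errors="ignore")
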
@@ -27,6 +27,7 @@ from langchain_core.prompt_values import PromptValue

from onyx.configs.app_configs import LOG_DANSWER_MODEL_INTERACTIONS
from onyx.configs.app_configs import MOCK_LLM_RESPONSE
from onyx.configs.chat_configs import QA_TIMEOUT
from onyx.configs.model_configs import (
    DISABLE_LITELLM_STREAMING,
)
@@ -35,6 +36,7 @@ from onyx.configs.model_configs import LITELLM_EXTRA_BODY
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMConfig
from onyx.llm.interfaces import ToolChoiceOptions
from onyx.llm.utils import model_is_reasoning_model
from onyx.server.utils import mask_string
from onyx.utils.logger import setup_logger
from onyx.utils.long_term_log import LongTermLogger
@@ -229,15 +231,15 @@ class DefaultMultiLLM(LLM):
    def __init__(
        self,
        api_key: str | None,
        timeout: int,
        model_provider: str,
        model_name: str,
        timeout: int | None = None,
        api_base: str | None = None,
        api_version: str | None = None,
        deployment_name: str | None = None,
        max_output_tokens: int | None = None,
        custom_llm_provider: str | None = None,
        temperature: float = GEN_AI_TEMPERATURE,
        temperature: float | None = None,
        custom_config: dict[str, str] | None = None,
        extra_headers: dict[str, str] | None = None,
        extra_body: dict | None = LITELLM_EXTRA_BODY,
@@ -245,9 +247,16 @@ class DefaultMultiLLM(LLM):
        long_term_logger: LongTermLogger | None = None,
    ):
        self._timeout = timeout
        if timeout is None:
            if model_is_reasoning_model(model_name):
                self._timeout = QA_TIMEOUT * 10  # Reasoning models are slow
            else:
                self._timeout = QA_TIMEOUT

        self._temperature = GEN_AI_TEMPERATURE if temperature is None else temperature

        self._model_provider = model_provider
        self._model_version = model_name
        self._temperature = temperature
        self._api_key = api_key
        self._deployment_name = deployment_name
        self._api_base = api_base

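The constructor change above makes `timeout` optional: when the caller passes `None`, the effective timeout is derived from the model name, with reasoning models given ten times the base `QA_TIMEOUT`. A standalone sketch of that selection logic, assuming `QA_TIMEOUT` is an int number of seconds (the real value comes from onyx.configs.chat_configs):

QA_TIMEOUT = 60  # assumed base timeout in seconds, for illustration only

def model_is_reasoning_model(model_name: str) -> bool:
    # same allowlist that the llm/utils.py hunk later in this diff introduces
    return model_name.lower() in {"o1", "o1-mini", "o3-mini", "deepseek-reasoner", "deepseek-r1"}

def resolve_timeout(timeout: int | None, model_name: str) -> int:
    # mirrors the defaulting logic in the DefaultMultiLLM.__init__ hunk above
    if timeout is not None:
        return timeout  # an explicit caller override always wins
    if model_is_reasoning_model(model_name):
        return QA_TIMEOUT * 10  # reasoning models are slow to finish
    return QA_TIMEOUT

assert resolve_timeout(None, "o1") == 600
assert resolve_timeout(30, "gpt-4o") == 30
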
@@ -2,7 +2,6 @@ from typing import Any

from onyx.chat.models import PersonaOverrideConfig
from onyx.configs.app_configs import DISABLE_GENERATIVE_AI
from onyx.configs.chat_configs import QA_TIMEOUT
from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
from onyx.configs.model_configs import GEN_AI_TEMPERATURE
from onyx.db.engine import get_session_context_manager
@@ -88,8 +87,8 @@ def get_llms_for_persona(


def get_default_llms(
    timeout: int = QA_TIMEOUT,
    temperature: float = GEN_AI_TEMPERATURE,
    timeout: int | None = None,
    temperature: float | None = None,
    additional_headers: dict[str, str] | None = None,
    long_term_logger: LongTermLogger | None = None,
) -> tuple[LLM, LLM]:
@@ -138,7 +137,7 @@ def get_llm(
    api_version: str | None = None,
    custom_config: dict[str, str] | None = None,
    temperature: float | None = None,
    timeout: int = QA_TIMEOUT,
    timeout: int | None = None,
    additional_headers: dict[str, str] | None = None,
    long_term_logger: LongTermLogger | None = None,
) -> LLM:

@@ -29,11 +29,11 @@ OPENAI_PROVIDER_NAME = "openai"
OPEN_AI_MODEL_NAMES = [
    "o3-mini",
    "o1-mini",
    "o1-preview",
    "o1-2024-12-17",
    "o1",
    "gpt-4",
    "gpt-4o",
    "gpt-4o-mini",
    "o1-preview",
    "gpt-4-turbo",
    "gpt-4-turbo-preview",
    "gpt-4-1106-preview",
@@ -543,3 +543,14 @@ def model_supports_image_input(model_name: str, model_provider: str) -> bool:
            f"Failed to get model object for {model_provider}/{model_name}"
        )
        return False


def model_is_reasoning_model(model_name: str) -> bool:
    _REASONING_MODEL_NAMES = [
        "o1",
        "o1-mini",
        "o3-mini",
        "deepseek-reasoner",
        "deepseek-r1",
    ]
    return model_name.lower() in _REASONING_MODEL_NAMES

@@ -91,7 +91,7 @@ SAMPLE RESPONSE:
# similar to the chat flow, but with the option of including a
# "conversation history" block
CITATIONS_PROMPT = f"""
Refer to the following context documents when responding to me.{DEFAULT_IGNORE_STATEMENT}
Refer to the following {{context_type}} when responding to me.{DEFAULT_IGNORE_STATEMENT}

CONTEXT:
{GENERAL_SEP_PAT}
@@ -108,7 +108,7 @@ CONTEXT:
# NOTE: need to add the extra line about "getting right to the point" since the
# tool calling models from OpenAI tend to be more verbose
CITATIONS_PROMPT_FOR_TOOL_CALLING = f"""
Refer to the provided context documents when responding to me.{DEFAULT_IGNORE_STATEMENT} \
Refer to the provided {{context_type}} when responding to me.{DEFAULT_IGNORE_STATEMENT} \
You should always get right to the point, and never use extraneous language.

{{history_block}}{{task_prompt}}

@@ -80,7 +80,8 @@ class RedisConnectorPermissionSync:
    def get_active_task_count(self) -> int:
        """Count of active permission sync tasks"""
        count = 0
        for _ in self.redis.scan_iter(
        for _ in self.redis.sscan_iter(
            OnyxRedisConstants.ACTIVE_FENCES,
            RedisConnectorPermissionSync.FENCE_PREFIX + "*",
            count=SCAN_ITER_COUNT_DEFAULT,
        ):

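This hunk (and its twins below for external group sync and pruning) swaps a keyspace-wide SCAN for a set-scoped SSCAN over the shared `ACTIVE_FENCES` set, so counting active fences only walks the set's members rather than every key in Redis. A sketch of the difference with redis-py; the set name and fence prefix here are made up, since the concrete values behind onyx's constants are not shown in this diff:

import redis

r = redis.Redis()

# Old approach: SCAN the entire keyspace for matching fence keys (cost grows
# with total keys in the database).
old_count = sum(1 for _ in r.scan_iter("connectorpermissions_fence_*", count=1000))

# New approach: SSCAN only the members of the active-fence set (cost grows
# with the set's size).
new_count = sum(
    1 for _ in r.sscan_iter("active_fences", "connectorpermissions_fence_*", count=1000)
)
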
@@ -1,5 +1,4 @@
from datetime import datetime
from typing import Any
from typing import cast

import redis
@@ -8,10 +7,12 @@ from pydantic import BaseModel
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session

from onyx.configs.constants import OnyxRedisConstants
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT


class RedisConnectorExternalGroupSyncPayload(BaseModel):
    id: str
    submitted: datetime
    started: datetime | None
    celery_task_id: str | None
@@ -37,6 +38,12 @@ class RedisConnectorExternalGroupSync:
    TASKSET_PREFIX = f"{PREFIX}_taskset"  # connectorexternalgroupsync_taskset
    SUBTASK_PREFIX = f"{PREFIX}+sub"  # connectorexternalgroupsync+sub

    # used to signal the overall workflow is still active
    # it's impossible to get the exact state of the system at a single point in time
    # so we need a signal with a TTL to bridge gaps in our checks
    ACTIVE_PREFIX = PREFIX + "_active"
    ACTIVE_TTL = 3600

    def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
        self.tenant_id: str | None = tenant_id
        self.id = id
@@ -50,6 +57,7 @@ class RedisConnectorExternalGroupSync:
        self.taskset_key = f"{self.TASKSET_PREFIX}_{id}"

        self.subtask_prefix: str = f"{self.SUBTASK_PREFIX}_{id}"
        self.active_key = f"{self.ACTIVE_PREFIX}_{id}"

    def taskset_clear(self) -> None:
        self.redis.delete(self.taskset_key)
@@ -66,7 +74,8 @@ class RedisConnectorExternalGroupSync:
    def get_active_task_count(self) -> int:
        """Count of active external group syncing tasks"""
        count = 0
        for _ in self.redis.scan_iter(
        for _ in self.redis.sscan_iter(
            OnyxRedisConstants.ACTIVE_FENCES,
            RedisConnectorExternalGroupSync.FENCE_PREFIX + "*",
            count=SCAN_ITER_COUNT_DEFAULT,
        ):
@@ -83,10 +92,11 @@ class RedisConnectorExternalGroupSync:
    @property
    def payload(self) -> RedisConnectorExternalGroupSyncPayload | None:
        # read related data and evaluate/print task progress
        fence_bytes = cast(Any, self.redis.get(self.fence_key))
        if fence_bytes is None:
        fence_raw = self.redis.get(self.fence_key)
        if fence_raw is None:
            return None

        fence_bytes = cast(bytes, fence_raw)
        fence_str = fence_bytes.decode("utf-8")
        payload = RedisConnectorExternalGroupSyncPayload.model_validate_json(
            cast(str, fence_str)
@@ -99,10 +109,26 @@ class RedisConnectorExternalGroupSync:
        payload: RedisConnectorExternalGroupSyncPayload | None,
    ) -> None:
        if not payload:
            self.redis.srem(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
            self.redis.delete(self.fence_key)
            return

        self.redis.set(self.fence_key, payload.model_dump_json())
        self.redis.sadd(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)

    def set_active(self) -> None:
        """This sets a signal to keep the permissioning flow from getting cleaned up within
        the expiration time.

        The slack in timing is needed because simply checking the celery queue and
        task status could result in race conditions."""
        self.redis.set(self.active_key, 0, ex=self.ACTIVE_TTL)

    def active(self) -> bool:
        if self.redis.exists(self.active_key):
            return True

        return False

    @property
    def generator_complete(self) -> int | None:
@@ -138,6 +164,8 @@ class RedisConnectorExternalGroupSync:
        pass

    def reset(self) -> None:
        self.redis.srem(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
        self.redis.delete(self.active_key)
        self.redis.delete(self.generator_progress_key)
        self.redis.delete(self.generator_complete_key)
        self.redis.delete(self.taskset_key)
@@ -152,6 +180,9 @@ class RedisConnectorExternalGroupSync:
    @staticmethod
    def reset_all(r: redis.Redis) -> None:
        """Deletes all redis values for all connectors"""
        for key in r.scan_iter(RedisConnectorExternalGroupSync.ACTIVE_PREFIX + "*"):
            r.delete(key)

        for key in r.scan_iter(RedisConnectorExternalGroupSync.TASKSET_PREFIX + "*"):
            r.delete(key)

@@ -1,9 +1,11 @@
import time
from datetime import datetime
from typing import cast
from uuid import uuid4

import redis
from celery import Celery
from pydantic import BaseModel
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session

@@ -16,6 +18,13 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair_from
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT


class RedisConnectorPrunePayload(BaseModel):
    id: str
    submitted: datetime
    started: datetime | None
    celery_task_id: str | None


class RedisConnectorPrune:
    """Manages interactions with redis for pruning tasks. Should only be accessed
    through RedisConnector."""
@@ -36,6 +45,12 @@ class RedisConnectorPrune:
    TASKSET_PREFIX = f"{PREFIX}_taskset"  # connectorpruning_taskset
    SUBTASK_PREFIX = f"{PREFIX}+sub"  # connectorpruning+sub

    # used to signal the overall workflow is still active
    # it's impossible to get the exact state of the system at a single point in time
    # so we need a signal with a TTL to bridge gaps in our checks
    ACTIVE_PREFIX = PREFIX + "_active"
    ACTIVE_TTL = 3600

    def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
        self.tenant_id: str | None = tenant_id
        self.id = id
@@ -49,6 +64,7 @@ class RedisConnectorPrune:
        self.taskset_key = f"{self.TASKSET_PREFIX}_{id}"

        self.subtask_prefix: str = f"{self.SUBTASK_PREFIX}_{id}"
        self.active_key = f"{self.ACTIVE_PREFIX}_{id}"

    def taskset_clear(self) -> None:
        self.redis.delete(self.taskset_key)
@@ -65,8 +81,10 @@ class RedisConnectorPrune:
    def get_active_task_count(self) -> int:
        """Count of active pruning tasks"""
        count = 0
        for key in self.redis.scan_iter(
            RedisConnectorPrune.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT
        for _ in self.redis.sscan_iter(
            OnyxRedisConstants.ACTIVE_FENCES,
            RedisConnectorPrune.FENCE_PREFIX + "*",
            count=SCAN_ITER_COUNT_DEFAULT,
        ):
            count += 1
        return count
@@ -78,15 +96,44 @@ class RedisConnectorPrune:

        return False

    def set_fence(self, value: bool) -> None:
        if not value:
    @property
    def payload(self) -> RedisConnectorPrunePayload | None:
        # read related data and evaluate/print task progress
        fence_bytes = cast(bytes, self.redis.get(self.fence_key))
        if fence_bytes is None:
            return None

        fence_str = fence_bytes.decode("utf-8")
        payload = RedisConnectorPrunePayload.model_validate_json(cast(str, fence_str))

        return payload

    def set_fence(
        self,
        payload: RedisConnectorPrunePayload | None,
    ) -> None:
        if not payload:
            self.redis.srem(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
            self.redis.delete(self.fence_key)
            return

        self.redis.set(self.fence_key, 0)
        self.redis.set(self.fence_key, payload.model_dump_json())
        self.redis.sadd(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)

    def set_active(self) -> None:
        """This sets a signal to keep the pruning flow from getting cleaned up within
        the expiration time.

        The slack in timing is needed because simply checking the celery queue and
        task status could result in race conditions."""
        self.redis.set(self.active_key, 0, ex=self.ACTIVE_TTL)

    def active(self) -> bool:
        if self.redis.exists(self.active_key):
            return True

        return False

    @property
    def generator_complete(self) -> int | None:
        """the fence payload is an int representing the starting number of
@@ -162,6 +209,7 @@ class RedisConnectorPrune:

    def reset(self) -> None:
        self.redis.srem(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
        self.redis.delete(self.active_key)
        self.redis.delete(self.generator_progress_key)
        self.redis.delete(self.generator_complete_key)
        self.redis.delete(self.taskset_key)
@@ -176,6 +224,9 @@ class RedisConnectorPrune:
    @staticmethod
    def reset_all(r: redis.Redis) -> None:
        """Deletes all redis values for all connectors"""
        for key in r.scan_iter(RedisConnectorPrune.ACTIVE_PREFIX + "*"):
            r.delete(key)

        for key in r.scan_iter(RedisConnectorPrune.TASKSET_PREFIX + "*"):
            r.delete(key)

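Both Redis helpers above replace a bare boolean fence with a serialized pydantic payload: `set_fence` writes `model_dump_json()` and registers the key in the shared `ACTIVE_FENCES` set, while the `payload` property validates the bytes back into the model. A self-contained sketch of that roundtrip, assuming a local Redis; the key and set names here are made up for illustration:

from datetime import datetime, timezone
from pydantic import BaseModel
import redis

class PrunePayloadSketch(BaseModel):
    # same field shape as RedisConnectorPrunePayload in the hunk above
    id: str
    submitted: datetime
    started: datetime | None
    celery_task_id: str | None

r = redis.Redis()
fence_key = "connectorpruning_fence_42"  # hypothetical key name

# write: serialize the payload and advertise the fence in the active set
payload = PrunePayloadSketch(
    id="abc123", submitted=datetime.now(timezone.utc), started=None, celery_task_id=None
)
r.set(fence_key, payload.model_dump_json())
r.sadd("active_fences", fence_key)  # hypothetical set name

# read: validate the raw bytes back into the model
raw = r.get(fence_key)
restored = PrunePayloadSketch.model_validate_json(raw) if raw is not None else None
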
@@ -22,6 +22,8 @@ from onyx.background.celery.tasks.pruning.tasks import (
    try_creating_prune_generator_task,
)
from onyx.background.celery.versioned_apps.primary import app as primary_app
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.connector_credential_pair import add_credential_to_connector
from onyx.db.connector_credential_pair import (
    get_connector_credential_pair_from_id_for_user,
@@ -228,6 +230,13 @@ def update_cc_pair_status(

    db_session.commit()

    # this speeds up the start of indexing by firing the check immediately
    primary_app.send_task(
        OnyxCeleryTask.CHECK_FOR_INDEXING,
        kwargs=dict(tenant_id=tenant_id),
        priority=OnyxCeleryPriority.HIGH,
    )

    return JSONResponse(
        status_code=HTTPStatus.OK, content={"message": str(HTTPStatus.OK)}
    )
@@ -359,15 +368,17 @@ def prune_cc_pair(
        f"credential={cc_pair.credential_id} "
        f"{cc_pair.connector.name} connector."
    )
    tasks_created = try_creating_prune_generator_task(
    payload_id = try_creating_prune_generator_task(
        primary_app, cc_pair, db_session, r, CURRENT_TENANT_ID_CONTEXTVAR.get()
    )
    if not tasks_created:
    if not payload_id:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
            detail="Pruning task creation failed.",
        )

    logger.info(f"Pruning queued: cc_pair={cc_pair.id} id={payload_id}")

    return StatusResponse(
        success=True,
        message="Successfully created the pruning task.",
@@ -505,15 +516,17 @@ def sync_cc_pair_groups(
        f"credential_id={cc_pair.credential_id} "
        f"{cc_pair.connector.name} connector."
    )
    tasks_created = try_creating_external_group_sync_task(
    payload_id = try_creating_external_group_sync_task(
        primary_app, cc_pair_id, r, CURRENT_TENANT_ID_CONTEXTVAR.get()
    )
    if not tasks_created:
    if not payload_id:
        raise HTTPException(
            status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
            detail="External group sync task creation failed.",
        )

    logger.info(f"External group sync queued: cc_pair={cc_pair_id} id={payload_id}")

    return StatusResponse(
        success=True,
        message="Successfully created the external group sync task.",
@@ -540,7 +553,14 @@ def associate_credential_to_connector(
    metadata: ConnectorCredentialPairMetadata,
    user: User | None = Depends(current_curator_or_admin_user),
    db_session: Session = Depends(get_session),
    tenant_id: str = Depends(get_current_tenant_id),
) -> StatusResponse[int]:
    """NOTE(rkuo): internally discussed and the consensus is this endpoint
    and create_connector_with_mock_credential should be combined.

    The intent of this endpoint is to handle connectors that actually need credentials.
    """

    fetch_ee_implementation_or_noop(
        "onyx.db.user_group", "validate_object_creation_for_user", None
    )(
@@ -563,6 +583,18 @@ def associate_credential_to_connector(
            groups=metadata.groups,
        )

        # trigger indexing immediately
        primary_app.send_task(
            OnyxCeleryTask.CHECK_FOR_INDEXING,
            priority=OnyxCeleryPriority.HIGH,
            kwargs={"tenant_id": tenant_id},
        )

        logger.info(
            f"associate_credential_to_connector - running check_for_indexing: "
            f"cc_pair={response.data}"
        )

        return response
    except IntegrityError as e:
        logger.error(f"IntegrityError: {e}")

@@ -804,6 +804,14 @@ def create_connector_with_mock_credential(
    db_session: Session = Depends(get_session),
    tenant_id: str = Depends(get_current_tenant_id),
) -> StatusResponse:
    """NOTE(rkuo): internally discussed and the consensus is this endpoint
    and associate_credential_to_connector should be combined.

    The intent of this endpoint is to handle connectors that don't need credentials,
    AKA web, file, etc ... but there isn't any reason a single endpoint couldn't
    serve this purpose.
    """

    fetch_ee_implementation_or_noop(
        "onyx.db.user_group", "validate_object_creation_for_user", None
    )(
@@ -841,6 +849,18 @@ def create_connector_with_mock_credential(
    groups=connector_data.groups,
    )

    # trigger indexing immediately
    primary_app.send_task(
        OnyxCeleryTask.CHECK_FOR_INDEXING,
        priority=OnyxCeleryPriority.HIGH,
        kwargs={"tenant_id": tenant_id},
    )

    logger.info(
        f"create_connector_with_mock_credential - running check_for_indexing: "
        f"cc_pair={response.data}"
    )

    create_milestone_and_report(
        user=user,
        distinct_id=user.email if user else tenant_id or "N/A",
@@ -1005,6 +1025,8 @@ def connector_run_once(
        kwargs={"tenant_id": tenant_id},
    )

    logger.info("connector_run_once - running check_for_indexing")

    msg = f"Marked {num_triggers} index attempts with indexing triggers."
    return StatusResponse(
        success=True,

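Several endpoints in this diff now enqueue a high-priority `CHECK_FOR_INDEXING` task right after committing, instead of waiting for the periodic beat to notice the new state. A sketch of the pattern with Celery's `send_task`, using hypothetical task and priority values in place of onyx's constants, whose concrete values are not shown here:

from celery import Celery

app = Celery("primary", broker="redis://localhost:6379/0")  # assumed broker URL

def kick_indexing_check(tenant_id: str | None) -> None:
    # send_task enqueues by name, so the caller does not need to import the
    # task function itself.
    app.send_task(
        "check_for_indexing",  # stand-in for OnyxCeleryTask.CHECK_FOR_INDEXING
        kwargs={"tenant_id": tenant_id},
        priority=0,  # with the Redis transport, lower numbers run first
    )
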
@@ -179,12 +179,10 @@ def oauth_callback(
        db_session=db_session,
    )

    # TODO: use a library for url handling
    sep = "&" if "?" in desired_return_url else "?"
    return CallbackResponse(
        redirect_url=(
            f"{desired_return_url}?credentialId={credential.id}"
            if "?" not in desired_return_url
            else f"{desired_return_url}&credentialId={credential.id}"
        )
        redirect_url=f"{desired_return_url}{sep}credentialId={credential.id}"
    )

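The hunk above collapses the duplicated f-strings into a single `sep` pick and leaves a TODO to use a real URL library. A sketch of what that follow-up could look like with the standard library's `urllib.parse`, which also handles pre-existing query strings and fragments correctly:

from urllib.parse import urlencode, urlparse, urlunparse, parse_qsl

def append_query_param(url: str, key: str, value: str) -> str:
    """Append key=value to url's query string without manual string surgery."""
    parts = urlparse(url)
    query = parse_qsl(parts.query, keep_blank_values=True)
    query.append((key, value))
    return urlunparse(parts._replace(query=urlencode(query)))

# e.g. append_query_param("https://app.example.com/cb?x=1", "credentialId", "7")
# -> "https://app.example.com/cb?x=1&credentialId=7"
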
@@ -6,11 +6,15 @@ from sqlalchemy.orm import Session

from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_user
from onyx.background.celery.versioned_apps.primary import app as primary_app
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.document_set import check_document_sets_are_public
from onyx.db.document_set import fetch_all_document_sets_for_user
from onyx.db.document_set import insert_document_set
from onyx.db.document_set import mark_document_set_as_to_be_deleted
from onyx.db.document_set import update_document_set
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.server.features.document_set.models import CheckDocSetPublicRequest
@@ -29,6 +33,7 @@ def create_document_set(
    document_set_creation_request: DocumentSetCreationRequest,
    user: User = Depends(current_curator_or_admin_user),
    db_session: Session = Depends(get_session),
    tenant_id: str = Depends(get_current_tenant_id),
) -> int:
    fetch_ee_implementation_or_noop(
        "onyx.db.user_group", "validate_object_creation_for_user", None
@@ -46,6 +51,13 @@ def create_document_set(
        )
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))

    primary_app.send_task(
        OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
        kwargs={"tenant_id": tenant_id},
        priority=OnyxCeleryPriority.HIGH,
    )

    return document_set_db_model.id


@@ -54,6 +66,7 @@ def patch_document_set(
    document_set_update_request: DocumentSetUpdateRequest,
    user: User = Depends(current_curator_or_admin_user),
    db_session: Session = Depends(get_session),
    tenant_id: str = Depends(get_current_tenant_id),
) -> None:
    fetch_ee_implementation_or_noop(
        "onyx.db.user_group", "validate_object_creation_for_user", None
@@ -72,12 +85,19 @@ def patch_document_set(
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))

    primary_app.send_task(
        OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
        kwargs={"tenant_id": tenant_id},
        priority=OnyxCeleryPriority.HIGH,
    )


@router.delete("/admin/document-set/{document_set_id}")
def delete_document_set(
    document_set_id: int,
    user: User = Depends(current_curator_or_admin_user),
    db_session: Session = Depends(get_session),
    tenant_id: str = Depends(get_current_tenant_id),
) -> None:
    try:
        mark_document_set_as_to_be_deleted(
@@ -88,6 +108,12 @@ def delete_document_set(
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))

    primary_app.send_task(
        OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
        kwargs={"tenant_id": tenant_id},
        priority=OnyxCeleryPriority.HIGH,
    )


"""Endpoints for non-admins"""

@@ -32,6 +32,7 @@ from onyx.db.persona import get_personas_for_user
from onyx.db.persona import mark_persona_as_deleted
from onyx.db.persona import mark_persona_as_not_deleted
from onyx.db.persona import update_all_personas_display_priority
from onyx.db.persona import update_persona_is_default
from onyx.db.persona import update_persona_label
from onyx.db.persona import update_persona_public_status
from onyx.db.persona import update_persona_shared_users
@@ -56,7 +57,6 @@ from onyx.tools.utils import is_image_generation_available
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import create_milestone_and_report


logger = setup_logger()


@@ -72,6 +72,10 @@ class IsPublicRequest(BaseModel):
    is_public: bool


class IsDefaultRequest(BaseModel):
    is_default_persona: bool


@admin_router.patch("/{persona_id}/visible")
def patch_persona_visibility(
    persona_id: int,
@@ -106,6 +110,25 @@ def patch_user_presona_public_status(
        raise HTTPException(status_code=403, detail=str(e))


@admin_router.patch("/{persona_id}/default")
def patch_persona_default_status(
    persona_id: int,
    is_default_request: IsDefaultRequest,
    user: User | None = Depends(current_curator_or_admin_user),
    db_session: Session = Depends(get_session),
) -> None:
    try:
        update_persona_is_default(
            persona_id=persona_id,
            is_default=is_default_request.is_default_persona,
            db_session=db_session,
            user=user,
        )
    except ValueError as e:
        logger.exception("Failed to update persona default status")
        raise HTTPException(status_code=403, detail=str(e))


@admin_router.put("/display-priority")
def patch_persona_display_priority(
    display_priority_request: DisplayPriorityRequest,

@@ -197,6 +197,11 @@ def create_deletion_attempt_for_connector_id(
        kwargs={"tenant_id": tenant_id},
    )

    logger.info(
        f"create_deletion_attempt_for_connector_id - running check_for_connector_deletion: "
        f"cc_pair={cc_pair.id}"
    )

    if cc_pair.connector.source == DocumentSource.FILE:
        connector = cc_pair.connector
        file_store = get_default_file_store(db_session)

@@ -34,6 +34,7 @@ from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_user
from onyx.auth.users import optional_user
from onyx.configs.app_configs import AUTH_TYPE
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.app_configs import ENABLE_EMAIL_INVITES
from onyx.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
from onyx.configs.app_configs import VALID_EMAIL_DOMAINS
@@ -286,7 +287,7 @@ def bulk_invite_users(
                detail=f"Invalid email address: {email} - {str(e)}",
            )

    if MULTI_TENANT:
    if MULTI_TENANT and not DEV_MODE:
        try:
            fetch_ee_implementation_or_noop(
                "onyx.server.tenants.provisioning", "add_users_to_tenant", None

@@ -279,4 +279,5 @@ class InternetSearchTool(Tool):
            using_tool_calling_llm=using_tool_calling_llm,
            answer_style_config=self.answer_style_config,
            prompt_config=self.prompt_config,
            context_type="internet search results",
        )

@@ -25,6 +25,7 @@ def build_next_prompt_for_search_like_tool(
    using_tool_calling_llm: bool,
    answer_style_config: AnswerStyleConfig,
    prompt_config: PromptConfig,
    context_type: str = "context documents",
) -> AnswerPromptBuilder:
    if not using_tool_calling_llm:
        final_context_docs_response = next(
@@ -58,6 +59,7 @@ def build_next_prompt_for_search_like_tool(
                else False
            ),
            history_message=prompt_builder.single_message_history or "",
            context_type=context_type,
        )
    )

@@ -86,7 +86,10 @@ def run_functions_in_parallel(
    Executes a list of FunctionCalls in parallel and stores the results in a dictionary where the keys
    are the result_id of the FunctionCall and the values are the results of the call.
    """
    results = {}
    results: dict[str, Any] = {}

    if len(function_calls) == 0:
        return results

    with ThreadPoolExecutor(max_workers=len(function_calls)) as executor:
        future_to_id = {

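The hunk above adds an explicit `dict[str, Any]` annotation and an early return for the empty case, which matters because `ThreadPoolExecutor(max_workers=0)` raises `ValueError`. A standalone sketch of the same fan-out/fan-in pattern, with a hypothetical stand-in for onyx's `FunctionCall` wrapper:

from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass, field
from typing import Any, Callable
import uuid

@dataclass
class FunctionCallSketch:
    # hypothetical stand-in for onyx's FunctionCall
    func: Callable[..., Any]
    args: tuple = ()
    result_id: str = field(default_factory=lambda: str(uuid.uuid4()))

def run_in_parallel(calls: list[FunctionCallSketch]) -> dict[str, Any]:
    results: dict[str, Any] = {}
    if len(calls) == 0:
        return results  # max_workers=0 would raise ValueError below

    with ThreadPoolExecutor(max_workers=len(calls)) as executor:
        future_to_id = {
            executor.submit(c.func, *c.args): c.result_id for c in calls
        }
        # collect each result under its call's result_id as futures finish
        for future in as_completed(future_to_id):
            results[future_to_id[future]] = future.result()
    return results
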
@@ -9,6 +9,8 @@ from onyx.connectors.airtable.airtable_connector import AirtableConnector
from onyx.connectors.models import Document
from onyx.connectors.models import Section

BASE_VIEW_ID = "viwVUEJjWPd8XYjh8"


class AirtableConfig(BaseModel):
    base_id: str
@@ -46,6 +48,8 @@ def create_test_document(
    days_since_status_change: int | None,
    attachments: list[tuple[str, str]] | None = None,
    all_fields_as_metadata: bool = False,
    share_id: str | None = None,
    view_id: str | None = None,
) -> Document:
    base_id = os.environ.get("AIRTABLE_TEST_BASE_ID")
    table_id = os.environ.get("AIRTABLE_TEST_TABLE_ID")
@@ -60,7 +64,13 @@ def create_test_document(
            f"Required environment variables not set: {', '.join(missing_vars)}. "
            "These variables are required to run Airtable connector tests."
        )
    link_base = f"https://airtable.com/{base_id}/{table_id}"
    link_base = f"https://airtable.com/{base_id}"
    if share_id:
        link_base = f"{link_base}/{share_id}"
    link_base = f"{link_base}/{table_id}"
    if view_id:
        link_base = f"{link_base}/{view_id}"

    sections = []

    if not all_fields_as_metadata:

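The rewritten `link_base` logic above builds the record URL incrementally, so the optional `share_id` and `view_id` segments slot in around the table ID. The resulting shapes, traced through a small sketch with made-up IDs:

def build_link_base(base_id: str, table_id: str, share_id: str | None, view_id: str | None) -> str:
    # same incremental construction as the test-helper hunk above
    link_base = f"https://airtable.com/{base_id}"
    if share_id:
        link_base = f"{link_base}/{share_id}"
    link_base = f"{link_base}/{table_id}"
    if view_id:
        link_base = f"{link_base}/{view_id}"
    return link_base

# made-up IDs purely for illustration:
# build_link_base("app1", "tbl2", None, None)      -> https://airtable.com/app1/tbl2
# build_link_base("app1", "tbl2", None, "viw4")    -> https://airtable.com/app1/tbl2/viw4
# build_link_base("app1", "tbl2", "shr3", "viw4")  -> https://airtable.com/app1/shr3/tbl2/viw4
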
@@ -214,6 +224,7 @@ def test_airtable_connector_basic(
            assignee="Chris Weaver (chris@onyx.app)",
            submitted_by="Chris Weaver (chris@onyx.app)",
            all_fields_as_metadata=False,
            view_id=BASE_VIEW_ID,
        ),
        create_test_document(
            id="reccSlIA4pZEFxPBg",
@@ -234,6 +245,7 @@ def test_airtable_connector_basic(
                )
            ],
            all_fields_as_metadata=False,
            view_id=BASE_VIEW_ID,
        ),
    ]

@@ -285,6 +297,81 @@ def test_airtable_connector_all_metadata(
                )
            ],
            all_fields_as_metadata=True,
            view_id=BASE_VIEW_ID,
        ),
    ]

    # Compare documents using the utility function
    compare_documents(doc_batch, expected_docs)


def test_airtable_connector_with_share_and_view(
    mock_get_unstructured_api_key: MagicMock, airtable_config: AirtableConfig
) -> None:
    """Test behavior when using share_id and view_id for URL generation."""
    SHARE_ID = "shrkfjEzDmLaDtK83"

    connector = AirtableConnector(
        base_id=airtable_config.base_id,
        table_name_or_id=airtable_config.table_identifier,
        treat_all_non_attachment_fields_as_metadata=False,
        share_id=SHARE_ID,
        view_id=BASE_VIEW_ID,
    )
    connector.load_credentials(
        {
            "airtable_access_token": airtable_config.access_token,
        }
    )
    doc_batch_generator = connector.load_from_state()
    doc_batch = next(doc_batch_generator)
    with pytest.raises(StopIteration):
        next(doc_batch_generator)

    assert len(doc_batch) == 2

    expected_docs = [
        create_test_document(
            id="rec8BnxDLyWeegOuO",
            title="Slow Internet",
            description="The internet connection is very slow.",
            priority="Medium",
            status="In Progress",
            ticket_id="2",
            created_time="2024-12-24T21:02:49.000Z",
            status_last_changed="2024-12-24T21:02:49.000Z",
            days_since_status_change=0,
            assignee="Chris Weaver (chris@onyx.app)",
            submitted_by="Chris Weaver (chris@onyx.app)",
            all_fields_as_metadata=False,
            share_id=SHARE_ID,
            view_id=BASE_VIEW_ID,
        ),
        create_test_document(
            id="reccSlIA4pZEFxPBg",
            title="Printer Issue",
            description="The office printer is not working.",
            priority="High",
            status="Open",
            ticket_id="1",
            created_time="2024-12-24T21:02:49.000Z",
            status_last_changed="2024-12-24T21:02:49.000Z",
            days_since_status_change=0,
            assignee="Chris Weaver (chris@onyx.app)",
            submitted_by="Chris Weaver (chris@onyx.app)",
            attachments=[
                (
                    "Test.pdf:\ntesting!!!",
                    (
                        f"https://airtable.com/{airtable_config.base_id}/{SHARE_ID}/"
                        f"{os.environ['AIRTABLE_TEST_TABLE_ID']}/{BASE_VIEW_ID}/reccSlIA4pZEFxPBg/"
                        "fld1u21zkJACIvAEF/attlj2UBWNEDZngCc?blocks=hide"
                    ),
                )
            ],
            all_fields_as_metadata=False,
            share_id=SHARE_ID,
            view_id=BASE_VIEW_ID,
        ),
    ]

@@ -70,8 +70,8 @@ COPY supervisord.conf /etc/supervisor/conf.d/supervisord.conf
# Set up application files
COPY ./onyx /app/onyx
COPY ./shared_configs /app/shared_configs
COPY ./alembic /app/alembic
COPY ./alembic_tenants /app/alembic_tenants
COPY ./alembic /app/alembic
COPY ./alembic.ini /app/alembic.ini
COPY ./pytest.ini /app/pytest.ini
COPY supervisord.conf /usr/etc/supervisord.conf

@@ -66,7 +66,7 @@ class PersonaManager:

        response = requests.post(
            f"{API_SERVER_URL}/persona",
            json=persona_creation_request.model_dump(),
            json=persona_creation_request.model_dump(mode="json"),
            headers=user_performing_action.headers
            if user_performing_action
            else GENERAL_HEADERS,
@@ -119,6 +119,7 @@ class PersonaManager:
    ) -> DATestPersona:
        system_prompt = system_prompt or f"System prompt for {persona.name}"
        task_prompt = task_prompt or f"Task prompt for {persona.name}"

        persona_update_request = PersonaUpsertRequest(
            name=name or persona.name,
            description=description or persona.description,
@@ -146,7 +147,7 @@ class PersonaManager:

        response = requests.patch(
            f"{API_SERVER_URL}/persona/{persona.id}",
            json=persona_update_request.model_dump(),
            json=persona_update_request.model_dump(mode="json"),
            headers=user_performing_action.headers
            if user_performing_action
            else GENERAL_HEADERS,

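The `mode="json"` switch above matters because pydantic's default `model_dump()` keeps Python-native values (`datetime`, `UUID`, enums) that the `json=` parameter of `requests` cannot serialize. A minimal sketch of the difference, using a hypothetical model:

from datetime import datetime, timezone
from uuid import UUID
from pydantic import BaseModel

class ExampleRequest(BaseModel):
    id: UUID
    created: datetime

req = ExampleRequest(
    id=UUID("12345678-1234-5678-1234-567812345678"),
    created=datetime(2025, 1, 1, tzinfo=timezone.utc),
)

req.model_dump()             # {'id': UUID(...), 'created': datetime(...)} - not JSON-safe
req.model_dump(mode="json")  # {'id': '12345678-...', 'created': '2025-01-01T00:00:00Z'}
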
@@ -24,35 +24,6 @@ def generate_auth_token() -> str:


class TenantManager:
    @staticmethod
    def create(
        tenant_id: str | None = None,
        initial_admin_email: str | None = None,
        referral_source: str | None = None,
    ) -> dict[str, str]:
        body = {
            "tenant_id": tenant_id,
            "initial_admin_email": initial_admin_email,
            "referral_source": referral_source,
        }

        token = generate_auth_token()
        headers = {
            "Authorization": f"Bearer {token}",
            "X-API-KEY": "",
            "Content-Type": "application/json",
        }

        response = requests.post(
            url=f"{API_SERVER_URL}/tenants/create",
            json=body,
            headers=headers,
        )

        response.raise_for_status()

        return response.json()

    @staticmethod
    def get_all_users(
        user_performing_action: DATestUser | None = None,

@@ -92,6 +92,7 @@ class UserManager:

        # Set cookies in the headers
        test_user.headers["Cookie"] = f"fastapiusersauth={session_cookie}; "
        test_user.cookies = {"fastapiusersauth": session_cookie}
        return test_user

    @staticmethod
@@ -102,6 +103,7 @@ class UserManager:
        response = requests.get(
            url=f"{API_SERVER_URL}/me",
            headers=user_to_verify.headers,
            cookies=user_to_verify.cookies,
        )

        if user_to_verify.is_active is False:

@@ -242,6 +242,18 @@ def reset_postgres_multitenant() -> None:
        schema_name = schema[0]
        cur.execute(f'DROP SCHEMA "{schema_name}" CASCADE')

    # Drop tables in the public schema
    cur.execute(
        """
        SELECT tablename FROM pg_tables
        WHERE schemaname = 'public'
        """
    )
    public_tables = cur.fetchall()
    for table in public_tables:
        table_name = table[0]
        cur.execute(f'DROP TABLE IF EXISTS public."{table_name}" CASCADE')

    cur.close()
    conn.close()

@@ -44,6 +44,7 @@ class DATestUser(BaseModel):
    headers: dict
    role: UserRole
    is_active: bool
    cookies: dict = {}


class DATestPersonaLabel(BaseModel):

@@ -4,7 +4,6 @@ from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.chat import ChatSessionManager
from tests.integration.common_utils.managers.document import DocumentManager
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
from tests.integration.common_utils.managers.tenant import TenantManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestAPIKey
from tests.integration.common_utils.test_models import DATestCCPair
@@ -13,25 +12,28 @@ from tests.integration.common_utils.test_models import DATestUser


def test_multi_tenant_access_control(reset_multitenant: None) -> None:
    # Create Tenant 1 and its Admin User
    TenantManager.create("tenant_dev1", "test1@test.com", "Data Plane Registration")
    test_user1: DATestUser = UserManager.create(name="test1", email="test1@test.com")
    assert UserManager.is_role(test_user1, UserRole.ADMIN)
    # Creating an admin user (the first user created is automatically an admin and also provisions the tenant)
    admin_user1: DATestUser = UserManager.create(
        email="admin@onyx-test.com",
    )

    assert UserManager.is_role(admin_user1, UserRole.ADMIN)

    # Create Tenant 2 and its Admin User
    TenantManager.create("tenant_dev2", "test2@test.com", "Data Plane Registration")
    test_user2: DATestUser = UserManager.create(name="test2", email="test2@test.com")
    assert UserManager.is_role(test_user2, UserRole.ADMIN)
    admin_user2: DATestUser = UserManager.create(
        email="admin2@onyx-test.com",
    )
    assert UserManager.is_role(admin_user2, UserRole.ADMIN)

    # Create connectors for Tenant 1
    cc_pair_1: DATestCCPair = CCPairManager.create_from_scratch(
        user_performing_action=test_user1,
        user_performing_action=admin_user1,
    )
    api_key_1: DATestAPIKey = APIKeyManager.create(
        user_performing_action=test_user1,
        user_performing_action=admin_user1,
    )
    api_key_1.headers.update(test_user1.headers)
    LLMProviderManager.create(user_performing_action=test_user1)
    api_key_1.headers.update(admin_user1.headers)
    LLMProviderManager.create(user_performing_action=admin_user1)

    # Seed documents for Tenant 1
    cc_pair_1.documents = []
@@ -49,13 +51,13 @@ def test_multi_tenant_access_control(reset_multitenant: None) -> None:

    # Create connectors for Tenant 2
    cc_pair_2: DATestCCPair = CCPairManager.create_from_scratch(
        user_performing_action=test_user2,
        user_performing_action=admin_user2,
    )
    api_key_2: DATestAPIKey = APIKeyManager.create(
        user_performing_action=test_user2,
        user_performing_action=admin_user2,
    )
    api_key_2.headers.update(test_user2.headers)
    LLMProviderManager.create(user_performing_action=test_user2)
    api_key_2.headers.update(admin_user2.headers)
    LLMProviderManager.create(user_performing_action=admin_user2)

    # Seed documents for Tenant 2
    cc_pair_2.documents = []
@@ -76,17 +78,17 @@ def test_multi_tenant_access_control(reset_multitenant: None) -> None:

    # Create chat sessions for each user
    chat_session1: DATestChatSession = ChatSessionManager.create(
        user_performing_action=test_user1
        user_performing_action=admin_user1
    )
    chat_session2: DATestChatSession = ChatSessionManager.create(
        user_performing_action=test_user2
        user_performing_action=admin_user2
    )

    # User 1 sends a message and gets a response
    response1 = ChatSessionManager.send_message(
        chat_session_id=chat_session1.id,
        message="What is in Tenant 1's documents?",
        user_performing_action=test_user1,
        user_performing_action=admin_user1,
    )
    # Assert that the search tool was used
    assert response1.tool_name == "run_search"
@@ -100,14 +102,16 @@ def test_multi_tenant_access_control(reset_multitenant: None) -> None:
    ), "Tenant 2 document IDs should not be in the response"

    # Assert that the contents are correct
    for doc in response1.tool_result or []:
        assert doc["content"] == "Tenant 1 Document Content"
    assert any(
        doc["content"] == "Tenant 1 Document Content"
        for doc in response1.tool_result or []
    ), "Tenant 1 Document Content not found in any document"

    # User 2 sends a message and gets a response
    response2 = ChatSessionManager.send_message(
        chat_session_id=chat_session2.id,
        message="What is in Tenant 2's documents?",
        user_performing_action=test_user2,
        user_performing_action=admin_user2,
    )
    # Assert that the search tool was used
    assert response2.tool_name == "run_search"
@@ -119,15 +123,18 @@ def test_multi_tenant_access_control(reset_multitenant: None) -> None:
    assert not response_doc_ids.intersection(
        tenant1_doc_ids
    ), "Tenant 1 document IDs should not be in the response"

    # Assert that the contents are correct
    for doc in response2.tool_result or []:
        assert doc["content"] == "Tenant 2 Document Content"
    assert any(
        doc["content"] == "Tenant 2 Document Content"
        for doc in response2.tool_result or []
    ), "Tenant 2 Document Content not found in any document"

    # User 1 tries to access Tenant 2's documents
    response_cross = ChatSessionManager.send_message(
        chat_session_id=chat_session1.id,
        message="What is in Tenant 2's documents?",
        user_performing_action=test_user1,
        user_performing_action=admin_user1,
    )
    # Assert that the search tool was used
    assert response_cross.tool_name == "run_search"
@@ -140,7 +147,7 @@ def test_multi_tenant_access_control(reset_multitenant: None) -> None:
    response_cross2 = ChatSessionManager.send_message(
        chat_session_id=chat_session2.id,
        message="What is in Tenant 1's documents?",
        user_performing_action=test_user2,
        user_performing_action=admin_user2,
    )
    # Assert that the search tool was used
    assert response_cross2.tool_name == "run_search"

@@ -4,14 +4,12 @@ from onyx.db.models import UserRole
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.connector import ConnectorManager
from tests.integration.common_utils.managers.credential import CredentialManager
from tests.integration.common_utils.managers.tenant import TenantManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestUser


# Test flow from creating tenant to registering as a user
def test_tenant_creation(reset_multitenant: None) -> None:
    TenantManager.create("tenant_dev", "test@test.com", "Data Plane Registration")
    test_user: DATestUser = UserManager.create(name="test", email="test@test.com")

    assert UserManager.is_role(test_user, UserRole.ADMIN)

@@ -1,23 +1,23 @@
import time
from datetime import datetime

from onyx.db.models import IndexingStatus
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.index_attempt import IndexAttemptManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestIndexAttempt
from tests.integration.common_utils.test_models import DATestUser


def _verify_index_attempt_pagination(
    cc_pair_id: int,
    index_attempts: list[DATestIndexAttempt],
    index_attempt_ids: list[int],
    page_size: int = 5,
    user_performing_action: DATestUser | None = None,
) -> None:
    retrieved_attempts: list[int] = []
    last_time_started = None  # Track the last time_started seen

    for i in range(0, len(index_attempts), page_size):
    for i in range(0, len(index_attempt_ids), page_size):
        paginated_result = IndexAttemptManager.get_index_attempt_page(
            cc_pair_id=cc_pair_id,
            page=(i // page_size),
@@ -26,9 +26,9 @@ def _verify_index_attempt_pagination(
        )

        # Verify that the total items is equal to the length of the index attempts list
        assert paginated_result.total_items == len(index_attempts)
        assert paginated_result.total_items == len(index_attempt_ids)
        # Verify that the number of items in the page is equal to the page size
        assert len(paginated_result.items) == min(page_size, len(index_attempts) - i)
        assert len(paginated_result.items) == min(page_size, len(index_attempt_ids) - i)

        # Verify time ordering within the page (descending order)
        for attempt in paginated_result.items:
@@ -42,7 +42,7 @@ def _verify_index_attempt_pagination(
        retrieved_attempts.extend([attempt.id for attempt in paginated_result.items])

    # Create a set of all the expected index attempt IDs
    all_expected_attempts = set(attempt.id for attempt in index_attempts)
    all_expected_attempts = set(index_attempt_ids)
    # Create a set of all the retrieved index attempt IDs
    all_retrieved_attempts = set(retrieved_attempts)

@@ -51,6 +51,9 @@ def _verify_index_attempt_pagination(


def test_index_attempt_pagination(reset: None) -> None:
    MAX_WAIT = 60
    all_attempt_ids: list[int] = []

    # Create an admin user to perform actions
    user_performing_action: DATestUser = UserManager.create(
        name="admin_performing_action",
@@ -62,20 +65,49 @@ def test_index_attempt_pagination(reset: None) -> None:
        user_performing_action=user_performing_action,
    )

    # Create 300 successful index attempts
    # Creating a CC pair will create an index attempt as well. Wait for it.
    start = time.monotonic()
    while True:
        paginated_result = IndexAttemptManager.get_index_attempt_page(
            cc_pair_id=cc_pair.id,
            page=0,
            page_size=5,
            user_performing_action=user_performing_action,
        )

        if paginated_result.total_items == 1:
            all_attempt_ids.append(paginated_result.items[0].id)
            print("Initial index attempt from cc_pair creation detected. Continuing...")
            break

        elapsed = time.monotonic() - start
        if elapsed > MAX_WAIT:
            raise TimeoutError(
                f"Initial index attempt: Not detected within {MAX_WAIT} seconds."
            )

        print(
            f"Waiting for initial index attempt: elapsed={elapsed:.2f} timeout={MAX_WAIT}"
        )
        time.sleep(1)

    # Create 299 successful index attempts (for 300 total)
    base_time = datetime.now()
    all_attempts = IndexAttemptManager.create_test_index_attempts(
        num_attempts=300,
    generated_attempts = IndexAttemptManager.create_test_index_attempts(
        num_attempts=299,
        cc_pair_id=cc_pair.id,
        status=IndexingStatus.SUCCESS,
        base_time=base_time,
    )

    for attempt in generated_attempts:
        all_attempt_ids.append(attempt.id)

    # Verify basic pagination with different page sizes
    print("Verifying basic pagination with page size 5")
    _verify_index_attempt_pagination(
        cc_pair_id=cc_pair.id,
        index_attempts=all_attempts,
        index_attempt_ids=all_attempt_ids,
        page_size=5,
        user_performing_action=user_performing_action,
    )
@@ -84,7 +116,7 @@ def test_index_attempt_pagination(reset: None) -> None:
    print("Verifying pagination with page size 100")
    _verify_index_attempt_pagination(
        cc_pair_id=cc_pair.id,
        index_attempts=all_attempts,
        index_attempt_ids=all_attempt_ids,
        page_size=100,
        user_performing_action=user_performing_action,
    )

@@ -58,6 +58,7 @@ def test_persona_permissions(reset: None) -> None:
        description="A persona created by basic user",
        is_public=False,
        groups=[],
        users=[admin_user.id],
        user_performing_action=basic_user,
    )
    PersonaManager.verify(basic_user_persona, user_performing_action=basic_user)
@@ -139,9 +140,14 @@ def test_persona_permissions(reset: None) -> None:

    """Test admin permissions"""
    # Admin can edit any persona

    # the persona was shared with the admin user on creation
    # this edit call will simulate having the same user in the list twice.
    # The server side should dedupe and handle this correctly (prior bug)
    PersonaManager.edit(
        persona=basic_user_persona,
        description="Updated by admin",
        description="Updated by admin 2",
        users=[admin_user.id, admin_user.id],
        user_performing_action=admin_user,
    )
    PersonaManager.verify(basic_user_persona, user_performing_action=admin_user)

deployment/docker_compose/docker-compose.multitenant-dev.yml (new file, 423 lines)
@@ -0,0 +1,423 @@
services:
  api_server:
    image: onyxdotapp/onyx-backend:${IMAGE_TAG:-latest}
    build:
      context: ../../backend
      dockerfile: Dockerfile
    command: >
      /bin/sh -c "
      alembic -n schema_private upgrade head &&
      echo \"Starting Onyx Api Server\" &&
      uvicorn onyx.main:app --host 0.0.0.0 --port 8080"
    depends_on:
      - relational_db
      - index
      - cache
      - inference_model_server
    restart: always
    ports:
      - "8080:8080"
    environment:
      - ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
      - MULTI_TENANT=true
      - LOG_LEVEL=DEBUG
      - AUTH_TYPE=cloud
      - REQUIRE_EMAIL_VERIFICATION=false
      - DISABLE_TELEMETRY=true
      - IMAGE_TAG=test
      - DEV_MODE=true
      # Auth Settings
      - SESSION_EXPIRE_TIME_SECONDS=${SESSION_EXPIRE_TIME_SECONDS:-}
      - ENCRYPTION_KEY_SECRET=${ENCRYPTION_KEY_SECRET:-}
      - VALID_EMAIL_DOMAINS=${VALID_EMAIL_DOMAINS:-}
      - GOOGLE_OAUTH_CLIENT_ID=${GOOGLE_OAUTH_CLIENT_ID:-}
      - GOOGLE_OAUTH_CLIENT_SECRET=${GOOGLE_OAUTH_CLIENT_SECRET:-}
      - SMTP_SERVER=${SMTP_SERVER:-}
      - SMTP_PORT=${SMTP_PORT:-587}
      - SMTP_USER=${SMTP_USER:-}
      - SMTP_PASS=${SMTP_PASS:-}
      - ENABLE_EMAIL_INVITES=${ENABLE_EMAIL_INVITES:-}
      - EMAIL_FROM=${EMAIL_FROM:-}
      - OAUTH_CLIENT_ID=${OAUTH_CLIENT_ID:-}
      - OAUTH_CLIENT_SECRET=${OAUTH_CLIENT_SECRET:-}
      - OPENID_CONFIG_URL=${OPENID_CONFIG_URL:-}
      - TRACK_EXTERNAL_IDP_EXPIRY=${TRACK_EXTERNAL_IDP_EXPIRY:-}
      - CORS_ALLOWED_ORIGIN=${CORS_ALLOWED_ORIGIN:-}
      # Gen AI Settings
      - GEN_AI_MAX_TOKENS=${GEN_AI_MAX_TOKENS:-}
      - QA_TIMEOUT=${QA_TIMEOUT:-}
      - MAX_CHUNKS_FED_TO_CHAT=${MAX_CHUNKS_FED_TO_CHAT:-}
      - DISABLE_LLM_CHOOSE_SEARCH=${DISABLE_LLM_CHOOSE_SEARCH:-}
      - DISABLE_LLM_QUERY_REPHRASE=${DISABLE_LLM_QUERY_REPHRASE:-}
      - DISABLE_GENERATIVE_AI=${DISABLE_GENERATIVE_AI:-}
      - DISABLE_LITELLM_STREAMING=${DISABLE_LITELLM_STREAMING:-}
      - LITELLM_EXTRA_HEADERS=${LITELLM_EXTRA_HEADERS:-}
      - BING_API_KEY=${BING_API_KEY:-}
      - DISABLE_LLM_DOC_RELEVANCE=${DISABLE_LLM_DOC_RELEVANCE:-}
      - GEN_AI_API_KEY=${GEN_AI_API_KEY:-}
      - TOKEN_BUDGET_GLOBALLY_ENABLED=${TOKEN_BUDGET_GLOBALLY_ENABLED:-}
      # Query Options
      - DOC_TIME_DECAY=${DOC_TIME_DECAY:-}
      - HYBRID_ALPHA=${HYBRID_ALPHA:-}
      - EDIT_KEYWORD_QUERY=${EDIT_KEYWORD_QUERY:-}
      - MULTILINGUAL_QUERY_EXPANSION=${MULTILINGUAL_QUERY_EXPANSION:-}
      - LANGUAGE_HINT=${LANGUAGE_HINT:-}
      - LANGUAGE_CHAT_NAMING_HINT=${LANGUAGE_CHAT_NAMING_HINT:-}
      - QA_PROMPT_OVERRIDE=${QA_PROMPT_OVERRIDE:-}
      # Other services
      - POSTGRES_HOST=relational_db
      - POSTGRES_DEFAULT_SCHEMA=${POSTGRES_DEFAULT_SCHEMA:-}
      - VESPA_HOST=index
      - REDIS_HOST=cache
      - WEB_DOMAIN=${WEB_DOMAIN:-}
      # Don't change the NLP model configs unless you know what you're doing
      - EMBEDDING_BATCH_SIZE=${EMBEDDING_BATCH_SIZE:-}
      - DOCUMENT_ENCODER_MODEL=${DOCUMENT_ENCODER_MODEL:-}
      - DOC_EMBEDDING_DIM=${DOC_EMBEDDING_DIM:-}
      - NORMALIZE_EMBEDDINGS=${NORMALIZE_EMBEDDINGS:-}
      - ASYM_QUERY_PREFIX=${ASYM_QUERY_PREFIX:-}
      - DISABLE_RERANK_FOR_STREAMING=${DISABLE_RERANK_FOR_STREAMING:-}
      - MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server}
      - MODEL_SERVER_PORT=${MODEL_SERVER_PORT:-}
      - LOG_ALL_MODEL_INTERACTIONS=${LOG_ALL_MODEL_INTERACTIONS:-}
      - LOG_DANSWER_MODEL_INTERACTIONS=${LOG_DANSWER_MODEL_INTERACTIONS:-}
      - LOG_INDIVIDUAL_MODEL_TOKENS=${LOG_INDIVIDUAL_MODEL_TOKENS:-}
      - LOG_VESPA_TIMING_INFORMATION=${LOG_VESPA_TIMING_INFORMATION:-}
      - LOG_ENDPOINT_LATENCY=${LOG_ENDPOINT_LATENCY:-}
      - LOG_POSTGRES_LATENCY=${LOG_POSTGRES_LATENCY:-}
      - LOG_POSTGRES_CONN_COUNTS=${LOG_POSTGRES_CONN_COUNTS:-}
|
||||
- CELERY_BROKER_POOL_LIMIT=${CELERY_BROKER_POOL_LIMIT:-}
|
||||
- LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS=${LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS:-}
|
||||
# Egnyte OAuth Configs
|
||||
- EGNYTE_CLIENT_ID=${EGNYTE_CLIENT_ID:-}
|
||||
- EGNYTE_CLIENT_SECRET=${EGNYTE_CLIENT_SECRET:-}
|
||||
- EGNYTE_LOCALHOST_OVERRIDE=${EGNYTE_LOCALHOST_OVERRIDE:-}
|
||||
# Linear OAuth Configs
|
||||
- LINEAR_CLIENT_ID=${LINEAR_CLIENT_ID:-}
|
||||
- LINEAR_CLIENT_SECRET=${LINEAR_CLIENT_SECRET:-}
|
||||
# Analytics Configs
|
||||
- SENTRY_DSN=${SENTRY_DSN:-}
|
||||
# Chat Configs
|
||||
- HARD_DELETE_CHATS=${HARD_DELETE_CHATS:-}
|
||||
# Enables the use of bedrock models or IAM Auth
|
||||
- AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID:-}
|
||||
- AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY:-}
|
||||
- AWS_REGION_NAME=${AWS_REGION_NAME:-}
|
||||
- API_KEY_HASH_ROUNDS=${API_KEY_HASH_ROUNDS:-}
|
||||
# Seeding configuration
|
||||
- USE_IAM_AUTH=${USE_IAM_AUTH:-}
|
||||
extra_hosts:
|
||||
- "host.docker.internal:host-gateway"
|
||||
logging:
|
||||
driver: json-file
|
||||
options:
|
||||
max-size: "50m"
|
||||
max-file: "6"
|
||||
|
||||
background:
|
||||
image: onyxdotapp/onyx-backend:${IMAGE_TAG:-latest}
|
||||
build:
|
||||
context: ../../backend
|
||||
dockerfile: Dockerfile
|
||||
command: >
|
||||
/bin/sh -c "
|
||||
if [ -f /etc/ssl/certs/custom-ca.crt ]; then
|
||||
update-ca-certificates;
|
||||
fi &&
|
||||
/usr/bin/supervisord -c /etc/supervisor/conf.d/supervisord.conf"
|
||||
depends_on:
|
||||
- relational_db
|
||||
- index
|
||||
- cache
|
||||
- inference_model_server
|
||||
- indexing_model_server
|
||||
restart: always
|
||||
environment:
|
||||
- ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
|
||||
- MULTI_TENANT=true
|
||||
- LOG_LEVEL=DEBUG
|
||||
- AUTH_TYPE=cloud
|
||||
- REQUIRE_EMAIL_VERIFICATION=false
|
||||
- DISABLE_TELEMETRY=true
|
||||
- IMAGE_TAG=test
|
||||
- ENCRYPTION_KEY_SECRET=${ENCRYPTION_KEY_SECRET:-}
|
||||
- JWT_PUBLIC_KEY_URL=${JWT_PUBLIC_KEY_URL:-}
|
||||
# Gen AI Settings (Needed by OnyxBot)
|
||||
- GEN_AI_MAX_TOKENS=${GEN_AI_MAX_TOKENS:-}
|
||||
- QA_TIMEOUT=${QA_TIMEOUT:-}
|
||||
- MAX_CHUNKS_FED_TO_CHAT=${MAX_CHUNKS_FED_TO_CHAT:-}
|
||||
- DISABLE_LLM_CHOOSE_SEARCH=${DISABLE_LLM_CHOOSE_SEARCH:-}
|
||||
- DISABLE_LLM_QUERY_REPHRASE=${DISABLE_LLM_QUERY_REPHRASE:-}
|
||||
- DISABLE_GENERATIVE_AI=${DISABLE_GENERATIVE_AI:-}
|
||||
- GENERATIVE_MODEL_ACCESS_CHECK_FREQ=${GENERATIVE_MODEL_ACCESS_CHECK_FREQ:-}
|
||||
- DISABLE_LITELLM_STREAMING=${DISABLE_LITELLM_STREAMING:-}
|
||||
- LITELLM_EXTRA_HEADERS=${LITELLM_EXTRA_HEADERS:-}
|
||||
- GEN_AI_API_KEY=${GEN_AI_API_KEY:-}
|
||||
- BING_API_KEY=${BING_API_KEY:-}
|
||||
# Query Options
|
||||
- DOC_TIME_DECAY=${DOC_TIME_DECAY:-}
|
||||
- HYBRID_ALPHA=${HYBRID_ALPHA:-}
|
||||
- EDIT_KEYWORD_QUERY=${EDIT_KEYWORD_QUERY:-}
|
||||
- MULTILINGUAL_QUERY_EXPANSION=${MULTILINGUAL_QUERY_EXPANSION:-}
|
||||
- LANGUAGE_HINT=${LANGUAGE_HINT:-}
|
||||
- LANGUAGE_CHAT_NAMING_HINT=${LANGUAGE_CHAT_NAMING_HINT:-}
|
||||
- QA_PROMPT_OVERRIDE=${QA_PROMPT_OVERRIDE:-}
|
||||
# Other Services
|
||||
- POSTGRES_HOST=relational_db
|
||||
- POSTGRES_USER=${POSTGRES_USER:-}
|
||||
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-}
|
||||
- POSTGRES_DB=${POSTGRES_DB:-}
|
||||
- POSTGRES_DEFAULT_SCHEMA=${POSTGRES_DEFAULT_SCHEMA:-}
|
||||
- VESPA_HOST=index
|
||||
- REDIS_HOST=cache
|
||||
- WEB_DOMAIN=${WEB_DOMAIN:-}
|
||||
# Don't change the NLP model configs unless you know what you're doing
|
||||
- DOCUMENT_ENCODER_MODEL=${DOCUMENT_ENCODER_MODEL:-}
|
||||
- DOC_EMBEDDING_DIM=${DOC_EMBEDDING_DIM:-}
|
||||
- NORMALIZE_EMBEDDINGS=${NORMALIZE_EMBEDDINGS:-}
|
||||
- ASYM_QUERY_PREFIX=${ASYM_QUERY_PREFIX:-}
|
||||
- ASYM_PASSAGE_PREFIX=${ASYM_PASSAGE_PREFIX:-}
|
||||
- MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server}
|
||||
- MODEL_SERVER_PORT=${MODEL_SERVER_PORT:-}
|
||||
- INDEXING_MODEL_SERVER_HOST=${INDEXING_MODEL_SERVER_HOST:-indexing_model_server}
|
||||
# Indexing Configs
|
||||
- VESPA_SEARCHER_THREADS=${VESPA_SEARCHER_THREADS:-}
|
||||
- NUM_INDEXING_WORKERS=${NUM_INDEXING_WORKERS:-}
|
||||
- ENABLED_CONNECTOR_TYPES=${ENABLED_CONNECTOR_TYPES:-}
|
||||
- DISABLE_INDEX_UPDATE_ON_SWAP=${DISABLE_INDEX_UPDATE_ON_SWAP:-}
|
||||
- DASK_JOB_CLIENT_ENABLED=${DASK_JOB_CLIENT_ENABLED:-}
|
||||
- CONTINUE_ON_CONNECTOR_FAILURE=${CONTINUE_ON_CONNECTOR_FAILURE:-}
|
||||
- EXPERIMENTAL_CHECKPOINTING_ENABLED=${EXPERIMENTAL_CHECKPOINTING_ENABLED:-}
|
||||
- CONFLUENCE_CONNECTOR_LABELS_TO_SKIP=${CONFLUENCE_CONNECTOR_LABELS_TO_SKIP:-}
|
||||
- JIRA_CONNECTOR_LABELS_TO_SKIP=${JIRA_CONNECTOR_LABELS_TO_SKIP:-}
|
||||
- WEB_CONNECTOR_VALIDATE_URLS=${WEB_CONNECTOR_VALIDATE_URLS:-}
|
||||
- JIRA_API_VERSION=${JIRA_API_VERSION:-}
|
||||
- GONG_CONNECTOR_START_TIME=${GONG_CONNECTOR_START_TIME:-}
|
||||
- NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP=${NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP:-}
|
||||
- GITHUB_CONNECTOR_BASE_URL=${GITHUB_CONNECTOR_BASE_URL:-}
|
||||
- MAX_DOCUMENT_CHARS=${MAX_DOCUMENT_CHARS:-}
|
||||
- MAX_FILE_SIZE_BYTES=${MAX_FILE_SIZE_BYTES:-}
|
||||
# Egnyte OAuth Configs
|
||||
- EGNYTE_CLIENT_ID=${EGNYTE_CLIENT_ID:-}
|
||||
- EGNYTE_CLIENT_SECRET=${EGNYTE_CLIENT_SECRET:-}
|
||||
- EGNYTE_LOCALHOST_OVERRIDE=${EGNYTE_LOCALHOST_OVERRIDE:-}
|
||||
# Lienar OAuth Configs
|
||||
- LINEAR_CLIENT_ID=${LINEAR_CLIENT_ID:-}
|
||||
- LINEAR_CLIENT_SECRET=${LINEAR_CLIENT_SECRET:-}
|
||||
# Celery Configs (defaults are set in the supervisord.conf file.
|
||||
# prefer doing that to have one source of defaults)
|
||||
- CELERY_WORKER_INDEXING_CONCURRENCY=${CELERY_WORKER_INDEXING_CONCURRENCY:-}
|
||||
- CELERY_WORKER_LIGHT_CONCURRENCY=${CELERY_WORKER_LIGHT_CONCURRENCY:-}
|
||||
- CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER=${CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER:-}
|
||||
|
||||
# Onyx SlackBot Configs
|
||||
- DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER=${DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER:-}
|
||||
- DANSWER_BOT_FEEDBACK_VISIBILITY=${DANSWER_BOT_FEEDBACK_VISIBILITY:-}
|
||||
- DANSWER_BOT_DISPLAY_ERROR_MSGS=${DANSWER_BOT_DISPLAY_ERROR_MSGS:-}
|
||||
- DANSWER_BOT_RESPOND_EVERY_CHANNEL=${DANSWER_BOT_RESPOND_EVERY_CHANNEL:-}
|
||||
- DANSWER_BOT_DISABLE_COT=${DANSWER_BOT_DISABLE_COT:-} # Currently unused
|
||||
- NOTIFY_SLACKBOT_NO_ANSWER=${NOTIFY_SLACKBOT_NO_ANSWER:-}
|
||||
- DANSWER_BOT_MAX_QPM=${DANSWER_BOT_MAX_QPM:-}
|
||||
- DANSWER_BOT_MAX_WAIT_TIME=${DANSWER_BOT_MAX_WAIT_TIME:-}
|
||||
# Logging
|
||||
# Leave this on pretty please? Nothing sensitive is collected!
|
||||
# https://docs.onyx.app/more/telemetry
|
||||
- DISABLE_TELEMETRY=${DISABLE_TELEMETRY:-}
|
||||
- LOG_LEVEL=${LOG_LEVEL:-info} # Set to debug to get more fine-grained logs
|
||||
- LOG_ALL_MODEL_INTERACTIONS=${LOG_ALL_MODEL_INTERACTIONS:-} # LiteLLM Verbose Logging
|
||||
# Log all of Onyx prompts and interactions with the LLM
|
||||
- LOG_DANSWER_MODEL_INTERACTIONS=${LOG_DANSWER_MODEL_INTERACTIONS:-}
|
||||
- LOG_INDIVIDUAL_MODEL_TOKENS=${LOG_INDIVIDUAL_MODEL_TOKENS:-}
|
||||
- LOG_VESPA_TIMING_INFORMATION=${LOG_VESPA_TIMING_INFORMATION:-}
|
||||
|
||||
# Analytics Configs
|
||||
- SENTRY_DSN=${SENTRY_DSN:-}
|
||||
|
||||
# Enterprise Edition stuff
|
||||
- ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${ENABLE_PAID_ENTERPRISE_EDITION_FEATURES:-false}
|
||||
- USE_IAM_AUTH=${USE_IAM_AUTH:-}
|
||||
- AWS_REGION_NAME=${AWS_REGION_NAME:-}
|
||||
- AWS_ACCESS_KEY_ID=${AWS_ACCESS_KEY_ID-}
|
||||
- AWS_SECRET_ACCESS_KEY=${AWS_SECRET_ACCESS_KEY-}
|
||||
# Uncomment the line below to use if IAM_AUTH is true and you are using iam auth for postgres
|
||||
# volumes:
|
||||
# - ./bundle.pem:/app/bundle.pem:ro
|
||||
extra_hosts:
|
||||
- "host.docker.internal:host-gateway"
|
||||
logging:
|
||||
driver: json-file
|
||||
options:
|
||||
max-size: "50m"
|
||||
max-file: "6"
|
||||
# Uncomment the following lines if you need to include a custom CA certificate
|
||||
# This section enables the use of a custom CA certificate
|
||||
# If present, the custom CA certificate is mounted as a volume
|
||||
# The container checks for its existence and updates the system's CA certificates
|
||||
# This allows for secure communication with services using custom SSL certificates
|
||||
# Optional volume mount for CA certificate
|
||||
# volumes:
|
||||
# # Maps to the CA_CERT_PATH environment variable in the Dockerfile
|
||||
# - ${CA_CERT_PATH:-./custom-ca.crt}:/etc/ssl/certs/custom-ca.crt:ro
|
||||
|
||||
web_server:
|
||||
image: onyxdotapp/onyx-web-server:${IMAGE_TAG:-latest}
|
||||
build:
|
||||
context: ../../web
|
||||
dockerfile: Dockerfile
|
||||
args:
|
||||
- NEXT_PUBLIC_DISABLE_STREAMING=${NEXT_PUBLIC_DISABLE_STREAMING:-false}
|
||||
- NEXT_PUBLIC_NEW_CHAT_DIRECTS_TO_SAME_PERSONA=${NEXT_PUBLIC_NEW_CHAT_DIRECTS_TO_SAME_PERSONA:-false}
|
||||
- NEXT_PUBLIC_POSITIVE_PREDEFINED_FEEDBACK_OPTIONS=${NEXT_PUBLIC_POSITIVE_PREDEFINED_FEEDBACK_OPTIONS:-}
|
||||
- NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS=${NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS:-}
|
||||
- NEXT_PUBLIC_DISABLE_LOGOUT=${NEXT_PUBLIC_DISABLE_LOGOUT:-}
|
||||
- NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN=${NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN:-}
|
||||
- NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=${NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED:-}
|
||||
# Enterprise Edition only
|
||||
- NEXT_PUBLIC_THEME=${NEXT_PUBLIC_THEME:-}
|
||||
# DO NOT TURN ON unless you have EXPLICIT PERMISSION from Onyx.
|
||||
- NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED=${NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED:-false}
|
||||
depends_on:
|
||||
- api_server
|
||||
restart: always
|
||||
environment:
|
||||
- INTERNAL_URL=http://api_server:8080
|
||||
- WEB_DOMAIN=${WEB_DOMAIN:-}
|
||||
- THEME_IS_DARK=${THEME_IS_DARK:-}
|
||||
- DISABLE_LLM_DOC_RELEVANCE=${DISABLE_LLM_DOC_RELEVANCE:-}
|
||||
|
||||
# Enterprise Edition only
|
||||
- ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${ENABLE_PAID_ENTERPRISE_EDITION_FEATURES:-false}
|
||||
- NEXT_PUBLIC_CUSTOM_REFRESH_URL=${NEXT_PUBLIC_CUSTOM_REFRESH_URL:-}
|
||||
|
||||
inference_model_server:
|
||||
image: onyxdotapp/onyx-model-server:${IMAGE_TAG:-latest}
|
||||
build:
|
||||
context: ../../backend
|
||||
dockerfile: Dockerfile.model_server
|
||||
command: >
|
||||
/bin/sh -c "if [ \"${DISABLE_MODEL_SERVER:-false}\" = \"True\" ]; then
|
||||
echo 'Skipping service...';
|
||||
exit 0;
|
||||
else
|
||||
exec uvicorn model_server.main:app --host 0.0.0.0 --port 9000;
|
||||
fi"
|
||||
restart: on-failure
|
||||
environment:
|
||||
- MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
|
||||
# Set to debug to get more fine-grained logs
|
||||
- LOG_LEVEL=${LOG_LEVEL:-info}
|
||||
|
||||
# Analytics Configs
|
||||
- SENTRY_DSN=${SENTRY_DSN:-}
|
||||
volumes:
|
||||
# Not necessary, this is just to reduce download time during startup
|
||||
- model_cache_huggingface:/root/.cache/huggingface/
|
||||
logging:
|
||||
driver: json-file
|
||||
options:
|
||||
max-size: "50m"
|
||||
max-file: "6"
|
||||
|
||||
indexing_model_server:
|
||||
image: onyxdotapp/onyx-model-server:${IMAGE_TAG:-latest}
|
||||
build:
|
||||
context: ../../backend
|
||||
dockerfile: Dockerfile.model_server
|
||||
command: >
|
||||
/bin/sh -c "if [ \"${DISABLE_MODEL_SERVER:-false}\" = \"True\" ]; then
|
||||
echo 'Skipping service...';
|
||||
exit 0;
|
||||
else
|
||||
exec uvicorn model_server.main:app --host 0.0.0.0 --port 9000;
|
||||
fi"
|
||||
restart: on-failure
|
||||
environment:
|
||||
- INDEX_BATCH_SIZE=${INDEX_BATCH_SIZE:-}
|
||||
- MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
|
||||
- INDEXING_ONLY=True
|
||||
# Set to debug to get more fine-grained logs
|
||||
- LOG_LEVEL=${LOG_LEVEL:-info}
|
||||
- CLIENT_EMBEDDING_TIMEOUT=${CLIENT_EMBEDDING_TIMEOUT:-}
|
||||
|
||||
# Analytics Configs
|
||||
- SENTRY_DSN=${SENTRY_DSN:-}
|
||||
volumes:
|
||||
# Not necessary, this is just to reduce download time during startup
|
||||
- indexing_huggingface_model_cache:/root/.cache/huggingface/
|
||||
logging:
|
||||
driver: json-file
|
||||
options:
|
||||
max-size: "50m"
|
||||
max-file: "6"
|
||||
|
||||
relational_db:
|
||||
image: postgres:15.2-alpine
|
||||
command: -c 'max_connections=250'
|
||||
restart: always
|
||||
environment:
|
||||
- POSTGRES_USER=${POSTGRES_USER:-postgres}
|
||||
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
|
||||
ports:
|
||||
- "5432:5432"
|
||||
volumes:
|
||||
- db_volume:/var/lib/postgresql/data
|
||||
|
||||
# This container name cannot have an underscore in it due to Vespa expectations of the URL
|
||||
index:
|
||||
image: vespaengine/vespa:8.277.17
|
||||
restart: always
|
||||
ports:
|
||||
- "19071:19071"
|
||||
- "8081:8081"
|
||||
volumes:
|
||||
- vespa_volume:/opt/vespa/var
|
||||
logging:
|
||||
driver: json-file
|
||||
options:
|
||||
max-size: "50m"
|
||||
max-file: "6"
|
||||
|
||||
nginx:
|
||||
image: nginx:1.23.4-alpine
|
||||
restart: always
|
||||
# nginx will immediately crash with `nginx: [emerg] host not found in upstream`
|
||||
# if api_server / web_server are not up
|
||||
depends_on:
|
||||
- api_server
|
||||
- web_server
|
||||
environment:
|
||||
- DOMAIN=localhost
|
||||
ports:
|
||||
- "80:80"
|
||||
- "3000:80" # allow for localhost:3000 usage, since that is the norm
|
||||
volumes:
|
||||
- ../data/nginx:/etc/nginx/conf.d
|
||||
logging:
|
||||
driver: json-file
|
||||
options:
|
||||
max-size: "50m"
|
||||
max-file: "6"
|
||||
# The specified script waits for the api_server to start up.
|
||||
# Without this we've seen issues where nginx shows no error logs but
|
||||
# does not recieve any traffic
|
||||
# NOTE: we have to use dos2unix to remove Carriage Return chars from the file
|
||||
# in order to make this work on both Unix-like systems and windows
|
||||
command: >
|
||||
/bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh
|
||||
&& /etc/nginx/conf.d/run-nginx.sh app.conf.template.dev"
|
||||
|
||||
cache:
|
||||
image: redis:7.4-alpine
|
||||
restart: always
|
||||
ports:
|
||||
- "6379:6379"
|
||||
# docker silently mounts /data even without an explicit volume mount, which enables
|
||||
# persistence. explicitly setting save and appendonly forces ephemeral behavior.
|
||||
command: redis-server --save "" --appendonly no
|
||||
|
||||
volumes:
  db_volume:
  vespa_volume: # Created by the container itself

  model_cache_huggingface:
  indexing_huggingface_model_cache:
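To sanity-check that the `cache` service above really is ephemeral, one can query the running container with the `redis` Python client. This snippet is an optional check, not part of the compose file, and assumes the stack is up with the default port mapping:

```python
import redis  # pip install redis

r = redis.Redis(host="localhost", port=6379, decode_responses=True)
assert r.config_get("save") == {"save": ""}                # RDB snapshots disabled
assert r.config_get("appendonly") == {"appendonly": "no"}  # AOF persistence disabled
```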
@@ -23,12 +23,12 @@ _Note:_ if you are having problems accessing the ^, try setting the `WEB_DOMAIN`
`http://127.0.0.1:3000` and accessing it there.

## Testing

This testing process will reset your application into a clean state.
Don't run these tests if you don't want to do this!

Bring up the entire application.

1. Reset the instance

   ```
   cd backend
@@ -59,4 +59,4 @@ may use this for local troubleshooting and testing.
   ```
   cd web
   npx chromatic --playwright --project-token={your token here}
   ```
@@ -4,7 +4,7 @@
  "rsc": true,
  "tsx": true,
  "tailwind": {
-   "config": "tailwind.config.js",
+   "config": "tailwind-themes/tailwind.config.js",
    "css": "src/app/globals.css",
    "baseColor": "neutral",
    "cssVariables": false,
web/package-lock.json (generated, 1001-line diff suppressed because it is too large)
@@ -4,7 +4,7 @@
  "version-comment": "version field must be SemVer or chromatic will barf",
  "private": true,
  "scripts": {
-   "dev": "next dev --turbopack",
+   "dev": "next dev --turbo",
    "build": "next build",
    "start": "next start",
    "lint": "next lint",
@@ -21,17 +21,17 @@
    "@radix-ui/react-accordion": "^1.2.2",
    "@radix-ui/react-checkbox": "^1.1.2",
    "@radix-ui/react-collapsible": "^1.1.2",
-   "@radix-ui/react-dialog": "^1.1.2",
-   "@radix-ui/react-dropdown-menu": "^2.1.4",
+   "@radix-ui/react-dialog": "^1.1.6",
+   "@radix-ui/react-dropdown-menu": "^2.1.6",
    "@radix-ui/react-label": "^2.1.1",
-   "@radix-ui/react-popover": "^1.1.2",
+   "@radix-ui/react-popover": "^1.1.6",
    "@radix-ui/react-radio-group": "^1.2.2",
    "@radix-ui/react-scroll-area": "^1.2.2",
-   "@radix-ui/react-select": "^2.1.2",
+   "@radix-ui/react-select": "^2.1.6",
    "@radix-ui/react-separator": "^1.1.0",
    "@radix-ui/react-slider": "^1.2.2",
-   "@radix-ui/react-slot": "^1.1.0",
-   "@radix-ui/react-switch": "^1.1.1",
+   "@radix-ui/react-slot": "^1.1.2",
+   "@radix-ui/react-switch": "^1.1.3",
    "@radix-ui/react-tabs": "^1.1.1",
    "@radix-ui/react-tooltip": "^1.1.3",
    "@sentry/nextjs": "^8.50.0",
@@ -56,6 +56,7 @@
    "lucide-react": "^0.454.0",
    "mdast-util-find-and-replace": "^3.0.1",
    "next": "^15.0.2",
+   "next-themes": "^0.4.4",
    "npm": "^10.8.0",
    "postcss": "^8.4.31",
    "posthog-js": "^1.176.0",
Binary files changed (contents not shown):
- two existing images removed (previously 12 KiB and 9.9 KiB; filenames not listed in this view)
- web/public/discord.webp (new file, 4.0 KiB)
- web/public/litellm.png (new file, 89 KiB)
- web/public/logo-dark.png (new file, 4.8 KiB)
- web/public/logotype-dark.png (new file, 14 KiB)
@@ -27,8 +27,12 @@ function SourceTile({
        w-40
        cursor-pointer
        shadow-md
-       hover:bg-hover
-       ${preSelect ? "bg-hover subtle-pulse" : "bg-hover-light"}
+       hover:bg-accent-background-hovered
+       ${
+         preSelect
+           ? "bg-accent-background-hovered subtle-pulse"
+           : "bg-accent-background"
+       }
      `}
      href={sourceMetadata.adminUrl}
    >

@@ -56,7 +56,7 @@ function NewApiKeyModal({
    <div className="flex mt-2">
      <b className="my-auto break-all">{apiKey}</b>
      <div
-       className="ml-2 my-auto p-2 hover:bg-hover rounded cursor-pointer"
+       className="ml-2 my-auto p-2 hover:bg-accent-background-hovered rounded cursor-pointer"
        onClick={() => {
          setCopyClicked(true);
          navigator.clipboard.writeText(apiKey);
@@ -112,7 +112,10 @@ function Main() {
  }

  const newApiKeyButton = (
-   <CreateButton href="/admin/api-key/new" text="Create API Key" />
+   <CreateButton
+     onClick={() => setShowCreateUpdateForm(true)}
+     text="Create API Key"
+   />
  );

  if (apiKeys.length === 0) {
@@ -179,7 +182,7 @@ function Main() {
        flex
        mb-1
        w-fit
-       hover:bg-hover cursor-pointer
+       hover:bg-accent-background-hovered cursor-pointer
        p-2
        rounded-lg
        border-border
@@ -203,7 +206,7 @@ function Main() {
        flex
        mb-1
        w-fit
-       hover:bg-hover cursor-pointer
+       hover:bg-accent-background-hovered cursor-pointer
        p-2
        rounded-lg
        border-border
@@ -3,7 +3,13 @@
import React from "react";
import { Option } from "@/components/Dropdown";
import { generateRandomIconShape } from "@/lib/assistantIconUtils";
-import { CCPairBasicInfo, DocumentSet, User, UserGroup } from "@/lib/types";
+import {
+  CCPairBasicInfo,
+  DocumentSet,
+  User,
+  UserGroup,
+  UserRole,
+} from "@/lib/types";
import { Separator } from "@/components/ui/separator";
import { Button } from "@/components/ui/button";
import { ArrayHelpers, FieldArray, Form, Formik, FormikProps } from "formik";
@@ -33,9 +39,8 @@ import {
  TooltipTrigger,
} from "@/components/ui/tooltip";
import Link from "next/link";
-import { useRouter } from "next/navigation";
+import { useRouter, useSearchParams } from "next/navigation";
import { useEffect, useMemo, useState } from "react";
import { FiInfo } from "react-icons/fi";
import * as Yup from "yup";
import CollapsibleSection from "./CollapsibleSection";
import { SuccessfulPersonaUpdateRedirectType } from "./enums";
@@ -71,11 +76,11 @@ import {
  Option as DropdownOption,
} from "@/components/Dropdown";
import { SourceChip } from "@/app/chat/input/ChatInputBar";
-import { TagIcon, UserIcon, XIcon } from "lucide-react";
+import { TagIcon, UserIcon, XIcon, InfoIcon } from "lucide-react";
import { LLMSelector } from "@/components/llm/LLMSelector";
import useSWR from "swr";
import { errorHandlingFetcher } from "@/lib/fetcher";
-import { DeleteEntityModal } from "@/components/modals/DeleteEntityModal";
+import { ConfirmEntityModal } from "@/components/modals/ConfirmEntityModal";
import Title from "@/components/ui/title";
import { SEARCH_TOOL_ID } from "@/app/chat/tools/constants";
@@ -127,6 +132,8 @@ export function AssistantEditor({
}) {
  const { refreshAssistants, isImageGenerationAvailable } = useAssistants();
  const router = useRouter();
+ const searchParams = useSearchParams();
+ const isAdminPage = searchParams.get("admin") === "true";

  const { popup, setPopup } = usePopup();
  const { labels, refreshLabels, createLabel, updateLabel, deleteLabel } =
@@ -216,6 +223,8 @@ export function AssistantEditor({
    enabledToolsMap[tool.id] = personaCurrentToolIds.includes(tool.id);
  });

+ const [showVisibilityWarning, setShowVisibilityWarning] = useState(false);

  const initialValues = {
    name: existingPersona?.name ?? "",
    description: existingPersona?.description ?? "",
@@ -252,6 +261,7 @@ export function AssistantEditor({
        (u) => u.id !== existingPersona.owner?.id
      ) ?? [],
    selectedGroups: existingPersona?.groups ?? [],
+   is_default_persona: existingPersona?.is_default_persona ?? false,
  };

  interface AssistantPrompt {
@@ -308,24 +318,12 @@ export function AssistantEditor({
  const [isRequestSuccessful, setIsRequestSuccessful] = useState(false);

  const { data: userGroups } = useUserGroups();
- // const { data: allUsers } = useUsers({ includeApiKeys: false }) as {
- //   data: MinimalUserSnapshot[] | undefined;
- // };

  const { data: users } = useSWR<MinimalUserSnapshot[]>(
    "/api/users",
    errorHandlingFetcher
  );

- const mapUsersToMinimalSnapshot = (users: any): MinimalUserSnapshot[] => {
-   if (!users || !Array.isArray(users.users)) return [];
-   return users.users.map((user: any) => ({
-     id: user.id,
-     name: user.name,
-     email: user.email,
-   }));
- };

  const [deleteModalOpen, setDeleteModalOpen] = useState(false);

  if (!labels) {
@@ -346,9 +344,7 @@ export function AssistantEditor({
    if (response.ok) {
      await refreshAssistants();
      router.push(
-       redirectType === SuccessfulPersonaUpdateRedirectType.ADMIN
-         ? `/admin/assistants?u=${Date.now()}`
-         : `/chat`
+       isAdminPage ? `/admin/assistants?u=${Date.now()}` : `/chat`
      );
    } else {
      setPopup({
@@ -374,8 +370,9 @@ export function AssistantEditor({
        <BackButton />
      </div>
    )}

    {labelToDelete && (
-     <DeleteEntityModal
+     <ConfirmEntityModal
        entityType="label"
        entityName={labelToDelete.name}
        onClose={() => setLabelToDelete(null)}
@@ -398,7 +395,7 @@ export function AssistantEditor({
      />
    )}
    {deleteModalOpen && existingPersona && (
-     <DeleteEntityModal
+     <ConfirmEntityModal
        entityType="Persona"
        entityName={existingPersona.name}
        onClose={closeDeleteModal}
@@ -439,6 +436,7 @@ export function AssistantEditor({
      label_ids: Yup.array().of(Yup.number()),
      selectedUsers: Yup.array().of(Yup.object()),
      selectedGroups: Yup.array().of(Yup.number()),
+     is_default_persona: Yup.boolean().required(),
    })
    .test(
      "system-prompt-or-task-prompt",
@@ -459,6 +457,19 @@ export function AssistantEditor({
          "Must provide either Instructions or Reminders (Advanced)",
        });
      }
    )
+   .test(
+     "default-persona-public",
+     "Default persona must be public",
+     function (values) {
+       if (values.is_default_persona && !values.is_public) {
+         return this.createError({
+           path: "is_public",
+           message: "Default persona must be public",
+         });
+       }
+       return true;
+     }
+   )}
  onSubmit={async (values, formikHelpers) => {
    if (
@@ -499,7 +510,6 @@ export function AssistantEditor({
      const submissionData: PersonaUpsertParameters = {
        ...values,
        existing_prompt_id: existingPrompt?.id ?? null,
-       is_default_persona: admin!,
        starter_messages: starterMessages,
        groups: groups,
        users: values.is_public
@@ -563,8 +573,9 @@ export function AssistantEditor({
      }

      await refreshAssistants();

      router.push(
-       redirectType === SuccessfulPersonaUpdateRedirectType.ADMIN
+       isAdminPage
          ? `/admin/assistants?u=${Date.now()}`
          : `/chat?assistantId=${assistantId}`
      );
@@ -825,10 +836,7 @@ export function AssistantEditor({
          </TooltipProvider>
        </div>
      </div>
-     <p
-       className="text-sm text-subtle"
-       style={{ color: "rgb(113, 114, 121)" }}
-     >
+     <p className="text-sm text-neutral-700 dark:text-neutral-400">
        Attach additional unique knowledge to this assistant
      </p>
    </div>
@@ -1008,6 +1016,22 @@ export function AssistantEditor({
    {showAdvancedOptions && (
      <>
        <div className="max-w-4xl w-full">
+         {user?.role == UserRole.ADMIN && (
+           <BooleanFormField
+             onChange={(checked) => {
+               if (checked) {
+                 setFieldValue("is_public", true);
+                 setFieldValue("is_default_persona", true);
+               }
+             }}
+             name="is_default_persona"
+             label="Featured Assistant"
+             subtext="If set, this assistant will be pinned for all new users and appear in the Featured list in the assistant explorer. This also makes the assistant public."
+           />
+         )}
+
+         <Separator />
+
          <div className="flex gap-x-2 items-center ">
            <div className="block font-medium text-sm">Access</div>
          </div>
@@ -1017,22 +1041,60 @@ export function AssistantEditor({

          <div className="min-h-[100px]">
            <div className="flex items-center mb-2">
-             <SwitchField
-               name="is_public"
-               size="md"
-               onCheckedChange={(checked) => {
-                 setFieldValue("is_public", checked);
-                 if (checked) {
-                   setFieldValue("selectedUsers", []);
-                   setFieldValue("selectedGroups", []);
-                 }
-               }}
-             />
+             <TooltipProvider delayDuration={0}>
+               <Tooltip>
+                 <TooltipTrigger asChild>
+                   <div>
+                     <SwitchField
+                       name="is_public"
+                       size="md"
+                       onCheckedChange={(checked) => {
+                         if (values.is_default_persona && !checked) {
+                           setShowVisibilityWarning(true);
+                         } else {
+                           setFieldValue("is_public", checked);
+                           if (!checked) {
+                             // Even though this code path should not be possible,
+                             // we set the default persona to false to be safe
+                             setFieldValue(
+                               "is_default_persona",
+                               false
+                             );
+                           }
+                           if (checked) {
+                             setFieldValue("selectedUsers", []);
+                             setFieldValue("selectedGroups", []);
+                           }
+                         }
+                       }}
+                       disabled={values.is_default_persona}
+                     />
+                   </div>
+                 </TooltipTrigger>
+                 {values.is_default_persona && (
+                   <TooltipContent side="top" align="center">
+                     Default persona must be public. Set
+                     "Default Persona" to false to change
+                     visibility.
+                   </TooltipContent>
+                 )}
+               </Tooltip>
+             </TooltipProvider>
              <span className="text-sm ml-2">
                {values.is_public ? "Public" : "Private"}
              </span>
            </div>

+           {showVisibilityWarning && (
+             <div className="flex items-center text-warning mt-2">
+               <InfoIcon size={16} className="mr-2" />
+               <span className="text-sm">
+                 Default persona must be public. Visibility has been
+                 automatically set to public.
+               </span>
+             </div>
+           )}

            {values.is_public ? (
              <p className="text-sm text-text-dark">
                Anyone from your organization can view and use this
@@ -1217,7 +1279,7 @@ export function AssistantEditor({
              setFieldValue("label_ids", newLabelIds);
            }}
            itemComponent={({ option }) => (
-             <div className="flex items-center justify-between px-4 py-3 text-sm hover:bg-hover cursor-pointer border-b border-border last:border-b-0">
+             <div className="flex items-center justify-between px-4 py-3 text-sm hover:bg-accent-background-hovered cursor-pointer border-b border-border last:border-b-0">
                <div
                  className="flex-grow"
                  onClick={() => {
@@ -1356,7 +1418,7 @@ export function AssistantEditor({
      </>
    )}

    <div className="mt-12 gap-x-2 w-full justify-end flex">
      <Button
        type="submit"
        disabled={isSubmitting || isRequestSuccessful}
@@ -31,7 +31,7 @@ export function HidableSection({
  return (
    <div>
      <div
-       className="flex hover:bg-hover-light rounded cursor-pointer p-2"
+       className="flex hover:bg-accent-background rounded cursor-pointer p-2"
        onClick={() => setIsHidden(!isHidden)}
      >
        <SectionHeader includeMargin={false}>{sectionTitle}</SectionHeader>
@@ -11,13 +11,14 @@ import { DraggableTable } from "@/components/table/DraggableTable";
import {
  deletePersona,
  personaComparator,
+ togglePersonaDefault,
  togglePersonaVisibility,
} from "./lib";
import { FiEdit2 } from "react-icons/fi";
import { TrashIcon } from "@/components/icons/icons";
import { useUser } from "@/components/user/UserProvider";
import { useAssistants } from "@/components/context/AssistantsContext";
-import { DeleteEntityModal } from "@/components/modals/DeleteEntityModal";
+import { ConfirmEntityModal } from "@/components/modals/ConfirmEntityModal";

function PersonaTypeDisplay({ persona }: { persona: Persona }) {
  if (persona.builtin_persona) {
@@ -56,6 +57,9 @@ export function PersonasTable() {
  const [finalPersonas, setFinalPersonas] = useState<Persona[]>([]);
  const [deleteModalOpen, setDeleteModalOpen] = useState(false);
  const [personaToDelete, setPersonaToDelete] = useState<Persona | null>(null);
+ const [defaultModalOpen, setDefaultModalOpen] = useState(false);
+ const [personaToToggleDefault, setPersonaToToggleDefault] =
+   useState<Persona | null>(null);

  useEffect(() => {
    const editable = editablePersonas.sort(personaComparator);
@@ -126,11 +130,39 @@ export function PersonasTable() {
    }
  };

+ const openDefaultModal = (persona: Persona) => {
+   setPersonaToToggleDefault(persona);
+   setDefaultModalOpen(true);
+ };
+
+ const closeDefaultModal = () => {
+   setDefaultModalOpen(false);
+   setPersonaToToggleDefault(null);
+ };
+
+ const handleToggleDefault = async () => {
+   if (personaToToggleDefault) {
+     const response = await togglePersonaDefault(
+       personaToToggleDefault.id,
+       personaToToggleDefault.is_default_persona
+     );
+     if (response.ok) {
+       await refreshAssistants();
+       closeDefaultModal();
+     } else {
+       setPopup({
+         type: "error",
+         message: `Failed to update persona - ${await response.text()}`,
+       });
+     }
+   }
+ };

  return (
    <div>
      {popup}
      {deleteModalOpen && personaToDelete && (
-       <DeleteEntityModal
+       <ConfirmEntityModal
          entityType="Persona"
          entityName={personaToDelete.name}
          onClose={closeDeleteModal}
@@ -138,8 +170,35 @@ export function PersonasTable() {
        />
      )}

+     {defaultModalOpen && personaToToggleDefault && (
+       <ConfirmEntityModal
+         variant="action"
+         entityType="Assistant"
+         entityName={personaToToggleDefault.name}
+         onClose={closeDefaultModal}
+         onSubmit={handleToggleDefault}
+         actionButtonText={
+           personaToToggleDefault.is_default_persona
+             ? "Remove Featured"
+             : "Set as Featured"
+         }
+         additionalDetails={
+           personaToToggleDefault.is_default_persona
+             ? `Removing "${personaToToggleDefault.name}" as a featured assistant will not affect its visibility or accessibility.`
+             : `Setting "${personaToToggleDefault.name}" as a featured assistant will make it public and visible to all users. This action cannot be undone.`
+         }
+       />
+     )}

      <DraggableTable
-       headers={["Name", "Description", "Type", "Is Visible", "Delete"]}
+       headers={[
+         "Name",
+         "Description",
+         "Type",
+         "Featured Assistant",
+         "Is Visible",
+         "Delete",
+       ]}
        isAdmin={isAdmin}
        rows={finalPersonas.map((persona) => {
          const isEditable = editablePersonas.includes(persona);
@@ -152,7 +211,9 @@ export function PersonasTable() {
            className="mr-1 my-auto cursor-pointer"
            onClick={() =>
              router.push(
-               `/admin/assistants/${persona.id}?u=${Date.now()}`
+               `/assistants/edit/${
+                 persona.id
+               }?u=${Date.now()}&admin=true`
              )
            }
          />
@@ -168,6 +229,30 @@ export function PersonasTable() {
          {persona.description}
        </p>,
        <PersonaTypeDisplay key={persona.id} persona={persona} />,
+       <div
+         key="is_default_persona"
+         onClick={() => {
+           if (isEditable) {
+             openDefaultModal(persona);
+           }
+         }}
+         className={`px-1 py-0.5 rounded flex ${
+           isEditable
+             ? "hover:bg-accent-background-hovered cursor-pointer"
+             : ""
+         } select-none w-fit`}
+       >
+         <div className="my-auto flex-none w-22">
+           {!persona.is_default_persona ? (
+             <div className="text-error">Not Featured</div>
+           ) : (
+             "Featured"
+           )}
+         </div>
+         <div className="ml-1 my-auto">
+           <CustomCheckbox checked={persona.is_default_persona} />
+         </div>
+       </div>,
        <div
          key="is_visible"
          onClick={async () => {
@@ -187,7 +272,9 @@ export function PersonasTable() {
            }
          }}
          className={`px-1 py-0.5 rounded flex ${
-           isEditable ? "hover:bg-hover cursor-pointer" : ""
+           isEditable
+             ? "hover:bg-accent-background-hovered cursor-pointer"
+             : ""
          } select-none w-fit`}
        >
          <div className="my-auto w-12">
@@ -205,7 +292,7 @@ export function PersonasTable() {
        <div className="mr-auto my-auto">
          {!persona.builtin_persona && isEditable ? (
            <div
-             className="hover:bg-hover rounded p-1 cursor-pointer"
+             className="hover:bg-accent-background-hovered rounded p-1 cursor-pointer"
              onClick={() => openDeleteModal(persona)}
            >
              <TrashIcon />

@@ -65,7 +65,7 @@ export default function StarterMessagesList({
          onClick={() => {
            arrayHelpers.remove(index);
          }}
-         className={`text-gray-400 hover:text-red-500 ${
+         className={`text-text-400 hover:text-red-500 ${
            index === values.length - 1 && !starterMessage.message
              ? "opacity-50 cursor-not-allowed"
              : ""
@@ -105,7 +105,7 @@ export default function StarterMessagesList({
              4 ||
            isRefreshing ||
            !autoStarterMessageEnabled
-             ? "bg-neutral-800 text-neutral-300 cursor-not-allowed"
+             ? "bg-background-800 text-text-300 cursor-not-allowed"
              : ""
          }
        `}
@@ -1,36 +0,0 @@
"use client";

import { Button } from "@/components/ui/button";
import { deletePersona } from "../lib";
import { useRouter } from "next/navigation";
import { SuccessfulPersonaUpdateRedirectType } from "../enums";

export function DeletePersonaButton({
  personaId,
  redirectType,
}: {
  personaId: number;
  redirectType: SuccessfulPersonaUpdateRedirectType;
}) {
  const router = useRouter();

  return (
    <Button
      variant="destructive"
      onClick={async () => {
        const response = await deletePersona(personaId);
        if (response.ok) {
          router.push(
            redirectType === SuccessfulPersonaUpdateRedirectType.ADMIN
              ? `/admin/assistants?u=${Date.now()}`
              : `/chat`
          );
        } else {
          alert(`Failed to delete persona - ${await response.text()}`);
        }
      }}
    >
      Delete
    </Button>
  );
}
@@ -1,43 +0,0 @@
import { ErrorCallout } from "@/components/ErrorCallout";
import { AssistantEditor } from "../AssistantEditor";
import { BackButton } from "@/components/BackButton";

import { DeletePersonaButton } from "./DeletePersonaButton";
import { fetchAssistantEditorInfoSS } from "@/lib/assistants/fetchPersonaEditorInfoSS";
import { SuccessfulPersonaUpdateRedirectType } from "../enums";
import { RobotIcon } from "@/components/icons/icons";
import { AdminPageTitle } from "@/components/admin/Title";
import CardSection from "@/components/admin/CardSection";
import Title from "@/components/ui/title";

export default async function Page(props: { params: Promise<{ id: string }> }) {
  const params = await props.params;
  const [values, error] = await fetchAssistantEditorInfoSS(params.id);

  let body;
  if (!values) {
    body = (
      <ErrorCallout errorTitle="Something went wrong :(" errorMsg={error} />
    );
  } else {
    body = (
      <>
        <CardSection className="!border-none !bg-transparent !ring-none">
          <AssistantEditor
            {...values}
            admin
            defaultPublic={true}
            redirectType={SuccessfulPersonaUpdateRedirectType.ADMIN}
          />
        </CardSection>
      </>
    );
  }

  return (
    <div className="w-full">
      <AdminPageTitle title="Edit Assistant" icon={<RobotIcon size={32} />} />
      {body}
    </div>
  );
}
@@ -261,6 +261,22 @@ export function personaComparator(a: Persona, b: Persona) {
  return closerToZeroNegativesFirstComparator(a.id, b.id);
}

+export const togglePersonaDefault = async (
+  personaId: number,
+  isDefault: boolean
+) => {
+  const response = await fetch(`/api/admin/persona/${personaId}/default`, {
+    method: "PATCH",
+    headers: {
+      "Content-Type": "application/json",
+    },
+    body: JSON.stringify({
+      is_default_persona: !isDefault,
+    }),
+  });
+  return response;
+};

export const togglePersonaVisibility = async (
  personaId: number,
  isVisible: boolean
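The new `togglePersonaDefault` helper implies a `PATCH /api/admin/persona/{id}/default` route on the backend. A hedged sketch of exercising that route outside the web UI with Python's `requests`; the persona id and base URL are placeholders, and session cookies or API-key auth are omitted, so this would be rejected by a real deployment without them:

```python
import requests

# Hypothetical direct call; persona id 42 and the base URL are placeholders.
resp = requests.patch(
    "http://localhost:8080/api/admin/persona/42/default",
    json={"is_default_persona": True},
)
resp.raise_for_status()
```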
@@ -1,25 +0,0 @@
import { AssistantEditor } from "../AssistantEditor";
import { ErrorCallout } from "@/components/ErrorCallout";
import { fetchAssistantEditorInfoSS } from "@/lib/assistants/fetchPersonaEditorInfoSS";
import { SuccessfulPersonaUpdateRedirectType } from "../enums";

export default async function Page() {
  const [values, error] = await fetchAssistantEditorInfoSS();

  if (!values) {
    return (
      <ErrorCallout errorTitle="Something went wrong :(" errorMsg={error} />
    );
  } else {
    return (
      <div className="w-full">
        <AssistantEditor
          {...values}
          admin
          defaultPublic={true}
          redirectType={SuccessfulPersonaUpdateRedirectType.ADMIN}
        />
      </div>
    );
  }
}
@@ -1,4 +1,3 @@
-"use client";
import { PersonasTable } from "./PersonaTable";
import Text from "@/components/ui/text";
import Title from "@/components/ui/title";
@@ -30,7 +29,7 @@ export default async function Page() {
      <Separator />

      <Title>Create an Assistant</Title>
-     <CreateButton href="/admin/assistants/new" text="New Assistant" />
+     <CreateButton href="/assistants/new?admin=true" text="New Assistant" />

      <Separator />
@@ -92,9 +92,9 @@ export const ExistingSlackBotForm = ({

  <div className="flex flex-col" ref={dropdownRef}>
    <div className="flex items-center gap-4">
-     <div className="border rounded-lg border-gray-200">
+     <div className="border rounded-lg border-background-200">
        <div
-         className="flex items-center gap-2 cursor-pointer hover:bg-gray-100 p-2"
+         className="flex items-center gap-2 cursor-pointer hover:bg-background-100 p-2"
          onClick={() => setIsExpanded(!isExpanded)}
        >
          {isExpanded ? (
@@ -117,7 +117,7 @@ export const ExistingSlackBotForm = ({
        </div>

        {isExpanded && (
-         <div className="bg-white border rounded-lg border-gray-200 shadow-lg absolute mt-12 right-0 z-10 w-full md:w-3/4 lg:w-1/2">
+         <div className="bg-white border rounded-lg border-background-200 shadow-lg absolute mt-12 right-0 z-10 w-full md:w-3/4 lg:w-1/2">
            <div className="p-4">
              <SlackTokensForm
                isUpdate={true}
@@ -134,7 +134,7 @@ export const ExistingSlackBotForm = ({
      </div>
    </div>
    <div className="mt-2">
-     <div className="inline-block border rounded-lg border-gray-200 p-2">
+     <div className="inline-block border rounded-lg border-background-200 p-2">
        <Checkbox
          label="Enabled"
          checked={formValues.enabled}

@@ -100,7 +100,7 @@ export function SlackChannelConfigsTable({
          slackChannelConfig.persona
        ) ? (
          <Link
-           href={`/admin/assistants/${slackChannelConfig.persona.id}`}
+           href={`/assistants/${slackChannelConfig.persona.id}`}
            className="text-primary hover:underline"
          >
            {slackChannelConfig.persona.name}

@@ -613,7 +613,7 @@ export function SlackChannelConfigFormFields({
        <Link
          key={ccpairinfo.id}
          href={`/admin/connector/${ccpairinfo.id}`}
-         className="flex items-center p-2 rounded-md hover:bg-gray-100 transition-colors"
+         className="flex items-center p-2 rounded-md hover:bg-background-100 transition-colors"
        >
          <div className="mr-2">
            <SourceIcon

@@ -84,7 +84,7 @@ function Main() {
      {isApiKeySet ? (
        <div className="w-full p-3 border rounded-md bg-background text-text flex items-center">
          <span className="flex-grow">••••••••••••••••</span>
-         <Lock className="h-5 w-5 text-gray-400" />
+         <Lock className="h-5 w-5 text-text-400" />
        </div>
      ) : (
        <input
Some files were not shown because too many files have changed in this diff.